sparsevector.java
来自「mallet是自然语言处理、机器学习领域的一个开源项目。」· Java 代码 · 共 947 行 · 第 1/2 页
JAVA
947 行
for (int loc = 0; loc < maxloc; loc++) { int idx = v.indexAtLocation (loc); if (idx >= values.length) break; values [idx] += v.valueAtLocation (loc) * factor; } } private void denseTimesEqualsSparse (SparseVector v, double factor) { int maxloc = v.numLocations(); for (int loc = 0; loc < maxloc; loc++) { int idx = v.indexAtLocation (loc); if (idx >= values.length) break; values [idx] *= v.valueAtLocation (loc) * factor; } } /** * Increments this[index] by value. * @throws IllegalArgumentException If index is not present. */ public void incrementValue (int index, double value) throws IllegalArgumentException { int loc = location (index); if (loc >= 0) values[loc] += value; else throw new IllegalArgumentException ("Trying to set value that isn't present in SparseVector"); } /** Sets every present index in the vector to v. */ public void setAll (double v) { for (int i = 0; i < values.length; i++) values[i] = v; } /** * Sets the value at the given index. * @throws IllegalArgumentException If index is not present. */ public void setValue (int index, double value) throws IllegalArgumentException { if (indices == null) values[index] = value; else { int loc = location(index); if (loc < 0) throw new IllegalArgumentException ("Can't insert values into a sparse Vector."); else values[loc] = value; } } /** Sets the value at the given location. */ public void setValueAtLocation (int location, double value) { values[location] = value; } /** Copy values from an array into this vector. The array should have the * same size as the vector */ // yanked from DenseVector public final void arrayCopyFrom( double[] a ) { arrayCopyFrom(a,0); } /** Copy values from an array starting at a particular location into this * vector. The array must have at least as many values beyond the * starting location as there are in the vector. * * @return Next uncopied location in the array. */ public final int arrayCopyFrom( double [] a , int startingArrayLocation ) { System.arraycopy( a, startingArrayLocation, values, 0, values.length ); return startingArrayLocation + values.length; } /** * Applies the method argument to each value in a non-binary vector. * The method should both accept a Double as an argument and return a Double. * * @throws IllegalArgumentException If the method argument has an * inappropriate signature. * @throws UnsupportedOperationException If vector is binary * @throws IllegalAccessException If the method is inaccessible * @throws Throwable If the method throws an exception it is relayed */ public final void map (Method f) throws IllegalAccessException, Throwable { if (values == null) throw new UnsupportedOperationException ("Binary values may not be altered via map"); if (f.getParameterTypes().length!=1 || f.getParameterTypes()[0] != Double.class || f.getReturnType() != Double.class ) throw new IllegalArgumentException ("Method signature must be \"Double f (Double x)\""); try { for (int i=0 ; i<values.length ; i++) values[i] = ((Double)f.invoke (null, new Object[] {new Double(values[i])})).doubleValue (); } catch (InvocationTargetException e) { throw e.getTargetException(); } } /** Copy the contents of this vector into an array starting at a * particular location. * * @return Next available location in the array */ public final int arrayCopyInto (double[] array, int startingArrayLocation) { System.arraycopy (values, 0, array, startingArrayLocation, values.length); return startingArrayLocation + values.length; } /*********************************************************************** * VECTOR OPERATIONS ***********************************************************************/ public double dotProduct (double[] v) { double ret = 0; if (values == null) for (int i = 0; i < indices.length; i++) ret += v[indices[i]]; else for (int i = 0; i < indices.length; i++) ret += values[i] * v[indices[i]]; return ret; } public double dotProduct (ConstantMatrix m) { if (m instanceof SparseVector) return dotProduct ((SparseVector)m); else if (m instanceof DenseVector) return dotProduct ((DenseVector)m); else throw new IllegalArgumentException ("Unrecognized Matrix type "+m.getClass()); } public double dotProduct (DenseVector v) { if (v.hasInfinite || this.hasInfinite) return extendedDotProduct(v); double ret = 0; if (values == null) for (int i = 0; i < indices.length; i++) ret += v.value(indices[i]); else for (int i = 0; i < indices.length; i++) ret += values[i] * v.value(indices[i]); if (Double.isNaN(ret)) return extendedDotProduct(v); return ret; } // sets -Inf * 0 = 0; Inf * 0 = 0 public double extendedDotProduct (DenseVector v) { double ret = 0; if (values == null) for (int i = 0; i < indices.length; i++) ret += v.value(indices[i]); else for (int i = 0; i < indices.length; i++) { if (Double.isInfinite(values[i]) && v.value(indices[i])==0.0) { this.hasInfinite = true; continue; } else if (Double.isInfinite(v.value(indices[i])) && values[i]==0.0) { v.hasInfinite = true; continue; } ret += values[i] * v.value(indices[i]); } return ret; } public double dotProduct (SparseVector v) { if (v.hasInfinite || hasInfinite) return extendedDotProduct(v); double ret; // Decide in which direction to do the dot product. // This is a heuristic choice based on efficiency, and it could certainly // be more complicated. if (v instanceof IndexedSparseVector) { ret = v.dotProduct (this); } else if(numLocations() > v.numLocations ()) { ret = dotProductInternal (v, this); } else { ret = dotProductInternal (this, v); } if (Double.isNaN (ret)) return extendedDotProduct (v); return ret; } private double dotProductInternal (SparseVector vShort, SparseVector vLong) { double ret = 0; int numShortLocs = vShort.numLocations(); if (vShort.isBinary ()) { for(int i = 0; i < numShortLocs; i++) { ret += vLong.value (vShort.indexAtLocation(i)); } } else { for(int i = 0; i < numShortLocs; i++) { double v1 = vShort.valueAtLocation(i); double v2 = vLong.value (vShort.indexAtLocation(i)); ret += v1*v2; } } return ret; } // sets -Inf * 0 = 0, Inf * 0 = 0 public double extendedDotProduct (SparseVector v) { double ret = 0.0; SparseVector vShort = null; SparseVector vLong = null; // this ensures minimal computational effort if(numLocations() > v.numLocations ()) { vShort = v; vLong = this; } else { vShort = this; vLong = v; } for(int i = 0; i < vShort.numLocations(); i++) { double v1 = vShort.valueAtLocation(i); double v2 = vLong.value (vShort.indexAtLocation(i)); if (Double.isInfinite(v1) && v2==0.0) { vShort.hasInfinite = true; continue; } else if (Double.isInfinite(v2) && v1==0.0) { vLong.hasInfinite = true; continue; } ret += v1*v2; } return ret; } public SparseVector vectorAdd(SparseVector v, double scale) { if(indices != null) { // sparse SparseVector int [] ind = v.getIndices(); double [] val = v.getValues(); int [] newIndices = new int[ind.length+indices.length]; double [] newVals = new double[ind.length+indices.length]; for(int i = 0; i < indices.length; i++) { newIndices[i] = indices[i]; newVals[i] = values[i]; } for(int i = 0; i < ind.length; i++) { newIndices[i+indices.length] = ind[i]; newVals[i+indices.length] = scale*val[i]; } return new SparseVector(newIndices, newVals, true, true, false); } int [] newIndices = new int[values.length]; double [] newVals = new double[values.length]; // dense SparseVector int curPos = 0; for(int i = 0; i < values.length; i++) { double val = values[i]+scale*v.value(i); if(val != 0.0) { newIndices[curPos] = i; newVals[curPos++] = val; } } return new SparseVector(newIndices, newVals, true, true, false); } public double oneNorm () { double ret = 0; if (values == null) return indices.length; for (int i = 0; i < values.length; i++) ret += values[i]; return ret; } public double absNorm () { double ret = 0; if (values == null) return indices.length; for (int i = 0; i < values.length; i++) ret += Math.abs(values[i]); return ret; } public double twoNorm () { double ret = 0; if (values == null) return Math.sqrt (indices.length); for (int i = 0; i < values.length; i++) ret += values[i] * values[i]; return Math.sqrt (ret); } public double infinityNorm () { if (values == null) return 1.0; double max = Double.NEGATIVE_INFINITY; for (int i = 0; i < values.length; i++) if (Math.abs(values[i]) > max) max = Math.abs(values[i]); return max; } public void print() { if (values == null) { // binary sparsevector for (int i = 0; i < indices.length; i++) System.out.println ("SparseVector["+indices[i]+"] = 1.0"); } else { for (int i = 0; i < values.length; i++) { int idx = (indices == null) ? i : indices [i]; System.out.println ("SparseVector["+idx+"] = "+values[i]); } } } public boolean isNaN() { if (values == null) return false; for (int i = 0; i < values.length; i++) if (Double.isNaN(values[i])) return true; return false; } protected void sortIndices () { if (indices == null) // It's dense, and thus by definition sorted. return; if (values == null) java.util.Arrays.sort (indices); else { // Just BubbleSort; this is efficient when already mostly sorted. // Note that we BubbleSort from the the end forward; this is most efficient // when we have added a few additional items to the end of a previously sorted list. // We could be much smarter if we remembered the highest index that was already sorted for (int i = indices.length-1; i >= 0; i--) { boolean swapped = false; for (int j = 0; j < i; j++) if (indices[j] > indices[j+1]) { // Swap both indices and values int f; f = indices[j]; indices[j] = indices[j+1]; indices[j+1] = f; if (values != null) { double v; v = values[j]; values[j] = values[j+1]; values[j+1] = v; } swapped = true; } if (!swapped) break; } } //if (values == null) int numDuplicates = 0; for (int i = 1; i < indices.length; i++) if (indices[i-1] == indices[i]) numDuplicates++; if (numDuplicates > 0) removeDuplicates (numDuplicates); } // Argument zero is special value meaning that this function should count them. protected void removeDuplicates (int numDuplicates) { if (numDuplicates == 0) for (int i = 1; i < indices.length; i++) if (indices[i-1] == indices[i]) numDuplicates++; if (numDuplicates == 0) return; int[] newIndices = new int[indices.length - numDuplicates]; double[] newValues = values == null ? null : new double[indices.length - numDuplicates]; newIndices[0] = indices[0]; if (values != null) newValues[0] = values[0]; for (int i = 1, j = 1; i < indices.length; i++) { if (indices[i] == indices[i-1]) { if (newValues != null) newValues[j-1] += values[i]; } else { newIndices[j] = indices[i]; if (values != null) newValues[j] = values[i]; j++; } } this.indices = newIndices; this.values = newValues; } /// Serialization private static final long serialVersionUID = 2; private static final int CURRENT_SERIAL_VERSION = 1; private void writeObject (ObjectOutputStream out) throws IOException { out.writeInt (CURRENT_SERIAL_VERSION); out.writeInt (indices == null ? -1 : indices.length); out.writeInt (values == null ? -1 : values.length); if (indices != null) for (int i = 0; i < indices.length; i++) out.writeInt (indices[i]); if (values != null) for (int i = 0; i < values.length; i++) out.writeDouble (values[i]); } private void readObject (ObjectInputStream in) throws IOException, ClassNotFoundException { int version = in.readInt (); int indicesSize = in.readInt(); int valuesSize = in.readInt(); this.hasInfinite = false; if (indicesSize >= 0) { indices = new int[indicesSize]; for (int i = 0; i < indicesSize; i++) { indices[i] = in.readInt(); } } if (valuesSize >= 0) { values = new double[valuesSize]; for (int i = 0; i < valuesSize; i++) { values[i] = in.readDouble(); if (Double.isInfinite (values[i])) this.hasInfinite = true; } } }}
⌨️ 快捷键说明
复制代码Ctrl + C
搜索代码Ctrl + F
全屏模式F11
增大字号Ctrl + =
减小字号Ctrl + -
显示快捷键?