📄 svdmatrix.java
字号:
// check all rows same length for (int i = 1; i < m; ++i) { if (values[i].length != n) { String msg = "All rows must be of same length." + " Found row[0].length=" + n + " row[" + i + "]=" + values[i].length; throw new IllegalArgumentException(msg); } } // shared column Ids rows int[] sharedRow = new int[n]; for (int j = 0; j < n; ++j) sharedRow[j] = j; int[][] columnIds = new int[m][]; for (int j = 0; j < m; ++j) columnIds[j] = sharedRow; return partialSvd(columnIds, values, maxOrder, featureInit, initialLearningRate, annealingRate, regularization, minImprovement, minEpochs, maxEpochs, writer); } /** * Return the singular value decomposition of the specified * partial matrix, using the specified search parameters. * * <p>The writer parameter may be set to allow incremental progress * reports to that writer during training. These report on RMSE * per epoch. * * <p>See the class documentation above for a description of the * algorithm. * * <p>There are a number of constraints on the input, any * violation of which will raise an illegal argument exception. * The conditions are: * <ul> * * * <li>The maximum order must be greater than zero.</li> * * <li>The minimum relative improvement in mean square error must be non-negative * and finite.</li> * * <li>The minimum number of epochs must be greater than zero * and less than or equal to the maximum number of epochs.</li> * * <li>The feature initialization value must be non-zero and finite.</li> * * <li>The learning rate must be positive and finite.</li> * * <li>The regularization parameter must be non-negative and finite.</li> * * <li>The column identitifer and value arrays must be the same * length.</li> * * <li>The elements of the column identifier array and the * value array must all be of the same length.</li> * * <li>All column identifiers must be non-negative.</li> * * <li>Each row of the column identifier matrix must contain * columns in strictly ascending order. 
* * </ul> * * @param columnIds Identifiers of column index for given row and entry. * @param values Values at row and column index for given entry. * @param maxOrder Maximum order of the decomposition. * @param featureInit Initial value for singular vectors. * @param initialLearningRate Incremental multiplier of error determining how * fast learning occurs. * @param annealingRate Rate at which annealing occurs; higher values * provide more gradual annealing. * @param regularization A regularization constant to damp learning. * @param minImprovement Minimum relative improvement in mean square error required * to finish an epoch. * @param minEpochs Minimum number of epochs for training. * @param maxEpochs Maximum number of epochs for training. * @param writer Print writer to which output is directed during * training, or null if no output is desired. * @return Singular value decomposition for the specified partial matrix * at the specified order. * @throws IllegalArgumentException Under conditions listed in the * method documentation above. 
*/
    public static SvdMatrix partialSvd(int[][] columnIds,
                                       double[][] values,
                                       int maxOrder,
                                       double featureInit,
                                       double initialLearningRate,
                                       double annealingRate,
                                       double regularization,
                                       double minImprovement,
                                       int minEpochs,
                                       int maxEpochs,
                                       PrintWriter writer) {
        printIfWriter(writer,"Start");

        // Scalar parameter validation. The short numeric messages are kept
        // exactly as they were so anything matching on them keeps working.
        if (maxOrder < 1) {
            throw new IllegalArgumentException("4");
        }
        if (minImprovement < 0 || notFinite(minImprovement)) {
            throw new IllegalArgumentException("5");
        }
        if (minEpochs <= 0 || maxEpochs < minEpochs) {
            throw new IllegalArgumentException("6");
        }
        if (notFinite(featureInit) || featureInit == 0.0) {
            throw new IllegalArgumentException("7");
        }
        if (notFinite(initialLearningRate) || initialLearningRate < 0) {
            throw new IllegalArgumentException("8");
        }
        if (notFinite(regularization) || regularization < 0) {
            throw new IllegalArgumentException("9");
        }

        // The outer arrays must be checked BEFORE columnIds.length is read;
        // otherwise a null argument surfaces as a NullPointerException
        // instead of the documented IllegalArgumentException.
        if (columnIds == null) {
            throw new IllegalArgumentException("colIds");
        }
        if (values == null) {
            throw new IllegalArgumentException("values");
        }
        // Documented constraint that was previously unchecked: a shorter
        // values array failed with ArrayIndexOutOfBoundsException and a longer
        // one with a NullPointerException while updating the cache.
        if (columnIds.length != values.length) {
            String msg = "Column identifier and value arrays must be the same length."
                + " Found columnIds.length=" + columnIds.length
                + " values.length=" + values.length;
            throw new IllegalArgumentException(msg);
        }
        for (int row = 0; row < columnIds.length; ++row) {
            if (columnIds[row] == null) {
                throw new IllegalArgumentException("columnIds " + row);
            }
            if (values[row] == null) {
                throw new IllegalArgumentException("vals " + row);
            }
            if (columnIds[row].length != values[row].length) {
                throw new IllegalArgumentException("10");
            }
            for (int i = 0; i < columnIds[row].length; ++i) {
                if (columnIds[row][i] < 0) {
                    throw new IllegalArgumentException("12");
                }
                // column ids must be strictly ascending within a row
                if (i > 0 && columnIds[row][i-1] >= columnIds[row][i]) {
                    throw new IllegalArgumentException("13");
                }
            }
        }
        // NOTE(review): an annealing rate of exactly 0 passes this check, but
        // makes the epoch-0 learning rate NaN (0 / 0.0) below. Rejecting it
        // would change the accepted-input contract, so it is only flagged here.
        if (annealingRate < 0 || notFinite(annealingRate)) {
            throw new IllegalArgumentException("14");
        }

        int numRows = columnIds.length;
        int numEntries = 0;
        for (double[] xs : values) {
            numEntries += xs.length;
        }
        int maxColumnIndex = 0;
        for (int[] xs : columnIds) {
            for (int i = 0; i < xs.length; ++i) {
                if (xs[i] > maxColumnIndex) {
                    maxColumnIndex = xs[i];
                }
            }
        }
        int numColumns = maxColumnIndex + 1;

        // never extract more factors than the smaller matrix dimension
        maxOrder = Math.min(maxOrder,Math.min(numRows,numColumns));

        // cache[row][i] holds the summed contribution of all previously
        // trained factors for entry i of the row; new double arrays are
        // already zero-filled, so no explicit fill is needed.
        double[][] cache = new double[numRows][];
        for (int row = 0; row < numRows; ++row) {
            cache[row] = new double[values[row].length];
        }

        List<double[]> rowVectorList = new ArrayList<double[]>(maxOrder);
        List<double[]> columnVectorList = new ArrayList<double[]>(maxOrder);
        for (int order = 0; order < maxOrder; ++order) {
            printIfWriter(writer," Factor=" + order);
            double[] rowVector = initArray(numRows,featureInit);
            double[] columnVector = initArray(numColumns,featureInit);
            double rmseLast = Double.POSITIVE_INFINITY;
            for (int epoch = 0; epoch < maxEpochs; ++epoch) {
                // learning rate decays as the epoch count grows
                double learningRateForEpoch
                    = initialLearningRate / (1.0 + epoch / annealingRate);
                double sumOfSquareErrors = 0.0;
                for (int row = 0; row < numRows; ++row) {
                    int[] columnIdsForRow = columnIds[row];
                    double[] valuesForRow = values[row];
                    double[] cacheForRow = cache[row];
                    for (int i = 0; i < columnIdsForRow.length; ++i) {
                        int column = columnIdsForRow[i];
                        double prediction = predict(row,column,
                                                    rowVector,columnVector,
                                                    cacheForRow[i]);
                        double error = valuesForRow[i] - prediction;
                        sumOfSquareErrors += error * error;
                        // snapshot both coordinates so each update uses the
                        // other's value from before this step
                        double rowCurrent = rowVector[row];
                        double columnCurrent = columnVector[column];
                        rowVector[row]
                            += learningRateForEpoch
                            * (error * columnCurrent - regularization * rowCurrent);
                        columnVector[column]
                            += learningRateForEpoch
                            * (error * rowCurrent - regularization * columnCurrent);
                    }
                }
                double rmse = Math.sqrt(sumOfSquareErrors/numEntries);
                printIfWriter(writer," epoch=" + epoch + " rmse=" + rmse);
                if ((epoch >= minEpochs)
                    && (relativeDifference(rmse,rmseLast) < minImprovement)) {
                    printIfWriter(writer," exiting in epoch=" + epoch
                                  + " rmse=" + rmse
                                  + " relDiff=" + relativeDifference(rmse,rmseLast));
                    break;
                }
                rmseLast = rmse;
            }
            printIfWriter(writer,"Order=" + order + " RMSE=" + rmseLast);
            rowVectorList.add(rowVector);
            columnVectorList.add(columnVector);

            // fold the finished factor into the cache for the next factor
            for (int row = 0; row < cache.length; ++row) {
                double[] cacheRow = cache[row];
                for (int i = 0; i < cacheRow.length; ++i) {
                    cacheRow[i] = predict(row,columnIds[row][i],
                                          rowVector,columnVector,
                                          cacheRow[i]);
                }
            }
        }

        double[][] rowVectors = toArray(rowVectorList);
        double[][] columnVectors = toArray(columnVectorList);
        return new SvdMatrix(transpose(rowVectors),
                             transpose(columnVectors),
                             maxOrder);
    }

    /**
     * Prints the message to the writer, prefixed with
     * <code>partialSvd| </code>, and flushes; does nothing if the
     * writer is null.
     *
     * @param writer Destination for the message, or null for no output.
     * @param msg Message to print.
     */
    static void printIfWriter(PrintWriter writer, String msg) {
        if (writer == null) {
            return;
        }
        writer.print("partialSvd| ");
        writer.println(msg);
        writer.flush();
    }

    /**
     * Returns the sum, over factors 0 through the specified order
     * inclusive, of the products of the row and column coordinates.
     * The bounds and init parameters are not used by the computation;
     * they are retained so existing call sites keep compiling.
     *
     * @param row Row index.
     * @param column Column index.
     * @param order Index of the last factor to include.
     * @param rowVectorList Row vectors, one array per factor.
     * @param columnVectorList Column vectors, one array per factor.
     * @param lowerBound Unused.
     * @param upperBound Unused.
     * @param init Unused.
     * @return Sum of per-factor products for the cell.
     */
    static double predictRaw(int row, int column, int order,
                             List<double[]> rowVectorList,
                             List<double[]> columnVectorList,
                             double lowerBound, double upperBound,
                             double init) {
        // index the lists directly rather than copying them to arrays per call
        double val = 0.0;
        for (int i = 0; i <= order; ++i) {
            val += rowVectorList.get(i)[row] * columnVectorList.get(i)[column];
        }
        return val;
    }

    /**
     * Returns <code>|x - y| / (|x| + |y|)</code>, or zero if both
     * arguments are zero (the unguarded ratio would be NaN, which
     * never compares below a threshold).
     *
     * @param x First value.
     * @param y Second value.
     * @return Relative difference between the values.
     */
    static double relativeDifference(double x, double y) {
        double denominator = Math.abs(x) + Math.abs(y);
        if (denominator == 0.0) {
            return 0.0;
        }
        return Math.abs(x - y) / denominator;
    }

    /**
     * Returns a newly allocated transpose of the specified matrix.
     * The result dimensions are taken from the first row; an empty
     * input produces an empty result.
     *
     * @param xs Matrix to transpose.
     * @return Transposed copy.
     */
    static double[][] transpose(double[][] xs) {
        if (xs.length == 0) {
            return new double[0][0];
        }
        double[][] ys = new double[xs[0].length][xs.length];
        for (int i = 0; i < xs.length; ++i) {
            for (int j = 0; j < xs[i].length; ++j) {
                ys[j][i] = xs[i][j];
            }
        }
        return ys;
    }

    /**
     * Returns an array holding the list's elements in order; the
     * inner arrays are shared with the list, not copied.
     *
     * @param list List of vectors.
     * @return Array of the same vectors.
     */
    static double[][] toArray(List<double[]> list) {
        double[][] result = new double[list.size()][];
        list.toArray(result);
        return result;
    }

    /**
     * Returns the cached value plus the product of the row and column
     * coordinates of the factor being trained.
     *
     * @param row Row index.
     * @param column Column index.
     * @param rowVector Row coordinates of the factor.
     * @param columnVector Column coordinates of the factor.
     * @param cache Value to which the product is added.
     * @return Prediction for the cell.
     */
    static double predict(int row, int column,
                          double[] rowVector,
                          double[] columnVector,
                          double cache) {
        return cache + rowVector[row] * columnVector[column];
    }

    /**
     * Returns a new array of the specified size whose entries are
     * Gaussian draws scaled by the specified value.  A fresh
     * <code>java.util.Random</code> is created on every call, so
     * results are not reproducible across calls.
     *
     * @param size Length of the array.
     * @param val Scale applied to each draw.
     * @return Randomly initialized array.
     */
    static double[] initArray(int size, double val) {
        double[] xs = new double[size];
        // random init (rather than filling with the constant val)
        java.util.Random random = new java.util.Random();
        for (int i = 0; i < xs.length; ++i) {
            xs[i] = random.nextGaussian() * val;
        }
        return xs;
    }

    /**
     * Returns true if the value is NaN or infinite.
     *
     * @param x Value to test.
     * @return True if the value is not a finite number.
     */
    static boolean notFinite(double x) {
        return Double.isNaN(x) || Double.isInfinite(x);
    }

    /**
     * Returns the Euclidean length of the specified column.
     *
     * @param xs Matrix indexed by row then column.
     * @param col Column index.
     * @return Square root of the sum of squared column entries.
     */
    static double columnLength(double[][] xs, int col) {
        double sumOfSquares = 0.0;
        for (int i = 0; i < xs.length; ++i) {
            sumOfSquares += xs[i][col] * xs[i][col]; // subopt array mem order
        }
        return Math.sqrt(sumOfSquares);
    }

    /**
     * Overwrites each entry of <code>vecs</code> with the matching
     * entry of <code>singularVecs</code> multiplied by the singular
     * value for its column.
     *
     * @param vecs Destination matrix, written in place.
     * @param singularVecs Source matrix.
     * @param singularVals Per-column multipliers.
     */
    static void scale(double[][] vecs,
                      double[][] singularVecs,
                      double[] singularVals) {
        for (int i = 0; i < vecs.length; ++i) {
            for (int k = 0; k < vecs[i].length; ++k) {
                vecs[i][k] = singularVecs[i][k] * singularVals[k];
            }
        }
    }

    /**
     * Checks that every vector has length equal to the order.
     *
     * @param prefix Label used in the error message.
     * @param order Required vector length.
     * @param vectors Vectors to check.
     * @throws IllegalArgumentException If any vector's length differs
     * from the order.
     */
    static void verifyDimensions(String prefix, int order,
                                 double[][] vectors) {
        for (int i = 0; i < vectors.length; ++i) {
            if (vectors[i].length != order) {
                String msg = "All vectors must have length equal to order."
                    + " order=" + order
                    + " " + prefix + "Vectors[" + i + "].length="
                    + vectors[i].length;
                throw new IllegalArgumentException(msg);
            }
        }
    }

    /**
     * Returns a copy of the matrix with each column divided by its
     * Euclidean length.  The input is not modified.  A column whose
     * entries are all zero has length zero, so its entries in the
     * result are NaN.
     *
     * @param xs Matrix indexed by row then column; must have at
     * least one row.
     * @return Column-normalized copy.
     */
    static double[][] normalizeColumns(double[][] xs) {
        int numDims = xs.length;
        int order = xs[0].length;
        double[][] result = new double[numDims][order];
        for (int j = 0; j < order; ++j) {
            // first pass copies the column and accumulates its squared length
            double sumOfSquares = 0.0;
            for (int i = 0; i < numDims; ++i) {
                double valIJ = xs[i][j];
                result[i][j] = valIJ;
                sumOfSquares += valIJ * valIJ;
            }
            double length = Math.sqrt(sumOfSquares);
            for (int i = 0; i < numDims; ++i) {
                result[i][j] /= length;
            }
        }
        return result;
    }

    /*
    public static void permute(int[] xs) {
        Random random = new Random();
        for (int i = xs.length; --i > 0; ) {
            int pos = random.nextInt(i);
            int temp = xs[pos];
            xs[pos] = xs[i];
            xs[i] = temp;
        }
    }
    */
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -