📄 algorithmrvm.java
字号:
bias_d.debug(L"bias_d"); weights_d.debug(L"weights_d"); targets_d.debug(L"targets_d"); vectors_d.debug(L"vectors_d"); } */ if (debug_level_d) { // System.out.println("tmp_A after pruning:"); // Matrix.printDoubleVector(tmp_A); } dimA_d = dimA; A_d.initDiagonalMatrix(tmp_A); Matrix tmp_phi = new Matrix(); tmp_phi.copyMatrixRows(phi_d, flags); phi_d = tmp_phi; // debugging information // if (debug_level_d) { // System.out.println("Pruning " + weights_pruned // + " vect//ors..." + weights_d.size() + " remain"); } // exit gracefully // return true; } /** * * Computes the diagonal elements of the inverse from the * cholesky decomposition * * * @return true * */ public boolean computeVarianceCholesky() { // Debug // System.out.println("AlgorithmRVM : computeVarianceCholesky()"); // find the size of the cholesky decomposition matrix // long nrows = covar_cholesky_d.getNumRows(); double sum = 0.0; // compute the inverse of the lower triangular cholesky decomposition // for (int i = 0; i < nrows; i++) { covar_cholesky_d.setValue(i, i, (1.0 / covar_cholesky_d.getValue (i, i))); for (int j = i + 1; j < nrows; j++) { sum = 0.0; for (int k = i; k < j; k++) { sum -= covar_cholesky_d.getValue(j, k) * covar_cholesky_d.getValue(k, i); } covar_cholesky_d.setValue(j, i, (sum / covar_cholesky_d.getValue(j, j))); } } // loop over columns and get the sumSquare values to put on // the diagonals. This simulates a matrix multiply but we // only retain the diagonal values. // /* Vector tmp; for (long i = 0; i < nrows; i++) { covar_cholesky_d.getColumn(tmp, i); double val = tmp.sumSquare(); covar_cholesky_d.setValue(i,i,val); } */ for (int i = 0; i < nrows; i++) { double val = covar_cholesky_d.getColSumSquare(i); covar_cholesky_d.setValue(i, i, val); } // exit gracefully // return true; } /** * * Computes the sigma valued defined on pp. 218 of [1]. * * @return true * */ public boolean computeSigma() { // Debug // System.out.println("AlgorithmRVM : computeSigma()"); // declare local variables // Matrix weights = new Matrix(); Matrix sigma = new Matrix(); // debug_level_d = false; weights.initMatrix(curr_weights_d); if (debug_level_d) { // System.out.println("phi_d: "); // phi_d.printMatrix(); // System.out.println("curr_weights_d: "); // weights.printMatrix(); // System.out.println("sigma: "); // sigma.printMatrix(); } weights.multMatrix(phi_d, sigma); sigma.scalarMultMatrix(-1); sigma.expMatrix(); sigma.addMatrixElements(1.0); sigma.inverseMatrixElements(); if (debug_level_d) { // System.out.println("phi_d: "); // phi_d.printMatrix(); // System.out.println("curr_weights_d: "); // weights.printMatrix(); // System.out.println("sigma: "); // sigma.printMatrix(); } sigma.toDoubleVector(sigma_d); // debug_level_d = true; // exit gracefully // return true; } /** * * Computes the log likelihood of the weights given the * data. We would expect this value to increase as training proceeds * * @return Likelihood of weights given the data * * */ public double computeLikelihood() { // Debug // // System.out.println("AlgorithmRVM : computeLikelihood()"); // declare local variables // double result = 0; double divisor = phi_d.getNumColumns(); // init the matrix Matrix weights_M = new Matrix(); Matrix temp_M = new Matrix(); Matrix result_M = new Matrix(); weights_M.initMatrix(curr_weights_d); // compute the -1/2 * w * A * w' component of the likelihood // weights_M.multMatrix(A_d, temp_M); result_M.copyMatrix(weights_M); result_M.transposeMatrix(weights_M); temp_M.multMatrix(weights_M, result_M); result = result_M.getValue(0, 0); result *= 0.5; result /= divisor; // compute the summation in the log likelihood // L = sum [log(sigma) * t + (1-t) * log(1-sigma)] // long len; if (sigma_d.size() > y_d.size()) { len = y_d.size(); } else { len = sigma_d.size(); } double t; double sig; for (int i = 0; i < len; i++) { t = MathUtil.doubleValue(y_d, i); sig = MathUtil.doubleValue(sigma_d, i); if (t < 0.5) { result -= Math.log(1 - sig) / divisor; } else { result -= Math.log(sig) / divisor; } } // exit gracefully // return result; } /** * * Computes the output of a specific example * * * @return result of the evaluation function * */ double evaluateOutput(Vector point_a) { // Debug // // System.out.println("AlgorithmRVM : evaluateOutput(Vector point_a)"); // declare local variables // double result = bias_d; // debug message // if (debug_level_d) { // System.out.println("bias_d: " + bias_d); // System.out.println("point_a: "); // Matrix.printDoubleVector(point_a); } // evaluate the output at the example // for (int i = 0; i < weights_d.size(); i++) { double w_i = ((Double)weights_d.get(i)).doubleValue(); Vector x_i = (Vector)evalx_d.get(i); result += w_i * K(point_a, x_i); if (debug_level_d) { // System.out.println("x_i: " + x_i + "i: " + i // + " result: " + result); } } // apply the logistic sigmoid link function // result = 1.0 / (1.0 + Math.exp(-result)); // debug message // if (debug_level_d) { // System.out.println("result: " + result); } // return the result // return result; } /** * * Evaluates the linear Kernel on the input vectors * K(x,y) = (x . y) * * @param point1 first point * @param point2 second point * * @return Kernel evaluation * */ double K(Vector point1, Vector point2) { // Debug // System.out.println( // "AlgorithmRVM : K(Vector point1, Vector point2)"); double result = 0; if (kernel_type_d == KERNEL_TYPE_LINEAR) { result = MathUtil.linearKernel(point1, point2); } if (kernel_type_d == KERNEL_TYPE_RBF) { result = MathUtil.rbfKernel(point1, point2, 20); } if (kernel_type_d == KERNEL_TYPE_POLYNOMIAL) { result = MathUtil.polynomialKernel(point1, point2); } // return the result // return result; } /** * * Computes the line of discrimination for the classification * algorithms when the corresponding flags have been initialized * */ public void computeDecisionRegions() { // Debug // // System.out.println("AlgorithmRVM : computeDecisionRegions()"); DisplayScale scale = output_panel_d.disp_area_d.getDisplayScale(); double currentX = scale.xmin; double currentY = scale.ymin; // set precision // int outputWidth = output_panel_d.disp_area_d.getXPrecision(); int outputHeight = output_panel_d.disp_area_d.getYPrecision(); double incrementY = (scale.ymax-scale.ymin)/outputHeight; double incrementX = (scale.xmax-scale.xmin)/outputWidth; // declare a 2D array to store the class associations // output_canvas_d = new int[outputWidth][outputHeight]; // loop through each and every point on the pixmap and // determine which class each pixel is associated with // int associated = 0; pro_box_d.setProgressMin(0); pro_box_d.setProgressMax(outputWidth); pro_box_d.setProgressCurr(20); for (int i = 0; i < outputWidth; i++) { currentX += incrementX; currentY = scale.ymin; // set current status // pro_box_d.setProgressCurr(i); // pro_box_d.appendMessage("."); for (int j = 0; j < outputHeight; j++) { // declare the current pixel point // currentY += incrementY; MyPoint pixel = new MyPoint(currentX, currentY); Vector<Double> curr_point = new Vector<Double>(); curr_point.add(new Double(pixel.x)); curr_point.add(new Double(pixel.y)); double output = evaluateOutput(curr_point); // System.out.println(message.toString()); if (output >= 0.5) { associated = 0; // decision_regions_d.add(pixel); } else { associated = 1; } // put and entry in the output canvas array to // indicate which class the current pixel is // closest to // output_canvas_d[i][j] = associated; // add a point to the vector of decision // region points if the class that the current // point is associated with is different for // the class what the previous point was // associated with i.e., a transition point // if (j > 0 && i > 0) { if (associated != output_canvas_d[i][j - 1] || associated != output_canvas_d[i - 1][j]) { decision_regions_d.add(pixel); } } } } // end of the loop } /** * * Computes the classification error for the data points * */ public void computeErrors() { // declare local variables // String text; double error; int samples = 0; int samples1 = 0; int samples2 = 0; int samples3 = 0; int samples4 = 0; int incorrect = 0; int incorrect1 = 0; int incorrect2 = 0; int incorrect3 = 0; int incorrect4 = 0; DisplayScale scale = output_panel_d.disp_area_d.getDisplayScale(); // set scales int outputWidth = output_panel_d.disp_area_d.getXPrecision(); int outputHeight = output_panel_d.disp_area_d.getYPrecision(); double incrementY = (scale.ymax-scale.ymin)/outputHeight; double incrementX = (scale.xmax-scale.xmin)/outputWidth; // compute the classification error for the first set // for (int i = 0; i < set1_d.size(); i++) { MyPoint point = (MyPoint)set1_d.elementAt(i); samples1++; if ((point.x > scale.xmin && point.x < scale.xmax) && (point.y > scale.ymin && point.y < scale.ymax)) { if (output_canvas_d[(int)((point.x - scale.xmin) / incrementX)] [(int)((point.y - scale.ymin) / incrementY)] != 0) { incorrect1++; } } } if (set1_d.size() > 0) { error = ((double)incorrect1 / (double)samples1) * 100.0; text = new String( " Results for class 0:\n" + " Total number of samples: " + samples1 + "\n" + " Misclassified samples: " + incorrect1 + "\n" + " Classification error: " + MathUtil.setDecimal(error, 2) + "%"); pro_box_d.appendMessage(text); } // compute the classification error for the second set // for (int i = 0; i < set2_d.size(); i++) { MyPoint point = (MyPoint)set2_d.elementAt(i); samples2++; if ((point.x > scale.xmin && point.x < scale.xmax) && (point.y > scale.ymin && point.y < scale.ymax)) { if (output_canvas_d[(int)((point.x - scale.xmin) / incrementX)] [(int)((point.y - scale.ymin) / incrementY)] != 1) { incorrect2++; } } } if (set2_d.size() > 0) { error = ((double)incorrect2 / (double)samples2) * 100.0; text = new String( " Results for class 1:\n" + " Total number of samples: " + samples2 + "\n" + " Misclassified samples: " + incorrect2 + "\n" + " Classification error: " + MathUtil.setDecimal(error, 2) + "%"); pro_box_d.appendMessage(text); } // compute the overall classification error // samples = samples1 + samples2 + samples3 + samples4; incorrect = incorrect1 + incorrect2 + incorrect3 + incorrect4; error = ((double)incorrect / (double)samples) * 100.0; text = new String( " Overall results:\n" + " Total number of samples: " + samples + "\n" + " Misclassified samples: " + incorrect + "\n" + " Classification error: " + MathUtil.setDecimal(error, 2) + "%"); pro_box_d.appendMessage(text); }}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -