algorithmrvm.java,v
        // last_rvm_weights_d = (Vector)curr_weights_d.clone();
        // last_rvm_weights_d = curr_weights_d;

        // debugging information
        //
        if (debug_level_d) {
            // System.out.println("RVM Iteration: " + num_rvm_its);
        }

        if (debug_level_d) {
            // System.out.println("phi_d: ");
            // phi_d.printMatrix();
            // System.out.println("A_d: ");
            // A_d.printMatrix();
            // System.out.println("weights_d: ");
            // Matrix.printDoubleVector(weights_d);
            // System.out.println("curr_weights_d: ");
            // Matrix.printDoubleVector(curr_weights_d);
        }

        // 1. prune only after the first iteration (we use a
        // weight of exactly 0.0 to trigger pruning, so pruning on
        // the first iteration would prune all weights)
        //
        if (num_rvm_its > 0) {
            pruneAndUpdate();
        }

        if (debug_level_d) {
            // System.out.println("After pruning: ");
            // System.out.println("phi_d: ");
            // phi_d.printMatrix();
            // System.out.println("A_d: ");
            // A_d.printMatrix();
            // System.out.println("weights_d: ");
            // Matrix.printDoubleVector(weights_d);
            // System.out.println("curr_weights_d: ");
            // Matrix.printDoubleVector(curr_weights_d);
            // System.out.println("curr_weights_d: ");
            // Matrix.printMatrix(curr);
        }

        // if all weights have been pruned, then there is nothing
        // left to do and the process has converged (albeit to a
        // pretty useless solution)
        //
        if (dimA_d == 0) {
            rvm_converged = true;

            // debugging information
            //
            if (debug_level_d) {
                // System.out.println("rvm convergence achieved");
            }

            // conclude training
            //
            break;
        }

        // 2. run a pass of IRLS training to estimate w_MP
        //
        irlsTrain();

        // 3. estimate the covariance of the Gaussian approximation,
        // i.e., compute the variance vector. The covariance is the
        // inverse of the Hessian matrix. Only the diagonal elements
        // are needed to update the hyperparameters. From the Cholesky
        // decomposition, we can efficiently find these values. After
        // this function call, the diagonal elements of covar_cholesky
        // will contain the negation of the diagonal elements of the
        // estimated covariance. The other elements of covar_cholesky
        // are not meaningful.
        //
        computeVarianceCholesky();

        // 4. reestimate the hyperparameters
        //
        boolean A_changed = false;
        A_changed = updateHyperparametersFull();

        // store the final weights from this iteration
        //
        int start_index = 0;
        if (num_samples_d < dimA_d) {
            start_index = 1;
            bias_d = MathUtil.doubleValue(curr_weights_d, 0);
        }
        MathUtil.copyVector(weights_d, curr_weights_d, num_samples_d,
                            0, start_index);

        // check convergence: converge if the weight updates have
        // stagnated and we have not pruned off a weight
        //
        if (MathUtil.almostEqual(curr_weights_d, last_rvm_weights_d)) {
            last_changed_d++;
            if (!A_changed || (num_rvm_its > max_rvm_its_d) ||
                (last_changed_d > max_update_its_d || weights_d.size() < 2)) {
                rvm_converged = true;
            }
        }
        else {
            last_changed_d = 0;
        }

        // next iteration of RVM optimization
        //
        num_rvm_its++;
    }

    // run any final steps in the training algorithm and store the
    // final model
    //
    finalizeTraining();

    // debugging information
    //
    if (debug_level_d) {
        // System.out.println(" convergence achieved");
    }

    // exit gracefully
    //
    return true;
}
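// The body of computeVarianceCholesky(), referenced in step 3 above, is not
// shown in this excerpt. As background only: for a positive definite Hessian
// H with Cholesky factor L (H = L * L^T), the covariance Sigma = H^-1
// satisfies Sigma(i,i) = sum_j (L^-1)(j,i)^2, i.e. the squared column norms
// of L^-1. The helper below is a hypothetical, stand-alone sketch of that
// computation on plain double[][] arrays; it ignores the sign convention
// noted above and is not the class's actual method.
//
private static double[] covarianceDiagonalFromCholesky(double[][] L) {

    int n = L.length;
    double[] diag = new double[n];

    for (int i = 0; i < n; i++) {

        // solve L * z = e_i by forward substitution; entries of z with
        // index < i are zero, so substitution can start at row i
        //
        double[] z = new double[n];
        for (int j = i; j < n; j++) {
            double sum = (j == i) ? 1.0 : 0.0;
            for (int k = i; k < j; k++) {
                sum -= L[j][k] * z[k];
            }
            z[j] = sum / L[j][j];
        }

        // Sigma(i,i) is the squared norm of column i of L^-1
        //
        for (int j = i; j < n; j++) {
            diag[i] += z[j] * z[j];
        }
    }
    return diag;
}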
/**
 * initializes the data structures in preparation for a full training
 * pass
 *
 * @return boolean value indicating status
 */
public boolean initFullTrain() {

    // Debug
    //
    // System.out.println("AlgorithmRVM : initFullTrain()");

    // initialize x_d and y_d
    //
    // total number of points
    //
    x_d.clear();
    y_d.clear();
    support_vectors_d.clear();

    /*
    Point p;
    Vector vec_point = new Vector();
    vec_point.add(new Double(-0.5545023696682464)); vec_point.add(new Double(0.4976303317535544));
    x_d.add(vec_point); y_d.add(CLASS_LABEL_1);
    p = new Point(47,53); support_vectors_d.add(p);

    vec_point = new Vector();
    vec_point.add(new Double(-0.5450236966824644)); vec_point.add(new Double(0.2890995260663507));
    x_d.add(vec_point); y_d.add(CLASS_LABEL_1);
    p = new Point(48,75); support_vectors_d.add(p);

    vec_point = new Vector();
    vec_point.add(new Double(-0.22274881516587675)); vec_point.add(new Double(0.27014218009478663));
    x_d.add(vec_point); y_d.add(CLASS_LABEL_1);
    p = new Point(82,77); support_vectors_d.add(p);

    vec_point = new Vector();
    vec_point.add(new Double(-0.21327014218009477)); vec_point.add(new Double(0.4691943127962085));
    x_d.add(vec_point); y_d.add(CLASS_LABEL_1);
    p = new Point(83,56); support_vectors_d.add(p);

    vec_point = new Vector();
    vec_point.add(new Double(0.3744075829383888)); vec_point.add(new Double(-0.31753554502369674));
    x_d.add(vec_point); y_d.add(CLASS_LABEL_2);
    p = new Point(145,139); support_vectors_d.add(p);

    vec_point = new Vector();
    vec_point.add(new Double(0.6303317535545025)); vec_point.add(new Double(-0.31753554502369674));
    x_d.add(vec_point); y_d.add(CLASS_LABEL_2);
    p = new Point(172,139); support_vectors_d.add(p);

    vec_point = new Vector();
    vec_point.add(new Double(0.2606635071090049)); vec_point.add(new Double(-0.4218009478672986));
    x_d.add(vec_point); y_d.add(CLASS_LABEL_2);
    p = new Point(133,150); support_vectors_d.add(p);

    vec_point = new Vector();
    vec_point.add(new Double(0.49763033175355464)); vec_point.add(new Double(-0.4407582938388628));
    x_d.add(vec_point); y_d.add(CLASS_LABEL_2);
    p = new Point(158, 152); support_vectors_d.add(p);

    if ( debug_level_d ) {
        Matrix.printDoubleVector(y_d);
    }
    */

    // add the data points
    //
    for (int i = 0; i < set1_d.size(); i++) {
        Vector<Double> vec_point = new Vector<Double>();
        MyPoint curr_point = (MyPoint)set1_d.get(i);
        support_vectors_d.add(curr_point);
        vec_point.add(new Double(curr_point.x));
        vec_point.add(new Double(curr_point.y));
        x_d.add(vec_point);
        y_d.add(CLASS_LABEL_1);
    }

    // System.out.println("class: 2");
    for (int i = 0; i < set2_d.size(); i++) {
        Vector<Double> vec_point = new Vector<Double>();
        MyPoint curr_point = (MyPoint)set2_d.get(i);
        support_vectors_d.add(curr_point);
        vec_point.add(new Double(curr_point.x));
        vec_point.add(new Double(curr_point.y));
        x_d.add(vec_point);
        y_d.add(CLASS_LABEL_2);
    }

    // evalx_d = (Vector)x_d.clone();
    // evaly_d = (Vector)y_d.clone();
    //
    evalx_d = x_d;
    evaly_d = y_d;

    // number of samples stored
    //
    num_samples_d = x_d.size();
    // System.out.println(num_samples_d);

    // the weights are a column vector of length N. The bias is stored
    // separately
    //
    weights_d.setSize(num_samples_d);
    MathUtil.initDoubleVector(weights_d, 0.0);

    // the hyperparameters, alphas, are stored as a diagonal array of
    // dimension N+1 x N+1.
    //
    A_d.initMatrixValue(num_samples_d + 1, num_samples_d + 1, 0,
                        Matrix.DIAGONAL);

    // the phi matrix (pp. 214 of [1]) is the design matrix of kernel
    // evaluations. Note that here we use the transpose of the PHI matrix
    // in [1] since our matrices are assumed to be column matrices. Thus
    // phi is an N+1 x N column matrix rather than an N x N+1 row matrix
    // as in [1].
    //
    phi_d.initMatrixValue(num_samples_d + 1, num_samples_d, 0, Matrix.FULL);

    // the sigma vector is the probability of each sample given
    // the current weights. The B matrix is a diagonal matrix of
    // size N x N that is used in the least squares optimization
    // of the weights (pp. 219 of [1]): B[i,i] = sigma(i) * (1 - sigma(i))
    //
    sigma_d.setSize(num_samples_d);
    MathUtil.initDoubleVector(sigma_d, 0.0);
    B_d.initMatrixValue(num_samples_d, num_samples_d, 0, Matrix.DIAGONAL);
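    // In the usual RVM classification setup, which the comments above appear
    // to follow, sigma(i) is the logistic sigmoid of the model output y_i
    // (the kernel expansion plus bias) for sample i:
    //
    //     sigma(i) = 1 / (1 + exp(-y_i))
    //
    // and B[i,i] = sigma(i) * (1 - sigma(i)) is the curvature of the
    // Bernoulli log-likelihood with respect to y_i, i.e. the standard IRLS
    // weight used in the least squares step mentioned above.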
    // temporary storage for the gradient vector, hessian matrix and the
    // covariance matrix (or at least the cholesky decomposition of the
    // hessian which can be used to find the covariance)
    //
    gradient_d.setSize(num_samples_d + 1);
    MathUtil.initDoubleVector(gradient_d, 0.0);
    hessian_d.initMatrixValue(num_samples_d + 1, num_samples_d + 1, 0,
                              Matrix.SYMMETRIC);
    covar_cholesky_d.initMatrixValue(num_samples_d + 1, num_samples_d + 1, 0,
                                     Matrix.LOWER_TRIANGULAR);

    // curr_weights are used during the optimization to hold the
    // current set of weights in a single vector. last_rvm_weights
    // holds the last set of weights that were converged to in the
    // RVM training iteration. If these weights match the weights
    // found in the next iteration then we have reached global
    // convergence (subject to a few other rules).
    // old_irls_weights contains the last set of weights tried for
    // the IRLS convergence tests
    //
    curr_weights_d.setSize(num_samples_d + 1);
    MathUtil.initDoubleVector(curr_weights_d, 0.0);
    last_rvm_weights_d.setSize(num_samples_d + 1);
    MathUtil.initDoubleVector(last_rvm_weights_d, 0.0);

    // if the weights stagnate then we can discontinue training even if the
    // hyperparameters are oscillating.
    //
    last_changed_d = 0;

    // compute the phi matrix:
    // The first element in each column is set to 1.0. Element i,j in the
    // matrix corresponds to the kernel function evaluated with
    // x_d(j) and x_d(i-1) as arguments
    //
    double kernel_value = 0.0;
    for (int j = 0; j < num_samples_d; j++) {
        phi_d.setValue(0, j, 1.0);
    }
    for (int i = 1; i <= num_samples_d; i++) {
        for (int j = 0; j < num_samples_d; j++) {
            Vector vec1 = (Vector)x_d.get(j);
            Vector vec2 = (Vector)x_d.get(i - 1);
            kernel_value = K(vec1, vec2);
            phi_d.setValue(i, j, kernel_value);
            // System.out.println("k_val: " + kernel_value);
        }
    }

    // initialize the A matrix:
    // A is initialized to a scaled N+1 x N+1 identity matrix
    //
    double scale = 1.0 / num_samples_d / num_samples_d;
    A_d.identityMatrix();
    A_d.scalarMultMatrix(scale);

    // keep track of the number of hyperparameters:
    // if dimA_d is equal to the size of the weights_d vector then the bias
    // has been pruned.
    //
    dimA_d = A_d.getNumRows();

    // initialize weights to zero - this gives every input vector
    // probability 1/2 of being an in-class example,
    // i.e. sigma(0.0) = 1 / (1 + exp(0)) = 1/2
    //
    MathUtil.initDoubleVector(weights_d, 0.0);
    bias_d = 0.0;

    // check that the data and targets are appropriate for training
    //
    if ((num_samples_d == 0) || (y_d.size() != num_samples_d)) {
        // System.out.println("Error at num_samples_d");
    }

    // exit gracefully
    //
    return true;
}
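// The kernel function K(vec1, vec2) used to fill phi_d above is not shown in
// this excerpt. As an illustration only, RVM demos of this kind often use a
// Gaussian (RBF) kernel; the hypothetical helper below, with an assumed
// width parameter "gamma", sketches such an evaluation. It is not claimed to
// match the K(...) actually defined elsewhere in this class.
//
private static double rbfKernel(Vector<Double> v1, Vector<Double> v2,
                                double gamma) {

    // squared Euclidean distance between the two feature vectors
    //
    double dist2 = 0.0;
    for (int k = 0; k < v1.size(); k++) {
        double diff = v1.get(k).doubleValue() - v2.get(k).doubleValue();
        dist2 += diff * diff;
    }

    // K(v1, v2) = exp(-gamma * ||v1 - v2||^2)
    //
    return Math.exp(-gamma * dist2);
}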
/**
 * Updates the hyperparameter values
 *
 * @return boolean value indicating status
 */
public boolean updateHyperparametersFull() {

    // Debug
    // System.out.println("AlgorithmRVM : updateHyperparametersFull()");

    // update the momentum components
    // only use momentum if no hyperparams have been pruned
    // recently.
    //
    // twoback_hyperparams_d = (Vector)last_hyperparams_d.clone();
    // twoback_hyperparams_d = last_hyperparams_d;
    A_d.getDiagonalVector(last_hyperparams_d);
    boolean use_momentum =
        (twoback_hyperparams_d.size() == last_hyperparams_d.size());

    // loop over each hyperparameter and update - remember that A is
    // diagonal where A(i,i) = alpha_i. The update equation used is
    // equation (16) of [1]. We also test for convergence in this loop.
    // If no alphas have substantially changed then we assume convergence.
    // We must take some care here that NaN and Infinity values do not
    // creep into our calculations
    //
    boolean updated = false;
    double gamma = 0.0;
    double tmp_alpha = 0.0;
    double old_alpha = 0.0;
    double sum_gamma = 0.0;

    for (int i = 0; i < dimA_d; i++) {

        // equation (16): alpha = (1 - alpha(i) * covar(i,i)) / (weights^2)
        //
        old_alpha = A_d.getValue(i, i);

        // error checking
        //
        if (old_alpha < 0.0) {
            // System.out.println("train:invalid old_alpha: " + old_alpha);
            Exception e = new Exception();
            e.printStackTrace();
            return false;
        }
        if (covar_cholesky_d.getValue(i, i) < 0.0) {
            // System.out.println("train:invalid covar");
            Exception e = new Exception();
            e.printStackTrace();
            return false;
        }

        // compute gamma = 1 - alpha(i)*covar(i,i)
        //
        gamma = 1.0 - (old_alpha * covar_cholesky_d.getValue(i, i));

        // gamma can be negative if covar(i,i) ~= 1/old_alpha
        // due to numerical imprecision. Only error if the result
        // is not close to zero.
        //
        if (gamma < 0.0) {
            if (!MathUtil.almostEqual(gamma, 0.0)) {
                // System.out.println("train:invalid gamma: " + gamma);
                Exception e = new Exception();
                e.printStackTrace();
                return false;
            }
            else {
                gamma = -gamma;
            }
        }
        sum_gamma += gamma;

        // error checking:
        // if the weight is already zero then alpha(i) will go to infinity,
        // so set it to a very large value instead
        //
        double weight_i = MathUtil.doubleValue(curr_weights_d, i);
        if (weight_i == 0.0) {
            tmp_alpha = alpha_thresh_d * 10;
        }
        else {
            tmp_alpha = gamma / weight_i / weight_i;
        }

        // compute momentum term
        //
        if (use_momentum) {
            double momentum_term = 0.0;
            double last_hyper = MathUtil.doubleValue(last_hyperparams_d, i);
            double last_hyper2 = MathUtil.doubleValue(twoback_hyperparams_d, i);
            if ((last_hyper < alpha_thresh_d) && (last_hyper2 < alpha_thresh_d)) {
                momentum_term = momentum_d * (last_hyper - last_hyper2);
                if (debug_level_d) {
                    // System.out.println(
                    //     "momentum_term = " + momentum_term);
                }

                // the hyperparameter cannot be negative, so we do not want
                // to decrease toward zero too quickly and we should never
                // allow it to go below zero, so we only take moderate
                // negative steps
                //
                if (momentum_term > (-0.5 * tmp_alpha)) {
                    tmp_alpha = tmp_alpha + momentum_term;