// algorithmrvm.java
        computeVarianceCholesky();

        // 4. reestimate the hyperparameters
        //
        boolean A_changed = false;
        A_changed = updateHyperparametersFull();

        // store the final weights from this iteration
        //
        int start_index = 0;
        if (num_samples_d < dimA_d) {
            start_index = 1;
            bias_d = MathUtil.doubleValue(curr_weights_d, 0);
        }
        MathUtil.copyVector(weights_d, curr_weights_d, num_samples_d, 0, start_index);

        // check convergence: converge if the weight updates have
        // stagnated and we have not pruned off a weight
        //
        if (MathUtil.almostEqual(curr_weights_d, last_rvm_weights_d)) {
            last_changed_d++;
            if (!A_changed || (num_rvm_its > max_rvm_its_d)
                || (last_changed_d > max_update_its_d || weights_d.size() < 2)) {
                rvm_converged = true;
            }
        } else {
            last_changed_d = 0;
        }

        // next iteration of RVM optimization
        //
        num_rvm_its++;
    }

    // run any final steps in the training algorithm and store the
    // final model
    //
    finalizeTraining();

    // debugging information
    //
    if (debug_level_d) {
        // System.out.println(" convergence achieved");
    }

    // exit gracefully
    //
    return true;
}

/**
 *
 * initializes the data structures in preparation for a full training
 * pass
 *
 * @return boolean value indicating status
 *
 */
public boolean initFullTrain() {

    // Debug
    //
    // System.out.println("AlgorithmRVM : initFullTrain()");

    // initialize x_d and y_d
    //
    // total number of points
    //
    x_d.clear();
    y_d.clear();
    support_vectors_d.clear();

    /*
    Point p;
    Vector vec_point = new Vector();
    vec_point.add(new Double(-0.5545023696682464));
    vec_point.add(new Double(0.4976303317535544));
    x_d.add(vec_point); y_d.add(CLASS_LABEL_1);
    p = new Point(47, 53); support_vectors_d.add(p);

    vec_point = new Vector();
    vec_point.add(new Double(-0.5450236966824644));
    vec_point.add(new Double(0.2890995260663507));
    x_d.add(vec_point); y_d.add(CLASS_LABEL_1);
    p = new Point(48, 75); support_vectors_d.add(p);

    vec_point = new Vector();
    vec_point.add(new Double(-0.22274881516587675));
    vec_point.add(new Double(0.27014218009478663));
    x_d.add(vec_point); y_d.add(CLASS_LABEL_1);
    p = new Point(82, 77); support_vectors_d.add(p);

    vec_point = new Vector();
    vec_point.add(new Double(-0.21327014218009477));
    vec_point.add(new Double(0.4691943127962085));
    x_d.add(vec_point); y_d.add(CLASS_LABEL_1);
    p = new Point(83, 56); support_vectors_d.add(p);

    vec_point = new Vector();
    vec_point.add(new Double(0.3744075829383888));
    vec_point.add(new Double(-0.31753554502369674));
    x_d.add(vec_point); y_d.add(CLASS_LABEL_2);
    p = new Point(145, 139); support_vectors_d.add(p);

    vec_point = new Vector();
    vec_point.add(new Double(0.6303317535545025));
    vec_point.add(new Double(-0.31753554502369674));
    x_d.add(vec_point); y_d.add(CLASS_LABEL_2);
    p = new Point(172, 139); support_vectors_d.add(p);

    vec_point = new Vector();
    vec_point.add(new Double(0.2606635071090049));
    vec_point.add(new Double(-0.4218009478672986));
    x_d.add(vec_point); y_d.add(CLASS_LABEL_2);
    p = new Point(133, 150); support_vectors_d.add(p);

    vec_point = new Vector();
    vec_point.add(new Double(0.49763033175355464));
    vec_point.add(new Double(-0.4407582938388628));
    x_d.add(vec_point); y_d.add(CLASS_LABEL_2);
    p = new Point(158, 152); support_vectors_d.add(p);

    if (debug_level_d) {
        Matrix.printDoubleVector(y_d);
    }
    */

    // add the data points
    //
    for (int i = 0; i < set1_d.size(); i++) {
        Vector<Double> vec_point = new Vector<Double>();
        MyPoint curr_point = (MyPoint) set1_d.get(i);
        support_vectors_d.add(curr_point);
        vec_point.add(new Double(curr_point.x));
        vec_point.add(new Double(curr_point.y));
        x_d.add(vec_point);
        y_d.add(CLASS_LABEL_1);
    }
    // System.out.println("class: 2");
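    // add the class 2 data points in the same format
    //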
    for (int i = 0; i < set2_d.size(); i++) {
        Vector<Double> vec_point = new Vector<Double>();
        MyPoint curr_point = (MyPoint) set2_d.get(i);
        support_vectors_d.add(curr_point);
        vec_point.add(new Double(curr_point.x));
        vec_point.add(new Double(curr_point.y));
        x_d.add(vec_point);
        y_d.add(CLASS_LABEL_2);
    }

    // evalx_d = (Vector)x_d.clone();
    // evaly_d = (Vector)y_d.clone();
    //
    evalx_d = x_d;
    evaly_d = y_d;

    // number of samples stored
    //
    num_samples_d = x_d.size();
    // System.out.println(num_samples_d);

    // the weights are a column vector of length N. The bias is stored
    // separately.
    //
    weights_d.setSize(num_samples_d);
    MathUtil.initDoubleVector(weights_d, 0.0);

    // the hyperparameters, alphas, are stored as a diagonal array of
    // dimension N+1 x N+1.
    //
    A_d.initMatrixValue(num_samples_d + 1, num_samples_d + 1, 0, Matrix.DIAGONAL);

    // the phi matrix (pp. 214 of [1]) is the design matrix of kernel
    // evaluations. Note that here we use the transpose of the PHI matrix
    // in [1] since our matrices are assumed to be column matrices. Thus
    // phi is an (N+1) x N column matrix rather than the N x (N+1) row
    // matrix used in [1].
    //
    phi_d.initMatrixValue(num_samples_d + 1, num_samples_d, 0, Matrix.FULL);

    // the sigma vector is the probability of each sample given
    // the current weights. the B matrix is a diagonal matrix of
    // size N x N that is used in the least squares optimization
    // of the weights (pp. 219 of [1]). B[i,i] = sigma(i) * (1 - sigma(i))
    //
    sigma_d.setSize(num_samples_d);
    MathUtil.initDoubleVector(sigma_d, 0.0);
    B_d.initMatrixValue(num_samples_d, num_samples_d, 0, Matrix.DIAGONAL);

    // temporary storage for the gradient vector, hessian matrix and the
    // covariance matrix (or at least the cholesky decomposition of the
    // hessian, which can be used to find the covariance)
    //
    gradient_d.setSize(num_samples_d + 1);
    MathUtil.initDoubleVector(gradient_d, 0.0);
    hessian_d.initMatrixValue(num_samples_d + 1, num_samples_d + 1, 0, Matrix.SYMMETRIC);
    covar_cholesky_d.initMatrixValue(num_samples_d + 1, num_samples_d + 1, 0,
                                     Matrix.LOWER_TRIANGULAR);

    // curr_weights is used during the optimization to hold the
    // current set of weights in a single vector. last_rvm_weights
    // holds the last set of weights that were converged to in the
    // RVM training iteration. If these weights match the weights
    // found in the next iteration then we have reached global
    // convergence (subject to a few other rules).
    // old_irls_weights contains the last set of weights tried for
    // the IRLS convergence tests.
    //
    curr_weights_d.setSize(num_samples_d + 1);
    MathUtil.initDoubleVector(curr_weights_d, 0.0);
    last_rvm_weights_d.setSize(num_samples_d + 1);
    MathUtil.initDoubleVector(last_rvm_weights_d, 0.0);

    // if the weights stagnate then we can discontinue training even if the
    // hyperparameters are oscillating.
    //
    last_changed_d = 0;

    // compute the phi matrix:
    // the first element in each column is set to 1.0. Element (i, j) in the
    // matrix corresponds to the kernel function evaluated with x_d(j) and
    // x_d(i-1) as arguments.
    //
    double kernel_value = 0.0;
    for (int j = 0; j < num_samples_d; j++) {
        phi_d.setValue(0, j, 1.0);
    }
    for (int i = 1; i <= num_samples_d; i++) {
        for (int j = 0; j < num_samples_d; j++) {
            Vector vec1 = (Vector) x_d.get(j);
            Vector vec2 = (Vector) x_d.get(i - 1);
            kernel_value = K(vec1, vec2);
            phi_d.setValue(i, j, kernel_value);
            // System.out.println("k_val: " + kernel_value);
        }
    }
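    // for example, with N = 3 samples the design matrix built above is
    //
    //     phi = [ 1          1          1         ]
    //           [ K(x0, x0)  K(x1, x0)  K(x2, x0) ]
    //           [ K(x0, x1)  K(x1, x1)  K(x2, x1) ]
    //           [ K(x0, x2)  K(x1, x2)  K(x2, x2) ]
    //
    // i.e. column j holds the basis-function responses for sample j, with
    // the constant bias basis in row 0.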
    // initialize the A matrix:
    // A is initialized to a scaled N+1 x N+1 identity matrix
    //
    double scale = 1.0 / num_samples_d / num_samples_d;
    A_d.identityMatrix();
    A_d.scalarMultMatrix(scale);

    // keep track of the number of hyperparameters:
    // if dimA_d is equal to the size of the weights_d vector then the bias
    // has been pruned.
    //
    dimA_d = A_d.getNumRows();

    // initialize the weights to zero - this gives every input vector
    // probability 1/2 of being an in-class example,
    // i.e. sigma(0.0) = 1 / (1 + exp(0)) = 1/2
    //
    MathUtil.initDoubleVector(weights_d, 0.0);
    bias_d = 0.0;

    // check that the data and targets are appropriate for training
    //
    if ((num_samples_d == 0) || (y_d.size() != num_samples_d)) {
        // System.out.println("Error at num_samples_d");
    }

    // exit gracefully
    //
    return true;
}
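/**
 *
 * The kernel K(.,.) used when filling phi above is defined elsewhere in
 * this class. As a rough illustration only, an RBF kernel of the kind
 * commonly used with RVMs would look like the sketch below; the method
 * name and the width parameter gamma_k are hypothetical and not part of
 * the original file.
 *
 */
private static double rbfKernelSketch(Vector<Double> u, Vector<Double> v,
                                      double gamma_k) {

    // squared Euclidean distance between the two feature vectors
    //
    double dist_sq = 0.0;
    for (int d = 0; d < u.size(); d++) {
        double diff = u.get(d).doubleValue() - v.get(d).doubleValue();
        dist_sq += diff * diff;
    }

    // K(u, v) = exp(-gamma_k * ||u - v||^2)
    //
    return Math.exp(-gamma_k * dist_sq);
}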
/**
 *
 * Updates the hyperparameter values
 *
 * @return boolean value indicating status
 *
 */
public boolean updateHyperparametersFull() {

    // Debug
    //
    // System.out.println("AlgorithmRVM : updateHyperparametersFull()");

    // update the momentum components
    // only use momentum if no hyperparams have been pruned recently.
    //
    // twoback_hyperparams_d = (Vector)last_hyperparams_d.clone();
    //
    twoback_hyperparams_d = last_hyperparams_d;
    A_d.getDiagonalVector(last_hyperparams_d);
    boolean use_momentum =
        (twoback_hyperparams_d.size() == last_hyperparams_d.size());

    // loop over each hyperparameter and update - remember that A is
    // diagonal where A(i,i) = alpha_i. The update equation used is
    // equation (16) of [1]. We also test for convergence in this loop.
    // If no alphas have substantially changed then we assume convergence.
    // We must take some care here that NaN and Infinity values do not
    // creep into our calculations.
    //
    boolean updated = false;
    double gamma = 0.0;
    double tmp_alpha = 0.0;
    double old_alpha = 0.0;
    double sum_gamma = 0.0;

    for (int i = 0; i < dimA_d; i++) {

        // equation (16): alpha(i) = (1 - alpha(i) * covar(i,i)) / weight(i)^2
        //
        old_alpha = A_d.getValue(i, i);

        // error checking
        //
        if (old_alpha < 0.0) {
            // System.out.println("train:invalid old_alpha: " + old_alpha);
            Exception e = new Exception();
            e.printStackTrace();
            return false;
        }
        if (covar_cholesky_d.getValue(i, i) < 0.0) {
            // System.out.println("train:invalid covar");
            Exception e = new Exception();
            e.printStackTrace();
            return false;
        }

        // compute gamma = 1 - alpha(i) * covar(i,i)
        //
        gamma = 1.0 - (old_alpha * covar_cholesky_d.getValue(i, i));

        // gamma can be negative if covar(i,i) ~= 1/old_alpha
        // due to numerical imprecision. only error if the result
        // is not close to zero.
        //
        if (gamma < 0.0) {
            if (!MathUtil.almostEqual(gamma, 0.0)) {
                // System.out.println("train:invalid gamma: " + gamma);
                Exception e = new Exception();
                e.printStackTrace();
                return false;
            } else {
                gamma = -gamma;
            }
        }
        sum_gamma += gamma;

        // error checking:
        // if the weight is already zero then alpha(i) will go to infinity;
        // set it to a very large value instead
        //
        double weight_i = MathUtil.doubleValue(curr_weights_d, i);
        if (weight_i == 0.0) {
            tmp_alpha = alpha_thresh_d * 10;
        } else {
            tmp_alpha = gamma / weight_i / weight_i;
        }

        // compute the momentum term
        //
        if (use_momentum) {
            double momentum_term = 0.0;
            double last_hyper = MathUtil.doubleValue(last_hyperparams_d, i);
            double last_hyper2 = MathUtil.doubleValue(twoback_hyperparams_d, i);
            if ((last_hyper < alpha_thresh_d) && (last_hyper2 < alpha_thresh_d)) {
                momentum_term = momentum_d * (last_hyper - last_hyper2);
                if (debug_level_d) {
                    // System.out.println("momentum_term = " + momentum_term);
                }

                // the hyperparameter cannot be negative, so we do not want
                // to decrease toward zero too quickly and we should never
                // allow it to go below zero; we only take moderate
                // negative steps
                //
                if (momentum_term > (-0.5 * tmp_alpha)) {
                    tmp_alpha = tmp_alpha + momentum_term;
                }
            }
        }
        A_d.setValue(i, i, tmp_alpha);

        // zero the weights to be pruned
        //
        if (debug_level_d) {
            // System.out.println("tmp_alpha: " + tmp_alpha
            //     + " alpha_thresh_d: " + alpha_thresh_d);
        }
        if (tmp_alpha > alpha_thresh_d) {
            curr_weights_d.set(i, new Double(0.0));
            updated = true;
            if (debug_level_d) {
                // System.out.println("weight set to 0.0");
                Matrix.printDoubleVector(curr_weights_d);
            }
        }

        // check for convergence in A
        //
        if (!MathUtil.almostEqual(tmp_alpha, old_alpha)) {
            updated = true;
        }
    }

    // exit gracefully
    //
    return updated;
}
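/**
 *
 * Illustrative sketch only, not part of the original file: the update in
 * updateHyperparametersFull() follows equation (16) of [1],
 *
 *     gamma_i = 1 - alpha_i * Sigma(i,i),   alpha_i_new = gamma_i / w_i^2
 *
 * For example, with alpha_i = 1.0, Sigma(i,i) = 0.2 and w_i = 0.5 this
 * gives gamma_i = 0.8 and a re-estimated alpha_i of 0.8 / 0.25 = 3.2.
 * The hypothetical helper below restates that arithmetic in isolation.
 *
 */
private static double reestimateAlphaSketch(double alpha, double sigma_ii,
                                            double weight) {

    // equation (16) numerator
    //
    double gamma = 1.0 - (alpha * sigma_ii);

    // a zero weight would drive alpha to infinity; the method above
    // substitutes a very large value (alpha_thresh_d * 10) in that case
    //
    if (weight == 0.0) {
        return Double.MAX_VALUE;
    }
    return gamma / (weight * weight);
}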
/**
 *
 * Completes one pass of IRLS training to update the weights
 * given the currently assigned hyperparameters. In this step, we find
 * those weights that maximize equation (24) of [1]. The iteratively
 * reweighted least squares (IRLS) algorithm is used to find w_mp. It
 * proceeds as follows:
 *
 *   a. initialize
 *   b. loop until convergence
 *      1. compute the B matrix
 *      2. compute the Hessian and the gradient
 *      3. update the weights with the formula
 *             new weights = weights - inv(Hessian) * gradient
 *      4. check convergence
 *
 * @return boolean value indicating status
 *
 */
public boolean irlsTrain() {

    // Debug
    //
    // System.out.println("AlgorithmRVM : irlsTrain()");

    // 1.a. initialize: curr_weights holds the weights currently
    // being considered for update
    //
    int start_index = 0;
    if (weights_d.size() < dimA_d) {
        curr_weights_d.set(0, new Double(bias_d));
        start_index = 1;
    }
    MathUtil.copyVector(curr_weights_d, weights_d, weights_d.size(), start_index, 0);

    // make sure there are some weights that need updating
    //
    if (dimA_d == 0) {
        /*
        curr_weights_d.setLength(0);
        old_irls_weights_d.setLength(0);
        gradient_d.setLength(0);
        hessian_d.setDimensions(0,0);
        covar_cholesky_d.setDimensions(0,0);
        sigma_d.assign(0.5);
        B_d.assign(0.25);
        */
        // System.out.println("dimA_d == 0");
        Exception e = new Exception();
        e.printStackTrace();
        return true;
    }

    // 1.b. loop until convergence:
    // convergence is achieved once the weights do not change over an
    // iteration. This corresponds to the gradient of the likelihood
    // function with respect to the weights going to zero.
    //
    boolean irls_converged = false;
    boolean last_iteration = false;
    long num_irls_its = 0;
    double lambda = 1.0;
    boolean reduced = true;    // initially true to skip the first iteration
    double old_likelihood = 0.0;
    double new_likelihood = old_likelihood;
    long size_B = B_d.getNumRows();
    Matrix sigma_M = new Matrix();
    Matrix y_M = new Matrix();
    Matrix temp_M = new Matrix();
    Matrix temp2_M = new Matrix();
    Matrix curr_weights_M = new Matrix();
    Matrix gradient_M = new Matrix();

    while (!irls_converged)
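        // rough outline of one iteration (based on the javadoc above and the
        // standard RVM formulation in [1]; the actual loop body may differ):
        //
        //     sigma(n)  = 1 / (1 + exp(-phi(:,n)' * w))     per-sample probability
        //     B(n,n)    = sigma(n) * (1 - sigma(n))
        //     gradient  = phi * (y - sigma) - A * w
        //     Hessian   = -(phi * B * phi' + A)
        //     w_new     = w - inv(Hessian) * gradient       Newton / IRLS step
        //
        // followed by a check of whether the weights have stopped changing.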