📄 algorithmrvm.java,v
          }
        }
      }
      A_d.setValue(i, i, tmp_alpha);

      // zero the weights to be pruned
      //
      if (debug_level_d) {
        // System.out.println("tmp_alpha: " + tmp_alpha
        //                    + " alpha_thresh_d: " + alpha_thresh_d);
      }
      if (tmp_alpha > alpha_thresh_d) {
        curr_weights_d.set(i, new Double(0.0));
        updated = true;
        if (debug_level_d) {
          // System.out.println("weight set to 0.0");
          Matrix.printDoubleVector(curr_weights_d);
        }
      }

      // check for convergence in A
      //
      if (!MathUtil.almostEqual(tmp_alpha, old_alpha)) {
        updated = true;
      }
    }

    // exit gracefully
    //
    return updated;
  }

  /**
   * Completes one pass of IRLS training to update the weights given the
   * currently assigned hyperparameters. In this step, we find those weights
   * that maximize equation (24) of [1]. The iteratively reweighted least
   * squares algorithm is used to find w_mp. It proceeds as follows:
   *
   *   a. initialize
   *   b. loop until convergence
   *      1. compute the B matrix
   *      2. compute the Hessian and the gradient
   *      3. update weights with formula
   *         new weights = weights - inv(Hessian) * gradient
   *      4. check convergence
   *
   * @return boolean value indicating status
   */
  public boolean irlsTrain() {

    // Debug
    //
    // System.out.println("AlgorithmRVM : irlsTrain()");

    // 1.a. initialize: curr_weights holds the weights currently
    // being considered for update
    //
    int start_index = 0;
    if (weights_d.size() < dimA_d) {
      curr_weights_d.set(0, new Double(bias_d));
      start_index = 1;
    }
    MathUtil.copyVector(curr_weights_d, weights_d, weights_d.size(),
                        start_index, 0);

    // make sure there are some weights that need updating
    //
    if (dimA_d == 0) {
      /*
      curr_weights_d.setLength(0);
      old_irls_weights_d.setLength(0);
      gradient_d.setLength(0);
      hessian_d.setDimensions(0, 0);
      covar_cholesky_d.setDimensions(0, 0);
      sigma_d.assign(0.5);
      B_d.assign(0.25);
      */
      // System.out.println("dimA_d == 0");
      Exception e = new Exception();
      e.printStackTrace();
      return true;
    }

    // 1.b. loop until convergence:
    // convergence is achieved once the weights do not change over an
    // iteration. This corresponds to the gradient of the likelihood
    // function with respect to the weights going to zero.
    //
    boolean irls_converged = false;
    boolean last_iteration = false;
    long num_irls_its = 0;
    double lambda = 1.0;
    boolean reduced = true;      // initially true to skip the first iteration
    double old_likelihood = 0.0;
    double new_likelihood = old_likelihood;
    long size_B = B_d.getNumRows();

    Matrix sigma_M = new Matrix();
    Matrix y_M = new Matrix();
    Matrix temp_M = new Matrix();
    Matrix temp2_M = new Matrix();
    Matrix curr_weights_M = new Matrix();
    Matrix gradient_M = new Matrix();

    while (!irls_converged) {

      if (debug_level_d) {
        // System.out.println("curr_weights_d: ");
        Matrix.printDoubleVector(curr_weights_d);
      }

      // copy out the current best weights to the weight vector
      // so they can be used for evaluations
      //
      if (num_irls_its > 0) {
        if (start_index == 1) {
          bias_d = MathUtil.doubleValue(curr_weights_d, 0);
        }
        else {
          bias_d = 0.0;
        }
        MathUtil.copyVector(weights_d, curr_weights_d, weights_d.size(),
                            0, start_index);
      }

      // 1.b.1. compute the B matrix and the sigma vector
      //
      if (debug_level_d) {
        // System.out.println("IRLS iteration: " + num_irls_its);
        // System.out.println(
        //     "1.b.1. compute the B matrix and the sigma vector");
      }
      computeSigma();
      for (int i = 0; i < size_B; i++) {
        double sigma = MathUtil.doubleValue(sigma_d, i);
        B_d.setValue(i, i, sigma * (1 - sigma));
      }

      // if the likelihood decreases then the newton step was
      // too large and needs to be reduced
      //
      new_likelihood = computeLikelihood();

      if (debug_level_d) {
        // System.out.println("sigma_d: ");
        // Matrix.printDoubleVector(sigma_d);
        // System.out.println("phi_d: ");
        // phi_d.printMatrix();
        // System.out.println("B_d: ");
        // B_d.printMatrix();
        // System.out.println("old likelihood 1 : " + old_likelihood);
        // System.out.println("new likelihood 1 : " + new_likelihood);
      }

      if (!reduced) {
        if (new_likelihood < old_likelihood || (num_irls_its < 2)) {

          // store the weights computed on the last IRLS iteration
          //
          old_likelihood = new_likelihood;
          if (lambda < 0.9) {
            // System.out.println("increasing step parameter");
            lambda *= 2.0;
          }
          else {
            lambda = 1.0;
          }
        }
        else {

          // revert to the previous weights
          //
          // curr_weights_d = (Vector)old_irls_weights_d.clone();
          // curr_weights_d = old_irls_weights_d;
          lambda *= 0.5;
          if (lambda < 0.0001) {
            last_iteration = true;
          }
          reduced = true;
          continue;
        }
      }
      reduced = false;

      // old_irls_weights_d = (Vector)curr_weights_d.clone();
      // old_irls_weights_d = curr_weights_d;

      if (debug_level_d) {
        // System.out.println("curr_weights_d: ");
        // Matrix.printDoubleVector(curr_weights_d);
        // System.out.println(
        //     "1.b.2. compute the gradient and Hessian ");
      }

      /* debug
      if (debug_level_d > Integral::BRIEF) {
        Vector error_signal;
        error_signal.sub(targets_d, sigma_d);
        error_signal.abs();
        Double out = error_signal.sum();
        out.debug(L"sum-squared error");
        curr_weights_d.debug(L"current weights");
      }
      */

      // 1.b.2. compute the gradient and Hessian
      //
      // gradient = phi * [y_d - sigma] - A * weights
      // sigma is reused as temporary data space in the second segment
      // for computing A * weights.
      //

      // init matrix
      //
      boolean status = true;
      sigma_M.initMatrix(sigma_d);
      y_M.initMatrix(y_d);
      curr_weights_M.initMatrix(curr_weights_d);
      gradient_M.initMatrix(gradient_d);

      sigma_M.scalarMultMatrix(-1.0);    // -sigma
      sigma_M.addMatrix(y_M);            // sigma_M = y_d - sigma_M
      sigma_M.transposeMatrix(temp_M);

      // gradient = phi * [y_d - sigma]
      //
      phi_d.multMatrix(temp_M, gradient_M);

      // temp_M = -A * weights
      curr_weights_M.transposeMatrix(temp_M);
      A_d.multMatrix(temp_M, temp2_M);
      temp2_M.scalarMultMatrix(-1);

      // gradient = phi * [y_d - sigma] - A * weights
      //
      gradient_M.addMatrix(temp2_M);
      // gradient_M.printMatrix();
      gradient_M.toDoubleVector(gradient_d);

      /*
      if (!status) {
        gradient_d.debug(L"computed gradient");
        return Error::handle(name(), L"train:gradient computation",
                             ERR, __FILE__, __LINE__);
      }

      // debugging information
      //
      if (debug_level_d > Integral::DETAILED) {
        gradient_d.debug(L"gradient");
      }
      */

      // Hessian = -(phi * B * transpose(phi) + A)
      // this is a quadratic form about B
      //
      status = true;

      // temp_M = phi * B, temp2_M = phi'
      // hessian_d = phi * B * phi'
      //
      phi_d.multMatrix(B_d, temp_M);
      phi_d.transposeMatrix(temp2_M);
      temp_M.multMatrix(temp2_M, hessian_d);

      // phi * B * phi' + A
      //
      hessian_d.addMatrix(A_d);

      /*
      if (!status) {
        hessian_d.debug(L"computed hessian");
        return Error::handle(name(), L"train:Hessian computation",
                             ERR, __FILE__, __LINE__);
      }

      // debugging information
      //
      if (debug_level_d > Integral::DETAILED) {
        hessian_d.debug(L"Hessian");
      }
      */

      // with the cholesky decomposition, we can solve systems of equations
      // as well as determine the diagonal elements of the covariance.
      // This is all we need for the RVM training
      //
      if (!hessian_d.decompositionCholesky(covar_cholesky_d)) {
        // System.out.println("train:Cholesky decomposition");
        return false;
      }

      /*
      // debugging information
      //
      if (debug_level_d > Integral::DETAILED) {
        covar_cholesky_d.debug(L"cholesky decomposition");
      }
      */

      if (debug_level_d) {
        // System.out.println("hessian_d: ");
        // hessian_d.printMatrix();
        // System.out.println("gradient_d: ");
        // Matrix.printDoubleVector(gradient_d);
        // System.out.println("covar_cholesky_d: ");
        // covar_cholesky_d.printMatrix();
      }

      // 1.b.3. update the weights:
      // because the space is convex we take a full Newton step
      //   weights* = weights - inverse(Hessian) * gradient
      //
      // System.out.println("1.b.3. update the weights:");
      status = true;
      gradient_M.initMatrix(gradient_d);
      status &= covar_cholesky_d.choleskySolve(temp_M, covar_cholesky_d,
                                               gradient_M);

      if (debug_level_d) {
        // System.out.println("temp_M: ");
        // temp_M.printMatrix();
      }

      if (!last_iteration) {
        temp_M.scalarMultMatrix(lambda);
        temp2_M.initMatrix(old_irls_weights_d);
        temp_M.addMatrix(temp2_M);
        // temp_M.printMatrix();
        temp_M.toDoubleVector(curr_weights_d);

        double val = MathUtil.vectorProduct(gradient_d, gradient_d);
        val = Math.sqrt(val);
        val /= curr_weights_d.size();

        // 1.b.4. check convergence
        //
        /*
        if (MathUtil.almostEqual(curr_weights_d, old_irls_weights_d)
            && MathUtil.almostEqual(gradient_d, 0.0)) {
        */
        if (val < 1e-6) {
          last_iteration = true;
        }
      }
      else {
        irls_converged = true;
      }

      /*
      if (!status) {
        covar_cholesky_d.debug(L"computed cholesky");
        return Error::handle(name(), L"train:weight update",
                             ERR, __FILE__, __LINE__);
      }
      */

      // next iteration
      //
      num_irls_its++;
      pro_box_d.setProgressCurr((int) num_irls_its % 20);
    }

    if (debug_level_d) {
      // System.out.println("new likelihood 2: " + new_likelihood);
    }

    /*
    if (debug_level_d > Integral::NONE) {
      new_likelihood.debug(L"new likelihood");
    }
    */

    // generate the true hessian (negative of the one we store)
    //
    hessian_d.scalarMultMatrix(-1.0);
    computeSigma();

    /*
    if (debug_level_d > Integral::DETAILED) {
      curr_weights_d.debug(L"current weights after IRLS");
      Vector tmp_1;
      tmp_1.sub(targets_d, sigma_d);
      Vector tmp_weights;
      Matrix tmp_matrix(A_d);
      tmp_matrix.inverse();
      tmp_matrix.changeType(Integral::FULL);
      tmp_matrix.mult(phi_d);
      tmp_matrix.multv(tmp_weights, tmp_1);
      tmp_weights.debug(L"computed weights after IRLS");
    }
    */

    // exit gracefully
    //
    return true;
  }

  /**
   * Finalizes the trained model, making it ready to write to file or to use
   * in prediction.
   *
   * @return true
   */
  boolean finalizeTraining() {

    // Debug
    //
    System.out.println("AlgorithmRVM : finalizeTraining()");

    // perform a final check on the range of the weights and manually prune
    // any that have fallen below the minimum allowed weight value
    //
    if (Math.abs(bias_d) < min_allowed_weight_d) {
      bias_d = 0.0;
    }
    for (int i = 0; i < weights_d.size(); i++) {
      if (Math.abs(MathUtil.doubleValue(weights_d, i)) < min_allowed_weight_d) {
        weights_d.set(i, new Double(0.0));
      }
    }

    // prune any remaining zero-weights with large hyperparameters
    //
    pruneAndUpdate();

    // run a final iteration of IRLS training to get the final hessian
    //
    irlsTrain();

    // compute the final inverse hessian and assign
    //
    inv_hessian_d.invertMatrix(hessian_d);
    inv_hessian_d.scalarMultMatrix(-1);

    // exit gracefully
    //
    return true;
  }

  /**
   * Prunes off vectors whose hyperparameters have gone to infinity and
   * updates the working data sets.
   *
   * @return true
   */
  public boolean pruneAndUpdate() {

    // Debug
    //
    System.out.println("AlgorithmRVM : pruneAndUpdate()");

    /*
    // debugging information
    //
    if (debug_level_d > Integral::BRIEF) {
      A_d.debug(L"A");
    }
    */

    if (debug_level_d) {
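For readers following the math rather than the Matrix/MathUtil plumbing, the damped Newton step that irlsTrain() iterates is summarized below as a minimal, self-contained sketch: gradient g = phi * (y - sigma) - A * w, (negated) Hessian H = phi * B * phi' + A with B = diag(sigma_i * (1 - sigma_i)), and update w <- w + lambda * inv(H) * g. The class name, the array layout (one row of phi per weight, one column per sample), and the naive Gaussian-elimination solve used in place of the Cholesky decomposition are illustrative assumptions only, not part of AlgorithmRVM.

// A minimal sketch of one damped IRLS/Newton step for the RVM logistic model,
// using plain arrays instead of the Matrix/MathUtil helpers above.
// All names here are illustrative and not part of AlgorithmRVM.
public final class IrlsStepSketch {

  // logistic sigmoid
  static double sigmoid(double a) {
    return 1.0 / (1.0 + Math.exp(-a));
  }

  // One step: w <- w + lambda * inv(Phi B Phi' + A) * (Phi (t - sigma) - A w),
  // where Phi is (numWeights x numSamples), A = diag(alpha),
  // B = diag(sigma_i * (1 - sigma_i)).
  static double[] newtonStep(double[][] phi, double[] t, double[] alpha,
                             double[] w, double lambda) {
    int m = w.length;       // number of weights (rows of phi)
    int n = t.length;       // number of samples (columns of phi)

    // sigma_i = sigmoid(w' * phi(:, i))
    double[] sigma = new double[n];
    for (int i = 0; i < n; i++) {
      double a = 0.0;
      for (int j = 0; j < m; j++) a += w[j] * phi[j][i];
      sigma[i] = sigmoid(a);
    }

    // gradient g = Phi (t - sigma) - A w
    double[] g = new double[m];
    for (int j = 0; j < m; j++) {
      double s = 0.0;
      for (int i = 0; i < n; i++) s += phi[j][i] * (t[i] - sigma[i]);
      g[j] = s - alpha[j] * w[j];
    }

    // negated Hessian H = Phi B Phi' + A
    double[][] h = new double[m][m];
    for (int j = 0; j < m; j++) {
      for (int k = 0; k < m; k++) {
        double s = 0.0;
        for (int i = 0; i < n; i++)
          s += phi[j][i] * sigma[i] * (1.0 - sigma[i]) * phi[k][i];
        h[j][k] = s + (j == k ? alpha[j] : 0.0);
      }
    }

    // solve H d = g (naive Gaussian elimination here; the class above uses a
    // Cholesky decomposition) and apply the damped update
    double[] d = solve(h, g);
    double[] out = new double[m];
    for (int j = 0; j < m; j++) out[j] = w[j] + lambda * d[j];
    return out;
  }

  // solve a small symmetric positive-definite system by Gaussian elimination
  static double[] solve(double[][] a, double[] b) {
    int n = b.length;
    double[][] aug = new double[n][n + 1];
    for (int i = 0; i < n; i++) {
      System.arraycopy(a[i], 0, aug[i], 0, n);
      aug[i][n] = b[i];
    }
    for (int p = 0; p < n; p++) {
      for (int r = p + 1; r < n; r++) {
        double f = aug[r][p] / aug[p][p];
        for (int c = p; c <= n; c++) aug[r][c] -= f * aug[p][c];
      }
    }
    double[] x = new double[n];
    for (int r = n - 1; r >= 0; r--) {
      double s = aug[r][n];
      for (int c = r + 1; c < n; c++) s -= aug[r][c] * x[c];
      x[r] = s / aug[r][r];
    }
    return x;
  }

  // tiny usage example: two weights (bias + one kernel column), three samples
  public static void main(String[] args) {
    double[][] phi = { { 1.0, 1.0, 1.0 },     // bias row
                       { 0.2, 0.5, 0.9 } };   // kernel row
    double[] t = { 0.0, 1.0, 1.0 };           // binary targets
    double[] alpha = { 1.0, 1.0 };            // hyperparameters (diag of A)
    double[] w = { 0.0, 0.0 };                // initial weights
    System.out.println(java.util.Arrays.toString(newtonStep(phi, t, alpha, w, 1.0)));
  }
}

In irlsTrain() above, this single step is wrapped in the step-size control loop: lambda is halved (and the previous weights restored) when the likelihood decreases, and doubled again on successful steps, until the scaled gradient norm drops below 1e-6.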