algorithmrvm.java,v
head 1.6;
access;
symbols;
locks; strict;
comment @# @;


1.6
date 2005.06.10.18.28.33; author rirwin; state Exp;
branches;
next 1.5;

1.5
date 2005.05.23.20.01.10; author rirwin; state Exp;
branches;
next 1.4;

1.4
date 2005.03.17.18.18.05; author patil; state Exp;
branches;
next 1.3;

1.3
date 2005.03.15.19.05.27; author patil; state Exp;
branches;
next 1.2;

1.2
date 2005.01.19.22.25.36; author patil; state Exp;
branches;
next 1.1;

1.1
date 2004.12.28.00.04.32; author patil; state Exp;
branches;
next ;


desc
@No changes made.
@


1.6
log
@Establishing RCS version.
@
text
@/*
 * @@(#) AlgorithmRVM.java v6.0 03/15/2005
 * created 02/09/03
 * Last edited: Ryan Irwin
 *
 */

// import java packages
//
import java.awt.*;
import java.util.*;

/**
 * Algorithm Relevance Vector Machines
 */
public class AlgorithmRVM extends Algorithm {

    //-----------------------------------------------------------------
    //
    // static data members
    //
    //-----------------------------------------------------------------

    static final double DEF_ALPHA_THRESH = 1e12;
    static final double DEF_MIN_ALLOWED_WEIGHT = 1e-12;
    static final long DEF_MAX_RVM_ITS = (long) (1 << 30);
    static final long DEF_MAX_UPDATE_ITS = (long) (1 << 30);
    static final double DEF_MIN_THETA = 1e-8;
    static final double DEF_MOMENTUM = 0.85;
    static final long DEF_MAX_ADDITIONS = 1;
    static final boolean DEF_SAVE_RESTART = false;
    static final boolean DEF_LOAD_RESTART = false;

    static final int KERNEL_TYPE_LINEAR = 0;
    static final int KERNEL_TYPE_RBF = 1;
    static final int KERNEL_TYPE_POLYNOMIAL = 2;
    static final int KERNEL_TYPE_DEFAULT = KERNEL_TYPE_LINEAR;

    static final Double CLASS_LABEL_1 = new Double(0.0);
    static final Double CLASS_LABEL_2 = new Double(1.0);

    //-----------------------------------------------------------------
    //
    // instance data members
    //
    //-----------------------------------------------------------------

    boolean debug_level_d = false;

    // RVM data members
    //
    // relevance vector weights, vectors, and labels
    //
    Vector<Vector<Double>> x_d = new Vector<Vector<Double>>();        // vectors_d
    Vector<Double> y_d = new Vector<Double>();                        // targets_d
    Vector<Vector<Double>> evalx_d = new Vector<Vector<Double>>();    // vectors_d
    Vector<Double> evaly_d = new Vector<Double>();                    // targets_d

    Matrix inv_hessian_d = new Matrix();                // A in [2]
    Vector<Double> weights_d = new Vector<Double>();    // w in [1]
    double bias_d = 0.0;

    // RVM parameters
    //
    // int kernel_type_d = KERNEL_TYPE_LINEAR;
    int kernel_type_d = KERNEL_TYPE_RBF;

    // vector of support region points
    //
    Vector<MyPoint> support_vectors_d = new Vector<MyPoint>();
    Vector<MyPoint> decision_regions_d = new Vector<MyPoint>();
    int output_canvas_d[][];
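    // ----------------------------------------------------------------
    // illustrative sketch (not part of the original file): the kernel
    // routine itself lies outside this excerpt, so this is only a
    // guess at how the KERNEL_TYPE_* constants above are typically
    // realized. the RBF width and polynomial degree are hard-coded
    // assumptions here, not values taken from this class.
    // ----------------------------------------------------------------
    static double kernelSketch(double[] a, double[] b, int type) {

        // accumulate the dot product and squared Euclidean distance
        //
        double dot = 0.0;
        double dist2 = 0.0;
        for (int i = 0; i < a.length; i++) {
            dot += a[i] * b[i];
            double diff = a[i] - b[i];
            dist2 += diff * diff;
        }

        // evaluate the selected kernel type
        //
        if (type == KERNEL_TYPE_RBF) {
            return Math.exp(-0.5 * dist2);
        }
        else if (type == KERNEL_TYPE_POLYNOMIAL) {
            return Math.pow(1.0 + dot, 2.0);
        }
        else {
            return dot;    // KERNEL_TYPE_LINEAR
        }
    }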
    /**
     * tuning parameters: These are the only parameters that a user needs
     * to worry about prior to training. The default quantities are usually
     * sufficient. However, run-time performance and accuracy can be
     * influenced by appropriately tuning these parameters.
     *
     * maximum hyperparameter value allowed before pruning. decreasing this
     * value can speed up convergence of the model, but may yield
     * overpruning and poor generalization. the value should always be
     * greater than zero.
     */
    double alpha_thresh_d;

    /**
     * minimum value of a weight allowed in the model. typically, as a
     * weight decreases toward zero, it should be pruned.
     */
    double min_allowed_weight_d;

    /**
     * maximum number of training iterations to carry out before stopping.
     * adjusting this parameter can result in sub-optimal results.
     */
    long max_rvm_its_d;

    /**
     * maximum number of iterations that are allowed to pass between
     * model updates (adding or pruning of a hyperparameter) before training
     * is terminated (for the full mode of training) or a vector is manually
     * added (for the incremental mode of training).
     */
    long max_update_its_d;

    /**
     * minimum value of the theta calculation (the divisor of equation
     * 17 in [3]) that will trigger a model addition (in the
     * incremental training mode).
     */
    double min_theta_d;

    /**
     * hyperparameter update momentum term. a larger value for this term
     * can lead to faster convergence, while too large a value can cause
     * oscillation. the value is typically on the range [0,1].
     */
    double momentum_d;

    /**
     * number of hyperparameters to add at a time. adding a small number of
     * hyperparameters at a time will yield a smoother movement through the
     * model space, but may increase the total convergence time.
     */
    long max_additions_d;

    /**
     * whether or not to create backup copies of training data. if true,
     * data will occasionally be saved to disk in the file provided. that
     * file can later be used to restart training in the middle of the
     * convergence process. *** the restart facility is currently available
     * only for incremental training ***
     */
    boolean save_restart_d;
    // Filename restart_save_file_d;

    /**
     * whether or not to bootstrap training from a restart file. if true,
     * the given restart file is read and training is continued from that
     * point forward. *** the restart facility is currently available
     * only for incremental training ***
     */
    boolean load_restart_d;
    // Filename restart_load_file_d;

    // model data
    //
    int num_samples_d;             // number of remaining RVs
    Matrix A_d = new Matrix();     // hyperparameter matrix
    int dimA_d;                    // number of non-pruned params
    Matrix phi_d = new Matrix();   // working design matrix
    Vector<Double> curr_weights_d = new Vector<Double>();        // updated weights
    Vector<Double> last_rvm_weights_d = new Vector<Double>();    // stored weights for rvm pass

    // IRLS training quantities
    //
    Vector<Double> sigma_d = new Vector<Double>();       // error vector
    Matrix B_d = new Matrix();                           // data-dependent "noise"
    Vector<Double> gradient_d = new Vector<Double>();    // gradient w.r.t. weights
    Matrix hessian_d = new Matrix();                     // hessian w.r.t. weights
    Matrix covar_cholesky_d = new Matrix();              // cholesky decomposition of covar
    Vector<Double> old_irls_weights_d = new Vector<Double>();    // stored weights for irls pass

    long last_changed_d;    // counter for last time model changed

    // incremental training quantities
    //
    // Vector S_d = new Vector();              // updates for incremental train
    // Vector hyperparams_d = new Vector();    // current hyperparameters
    // Vector weights_d;                       // current hyperparameters
    Vector<Double> last_hyperparams_d = new Vector<Double>();       // previous iteration's hyperparams
    Vector<Double> twoback_hyperparams_d = new Vector<Double>();    // hyperparameters from two iterations back
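    // ----------------------------------------------------------------
    // illustrative sketch (not part of the original file): the IRLS
    // pass in step 2 of training (see trainFull below) maximizes a
    // Bernoulli likelihood, which requires the logistic sigmoid
    // y = 1 / (1 + exp(-x)) applied to each model output.
    // ----------------------------------------------------------------
    static double sigmoidSketch(double x) {
        return 1.0 / (1.0 + Math.exp(-x));
    }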
    //-------------------------------------------------------------------
    //
    // classification functions
    //
    //-------------------------------------------------------------------

    /**
     * Overrides the initialize() method in the base class. Initializes
     * member data and prepares for execution of the first step. This
     * method "resets" the algorithm.
     *
     * @@return true
     */
    public boolean initialize() {

        // Debug
        //
        // System.out.println("AlgorithmRVM : initialize()");

        // check the data points
        //
        if (output_panel_d == null) {
            return false;
        }

        alpha_thresh_d = 1e4;
        min_allowed_weight_d = 1e-8;
        max_rvm_its_d = DEF_MAX_RVM_ITS;
        // max_update_its_d = DEF_MAX_UPDATE_ITS;
        max_update_its_d = 100;
        min_theta_d = DEF_MIN_THETA;
        momentum_d = DEF_MOMENTUM;
        max_additions_d = DEF_MAX_ADDITIONS;
        save_restart_d = DEF_SAVE_RESTART;
        load_restart_d = DEF_LOAD_RESTART;

        // add the process description for the RVM algorithm
        //
        if (description_d.size() == 0) {
            String str = new String(" 0. Initialize the original data.");
            description_d.addElement(str);
            str = new String(" 1. Displaying the original data.");
            description_d.addElement(str);
            str = new String(" 2. Computing the Relevance Vectors.");
            description_d.addElement(str);
            str = new String(" 3. Computing the decision regions.");
            description_d.addElement(str);
        }

        // append message to process box
        //
        pro_box_d.appendMessage("Relevance Vector Machine:" + "\n");

        // set the data points for this algorithm
        //
        // set1_d = (Vector)data_points_d.dset1.clone();
        // set2_d = (Vector)data_points_d.dset2.clone();
        //
        set1_d = data_points_d.dset1;
        set2_d = data_points_d.dset2;

        // reset values
        //
        support_vectors_d = new Vector<MyPoint>();
        decision_regions_d = new Vector<MyPoint>();
        step_count = 3;
        x_d = new Vector<Vector<Double>>();
        y_d = new Vector<Double>();

        // set the step index
        //
        step_index_d = 0;

        // append message to process box
        //
        pro_box_d.appendMessage((String) description_d.get(step_index_d));

        // exit gracefully
        //
        return true;
    }

    /**
     * Implementation of the run function from the Runnable interface.
     * Determines what the current step is and calls the appropriate method.
     */
    public void run() {

        // Debug
        //
        // System.out.println(algo_id + ": run()");

        if (step_index_d == 1) {
            disableControl();
            step1();
            enableControl();
        }
        else if (step_index_d == 2) {
            disableControl();
            step2();
            enableControl();
        }
        else if (step_index_d == 3) {
            disableControl();
            step3();
            pro_box_d.appendMessage(" Algorithm Complete");
            enableControl();
        }

        // exit gracefully
        //
        return;
    }

    /**
     * step one of the algorithm. Scales the display to fit the plot and
     * shows the original data.
     *
     * @@return true
     */
    boolean step1() {

        // debug
        //
        // System.out.println(algo_id + ": step1()");

        pro_box_d.setProgressMin(0);
        pro_box_d.setProgressMax(1);
        pro_box_d.setProgressCurr(0);

        scaleToFitData();

        // display original data
        //
        output_panel_d.addOutput(set1_d, Classify.PTYPE_INPUT,
                                 data_points_d.color_dset1);
        output_panel_d.addOutput(set2_d, Classify.PTYPE_INPUT,
                                 data_points_d.color_dset2);
        output_panel_d.addOutput(set3_d, Classify.PTYPE_INPUT,
                                 data_points_d.color_dset3);
        output_panel_d.addOutput(set4_d, Classify.PTYPE_INPUT,
                                 data_points_d.color_dset4);

        // step 1 completed
        //
        pro_box_d.setProgressCurr(1);
        output_panel_d.repaint();

        // exit gracefully
        //
        return true;
    }
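    // ----------------------------------------------------------------
    // illustrative sketch (not part of the original file): once
    // training has produced weights_d and bias_d, an RVM scores a
    // point x as p(class 2 | x) = sigmoid( sum_i w_i K(x, x_i) + b ),
    // summing over the retained relevance vectors. rvs and w are
    // hypothetical flattened copies of x_d and weights_d; the sketch
    // reuses kernelSketch and sigmoidSketch above.
    // ----------------------------------------------------------------
    double rvmOutputSketch(double[] x, double[][] rvs, double[] w) {
        double sum = bias_d;
        for (int i = 0; i < rvs.length; i++) {
            sum += w[i] * kernelSketch(x, rvs[i], kernel_type_d);
        }
        return sigmoidSketch(sum);
    }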
    /**
     * step two of the algorithm. Trains the RVM and computes the
     * relevance vectors for the given data.
     *
     * @@return true
     */
    boolean step2() {

        // Debug
        //
        // System.out.println("AlgorithmRVM : step2()");

        pro_box_d.setProgressMin(0);
        pro_box_d.setProgressMax(20);
        pro_box_d.setProgressCurr(0);

        trainFull();

        output_panel_d.addOutput(support_vectors_d,
                                 Classify.PTYPE_SUPPORT_VECTOR, Color.black);

        pro_box_d.setProgressCurr(20);
        output_panel_d.repaint();

        // exit gracefully
        //
        return true;
    }

    /**
     * step three of the algorithm. Computes and displays the decision
     * regions.
     *
     * @@return true
     */
    boolean step3() {

        // Debug
        //
        // System.out.println("AlgorithmRVM : step3()");

        computeDecisionRegions();

        // display the decision regions
        //
        output_panel_d.addOutput(decision_regions_d, Classify.PTYPE_INPUT,
                                 new Color(255, 200, 0));
        output_panel_d.repaint();

        computeErrors();

        // exit gracefully
        //
        return true;
    }

    /**
     * this method trains an RVM probabilistic classifier on the input data
     * and targets provided. the inputs and targets should have a one-to-one
     * correspondence, and all targets should be either 0 (out-of-class) or
     * 1 (in-class). the training scheme follows that of [1], section 3.
     * it is assumed that the class data and targets are already set when
     * this method is called.
     *
     * the training algorithm in [1] for RVMs proceeds in four iterative
     * steps:
     *
     * 1. prune away any weights whose hyperparameters have gone to
     *    infinity
     *
     * 2. estimate the most probable weights: in this step we find those
     *    weights that maximize equation (24) of [1]. the iteratively
     *    reweighted least squares (IRLS) algorithm is used to find w_mp
     *
     * 3. estimate the covariance of a Gaussian approximation to the
     *    posterior distribution over the weights (the posterior is what
     *    we want to model in the end), centered at the weights w_mp
     *
     * 4. estimate the hyperparameters that govern the weights. this
     *    is done by evaluating the derivative over the hyperparameters
     *    and finding the maximizing hyperparameters (a sketch of this
     *    update appears below)
     *
     * steps 1 through 4 are carried out iteratively until a suitable
     * convergence criterion is satisfied.
     *
     * @@return boolean value indicating status
     */
    public boolean trainFull() {

        // Debug
        //
        // System.out.println("AlgorithmRVM : trainFull()");

        // debugging information
        //
        // debug_level_d = true;

        // 0. initialize data structures for training
        //
        if (!initFullTrain()) {
            // System.out.println("Error at initFullTrain");
        }

        if (debug_level_d) {
            // System.out.println("RVM training");
        }

        if (debug_level_d) {
            // Matrix.printDoubleVector(x_d);
            Matrix.printDoubleVector(y_d);
        }

        // debugging information
        //
        if (debug_level_d) {
            // System.out.println("RVM convergence loop");
        }

        // loop until convergence or until a maximum number of iterations
        // has been reached
        //
        long num_rvm_its = 0;
        boolean rvm_converged = false;

        while (!rvm_converged) {

            // store the weights achieved on the last iteration so we
            // can test for convergence later
            //
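            // ------------------------------------------------------------
            // illustrative sketch (not part of the original file): the
            // hyperparameter re-estimation in step 4 of the comment above
            // follows the update rule of [1],
            //
            //   gamma_i = 1 - alpha_i * Sigma_ii
            //   alpha_i <- gamma_i / (mu_i * mu_i)
            //
            // where Sigma is the posterior covariance from step 3 and mu
            // holds the most probable weights w_mp from step 2. a minimal
            // rendering over hypothetical aligned arrays:
            //
            //   for (int i = 0; i < alpha.length; i++) {
            //     double gamma = 1.0 - alpha[i] * sigmaDiag[i];
            //     alpha[i] = gamma / (mu[i] * mu[i]);
            //   }
            // ------------------------------------------------------------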