📄 bayesianlogisticregression.java
字号:
/* * This program is free software; you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation; either version 2 of the License, or * (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with this program; if not, write to the Free Software * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA. *//* * BayesianLogisticRegression.java * Copyright (C) 2008 Illinois Institute of Technology * */package weka.classifiers.bayes;import weka.classifiers.Classifier;import weka.classifiers.bayes.blr.GaussianPriorImpl;import weka.classifiers.bayes.blr.LaplacePriorImpl;import weka.classifiers.bayes.blr.Prior;import weka.core.Attribute;import weka.core.Capabilities;import weka.core.Instance;import weka.core.Instances;import weka.core.Option;import weka.core.OptionHandler;import weka.core.RevisionUtils;import weka.core.SelectedTag;import weka.core.SerializedObject;import weka.core.Tag;import weka.core.TechnicalInformation;import weka.core.TechnicalInformationHandler;import weka.core.Utils;import weka.core.Capabilities.Capability;import weka.core.TechnicalInformation.Field;import weka.core.TechnicalInformation.Type;import weka.filters.Filter;import weka.filters.unsupervised.attribute.Normalize;import java.util.Enumeration;import java.util.Random;import java.util.StringTokenizer;import java.util.Vector;/** <!-- globalinfo-start --> * Implements Bayesian Logistic Regression for both Gaussian and Laplace Priors.<br/> * <br/> * For more information, see<br/> * <br/> * Alexander Genkin, David D. Lewis, David Madigan (2004). Large-scale bayesian logistic regression for text categorization. URL http://www.stat.rutgers.edu/~madigan/PAPERS/shortFat-v3a.pdf. * <p/> <!-- globalinfo-end --> * <!-- technical-bibtex-start --> * BibTeX: * <pre> * @techreport{Genkin2004, * author = {Alexander Genkin and David D. Lewis and David Madigan}, * institution = {DIMACS}, * title = {Large-scale bayesian logistic regression for text categorization}, * year = {2004}, * URL = {http://www.stat.rutgers.edu/\~madigan/PAPERS/shortFat-v3a.pdf} * } * </pre> * <p/> <!-- technical-bibtex-end --> * * * @author Navendu Garg (gargnav at iit dot edu) * @version $Revision: 1.3 $ */public class BayesianLogisticRegression extends Classifier implements OptionHandler, TechnicalInformationHandler { static final long serialVersionUID = -8013478897911757631L; /** Log-likelihood values to be used to choose the best hyperparameter. */ public static double[] LogLikelihood; /** Set of values to be used as hyperparameter values during Cross-Validation. */ public static double[] InputHyperparameterValues; /** DEBUG Mode*/ boolean debug = false; /** Choose whether to normalize data or not */ public boolean NormalizeData = false; /** Tolerance criteria for the stopping criterion. */ public double Tolerance = 0.0005; /** Threshold for binary classification of probabilisitic estimate*/ public double Threshold = 0.5; /** Distributions available */ public static final int GAUSSIAN = 1; public static final int LAPLACIAN = 2; public static final Tag[] TAGS_PRIOR = { new Tag(GAUSSIAN, "Gaussian"), new Tag(LAPLACIAN, "Laplacian") }; /** Distribution Prior class */ public int PriorClass = GAUSSIAN; /** NumFolds for CV based Hyperparameters selection*/ public int NumFolds = 2; /** Methods for selecting the hyperparameter value */ public static final int NORM_BASED = 1; public static final int CV_BASED = 2; public static final int SPECIFIC_VALUE = 3; public static final Tag[] TAGS_HYPER_METHOD = { new Tag(NORM_BASED, "Norm-based"), new Tag(CV_BASED, "CV-based"), new Tag(SPECIFIC_VALUE, "Specific value") }; /** Hyperparameter selection method */ public int HyperparameterSelection = NORM_BASED; /** The class index from the training data */ public int ClassIndex = -1; /** Best hyperparameter for test phase */ public double HyperparameterValue = 0.27; /** CV Hyperparameter Range */ public String HyperparameterRange = "R:0.01-316,3.16"; /** Maximum number of iterations */ public int maxIterations = 100; /**Iteration counter */ public int iterationCounter = 0; /** Array for storing coefficients of Bayesian regression model. */ public double[] BetaVector; /** Array to store Regression Coefficient updates. */ public double[] DeltaBeta; /** Trust Region Radius Update*/ public double[] DeltaUpdate; /** Trust Region Radius*/ public double[] Delta; /** Array to store Hyperparameter values for each feature. */ public double[] Hyperparameters; /** R(i)= BetaVector X x(i) X y(i). * This an intermediate value with respect to vector BETA, input values and corresponding class labels*/ public double[] R; /** This vector is used to store the increments on the R(i). It is also used to determining the stopping criterion.*/ public double[] DeltaR; /** * This variable is used to keep track of change in * the value of delta summation of r(i). */ public double Change; /** * Bayesian Logistic Regression returns the probability of a given instance will belong to a certain * class (p(y=+1|Beta,X). To obtain a binary value the Threshold value is used. * <pre> * p(y=+1|Beta,X)>Threshold ? 1 : -1 * </pre> */ /** Filter interface used to point to weka.filters.unsupervised.attribute.Normalize object * */ public Filter m_Filter; /** Dataset provided to do Training/Test set.*/ protected Instances m_Instances; /** Prior class object interface*/ protected Prior m_PriorUpdate; public String globalInfo() { return "Implements Bayesian Logistic Regression " + "for both Gaussian and Laplace Priors.\n\n" + "For more information, see\n\n" + getTechnicalInformation(); } /** * <pre> * (1)Initialize m_Beta[j] to 0. * (2)Initialize m_DeltaUpdate[j]. * </pre> * * */ public void initialize() throws Exception { int numOfAttributes; int numOfInstances; int i; int j; Change = 0.0; //Manipulate Data if (NormalizeData) { m_Filter = new Normalize(); m_Filter.setInputFormat(m_Instances); m_Instances = Filter.useFilter(m_Instances, m_Filter); } //Set the intecept coefficient. Attribute att = new Attribute("(intercept)"); Instance instance; m_Instances.insertAttributeAt(att, 0); for (i = 0; i < m_Instances.numInstances(); i++) { instance = m_Instances.instance(i); instance.setValue(0, 1.0); } //Get the number of attributes numOfAttributes = m_Instances.numAttributes(); numOfInstances = m_Instances.numInstances(); ClassIndex = m_Instances.classIndex(); iterationCounter = 0; //Initialize Arrays. switch (HyperparameterSelection) { case 1: HyperparameterValue = normBasedHyperParameter(); if (debug) { System.out.println("Norm-based Hyperparameter: " + HyperparameterValue); } break; case 2: HyperparameterValue = CVBasedHyperparameter(); if (debug) { System.out.println("CV-based Hyperparameter: " + HyperparameterValue); } break; } BetaVector = new double[numOfAttributes]; Delta = new double[numOfAttributes]; DeltaBeta = new double[numOfAttributes]; Hyperparameters = new double[numOfAttributes]; DeltaUpdate = new double[numOfAttributes]; for (j = 0; j < numOfAttributes; j++) { BetaVector[j] = 0.0; Delta[j] = 1.0; DeltaBeta[j] = 0.0; DeltaUpdate[j] = 0.0; //TODO: Change the way it takes values. Hyperparameters[j] = HyperparameterValue; } DeltaR = new double[numOfInstances]; R = new double[numOfInstances]; for (i = 0; i < numOfInstances; i++) { DeltaR[i] = 0.0; R[i] = 0.0; } //Set the Prior interface to the appropriate prior implementation. if (PriorClass == GAUSSIAN) { m_PriorUpdate = new GaussianPriorImpl(); } else { m_PriorUpdate = new LaplacePriorImpl(); } } /** * This method tests what kind of data this classifier can handle. * return Capabilities */ public Capabilities getCapabilities() { Capabilities result = super.getCapabilities(); // attributes result.enable(Capability.NUMERIC_ATTRIBUTES); result.enable(Capability.BINARY_ATTRIBUTES); // class result.enable(Capability.BINARY_CLASS); // instances result.setMinimumNumberInstances(0); return result; } /** * <ul> * <li>(1) Set the data to the class attribute m_Instances.</li> * <li>(2)Call the method initialize() to initialize the values.</li> * </ul> * @param data training data * @exception Exception if classifier can't be built successfully. */ public void buildClassifier(Instances data) throws Exception { Instance instance; int i; int j; // can classifier handle the data? getCapabilities().testWithFail(data); //(1) Set the data to the class attribute m_Instances. m_Instances = new Instances(data); //(2)Call the method initialize() to initialize the values. initialize(); do { //Compute the prior Trust Region Radius Update; for (j = 0; j < m_Instances.numAttributes(); j++) { if (j != ClassIndex) { DeltaUpdate[j] = m_PriorUpdate.update(j, m_Instances, BetaVector[j], Hyperparameters[j], R, Delta[j]); //limit step to trust region. DeltaBeta[j] = Math.min(Math.max(DeltaUpdate[j], 0 - Delta[j]), Delta[j]); //Update the for (i = 0; i < m_Instances.numInstances(); i++) { instance = m_Instances.instance(i); if (instance.value(j) != 0) { DeltaR[i] = DeltaBeta[j] * instance.value(j) * classSgn(instance.classValue()); R[i] += DeltaR[i]; } } //Updated Beta values. BetaVector[j] += DeltaBeta[j]; //Update size of trust region. Delta[j] = Math.max(2 * Math.abs(DeltaBeta[j]), Delta[j] / 2.0); } } } while (!stoppingCriterion()); m_PriorUpdate.computelogLikelihood(BetaVector, m_Instances); m_PriorUpdate.computePenalty(BetaVector, Hyperparameters); } /** * This class is used to mask the internal class labels. * * @param value internal class label * @return * <pre> * <ul><li> * -1 for internal class label 0 * </li> * <li> * +1 for internal class label 1 * </li> * </ul> * </pre> */ public static double classSgn(double value) { if (value == 0.0) { return -1.0; } else { return 1.0; } } /** * Returns an instance of a TechnicalInformation object, containing * detailed information about the technical background of this class, * e.g., paper reference or book this class is based on. * * @return the technical information about this class */ public TechnicalInformation getTechnicalInformation() { TechnicalInformation result = null; result = new TechnicalInformation(Type.TECHREPORT); result.setValue(Field.AUTHOR, "Alexander Genkin and David D. Lewis and David Madigan"); result.setValue(Field.YEAR, "2004"); result.setValue(Field.TITLE, "Large-scale bayesian logistic regression for text categorization"); result.setValue(Field.INSTITUTION, "DIMACS"); result.setValue(Field.URL, "http://www.stat.rutgers.edu/~madigan/PAPERS/shortFat-v3a.pdf"); return result; } /** * This is a convient function that defines and upper bound * (Delta>0) for values of r(i) reachable by updates in the * trust region. * * r BetaVector X x(i)y(i). * delta A parameter where sigma > 0 * @return double function value */
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -