📄 logisticbase.java
字号:
/*
 *    This program is free software; you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation; either version 2 of the License, or
 *    (at your option) any later version.
 *
 *    This program is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with this program; if not, write to the Free Software
 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

/*
 *    LogisticBase.java
 *    Copyright (C) 2003 Niels Landwehr
 *
 */

package weka.classifiers.trees.lmt;

import weka.classifiers.Classifier;
import weka.classifiers.Evaluation;
import weka.classifiers.functions.SimpleLinearRegression;
import weka.core.Attribute;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Utils;
import weka.core.WeightedInstancesHandler;

/**
 * Base/helper class for building logistic regression models with the LogitBoost algorithm.
 * Used for building logistic model trees (weka.classifiers.trees.lmt.LMT)
 * and standalone logistic regression (weka.classifiers.functions.SimpleLogistic).
 *
 * <!-- options-start -->
 * Valid options are: <p/>
 *
 * <pre> -D
 *  If set, classifier is run in debug mode and
 *  may output additional info to the console</pre>
 *
 * <!-- options-end -->
 *
 * @author Niels Landwehr
 * @author Marc Sumner
 * @version $Revision: 1.6 $
 */
public class LogisticBase extends Classifier implements WeightedInstancesHandler {

  /** for serialization */
  static final long serialVersionUID = 168765678097825064L;

  /** Header-only (zero-instance) version of the numeric version of the training data. */
  protected Instances m_numericDataHeader;

  /**
   * Numeric version of the training data. The original class is replaced by a numeric
   * pseudo-class.
   */
  protected Instances m_numericData;

  /** Training data (set in buildClassifier from the data passed in). */
  protected Instances m_train;

  /** Use cross-validation to determine the best number of LogitBoost iterations? */
  protected boolean m_useCrossValidation;

  /**
   * Use error on probabilities (mean absolute error) instead of misclassification rate
   * for the stopping criterion of LogitBoost?
   */
  protected boolean m_errorOnProbabilities;

  /**
   * Fixed number of iterations for LogitBoost. Only a positive value is used as-is;
   * otherwise buildClassifier chooses the number of iterations automatically
   * (AIC if enabled, else cross-validation if enabled, else error on the training set).
   */
  protected int m_fixedNumIterations;

  /**
   * Heuristic for stopping LogitBoost iterations early: LogitBoost is stopped once the
   * current (local) minimum of the monitored error/criterion, as a function of the number
   * of iterations, has not changed for more than m_heuristicStop iterations.
   * NOTE(review): the loops visible in this file apply this check unconditionally; no
   * "disabled" value is special-cased there -- confirm before relying on 0 to disable it.
   */
  protected int m_heuristicStop = 50;

  /** The number of LogitBoost iterations performed (i.e. regression functions in use). */
  protected int m_numRegressions = 0;

  /** The maximum number of LogitBoost iterations. */
  protected int m_maxIterations;

  /** The number of different classes. */
  protected int m_numClasses;

  /** Array holding the simple regression functions fit by LogitBoost. */
  protected SimpleLinearRegression[][] m_regressions;

  /**
   * Number of folds for cross-validating the number of LogitBoost iterations.
   * NOTE(review): this is a mutable static, so a change made through one instance is
   * seen by every instance.
   */
  protected static int m_numFoldsBoosting = 5;

  /** Threshold on the Z-value for LogitBoost. */
  protected static final double Z_MAX = 3;

  /**
   * If true, the AIC is used to choose the best iteration. buildClassifier checks this
   * before m_useCrossValidation, so it takes precedence over cross-validation.
   */
  private boolean m_useAIC = false;

  /**
   * Effective number of parameters used for AIC / BIC automatic stopping; added to the
   * iteration count in the penalty term of the information criterion.
   */
  protected double m_numParameters = 0;

  /**
   * Threshold for trimming weights. Instances with a weight lower than this (as a percentage
   * of total weights) are not included in the regression fit.
   */
  protected double m_weightTrimBeta = 0;

  /**
   * Constructor that creates LogisticBase object with standard options: the number of
   * LogitBoost iterations (at most 500) is chosen by cross-validation, misclassification
   * error is the stopping criterion, and AIC-based stopping is off.
*/ public LogisticBase(){ m_fixedNumIterations = -1; m_useCrossValidation = true; m_errorOnProbabilities = false; m_maxIterations = 500; m_useAIC = false; m_numParameters = 0; } /** * Constructor to create LogisticBase object. * @param numBoostingIterations fixed number of iterations for LogitBoost (if negative, use cross-validation or * stopping criterion on the training data). * @param useCrossValidation cross-validate number of LogitBoost iterations (if false, use stopping * criterion on the training data). * @param errorOnProbabilities if true, use error on probabilities * instead of misclassification for stopping criterion of LogitBoost */ public LogisticBase(int numBoostingIterations, boolean useCrossValidation, boolean errorOnProbabilities){ m_fixedNumIterations = numBoostingIterations; m_useCrossValidation = useCrossValidation; m_errorOnProbabilities = errorOnProbabilities; m_maxIterations = 500; m_useAIC = false; m_numParameters = 0; } /** * Builds the logistic regression model usiing LogitBoost. * * @param data the training data * @throws Exception if something goes wrong */ public void buildClassifier(Instances data) throws Exception { m_train = new Instances(data); m_numClasses = m_train.numClasses(); //init the array of simple regression functions m_regressions = initRegressions(); m_numRegressions = 0; //get numeric version of the training data (class variable replaced by numeric pseudo-class) m_numericData = getNumericData(m_train); //save header info m_numericDataHeader = new Instances(m_numericData, 0); if (m_fixedNumIterations > 0) { //run LogitBoost for fixed number of iterations performBoosting(m_fixedNumIterations); } else if (m_useAIC) { // Marc had this after the test for m_useCrossValidation. Changed by Eibe. 
//run LogitBoost using information criterion for stopping performBoostingInfCriterion(); } else if (m_useCrossValidation) { //cross-validate number of LogitBoost iterations performBoostingCV(); } else { //run LogitBoost with number of iterations that minimizes error on the training set performBoosting(); } //only keep the simple regression functions that correspond to the selected number of LogitBoost iterations m_regressions = selectRegressions(m_regressions); } /** * Runs LogitBoost, determining the best number of iterations by cross-validation. * * @throws Exception if something goes wrong */ protected void performBoostingCV() throws Exception{ //completed iteration keeps track of the number of iterations that have been //performed in every fold (some might stop earlier than others). //Best iteration is selected only from these. int completedIterations = m_maxIterations; Instances allData = new Instances(m_train); allData.stratify(m_numFoldsBoosting); double[] error = new double[m_maxIterations + 1]; for (int i = 0; i < m_numFoldsBoosting; i++) { //split into training/test data in fold Instances train = allData.trainCV(m_numFoldsBoosting,i); Instances test = allData.testCV(m_numFoldsBoosting,i); //initialize LogitBoost m_numRegressions = 0; m_regressions = initRegressions(); //run LogitBoost iterations int iterations = performBoosting(train,test,error,completedIterations); if (iterations < completedIterations) completedIterations = iterations; } //determine iteration with minimum error over the folds int bestIteration = getBestIteration(error,completedIterations); //rebuild model on all of the training data m_numRegressions = 0; performBoosting(bestIteration); } /** * Runs LogitBoost, determining the best number of iterations by an information criterion (currently AIC). 
*/ protected void performBoostingInfCriterion() throws Exception{ double criterion = 0.0; double bestCriterion = Double.MAX_VALUE; int bestIteration = 0; int noMin = 0; // Variable to keep track of criterion values (AIC) double criterionValue = Double.MAX_VALUE; // initialize Ys/Fs/ps double[][] trainYs = getYs(m_train); double[][] trainFs = getFs(m_numericData); double[][] probs = getProbs(trainFs); // Array with true/false if the attribute is included in the model or not boolean[][] attributes = new boolean[m_numClasses][m_numericDataHeader.numAttributes()]; int iteration = 0; while (iteration < m_maxIterations) { //perform single LogitBoost iteration boolean foundAttribute = performIteration(iteration, trainYs, trainFs, probs, m_numericData); if (foundAttribute) { iteration++; m_numRegressions = iteration; } else { //could not fit simple linear regression: stop LogitBoost break; } double numberOfAttributes = m_numParameters + iteration; // Fill criterion array values criterionValue = 2.0 * negativeLogLikelihood(trainYs, probs) + 2.0 * numberOfAttributes; //heuristic: stop LogitBoost if the current minimum has not changed for <m_heuristicStop> iterations if (noMin > m_heuristicStop) break; if (criterionValue < bestCriterion) { bestCriterion = criterionValue; bestIteration = iteration; noMin = 0; } else { noMin++; } } m_numRegressions = 0; performBoosting(bestIteration); } /** * Runs LogitBoost on a training set and monitors the error on a test set. * Used for running one fold when cross-validating the number of LogitBoost iterations. * @param train the training set * @param test the test set * @param error array to hold the logged error values * @param maxIterations the maximum number of LogitBoost iterations to run * @return the number of completed LogitBoost iterations (can be smaller than maxIterations * if the heuristic for early stopping is active or there is a problem while fitting the regressions * in LogitBoost). 
* @throws Exception if something goes wrong */ protected int performBoosting(Instances train, Instances test, double[] error, int maxIterations) throws Exception{ //get numeric version of the (sub)set of training instances Instances numericTrain = getNumericData(train); //initialize Ys/Fs/ps double[][] trainYs = getYs(train); double[][] trainFs = getFs(numericTrain); double[][] probs = getProbs(trainFs); int iteration = 0; int noMin = 0; double lastMin = Double.MAX_VALUE; if (m_errorOnProbabilities) error[0] += getMeanAbsoluteError(test); else error[0] += getErrorRate(test); while (iteration < maxIterations) { //perform single LogitBoost iteration boolean foundAttribute = performIteration(iteration, trainYs, trainFs, probs, numericTrain); if (foundAttribute) { iteration++; m_numRegressions = iteration; } else { //could not fit simple linear regression: stop LogitBoost break; } if (m_errorOnProbabilities) error[iteration] += getMeanAbsoluteError(test); else error[iteration] += getErrorRate(test); //heuristic: stop LogitBoost if the current minimum has not changed for <m_heuristicStop> iterations if (noMin > m_heuristicStop) break; if (error[iteration] < lastMin) { lastMin = error[iteration]; noMin = 0; } else { noMin++; } } return iteration; } /** * Runs LogitBoost with a fixed number of iterations. * @param numIterations the number of iterations to run * @throws Exception if something goes wrong
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -