📄 logisticbase.java
字号:
/*
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 2 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

/*
 * LogisticBase.java
 * Copyright (C) 2003 Niels Landwehr
 *
 */

package weka.classifiers.trees.lmt;

import weka.core.*;
import weka.classifiers.*;
import weka.classifiers.functions.*;
import weka.filters.Filter;

/**
 * Base/helper class for building logistic regression models with the LogitBoost algorithm.
 * Used for building logistic model trees (weka.classifiers.trees.lmt.LMT)
 * and standalone logistic regression (weka.classifiers.functions.SimpleLogistic).
 *
 * @author Niels Landwehr
 * @version $Revision: 1.2 $
 */
public class LogisticBase extends Classifier implements WeightedInstancesHandler{

    /** Header-only version of the numeric version of the training data */
    protected Instances m_numericDataHeader;

    /**
     * Numeric version of the training data. Original class is replaced by a numeric pseudo-class.
     */
    protected Instances m_numericData;

    /** Training data */
    protected Instances m_train;

    /** Use cross-validation to determine best number of LogitBoost iterations? */
    protected boolean m_useCrossValidation;

    /** Use error on probabilities for stopping criterion of LogitBoost? */
    protected boolean m_errorOnProbabilities;

    /**
     * Use fixed number of iterations for LogitBoost?
     * (if negative, cross-validate number of iterations)
     */
    protected int m_fixedNumIterations;

    /**
     * Use heuristic to stop performing LogitBoost iterations earlier?
     * If enabled, LogitBoost is stopped if the current (local) minimum of the error on a test set as
     * a function of the number of iterations has not changed for m_heuristicStop iterations.
     */
    protected int m_heuristicStop = 50;

    /** The number of LogitBoost iterations performed. */
    protected int m_numRegressions = 0;

    /** The maximum number of LogitBoost iterations */
    protected int m_maxIterations;

    /** The number of different classes */
    protected int m_numClasses;

    /** Array holding the simple regression functions fit by LogitBoost */
    protected SimpleLinearRegression[][] m_regressions;

    /** Number of folds for cross-validating number of LogitBoost iterations */
    protected static int m_numFoldsBoosting = 5;

    /** Threshold on the Z-value for LogitBoost */
    protected static final double Z_MAX = 3;

    /**
     * Constructor that creates LogisticBase object with standard options.
     */
    public LogisticBase(){
        // Delegate to the full constructor so the default option values
        // (-1 = no fixed iteration count, cross-validation on, error on
        // probabilities off) are defined in exactly one place.
        this(-1, true, false);
    }

    /**
     * Constructor to create LogisticBase object.
     * @param numBoostingIterations fixed number of iterations for LogitBoost (if negative, use cross-validation or
     * stopping criterion on the training data).
     * @param useCrossValidation cross-validate number of LogitBoost iterations (if false, use stopping
     * criterion on the training data).
     * @param errorOnProbabilities if true, use error on probabilities
     * instead of misclassification for stopping criterion of LogitBoost
     */
    public LogisticBase(int numBoostingIterations, boolean useCrossValidation, boolean errorOnProbabilities){
        m_fixedNumIterations = numBoostingIterations;
        m_useCrossValidation = useCrossValidation;
        m_errorOnProbabilities = errorOnProbabilities;
        m_maxIterations = 500;
    }

    /**
     * Builds the logistic regression model using LogitBoost.
* * @param data the training data */ public void buildClassifier(Instances data) throws Exception{ m_train = new Instances(data); m_numClasses = m_train.numClasses(); //init the array of simple regression functions m_regressions = initRegressions(); m_numRegressions = 0; //get numeric version of the training data (class variable replaced by numeric pseudo-class) m_numericData = getNumericData(m_train); //save header info m_numericDataHeader = new Instances(m_numericData, 0); if (m_fixedNumIterations > 0) { //run LogitBoost for fixed number of iterations performBoosting(m_fixedNumIterations); } else if (m_useCrossValidation) { //cross-validate number of LogitBoost iterations performBoostingCV(); } else { //run LogitBoost with number of iterations that minimizes error on the training set performBoosting(); } //only keep the simple regression functions that correspond to the selected number of LogitBoost iterations m_regressions = selectRegressions(m_regressions); } /** * Runs LogitBoost, determining the best number of iterations by cross-validation. */ protected void performBoostingCV() throws Exception{ //completed iteration keeps track of the number of iterations that have been //performed in every fold (some might stop earlier than others). //Best iteration is selected only from these. 
int completedIterations = m_maxIterations; Instances allData = new Instances(m_train); allData.stratify(m_numFoldsBoosting); double[] error = new double[m_maxIterations + 1]; for (int i = 0; i < m_numFoldsBoosting; i++) { //split into training/test data in fold Instances train = allData.trainCV(m_numFoldsBoosting,i); Instances test = allData.testCV(m_numFoldsBoosting,i); //initialize LogitBoost m_numRegressions = 0; m_regressions = initRegressions(); //run LogitBoost iterations int iterations = performBoosting(train,test,error,completedIterations); if (iterations < completedIterations) completedIterations = iterations; } //determine iteration with minimum error over the folds int bestIteration = getBestIteration(error,completedIterations); //rebuild model on all of the training data m_numRegressions = 0; performBoosting(bestIteration); } /** * Runs LogitBoost on a training set and monitors the error on a test set. * Used for running one fold when cross-validating the number of LogitBoost iterations. * @param train the training set * @param test the test set * @param error array to hold the logged error values * @param maxIterations the maximum number of LogitBoost iterations to run * @return the number of completed LogitBoost iterations (can be smaller than maxIterations * if the heuristic for early stopping is active or there is a problem while fitting the regressions * in LogitBoost). 
* */ protected int performBoosting(Instances train, Instances test, double[] error, int maxIterations) throws Exception{ //get numeric version of the (sub)set of training instances Instances numericTrain = getNumericData(train); //initialize Ys/Fs/ps double[][] trainYs = getYs(train); double[][] trainFs = getFs(numericTrain); double[][] probs = getProbs(trainFs); int iteration = 0; double[] testErrors = new double[maxIterations+1]; int noMin = 0; double lastMin = Double.MAX_VALUE; if (m_errorOnProbabilities) error[0] += getMeanAbsoluteError(test); else error[0] += getErrorRate(test); while (iteration < maxIterations) { //perform single LogitBoost iteration boolean foundAttribute = performIteration(iteration, trainYs, trainFs, probs, numericTrain); if (foundAttribute) { iteration++; m_numRegressions = iteration; } else { //could not fit simple linear regression: stop LogitBoost break; } if (m_errorOnProbabilities) error[iteration] += getMeanAbsoluteError(test); else error[iteration] += getErrorRate(test); //heuristic: stop LogitBoost if the current minimum has not changed for <m_heuristicStop> iterations if (noMin > m_heuristicStop) break; if (error[iteration] < lastMin) { lastMin = error[iteration]; noMin = 0; } else { noMin++; } } return iteration; } /** * Runs LogitBoost with a fixed number of iterations. * @param numIterations the number of iterations to run */ protected void performBoosting(int numIterations) throws Exception{ //initialize Ys/Fs/ps double[][] trainYs = getYs(m_train); double[][] trainFs = getFs(m_numericData); double[][] probs = getProbs(trainFs); int iteration = 0; //run iterations while (iteration < numIterations) { boolean foundAttribute = performIteration(iteration, trainYs, trainFs, probs, m_numericData); if (foundAttribute) iteration++; else break; } m_numRegressions = iteration; } /** * Runs LogitBoost using the stopping criterion on the training set. 
* The number of iterations is used that gives the lowest error on the training set, either misclassification * or error on probabilities (depending on the errorOnProbabilities option). */ protected void performBoosting() throws Exception{ //initialize Ys/Fs/ps double[][] trainYs = getYs(m_train); double[][] trainFs = getFs(m_numericData); double[][] probs = getProbs(trainFs); int iteration = 0; double[] trainErrors = new double[m_maxIterations+1]; trainErrors[0] = getErrorRate(m_train); int noMin = 0; double lastMin = Double.MAX_VALUE; while (iteration < m_maxIterations) { boolean foundAttribute = performIteration(iteration, trainYs, trainFs, probs, m_numericData); if (foundAttribute) { iteration++; m_numRegressions = iteration; } else { //could not fit simple regression break; } trainErrors[iteration] = getErrorRate(m_train); //heuristic: stop LogitBoost if the current minimum has not changed for <m_heuristicStop> iterations if (noMin > m_heuristicStop) break; if (trainErrors[iteration] < lastMin) { lastMin = trainErrors[iteration]; noMin = 0; } else { noMin++; } } //find iteration with best error m_numRegressions = getBestIteration(trainErrors, iteration); } /** * Returns the misclassification error of the current model on a set of instances. * @param data the set of instances * @return the error rate */ protected double getErrorRate(Instances data) throws Exception { Evaluation eval = new Evaluation(data); eval.evaluateModel(this,data); return eval.errorRate(); } /** * Returns the error of the probability estimates for the current model on a set of instances. * @param data the set of instances * @return the error */ protected double getMeanAbsoluteError(Instances data) throws Exception { Evaluation eval = new Evaluation(data); eval.evaluateModel(this,data); return eval.meanAbsoluteError(); } /** * Helper function to find the minimum in an array of error values. 
 */
protected int getBestIteration(double[] errors, int maxIteration) {
    double bestError = errors[0];
    int bestIteration = 0;
    // linear scan over errors[0..maxIteration] (inclusive); ties keep the
    // earlier (smaller) iteration, i.e. the simpler model
    for (int i = 1; i <= maxIteration; i++) {
        if (errors[i] < bestError) {
            bestError = errors[i];
            bestIteration = i;
        }
    }
    return bestIteration;
}

/**
 * Performs a single iteration of LogitBoost, and updates the model accordingly.
 * A simple regression function is fit to the response and added to the m_regressions array.
 * @param iteration the current iteration
 * @param trainYs the y-values (see description of LogitBoost) for the model trained so far
 * @param trainFs the F-values (see description of LogitBoost) for the model trained so far
 * @param probs the p-values (see description of LogitBoost) for the model trained so far
 * @param trainNumeric numeric version of the training data
 * @return returns true if iteration performed successfully, false if no simple regression function
 * could be fitted.
 */
protected boolean performIteration(int iteration, double[][] trainYs, double[][] trainFs, double[][] probs, Instances trainNumeric) throws Exception {
    // one regression fit per class (multi-class LogitBoost)
    for (int j = 0; j < m_numClasses; j++) {
        // make copy of data (need to save the weights)
        Instances boostData = new Instances(trainNumeric);
        for (int i = 0; i < trainNumeric.numInstances(); i++) {
            // compute response and weight; getZ presumably clips the working
            // response at Z_MAX — defined outside this view, TODO confirm
            double p = probs[i][j];
            double actual = trainYs[i][j];
            double z = getZ(actual, p);
            double w = (actual - p) / z;
            // set values for instance: pseudo-class gets the working response,
            // instance weight is rescaled by w
            Instance current = boostData.instance(i);
            current.setValue(boostData.classIndex(), z);
            current.setWeight(current.weight() * w);
        }
        // fit simple regression function to the reweighted data
        m_regressions[j][iteration].buildClassifier(boostData);
        boolean foundAttribute = m_regressions[j][iteration].foundUsefulAttribute();
        if (!foundAttribute) {
            // could not fit simple regression function
            return false;
        }
    }
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -