// LogisticBase.java (extraction artifact removed: web code-viewer header)
/*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 2 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
*/
/*
* LogisticBase.java
* Copyright (C) 2003 Niels Landwehr
*
*/
package weka.classifiers.trees.lmt;
import weka.classifiers.Classifier;
import weka.classifiers.Evaluation;
import weka.classifiers.functions.SimpleLinearRegression;
import weka.core.Attribute;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Utils;
import weka.core.WeightedInstancesHandler;
/**
 * Base/helper class for building logistic regression models with the LogitBoost algorithm.
 * Used for building logistic model trees (weka.classifiers.trees.lmt.LMT)
 * and standalone logistic regression (weka.classifiers.functions.SimpleLogistic).
 *
 * @author Niels Landwehr
 * @version $Revision$
 */
public class LogisticBase extends Classifier implements WeightedInstancesHandler{
/** Header-only (zero-instance) copy of the numeric version of the training data. */
protected Instances m_numericDataHeader;
/**
 * Numeric version of the training data. Original class is replaced by a numeric pseudo-class
 * (the LogitBoost working response is written into this attribute each iteration).
 */
protected Instances m_numericData;
/** Training data */
protected Instances m_train;
/** Use cross-validation to determine best number of LogitBoost iterations ?*/
protected boolean m_useCrossValidation;
/**Use error on probabilities for stopping criterion of LogitBoost? */
protected boolean m_errorOnProbabilities;
/**Use fixed number of iterations for LogitBoost? (if negative, cross-validate number of iterations)*/
protected int m_fixedNumIterations;
/**Use heuristic to stop performing LogitBoost iterations earlier?
 * If enabled, LogitBoost is stopped if the current (local) minimum of the error on a test set as
 * a function of the number of iterations has not changed for m_heuristicStop iterations.
 */
protected int m_heuristicStop = 50;
/**The number of LogitBoost iterations performed.*/
protected int m_numRegressions = 0;
/**The maximum number of LogitBoost iterations*/
protected int m_maxIterations;
/**The number of different classes*/
protected int m_numClasses;
/** Array of simple regression functions fit by LogitBoost, indexed [class][iteration]. */
protected SimpleLinearRegression[][] m_regressions;
/**Number of folds for cross-validating number of LogitBoost iterations*/
protected static int m_numFoldsBoosting = 5;
/** Threshold (cap) on the absolute Z-value (working response) used by LogitBoost. */
protected static final double Z_MAX = 3;
/**
 * Creates a LogisticBase object with the standard options: no fixed number of
 * LogitBoost iterations (-1), cross-validation of the iteration count enabled,
 * and misclassification (not probability error) as the stopping criterion.
 */
public LogisticBase(){
    // Delegate to the configurable constructor with the default settings.
    this(-1, true, false);
}
/**
 * Creates a LogisticBase object with the given options. The maximum number of
 * LogitBoost iterations is fixed at 500.
 *
 * @param numBoostingIterations fixed number of iterations for LogitBoost (if negative,
 *        use cross-validation or the stopping criterion on the training data)
 * @param useCrossValidation cross-validate the number of LogitBoost iterations
 *        (if false, use the stopping criterion on the training data)
 * @param errorOnProbabilities if true, use error on probabilities instead of
 *        misclassification for the LogitBoost stopping criterion
 */
public LogisticBase(int numBoostingIterations, boolean useCrossValidation, boolean errorOnProbabilities){
    m_maxIterations = 500;
    m_errorOnProbabilities = errorOnProbabilities;
    m_useCrossValidation = useCrossValidation;
    m_fixedNumIterations = numBoostingIterations;
}
/**
 * Builds the logistic regression model using LogitBoost.
 *
 * The number of LogitBoost iterations is either fixed (m_fixedNumIterations &gt; 0),
 * chosen by cross-validation (m_useCrossValidation), or chosen to minimize the
 * error on the training data itself.
 *
 * @param data the training data
 * @throws Exception if the model cannot be built
 */
public void buildClassifier(Instances data) throws Exception{
m_train = new Instances(data);
m_numClasses = m_train.numClasses();
//init the array of simple regression functions
m_regressions = initRegressions();
m_numRegressions = 0;
//get numeric version of the training data (class variable replaced by numeric pseudo-class)
m_numericData = getNumericData(m_train);
//save header info (zero-instance copy, kept for classifying new instances later)
m_numericDataHeader = new Instances(m_numericData, 0);
if (m_fixedNumIterations > 0) {
//run LogitBoost for fixed number of iterations
performBoosting(m_fixedNumIterations);
} else if (m_useCrossValidation) {
//cross-validate number of LogitBoost iterations
performBoostingCV();
} else {
//run LogitBoost with number of iterations that minimizes error on the training set
performBoosting();
}
//only keep the simple regression functions that correspond to the selected number of LogitBoost iterations
m_regressions = selectRegressions(m_regressions);
}
/**
 * Runs LogitBoost, determining the best number of iterations by cross-validation.
 *
 * Each fold accumulates its per-iteration error into a shared array; the
 * iteration with the lowest summed error (among iterations completed by every
 * fold) is selected, and the model is then rebuilt on all training data.
 *
 * @throws Exception if boosting fails
 */
protected void performBoostingCV() throws Exception{
    // Number of iterations finished by EVERY fold (folds may stop early);
    // only iterations up to this count are candidates for selection.
    int completedIterations = m_maxIterations;
    Instances allData = new Instances(m_train);
    allData.stratify(m_numFoldsBoosting);
    // error[i] = summed error over folds after i iterations
    double[] error = new double[m_maxIterations + 1];
    for (int fold = 0; fold < m_numFoldsBoosting; fold++) {
        // training/test split for this fold
        Instances train = allData.trainCV(m_numFoldsBoosting, fold);
        Instances test = allData.testCV(m_numFoldsBoosting, fold);
        // reset the LogitBoost state before running the fold
        m_numRegressions = 0;
        m_regressions = initRegressions();
        // run the iterations, shrinking the common completed count if this fold stopped early
        completedIterations = Math.min(completedIterations,
                performBoosting(train, test, error, completedIterations));
    }
    // pick the iteration with minimum summed error over the folds
    int bestIteration = getBestIteration(error, completedIterations);
    // rebuild the model on all of the training data
    m_numRegressions = 0;
    performBoosting(bestIteration);
}
/**
 * Runs LogitBoost on a training set and monitors the error on a test set.
 * Used for running one fold when cross-validating the number of LogitBoost iterations.
 *
 * The error after each iteration is ACCUMULATED into the given array (shared
 * across folds by performBoostingCV).
 *
 * @param train the training set
 * @param test the test set
 * @param error array to accumulate the logged error values into (index = iteration count)
 * @param maxIterations the maximum number of LogitBoost iterations to run
 * @return the number of completed LogitBoost iterations (can be smaller than maxIterations
 * if the heuristic for early stopping is active or there is a problem while fitting the regressions
 * in LogitBoost).
 * @throws Exception if an iteration fails
 */
protected int performBoosting(Instances train, Instances test,
                              double[] error, int maxIterations) throws Exception{
    //get numeric version of the (sub)set of training instances
    Instances numericTrain = getNumericData(train);
    //initialize Ys/Fs/ps (LogitBoost working values)
    double[][] trainYs = getYs(train);
    double[][] trainFs = getFs(numericTrain);
    double[][] probs = getProbs(trainFs);
    int iteration = 0;
    // NOTE: a dead local array (testErrors) was removed here — it was allocated
    // but never read or written; all error logging goes through error[].
    int noMin = 0;
    double lastMin = Double.MAX_VALUE;
    //log the error of the empty model (0 iterations)
    if (m_errorOnProbabilities) error[0] += getMeanAbsoluteError(test);
    else error[0] += getErrorRate(test);
    while (iteration < maxIterations) {
        //perform single LogitBoost iteration
        boolean foundAttribute = performIteration(iteration, trainYs, trainFs, probs, numericTrain);
        if (foundAttribute) {
            iteration++;
            m_numRegressions = iteration;
        } else {
            //could not fit simple linear regression: stop LogitBoost
            break;
        }
        if (m_errorOnProbabilities) error[iteration] += getMeanAbsoluteError(test);
        else error[iteration] += getErrorRate(test);
        //heuristic: stop LogitBoost if the current minimum has not changed for <m_heuristicStop> iterations
        if (noMin > m_heuristicStop) break;
        if (error[iteration] < lastMin) {
            lastMin = error[iteration];
            noMin = 0;
        } else {
            noMin++;
        }
    }
    return iteration;
}
/**
 * Runs LogitBoost with a fixed number of iterations, stopping early only if a
 * simple regression function cannot be fitted.
 *
 * @param numIterations the number of iterations to run
 * @throws Exception if an iteration fails
 */
protected void performBoosting(int numIterations) throws Exception{
    // initialize the LogitBoost working values (Ys/Fs/ps)
    double[][] trainYs = getYs(m_train);
    double[][] trainFs = getFs(m_numericData);
    double[][] probs = getProbs(trainFs);
    int completed = 0;
    // keep iterating while a useful regression function can still be fitted
    while (completed < numIterations
           && performIteration(completed, trainYs, trainFs, probs, m_numericData)) {
        completed++;
    }
    m_numRegressions = completed;
}
/**
 * Runs LogitBoost using the stopping criterion on the training set.
 * The number of iterations that gives the lowest error on the training set is
 * selected, using either misclassification or error on probabilities
 * (depending on the errorOnProbabilities option).
 *
 * @throws Exception if an iteration fails
 */
protected void performBoosting() throws Exception{
    // initialize the LogitBoost working values (Ys/Fs/ps)
    double[][] trainYs = getYs(m_train);
    double[][] trainFs = getFs(m_numericData);
    double[][] probs = getProbs(trainFs);
    // errors[i] = training error after i iterations (index 0 = empty model)
    double[] errors = new double[m_maxIterations + 1];
    errors[0] = getErrorRate(m_train);
    double bestSoFar = Double.MAX_VALUE;
    int sinceBest = 0;
    int performed = 0;
    while (performed < m_maxIterations) {
        if (!performIteration(performed, trainYs, trainFs, probs, m_numericData)) {
            // could not fit a simple regression function: stop
            break;
        }
        performed++;
        m_numRegressions = performed;
        errors[performed] = getErrorRate(m_train);
        // heuristic: stop if the current minimum has not changed for <m_heuristicStop> iterations
        if (sinceBest > m_heuristicStop) break;
        if (errors[performed] < bestSoFar) {
            bestSoFar = errors[performed];
            sinceBest = 0;
        } else {
            sinceBest++;
        }
    }
    // keep only the iteration count with the best training error
    m_numRegressions = getBestIteration(errors, performed);
}
/**
 * Returns the misclassification error of the current model on a set of instances.
 *
 * @param data the set of instances
 * @return the error rate
 * @throws Exception if the evaluation fails
 */
protected double getErrorRate(Instances data) throws Exception {
    Evaluation evaluation = new Evaluation(data);
    evaluation.evaluateModel(this, data);
    return evaluation.errorRate();
}
/**
 * Returns the error of the probability estimates (mean absolute error) for the
 * current model on a set of instances.
 *
 * @param data the set of instances
 * @return the error
 * @throws Exception if the evaluation fails
 */
protected double getMeanAbsoluteError(Instances data) throws Exception {
    Evaluation evaluation = new Evaluation(data);
    evaluation.evaluateModel(this, data);
    return evaluation.meanAbsoluteError();
}
/**
 * Helper function to find the index of the minimum in an array of error values.
 * Ties are broken in favor of the smaller index (fewer iterations).
 *
 * @param errors the per-iteration error values
 * @param maxIteration the highest index to consider (inclusive)
 * @return the index with the lowest error
 */
protected int getBestIteration(double[] errors, int maxIteration) {
    int best = 0;
    for (int i = 1; i <= maxIteration; i++) {
        // strict < keeps the first (earliest) minimum on ties
        if (errors[i] < errors[best]) {
            best = i;
        }
    }
    return best;
}
/**
* Performs a single iteration of LogitBoost, and updates the model accordingly.
* A simple regression function is fit to the response and added to the m_regressions array.
* @param iteration the current iteration
* @param trainYs the y-values (see description of LogitBoost) for the model trained so far
* @param trainFs the F-values (see description of LogitBoost) for the model trained so far
* @param probs the p-values (see description of LogitBoost) for the model trained so far
* @param trainNumeric numeric version of the training data
* @return returns true if iteration performed successfully, false if no simple regression function
* could be fitted.
*/
protected boolean performIteration(int iteration,
double[][] trainYs,
double[][] trainFs,
double[][] probs,
Instances trainNumeric) throws Exception {
for (int j = 0; j < m_numClasses; j++) {
//make copy of data (need to save the weights)
Instances boostData = new Instances(trainNumeric);
for (int i = 0; i < trainNumeric.numInstances(); i++) {
//compute response and weight
double p = probs[i][j];
double actual = trainYs[i][j];
double z = getZ(actual, p);
double w = (actual - p) / z;
//set values for instance
Instance current = boostData.instance(i);
current.setValue(boostData.classIndex(), z);
current.setWeight(current.weight() * w);
}
//fit simple regression function
m_regressions[j][iteration].buildClassifier(boostData);
boolean foundAttribute = m_regressions[j][iteration].foundUsefulAttribute();
if (!foundAttribute) {
//could not fit simple regression function
return false;
}
// NOTE(review): SOURCE is truncated here — the remainder of performIteration()
// and the rest of the LogisticBase class body are missing. The text that
// followed was web code-viewer hotkey-menu residue, not program code, and has
// been removed. Recover the full file from the Weka distribution
// (weka/classifiers/trees/lmt/LogisticBase.java).