/*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 2 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
*/
/*
* LogitBoost.java
* Copyright (C) 1999, 2002 Len Trigg, Eibe Frank
*
*/
package weka.classifiers.meta;
import java.util.Enumeration;
import java.util.Random;
import java.util.Vector;
import weka.classifiers.Classifier;
import weka.classifiers.Evaluation;
import weka.classifiers.RandomizableIteratedSingleClassifierEnhancer;
import weka.classifiers.Sourcable;
import weka.core.Attribute;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.UnsupportedAttributeTypeException;
import weka.core.UnsupportedClassTypeException;
import weka.core.Utils;
import weka.core.WeightedInstancesHandler;
/**
 * Class for performing additive logistic regression.
* This class performs classification using a regression scheme as the
* base learner, and can handle multi-class problems. For more
* information, see<p>
*
* Friedman, J., T. Hastie and R. Tibshirani (1998) <i>Additive Logistic
* Regression: a Statistical View of Boosting</i>
* <a href="ftp://stat.stanford.edu/pub/friedman/boost.ps">download
* postscript</a>. <p>
*
* Valid options are:<p>
*
* -D <br>
* Turn on debugging output.<p>
*
* -W classname <br>
* Specify the full class name of a weak learner as the basis for
* boosting (required).<p>
*
* -I num <br>
* Set the number of boost iterations (default 10). <p>
*
* -Q <br>
* Use resampling instead of reweighting.<p>
*
* -S seed <br>
* Random number seed for resampling (default 1).<p>
*
* -P num <br>
* Set the percentage of weight mass used to build classifiers
* (default 100). <p>
*
* -F num <br>
* Set number of folds for the internal cross-validation
* (default 0 -- no cross-validation). <p>
*
* -R num <br>
* Set number of runs for the internal cross-validation
* (default 1). <p>
*
* -L num <br>
* Set the threshold for the improvement of the
* average loglikelihood (default -Double.MAX_VALUE). <p>
*
* -H num <br>
* Set the value of the shrinkage parameter (default 1). <p>
*
* Options after -- are passed to the designated learner.<p>
*
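 * For example, boosting 100 decision stumps from the command line might
 * look like this (the dataset path is purely illustrative):<p>
 *
 * <code>java weka.classifiers.meta.LogitBoost -I 100 -t data/train.arff</code><p>
 *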
* @author Len Trigg (trigg@cs.waikato.ac.nz)
* @author Eibe Frank (eibe@cs.waikato.ac.nz)
* @version $Revision$
*/
public class LogitBoost extends RandomizableIteratedSingleClassifierEnhancer
implements Sourcable, WeightedInstancesHandler {
/** Array for storing the generated base classifiers.
Note: we are hiding the variable from IteratedSingleClassifierEnhancer*/
protected Classifier [][] m_Classifiers;
/** The number of classes */
protected int m_NumClasses;
/** The number of successfully generated base classifiers. */
protected int m_NumGenerated;
/** The number of folds for the internal cross-validation. */
protected int m_NumFolds = 0;
/** The number of runs for the internal cross-validation. */
protected int m_NumRuns = 1;
/** Weight thresholding. The percentage of weight mass used in training */
protected int m_WeightThreshold = 100;
/** A threshold for responses (Friedman suggests between 2 and 4) */
protected static final double Z_MAX = 3;
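  /* In each iteration the regression target for a class is
     z = (y* - p) / (p * (1 - p)), which blows up as the estimated
     probability p approaches 0 or 1; capping the response at Z_MAX
     keeps the base learners from chasing extreme values. */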
/** Dummy dataset with a numeric class */
protected Instances m_NumericClassData;
/** The actual class attribute (for getting class names) */
protected Attribute m_ClassAttribute;
  /** Use resampling instead of reweighting? */
protected boolean m_UseResampling;
/** The threshold on the improvement of the likelihood */
protected double m_Precision = -Double.MAX_VALUE;
/** The value of the shrinkage parameter */
protected double m_Shrinkage = 1;
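  /* Each base model's predictions are multiplied by this factor before
     being added to the committee; values below 1 slow down learning
     but can improve generalization. */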
/** The random number generator used */
protected Random m_RandomInstance = null;
/** The value by which the actual target value for the
true class is offset. */
protected double m_Offset = 0.0;
/**
* Returns a string describing classifier
* @return a description suitable for
* displaying in the explorer/experimenter gui
*/
public String globalInfo() {
return "Class for performing additive logistic regression. "
+ "This class performs classification using a regression scheme as the "
+ "base learner, and can handle multi-class problems. For more "
+ "information, see\n\n"
+ "Friedman, J., T. Hastie and R. Tibshirani (1998) \"Additive Logistic "
+ "Regression: a Statistical View of Boosting\". Technical report. "
+ "Stanford University.\n\n"
+ "Can do efficient internal cross-validation to determine "
+ "appropriate number of iterations.";
}
/**
* Constructor.
*/
public LogitBoost() {
m_Classifier = new weka.classifiers.trees.DecisionStump();
}
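  /* Illustrative programmatic use (the Instances object "train" is
     assumed to hold a dataset with a nominal class attribute):

       LogitBoost lb = new LogitBoost();
       lb.setNumIterations(100);
       lb.buildClassifier(train);
       double[] dist = lb.distributionForInstance(train.instance(0));
  */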
/**
* String describing default classifier.
*/
protected String defaultClassifierString() {
return "weka.classifiers.trees.DecisionStump";
}
/**
* Select only instances with weights that contribute to
* the specified quantile of the weight distribution
*
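 * For example (illustrative numbers): with instance weights {6, 3, 1}
 * and quantile 0.8 the target mass is 8, so the two heaviest instances
 * (total weight 9) are selected before the loop stops. Instances tied
 * in weight at the cut-off are always kept together.
 *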
* @param data the input instances
 * @param quantile the specified quantile, e.g. 0.9 to select
* 90% of the weight mass
* @return the selected instances
*/
protected Instances selectWeightQuantile(Instances data, double quantile) {
int numInstances = data.numInstances();
Instances trainData = new Instances(data, numInstances);
double [] weights = new double [numInstances];
double sumOfWeights = 0;
for (int i = 0; i < numInstances; i++) {
weights[i] = data.instance(i).weight();
sumOfWeights += weights[i];
}
double weightMassToSelect = sumOfWeights * quantile;
int [] sortedIndices = Utils.sort(weights);
// Select the instances
sumOfWeights = 0;
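    // Walk from the heaviest instance down, stopping once the requested
    // mass is covered and the next weight value differs (so instances
    // with equal weights are never split across the cut-off).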
for (int i = numInstances-1; i >= 0; i--) {
Instance instance = (Instance)data.instance(sortedIndices[i]).copy();
trainData.add(instance);
sumOfWeights += weights[sortedIndices[i]];
if ((sumOfWeights > weightMassToSelect) &&
(i > 0) &&
(weights[sortedIndices[i]] != weights[sortedIndices[i-1]])) {
break;
}
}
if (m_Debug) {
System.err.println("Selected " + trainData.numInstances()
+ " out of " + numInstances);
}
return trainData;
}
/**
* Returns an enumeration describing the available options.
*
* @return an enumeration of all the available options.
*/
public Enumeration listOptions() {
Vector newVector = new Vector(6);
newVector.addElement(new Option(
"\tUse resampling for boosting.",
"Q", 0, "-Q"));
newVector.addElement(new Option(
"\tPercentage of weight mass to base training on.\n"
+"\t(default 100, reduce to around 90 speed up)",
"P", 1, "-P <percent>"));
newVector.addElement(new Option(
"\tNumber of folds for internal cross-validation.\n"
+"\t(default 0 -- no cross-validation)",
"F", 1, "-F <num>"));
newVector.addElement(new Option(
"\tNumber of runs for internal cross-validation.\n"
+"\t(default 1)",
"R", 1, "-R <num>"));
newVector.addElement(new Option(
"\tThreshold on the improvement of the likelihood.\n"
+"\t(default -Double.MAX_VALUE)",
"L", 1, "-L <num>"));
newVector.addElement(new Option(
"\tShrinkage parameter.\n"
+"\t(default 1)",
"H", 1, "-H <num>"));
Enumeration em = super.listOptions();
while (em.hasMoreElements()) {
newVector.addElement(em.nextElement());
}
return newVector.elements();
}
/**
* Parses a given list of options. Valid options are:<p>
*
* -D <br>
* Turn on debugging output.<p>
*
* -W classname <br>
* Specify the full class name of a weak learner as the basis for
* boosting (required).<p>
*
* -I num <br>
* Set the number of boost iterations (default 10). <p>
*
* -Q <br>
* Use resampling instead of reweighting.<p>
*
* -S seed <br>
* Random number seed for resampling (default 1).<p>
*
* -P num <br>
* Set the percentage of weight mass used to build classifiers
* (default 100). <p>
*
* -F num <br>
* Set number of folds for the internal cross-validation
* (default 0 -- no cross-validation). <p>
*
* -R num <br>
* Set number of runs for the internal cross-validation
 * (default 1). <p>
*
* -L num <br>
* Set the threshold for the improvement of the
* average loglikelihood (default -Double.MAX_VALUE). <p>
*
* -H num <br>
* Set the value of the shrinkage parameter (default 1). <p>
*
* Options after -- are passed to the designated learner.<p>
*
* @param options the list of options as an array of strings
* @exception Exception if an option is not supported
*/
public void setOptions(String[] options) throws Exception {
String numFolds = Utils.getOption('F', options);
if (numFolds.length() != 0) {
setNumFolds(Integer.parseInt(numFolds));
} else {
setNumFolds(0);
}
String numRuns = Utils.getOption('R', options);
if (numRuns.length() != 0) {
setNumRuns(Integer.parseInt(numRuns));
} else {
setNumRuns(1);
}
String thresholdString = Utils.getOption('P', options);
if (thresholdString.length() != 0) {
setWeightThreshold(Integer.parseInt(thresholdString));
} else {
setWeightThreshold(100);
}
String precisionString = Utils.getOption('L', options);
if (precisionString.length() != 0) {
    setLikelihoodThreshold(Double.parseDouble(precisionString));
} else {
setLikelihoodThreshold(-Double.MAX_VALUE);
}
String shrinkageString = Utils.getOption('H', options);
if (shrinkageString.length() != 0) {
    setShrinkage(Double.parseDouble(shrinkageString));
} else {
setShrinkage(1.0);
}
setUseResampling(Utils.getFlag('Q', options));
if (m_UseResampling && (thresholdString.length() != 0)) {
      throw new Exception("Weight pruning with resampling"+
                          " not allowed.");
    }
    super.setOptions(options);
  }