/*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 2 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
*/
/*
* AdaBoostM1.java
* Copyright (C) 1999 Eibe Frank, Len Trigg
*
*/
package weka.classifiers.meta;
import java.util.Enumeration;
import java.util.Random;
import java.util.Vector;
import eti.bi.util.NumberFormatter;
import weka.classifiers.Evaluation;
import weka.classifiers.RandomizableIteratedSingleClassifierEnhancer;
import weka.classifiers.Sourcable;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.UnsupportedAttributeTypeException;
import weka.core.UnsupportedClassTypeException;
import weka.core.Utils;
import weka.core.WeightedInstancesHandler;
/**
* Class for boosting a classifier using Freund & Schapire's Adaboost
* M1 method. For more information, see<p>
*
* Yoav Freund and Robert E. Schapire
* (1996). <i>Experiments with a new boosting algorithm</i>. Proc
* International Conference on Machine Learning, pages 148-156, Morgan
* Kaufmann, San Francisco.<p>
*
* Valid options are:<p>
*
* -D <br>
* Turn on debugging output.<p>
*
* -W classname <br>
* Specify the full class name of a classifier as the basis for
* boosting (required).<p>
*
* -I num <br>
* Set the number of boost iterations (default 10). <p>
*
* -P num <br>
* Set the percentage of weight mass used to build classifiers
* (default 100). <p>
*
* -Q <br>
* Use resampling instead of reweighting.<p>
*
* -S seed <br>
* Random number seed for resampling (default 1). <p>
*
* Options after -- are passed to the designated classifier.<p>
*
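* A minimal usage sketch (illustrative only; the trainingInstances variable
* is a placeholder, not part of this class):<p>
*
* <pre>
* // Boost decision stumps for 10 iterations over a nominal-class dataset.
* AdaBoostM1 booster = new AdaBoostM1();
* booster.setClassifier(new weka.classifiers.trees.DecisionStump());
* booster.setNumIterations(10);
* booster.buildClassifier(trainingInstances);
* </pre>
*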
* @author Eibe Frank (eibe@cs.waikato.ac.nz)
* @author Len Trigg (trigg@cs.waikato.ac.nz)
* @version $Revision$
*/
public class AdaBoostM1 extends RandomizableIteratedSingleClassifierEnhancer
implements WeightedInstancesHandler, Sourcable {
/** Max num iterations tried to find classifier with non-zero error. */
private static int MAX_NUM_RESAMPLING_ITERATIONS = 10;
/** Array for storing the weights for the votes. */
protected double [] m_Betas;
/** The number of successfully generated base classifiers. */
protected int m_NumIterationsPerformed;
/** Weight Threshold. The percentage of weight mass used in training */
protected int m_WeightThreshold = 100;
/** Use boosting with resampling instead of reweighting? */
protected boolean m_UseResampling;
/** The number of classes */
protected int m_NumClasses;
/**
* Returns a string describing this classifier
* @return a description suitable for
* displaying in the explorer/experimenter gui
*/
public String globalInfo() {
return "Class for boosting a nominal class classifier using the Adaboost "
+ "M1 method. Only nominal class problems can be tackled. Often "
+ "dramatically improves performance, but sometimes overfits. For more "
+ "information, see\n\n"
+ "Yoav Freund and Robert E. Schapire (1996). \"Experiments with a new boosting "
+ "algorithm\". Proc International Conference on Machine Learning, "
+ "pages 148-156, Morgan Kaufmann, San Francisco.";
}
/**
* Constructor.
*/
public AdaBoostM1() {
m_Classifier = new weka.classifiers.trees.DecisionStump();
}
/**
* String describing default classifier.
*/
protected String defaultClassifierString() {
return "weka.classifiers.trees.DecisionStump";
}
/**
* Select only instances with weights that contribute to
* the specified quantile of the weight distribution
*
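* For example (an illustrative call, not part of the original code), a
* weight threshold of 90 translates into:<p>
*
* <pre>
* // Keep the heaviest instances covering at least 90% of the weight mass.
* Instances pruned = selectWeightQuantile(data, 90.0 / 100.0);
* </pre>
*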
* @param data the input instances
* @param quantile the specified quantile, e.g. 0.9 to select
* 90% of the weight mass
* @return the selected instances
*/
protected Instances selectWeightQuantile(Instances data, double quantile) {
int numInstances = data.numInstances();
Instances trainData = new Instances(data, numInstances);
double [] weights = new double [numInstances];
double sumOfWeights = 0;
for(int i = 0; i < numInstances; i++) {
weights[i] = data.instance(i).weight();
sumOfWeights += weights[i];
}
double weightMassToSelect = sumOfWeights * quantile;
int [] sortedIndices = Utils.sort(weights);
// Select the instances
sumOfWeights = 0;
for(int i = numInstances - 1; i >= 0; i--) {
Instance instance = (Instance)data.instance(sortedIndices[i]).copy();
trainData.add(instance);
sumOfWeights += weights[sortedIndices[i]];
if ((sumOfWeights > weightMassToSelect) &&
(i > 0) &&
(weights[sortedIndices[i]] != weights[sortedIndices[i - 1]])) {
break;
}
}
if (m_Debug) {
System.err.println("Selected " + trainData.numInstances()
+ " out of " + numInstances);
}
return trainData;
}
/**
* Returns an enumeration describing the available options.
*
* @return an enumeration of all the available options.
*/
public Enumeration listOptions() {
Vector newVector = new Vector(2);
newVector.addElement(new Option(
"\tPercentage of weight mass to base training on.\n"
+"\t(default 100, reduce to around 90 speed up)",
"P", 1, "-P <num>"));
newVector.addElement(new Option(
"\tUse resampling for boosting.",
"Q", 0, "-Q"));
Enumeration em = super.listOptions();
while (em.hasMoreElements()) {
newVector.addElement(em.nextElement());
}
return newVector.elements();
}
/**
* Parses a given list of options. Valid options are:<p>
*
* -D <br>
* Turn on debugging output.<p>
*
* -W classname <br>
* Specify the full class name of a classifier as the basis for
* boosting (required).<p>
*
* -I num <br>
* Set the number of boost iterations (default 10). <p>
*
* -P num <br>
* Set the percentage of weight mass used to build classifiers
* (default 100). <p>
*
* -Q <br>
* Use resampling instead of reweighting.<p>
*
* -S seed <br>
* Random number seed for resampling (default 1).<p>
*
* Options after -- are passed to the designated classifier.<p>
*
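* A hedged example of building such an option array programmatically (the
* option string and the booster variable are illustrative):<p>
*
* <pre>
* String [] opts = weka.core.Utils.splitOptions(
* "-P 90 -I 20 -W weka.classifiers.trees.DecisionStump");
* booster.setOptions(opts);
* </pre>
*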
* @param options the list of options as an array of strings
* @exception Exception if an option is not supported
*/
public void setOptions(String[] options) throws Exception {
String thresholdString = Utils.getOption('P', options);
if (thresholdString.length() != 0) {
setWeightThreshold(Integer.parseInt(thresholdString));
} else {
setWeightThreshold(100);
}
setUseResampling(Utils.getFlag('Q', options));
if (m_UseResampling && (thresholdString.length() != 0)) {
throw new Exception("Weight pruning with resampling "
+ "not allowed.");
}
super.setOptions(options);
}
/**
* Gets the current settings of the Classifier.
*
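* A round-trip sketch (illustrative; the booster variable is a placeholder):<p>
*
* <pre>
* String settings = weka.core.Utils.joinOptions(booster.getOptions());
* booster.setOptions(weka.core.Utils.splitOptions(settings));
* </pre>
*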
* @return an array of strings suitable for passing to setOptions
*/
public String [] getOptions() {
String [] superOptions = super.getOptions();
String [] options = new String [superOptions.length + 2];
int current = 0;
if (getUseResampling()) {
options[current++] = "-Q";
} else {
options[current++] = "-P";
options[current++] = "" + getWeightThreshold();
}
System.arraycopy(superOptions, 0, options, current,
superOptions.length);
return options;
}
/**
* Returns the tip text for this property
* @return tip text for this property suitable for
* displaying in the explorer/experimenter gui
*/
public String weightThresholdTipText() {
return "Weight threshold for weight pruning.";
}
/**
* Set weight threshold
*
* @param threshold the percentage of weight mass used for training
*/
public void setWeightThreshold(int threshold) {
m_WeightThreshold = threshold;
}
/**
* Get the degree of weight thresholding
*
* @return the percentage of weight mass used for training
*/
public int getWeightThreshold() {
return m_WeightThreshold;
}
/**
* Returns the tip text for this property
* @return tip text for this property suitable for
* displaying in the explorer/experimenter gui
*/
public String useResamplingTipText() {
return "Whether resampling is used instead of reweighting.";
}
/**
* Set resampling mode
*
* @param r true if resampling should be done
*/
public void setUseResampling(boolean r) {
m_UseResampling = r;
}
/**
* Get whether resampling is turned on
*
* @return true if resampling is turned on
*/
public boolean getUseResampling() {
return m_UseResampling;
}
/**