// ThresholdSelector.java (web-page scrape residue removed from this header)
/*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 2 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
*/
/*
* ThresholdSelector.java
* Copyright (C) 1999 Eibe Frank
*
*/
package weka.classifiers.meta;
import java.util.Enumeration;
import java.util.Random;
import java.util.Vector;
import weka.classifiers.Classifier;
import weka.classifiers.Evaluation;
import weka.classifiers.evaluation.EvaluationUtils;
import weka.classifiers.evaluation.ThresholdCurve;
import weka.core.Attribute;
import weka.core.AttributeStats;
import weka.core.Drawable;
import weka.core.FastVector;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.SelectedTag;
import weka.core.Tag;
import weka.core.UnsupportedClassTypeException;
import weka.core.Utils;
/**
* Class for selecting a threshold on a probability output by a
* distribution classifier. The threshold is set so that a given
* performance measure is optimized. Currently this is the
* F-measure. Performance is measured either on the training data, a hold-out
* set or using cross-validation. In addition, the probabilities returned
* by the base learner can have their range expanded so that the output
* probabilities will reside between 0 and 1 (this is useful if the scheme
* normally produces probabilities in a very narrow range).<p>
*
* Valid options are:<p>
*
* -C num <br>
* The class for which threshold is determined. Valid values are:
* 1, 2 (for first and second classes, respectively), 3 (for whichever
* class is least frequent), 4 (for whichever class value is most
* frequent), and 5 (for the first class named any of "yes","pos(itive)",
* "1", or method 3 if no matches). (default 5). <p>
*
* -W classname <br>
* Specify the full class name of the base classifier. <p>
*
* -X num <br>
* Number of folds used for cross validation. If just a
* hold-out set is used, this determines the size of the hold-out set
* (default 3).<p>
*
* -R integer <br>
* Sets whether confidence range correction is applied. This can be used
* to ensure the confidences range from 0 to 1. Use 0 for no range correction,
* 1 for correction based on the min/max values seen during threshold selection
* (default 0).<p>
*
* -S seed <br>
* Random number seed (default 1).<p>
*
* -E integer <br>
* Sets the evaluation mode. Use 0 for evaluation using cross-validation,
* 1 for evaluation using hold-out set, and 2 for evaluation on the
* training data (default 1).<p>
*
* Options after -- are passed to the designated sub-classifier. <p>
*
* @author Eibe Frank (eibe@cs.waikato.ac.nz)
* @version $Revision$
*/
public class ThresholdSelector extends Classifier
implements OptionHandler, Drawable {
/** Type of correction applied to the predicted-probability range (see m_RangeMode). */
public static final int RANGE_NONE = 0;
public static final int RANGE_BOUNDS = 1;
/** Tags describing the range-correction modes (used with SelectedTag / the -R option). */
public static final Tag [] TAGS_RANGE = {
new Tag(RANGE_NONE, "No range correction"),
new Tag(RANGE_BOUNDS, "Correct based on min/max observed")
};
/** The evaluation modes (used with SelectedTag / the -E option). */
public static final int EVAL_TRAINING_SET = 2;
public static final int EVAL_TUNED_SPLIT = 1;
public static final int EVAL_CROSS_VALIDATION = 0;
public static final Tag [] TAGS_EVAL = {
new Tag(EVAL_TRAINING_SET, "Entire training set"),
new Tag(EVAL_TUNED_SPLIT, "Single tuned fold"),
new Tag(EVAL_CROSS_VALIDATION, "N-Fold cross validation")
};
/** How to determine which class value to optimize for (used with the -C option). */
public static final int OPTIMIZE_0 = 0;
public static final int OPTIMIZE_1 = 1;
public static final int OPTIMIZE_LFREQ = 2;
public static final int OPTIMIZE_MFREQ = 3;
public static final int OPTIMIZE_POS_NAME = 4;
public static final Tag [] TAGS_OPTIMIZE = {
new Tag(OPTIMIZE_0, "First class value"),
new Tag(OPTIMIZE_1, "Second class value"),
new Tag(OPTIMIZE_LFREQ, "Least frequent class value"),
new Tag(OPTIMIZE_MFREQ, "Most frequent class value"),
new Tag(OPTIMIZE_POS_NAME, "Class value named: \"yes\", \"pos(itive)\",\"1\"")
};
/** The generated base classifier. */
protected Classifier m_Classifier = new weka.classifiers.functions.Logistic();
/** The upper threshold used as the basis of range correction. */
protected double m_HighThreshold = 1;
/** The lower threshold used as the basis of range correction. */
protected double m_LowThreshold = 0;
/** The threshold that led to the best performance; -Double.MAX_VALUE until tuned. */
protected double m_BestThreshold = -Double.MAX_VALUE;
/** The best criterion value (F-measure) observed during threshold selection. */
protected double m_BestValue = - Double.MAX_VALUE;
/** The number of folds used in cross-validation (also sizes the hold-out split). */
protected int m_NumXValFolds = 3;
/** Random number seed used for randomizing/stratifying the data. */
protected int m_Seed = 1;
/** Designated class value (index), determined during building. */
protected int m_DesignatedClass = 0;
/** Method to determine which class to optimize for (one of the OPTIMIZE_* constants). */
protected int m_ClassMode = OPTIMIZE_POS_NAME;
/** The evaluation mode (one of the EVAL_* constants). */
protected int m_EvalMode = EVAL_TUNED_SPLIT;
/** The range correction mode (one of the RANGE_* constants). */
protected int m_RangeMode = RANGE_NONE;
/** The minimum acceptable value for the criterion. If threshold adjustment
yields less than that, the default threshold of 0.5 is used. */
protected static final double MIN_VALUE = 0.05;
/**
* Collects the classifier predictions using the specified evaluation method.
*
* @param instances the set of <code>Instances</code> to generate
* predictions for.
* @param mode the evaluation mode.
* @param numFolds the number of folds to use if not evaluating on the
* full training set.
* @return a <code>FastVector</code> containing the predictions.
* @exception Exception if an error occurs generating the predictions.
*/
protected FastVector getPredictions(Instances instances, int mode, int numFolds)
  throws Exception {

  EvaluationUtils evalUtils = new EvaluationUtils();
  evalUtils.setSeed(m_Seed);

  if (mode == EVAL_TRAINING_SET) {
    // Train and evaluate on the same (full) data set.
    return evalUtils.getTrainTestPredictions(m_Classifier, instances, instances);
  }
  if (mode == EVAL_CROSS_VALIDATION) {
    return evalUtils.getCVPredictions(m_Classifier, instances, numFolds);
  }
  if (mode == EVAL_TUNED_SPLIT) {
    Instances shuffled = new Instances(instances);
    Random rnd = new Random(m_Seed);
    shuffled.randomize(rnd);
    shuffled.stratify(numFolds);

    // Use the first train/test split in which both parts contain at
    // least one instance of the designated class; if no such split is
    // found, the last split examined is used.
    Instances fitSet = null;
    Instances holdOut = null;
    for (int fold = 0; fold < numFolds; fold++) {
      fitSet = shuffled.trainCV(numFolds, fold, rnd);
      holdOut = shuffled.testCV(numFolds, fold);
      if (checkForInstance(fitSet) && checkForInstance(holdOut)) {
        break;
      }
    }
    return evalUtils.getTrainTestPredictions(m_Classifier, fitSet, holdOut);
  }
  throw new RuntimeException("Unrecognized evaluation mode");
}
/**
* Finds the best threshold, this implementation searches for the
* highest FMeasure. If no FMeasure higher than MIN_VALUE is found,
* the default threshold of 0.5 is used.
*
* @param predictions a <code>FastVector</code> containing the predictions.
*/
protected void findThreshold(FastVector predictions) {

  Instances curve = (new ThresholdCurve()).getCurve(predictions, m_DesignatedClass);

  double low = 1.0;
  double high = 0.0;
  if (curve.numInstances() > 0) {
    int indexFM = curve.attribute(ThresholdCurve.FMEASURE_NAME).index();
    int indexThreshold = curve.attribute(ThresholdCurve.THRESHOLD_NAME).index();
    Instance maxFM = curve.instance(0);
    // Scan from i = 0 so the first curve point also contributes to the
    // observed min/max threshold: the previous loop started at i = 1 and
    // silently excluded instance 0 from the RANGE_BOUNDS bookkeeping.
    for (int i = 0; i < curve.numInstances(); i++) {
      Instance current = curve.instance(i);
      if (current.value(indexFM) > maxFM.value(indexFM)) {
        maxFM = current;
      }
      if (m_RangeMode == RANGE_BOUNDS) {
        double thresh = current.value(indexThreshold);
        if (thresh < low) {
          low = thresh;
        }
        if (thresh > high) {
          high = thresh;
        }
      }
    }
    // Only adopt the tuned threshold when it achieves a minimally useful
    // F-measure; otherwise the default threshold of 0.5 is kept.
    if (maxFM.value(indexFM) > MIN_VALUE) {
      m_BestThreshold = maxFM.value(indexThreshold);
      m_BestValue = maxFM.value(indexFM);
    }
    if (m_RangeMode == RANGE_BOUNDS) {
      m_LowThreshold = low;
      m_HighThreshold = high;
    }
  }
}
/**
* Returns an enumeration describing the available options.
*
* @return an enumeration of all the available options.
*/
public Enumeration listOptions() {

  // Scheme-specific options first, then (if available) the sub-classifier's.
  Vector opts = new Vector(6);

  opts.addElement(new Option(
      "\tThe class for which threshold is determined. Valid values are:\n" +
      "\t1, 2 (for first and second classes, respectively), 3 (for whichever\n" +
      "\tclass is least frequent), and 4 (for whichever class value is most\n" +
      "\tfrequent), and 5 (for the first class named any of \"yes\",\"pos(itive)\"\n" +
      "\t\"1\", or method 3 if no matches). (default 5).",
      "C", 1, "-C <integer>"));
  opts.addElement(new Option(
      "\tFull name of classifier to perform parameter selection on.\n"
      + "\teg: weka.classifiers.bayes.NaiveBayes",
      "W", 1, "-W <classifier class name>"));
  opts.addElement(new Option(
      "\tNumber of folds used for cross validation. If just a\n" +
      "\thold-out set is used, this determines the size of the hold-out set\n" +
      "\t(default 3).",
      "X", 1, "-X <number of folds>"));
  opts.addElement(new Option(
      "\tSets whether confidence range correction is applied. This\n" +
      "\tcan be used to ensure the confidences range from 0 to 1.\n" +
      "\tUse 0 for no range correction, 1 for correction based on\n" +
      "\tthe min/max values seen during threshold selection\n"+
      "\t(default 0).",
      "R", 1, "-R <integer>"));
  opts.addElement(new Option(
      "\tSets the random number seed (default 1).",
      "S", 1, "-S <random number seed>"));
  opts.addElement(new Option(
      "\tSets the evaluation mode. Use 0 for\n" +
      "\tevaluation using cross-validation,\n" +
      "\t1 for evaluation using hold-out set,\n" +
      "\tand 2 for evaluation on the\n" +
      "\ttraining data (default 1).",
      "E", 1, "-E <integer>"));

  if ((m_Classifier != null) &&
      (m_Classifier instanceof OptionHandler)) {
    opts.addElement(new Option("",
        "", 0,
        "\nOptions specific to sub-classifier "
        + m_Classifier.getClass().getName()
        + ":\n(use -- to signal start of sub-classifier options)"));
    for (Enumeration e = ((OptionHandler)m_Classifier).listOptions();
         e.hasMoreElements(); ) {
      opts.addElement(e.nextElement());
    }
  }
  return opts.elements();
}
/**
* Parses a given list of options. Valid options are:<p>
*
* -C num <br>
* The class for which threshold is determined. Valid values are:
* 1, 2 (for first and second classes, respectively), 3 (for whichever
* class is least frequent), 4 (for whichever class value is most
* frequent), and 5 (for the first class named any of "yes","pos(itive)",
* "1", or method 3 if no matches). (default 5). <p>
*
* -W classname <br>
* Specify the full class name of classifier to perform cross-validation
* selection on.<p>
*
* -X num <br>
* Number of folds used for cross validation. If just a
* hold-out set is used, this determines the size of the hold-out set
* (default 3).<p>
*
* -R integer <br>
* Sets whether confidence range correction is applied. This can be used
* to ensure the confidences range from 0 to 1. Use 0 for no range correction,
* 1 for correction based on the min/max values seen during threshold
* selection (default 0).<p>
*
* -S seed <br>
* Random number seed (default 1).<p>
*
* -E integer <br>
* Sets the evaluation mode. Use 0 for evaluation using cross-validation,
* 1 for evaluation using hold-out set, and 2 for evaluation on the
* training data (default 1).<p>
*
* Options after -- are passed to the designated sub-classifier. <p>
*
* @param options the list of options as an array of strings
* @exception Exception if an option is not supported
*/
public void setOptions(String[] options) throws Exception {

  String classString = Utils.getOption('C', options);
  if (classString.length() != 0) {
    // The command line uses 1-based values (1..5); the tags are 0-based.
    setDesignatedClass(new SelectedTag(Integer.parseInt(classString) - 1,
                                       TAGS_OPTIMIZE));
  } else {
    // The documented default is 5, i.e. OPTIMIZE_POS_NAME (also the field
    // default); the previous code inconsistently fell back to OPTIMIZE_LFREQ.
    setDesignatedClass(new SelectedTag(OPTIMIZE_POS_NAME, TAGS_OPTIMIZE));
  }

  String modeString = Utils.getOption('E', options);
  if (modeString.length() != 0) {
    setEvaluationMode(new SelectedTag(Integer.parseInt(modeString),
                                      TAGS_EVAL));
  } else {
    setEvaluationMode(new SelectedTag(EVAL_TUNED_SPLIT, TAGS_EVAL));
  }

  String rangeString = Utils.getOption('R', options);
  if (rangeString.length() != 0) {
    // -R takes the tag value directly (0 = none, 1 = bounds), matching the
    // documentation and what getOptions() emits. The previous "- 1" broke
    // the setOptions/getOptions round trip (e.g. "-R 1" mapped to RANGE_NONE).
    setRangeCorrection(new SelectedTag(Integer.parseInt(rangeString),
                                       TAGS_RANGE));
  } else {
    setRangeCorrection(new SelectedTag(RANGE_NONE, TAGS_RANGE));
  }

  String foldsString = Utils.getOption('X', options);
  if (foldsString.length() != 0) {
    setNumXValFolds(Integer.parseInt(foldsString));
  } else {
    setNumXValFolds(3);
  }

  String randomString = Utils.getOption('S', options);
  if (randomString.length() != 0) {
    setSeed(Integer.parseInt(randomString));
  } else {
    setSeed(1);
  }

  // The base classifier is mandatory; everything after "--" is handed to it.
  String classifierName = Utils.getOption('W', options);
  if (classifierName.length() == 0) {
    throw new Exception("A classifier must be specified with"
                        + " the -W option.");
  }
  setClassifier(Classifier.forName(classifierName,
                                   Utils.partitionOptions(options)));
}
/**
* Gets the current settings of the Classifier.
*
* @return an array of strings suitable for passing to setOptions
*/
public String [] getOptions() {

  String [] subOptions = new String [0];
  if ((m_Classifier != null) &&
      (m_Classifier instanceof OptionHandler)) {
    subOptions = ((OptionHandler)m_Classifier).getOptions();
  }

  // 13 slots cover this scheme's own flags plus the "--" separator; any
  // unused slots (e.g. when no classifier is set) are padded with "".
  String [] result = new String [subOptions.length + 13];
  int pos = 0;

  result[pos++] = "-C";
  result[pos++] = "" + (m_DesignatedClass + 1);
  result[pos++] = "-X";
  result[pos++] = "" + getNumXValFolds();
  result[pos++] = "-S";
  result[pos++] = "" + getSeed();
  if (getClassifier() != null) {
    result[pos++] = "-W";
    result[pos++] = getClassifier().getClass().getName();
  }
  result[pos++] = "-E";
  result[pos++] = "" + m_EvalMode;
  result[pos++] = "-R";
  result[pos++] = "" + m_RangeMode;
  result[pos++] = "--";

  System.arraycopy(subOptions, 0, result, pos, subOptions.length);
  pos += subOptions.length;
  while (pos < result.length) {
    result[pos++] = "";
  }
  return result;
}
/**
* Generates the classifier.
*
* @param instances set of instances serving as training data
* @exception Exception if the classifier has not been generated successfully
*/
public void buildClassifier(Instances instances)
throws Exception {
if (instances.numClasses() > 2) {
// NOTE(review): the lines that followed here were web-page UI residue
// (hotkey help text), not source code, and have been removed. The file is
// truncated at this point: the remainder of buildClassifier() and the rest
// of the ThresholdSelector class are missing and must be restored from the
// original source.