📄 thresholdselector.java
字号:
/* * This program is free software; you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation; either version 2 of the License, or * (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with this program; if not, write to the Free Software * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA. *//* * ThresholdSelector.java * Copyright (C) 1999 Eibe Frank * */package weka.classifiers;import java.util.Enumeration;import java.util.Random;import java.util.Vector;import weka.classifiers.evaluation.EvaluationUtils;import weka.classifiers.evaluation.ThresholdCurve;import weka.core.Attribute;import weka.core.AttributeStats;import weka.core.FastVector;import weka.core.Instance;import weka.core.Instances;import weka.core.Option;import weka.core.OptionHandler;import weka.core.SelectedTag;import weka.core.Tag;import weka.core.Utils;import weka.core.UnsupportedClassTypeException;/** * Class for selecting a threshold on a probability output by a * distribution classifier. The threshold is set so that a given * performance measure is optimized. Currently this is the * F-measure. Performance is measured either on the training data, a hold-out * set or using cross-validation. In addition, the probabilities returned * by the base learner can have their range expanded so that the output * probabilities will reside between 0 and 1 (this is useful if the scheme * normally produces probabilities in a very narrow range).<p> * * Valid options are:<p> * * -C num <br> * The class for which threshold is determined. 
Valid values are: * 1, 2 (for first and second classes, respectively), 3 (for whichever * class is least frequent), 4 (for whichever class value is most * frequent), and 5 (for the first class named any of "yes","pos(itive)", * "1", or method 3 if no matches). (default 5). <p> * * -W classname <br> * Specify the full class name of the base classifier. <p> * * -X num <br> * Number of folds used for cross validation. If just a * hold-out set is used, this determines the size of the hold-out set * (default 3).<p> * * -R integer <br> * Sets whether confidence range correction is applied. This can be used * to ensure the confidences range from 0 to 1. Use 0 for no range correction, * 1 for correction based on the min/max values seen during threshold selection * (default 0).<p> * * -S seed <br> * Random number seed (default 1).<p> * * -E integer <br> * Sets the evaluation mode. Use 0 for evaluation using cross-validation, * 1 for evaluation using hold-out set, and 2 for evaluation on the * training data (default 1).<p> * * Options after -- are passed to the designated sub-classifier. 
<p>
 *
 * @author Eibe Frank (eibe@cs.waikato.ac.nz)
 * @version $Revision: 1.22 $
 */
public class ThresholdSelector extends DistributionClassifier
  implements OptionHandler {

  /* Type of correction applied to the threshold range (-R option). */
  public final static int RANGE_NONE = 0;
  public final static int RANGE_BOUNDS = 1;
  public static final Tag [] TAGS_RANGE = {
    new Tag(RANGE_NONE, "No range correction"),
    new Tag(RANGE_BOUNDS, "Correct based on min/max observed")
  };

  /* The evaluation modes (-E option). Note the tag ids are 0-based and do
     not follow the declaration order below. */
  public final static int EVAL_TRAINING_SET = 2;
  public final static int EVAL_TUNED_SPLIT = 1;
  public final static int EVAL_CROSS_VALIDATION = 0;
  public static final Tag [] TAGS_EVAL = {
    new Tag(EVAL_TRAINING_SET, "Entire training set"),
    new Tag(EVAL_TUNED_SPLIT, "Single tuned fold"),
    new Tag(EVAL_CROSS_VALIDATION, "N-Fold cross validation")
  };

  /* How to determine which class value to optimize for (-C option;
     the command-line value is this id plus one). */
  public final static int OPTIMIZE_0 = 0;
  public final static int OPTIMIZE_1 = 1;
  public final static int OPTIMIZE_LFREQ = 2;
  public final static int OPTIMIZE_MFREQ = 3;
  public final static int OPTIMIZE_POS_NAME = 4;
  public static final Tag [] TAGS_OPTIMIZE = {
    new Tag(OPTIMIZE_0, "First class value"),
    new Tag(OPTIMIZE_1, "Second class value"),
    new Tag(OPTIMIZE_LFREQ, "Least frequent class value"),
    new Tag(OPTIMIZE_MFREQ, "Most frequent class value"),
    new Tag(OPTIMIZE_POS_NAME, "Class value named: \"yes\", \"pos(itive)\",\"1\"")
  };

  /** The generated base classifier */
  protected DistributionClassifier m_Classifier =
    new weka.classifiers.ZeroR();

  /** The upper threshold used as the basis of correction */
  protected double m_HighThreshold = 1;

  /** The lower threshold used as the basis of correction */
  protected double m_LowThreshold = 0;

  /** The threshold that lead to the best performance */
  protected double m_BestThreshold = -Double.MAX_VALUE;

  /** The best value that has been observed */
  protected double m_BestValue = - Double.MAX_VALUE;

  /** The number of folds used in cross-validation */
  protected int m_NumXValFolds = 3;

  /** Random number seed */
  protected int m_Seed = 1;

  /** Designated class value, determined during building */
  protected int m_DesignatedClass = 0;

  /** Method to determine which class to optimize for */
  protected int m_ClassMode = OPTIMIZE_POS_NAME;

  /** The evaluation mode */
  protected int m_EvalMode = EVAL_TUNED_SPLIT;

  /** The range correction mode */
  protected int m_RangeMode = RANGE_NONE;

  /** The minimum value for the criterion. If threshold adjustment
      yields less than that, the default threshold of 0.5 is used. */
  protected final static double MIN_VALUE = 0.05;

  /**
   * Collects the classifier predictions using the specified evaluation method.
   *
   * @param instances the set of <code>Instances</code> to generate
   * predictions for.
   * @param mode the evaluation mode.
   * @param numFolds the number of folds to use if not evaluating on the
   * full training set.
   * @return a <code>FastVector</code> containing the predictions.
   * @exception Exception if an error occurs generating the predictions.
*/ protected FastVector getPredictions(Instances instances, int mode, int numFolds) throws Exception { EvaluationUtils eu = new EvaluationUtils(); eu.setSeed(m_Seed); switch (mode) { case EVAL_TUNED_SPLIT: Instances trainData = null, evalData = null; Instances data = new Instances(instances); data.randomize(new Random(m_Seed)); data.stratify(numFolds); // Make sure that both subsets contain at least one positive instance for (int subsetIndex = 0; subsetIndex < numFolds; subsetIndex++) { trainData = data.trainCV(numFolds, subsetIndex); evalData = data.testCV(numFolds, subsetIndex); if (checkForInstance(trainData) && checkForInstance(evalData)) { break; } } return eu.getTrainTestPredictions(m_Classifier, trainData, evalData); case EVAL_TRAINING_SET: return eu.getTrainTestPredictions(m_Classifier, instances, instances); case EVAL_CROSS_VALIDATION: return eu.getCVPredictions(m_Classifier, instances, numFolds); default: throw new RuntimeException("Unrecognized evaluation mode"); } } /** * Finds the best threshold, this implementation searches for the * highest FMeasure. If no FMeasure higher than MIN_VALUE is found, * the default threshold of 0.5 is used. * * @param predictions a <code>FastVector</code> containing the predictions. 
*/ protected void findThreshold(FastVector predictions) { Instances curve = (new ThresholdCurve()).getCurve(predictions, m_DesignatedClass); //System.err.println(curve); double low = 1.0; double high = 0.0; if (curve.numInstances() > 0) { Instance maxFM = curve.instance(0); int indexFM = curve.attribute(ThresholdCurve.FMEASURE_NAME).index(); int indexThreshold = curve.attribute(ThresholdCurve.THRESHOLD_NAME).index(); for (int i = 1; i < curve.numInstances(); i++) { Instance current = curve.instance(i); if (current.value(indexFM) > maxFM.value(indexFM)) { maxFM = current; } if (m_RangeMode == RANGE_BOUNDS) { double thresh = current.value(indexThreshold); if (thresh < low) { low = thresh; } if (thresh > high) { high = thresh; } } } if (maxFM.value(indexFM) > MIN_VALUE) { m_BestThreshold = maxFM.value(indexThreshold); m_BestValue = maxFM.value(indexFM); //System.err.println("maxFM: " + maxFM); } if (m_RangeMode == RANGE_BOUNDS) { m_LowThreshold = low; m_HighThreshold = high; //System.err.println("Threshold range: " + low + " - " + high); } } } /** * Returns an enumeration describing the available options * * @return an enumeration of all the available options */ public Enumeration listOptions() { Vector newVector = new Vector(6); newVector.addElement(new Option( "\tThe class for which threshold is determined. Valid values are:\n" + "\t1, 2 (for first and second classes, respectively), 3 (for whichever\n" + "\tclass is least frequent), and 4 (for whichever class value is most\n" + "\tfrequent), and 5 (for the first class named any of \"yes\",\"pos(itive)\"\n" + "\t\"1\", or method 3 if no matches). (default 5).", "C", 1, "-C <integer>")); newVector.addElement(new Option( "\tFull name of classifier to perform parameter selection on.\n" + "\teg: weka.classifiers.NaiveBayes", "W", 1, "-W <classifier class name>")); newVector.addElement(new Option( "\tNumber of folds used for cross validation. 
If just a\n" + "\thold-out set is used, this determines the size of the hold-out set\n" + "\t(default 3).", "X", 1, "-X <number of folds>")); newVector.addElement(new Option( "\tSets whether confidence range correction is applied. This\n" + "\tcan be used to ensure the confidences range from 0 to 1.\n" + "\tUse 0 for no range correction, 1 for correction based on\n" + "\tthe min/max values seen during threshold selection\n"+ "\t(default 0).", "R", 1, "-R <integer>")); newVector.addElement(new Option( "\tSets the random number seed (default 1).", "S", 1, "-S <random number seed>")); newVector.addElement(new Option( "\tSets the evaluation mode. Use 0 for\n" + "\tevaluation using cross-validation,\n" + "\t1 for evaluation using hold-out set,\n" + "\tand 2 for evaluation on the\n" + "\ttraining data (default 1).", "E", 1, "-E <integer>")); if ((m_Classifier != null) && (m_Classifier instanceof OptionHandler)) { newVector.addElement(new Option("", "", 0, "\nOptions specific to sub-classifier " + m_Classifier.getClass().getName() + ":\n(use -- to signal start of sub-classifier options)")); Enumeration enum = ((OptionHandler)m_Classifier).listOptions(); while (enum.hasMoreElements()) { newVector.addElement(enum.nextElement()); } } return newVector.elements(); } /** * Parses a given list of options. Valid options are:<p> * * -C num <br> * The class for which threshold is determined. Valid values are: * 1, 2 (for first and second classes, respectively), 3 (for whichever * class is least frequent), 4 (for whichever class value is most * frequent), and 5 (for the first class named any of "yes","pos(itive)", * "1", or method 3 if no matches). (default 3). <p> * * -W classname <br> * Specify the full class name of classifier to perform cross-validation * selection on.<p> * * -X num <br> * Number of folds used for cross validation. 
If just a * hold-out set is used, this determines the size of the hold-out set * (default 3).<p> * * -R integer <br> * Sets whether confidence range correction is applied. This can be used * to ensure the confidences range from 0 to 1. Use 0 for no range correction, * 1 for correction based on the min/max values seen during threshold * selection (default 0).<p> * * -S seed <br> * Random number seed (default 1).<p> * * -E integer <br> * Sets the evaluation mode. Use 0 for evaluation using cross-validation, * 1 for evaluation using hold-out set, and 2 for evaluation on the * training data (default 1).<p> * * Options after -- are passed to the designated sub-classifier. <p> * * @param options the list of options as an array of strings * @exception Exception if an option is not supported */ public void setOptions(String[] options) throws Exception { String classString = Utils.getOption('C', options); if (classString.length() != 0) { setDesignatedClass(new SelectedTag(Integer.parseInt(classString) - 1, TAGS_OPTIMIZE)); } else { setDesignatedClass(new SelectedTag(OPTIMIZE_LFREQ, TAGS_OPTIMIZE)); } String modeString = Utils.getOption('E', options); if (modeString.length() != 0) { setEvaluationMode(new SelectedTag(Integer.parseInt(modeString), TAGS_EVAL)); } else { setEvaluationMode(new SelectedTag(EVAL_TUNED_SPLIT, TAGS_EVAL)); } String rangeString = Utils.getOption('R', options); if (rangeString.length() != 0) { setRangeCorrection(new SelectedTag(Integer.parseInt(rangeString) - 1, TAGS_RANGE)); } else { setRangeCorrection(new SelectedTag(RANGE_NONE, TAGS_RANGE)); } String foldsString = Utils.getOption('X', options); if (foldsString.length() != 0) { setNumXValFolds(Integer.parseInt(foldsString)); } else { setNumXValFolds(3); } String randomString = Utils.getOption('S', options); if (randomString.length() != 0) { setSeed(Integer.parseInt(randomString)); } else { setSeed(1); } String classifierName = Utils.getOption('W', options); if (classifierName.length() == 0) { throw 
new Exception("A classifier must be specified with" + " the -W option."); } setDistributionClassifier((DistributionClassifier)Classifier. forName(classifierName, Utils.partitionOptions(options))); } /** * Gets the current settings of the Classifier. * * @return an array of strings suitable for passing to setOptions */ public String [] getOptions() { String [] classifierOptions = new String [0]; if ((m_Classifier != null) && (m_Classifier instanceof OptionHandler)) { classifierOptions = ((OptionHandler)m_Classifier).getOptions(); } int current = 0; String [] options = new String [classifierOptions.length + 13]; options[current++] = "-C"; options[current++] = "" + (m_DesignatedClass + 1); options[current++] = "-X"; options[current++] = "" + getNumXValFolds(); options[current++] = "-S"; options[current++] = "" + getSeed(); if (getDistributionClassifier() != null) { options[current++] = "-W"; options[current++] = getDistributionClassifier().getClass().getName(); } options[current++] = "-E"; options[current++] = "" + m_EvalMode; options[current++] = "-R"; options[current++] = "" + m_RangeMode; options[current++] = "--";
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -