📄 thresholdselector.java
字号:
/* * This program is free software; you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation; either version 2 of the License, or * (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with this program; if not, write to the Free Software * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA. *//* * ThresholdSelector.java * Copyright (C) 1999 Eibe Frank * */package weka.classifiers.meta;import weka.classifiers.RandomizableSingleClassifierEnhancer;import weka.classifiers.evaluation.EvaluationUtils;import weka.classifiers.evaluation.ThresholdCurve;import weka.core.Attribute;import weka.core.AttributeStats;import weka.core.Capabilities;import weka.core.Drawable;import weka.core.FastVector;import weka.core.Instance;import weka.core.Instances;import weka.core.Option;import weka.core.OptionHandler;import weka.core.SelectedTag;import weka.core.Tag;import weka.core.Utils;import weka.core.Capabilities.Capability;import java.util.Enumeration;import java.util.Random;import java.util.Vector;/** <!-- globalinfo-start --> * A metaclassifier that selecting a mid-point threshold on the probability output by a Classifier. The midpoint threshold is set so that a given performance measure is optimized. Currently this is the F-measure. Performance is measured either on the training data, a hold-out set or using cross-validation. In addition, the probabilities returned by the base learner can have their range expanded so that the output probabilities will reside between 0 and 1 (this is useful if the scheme normally produces probabilities in a very narrow range). 
* <p/> <!-- globalinfo-end --> * <!-- options-start --> * Valid options are: <p/> * * <pre> -C <integer> * The class for which threshold is determined. Valid values are: * 1, 2 (for first and second classes, respectively), 3 (for whichever * class is least frequent), and 4 (for whichever class value is most * frequent), and 5 (for the first class named any of "yes","pos(itive)" * "1", or method 3 if no matches). (default 5).</pre> * * <pre> -X <number of folds> * Number of folds used for cross validation. If just a * hold-out set is used, this determines the size of the hold-out set * (default 3).</pre> * * <pre> -R <integer> * Sets whether confidence range correction is applied. This * can be used to ensure the confidences range from 0 to 1. * Use 0 for no range correction, 1 for correction based on * the min/max values seen during threshold selection * (default 0).</pre> * * <pre> -E <integer> * Sets the evaluation mode. Use 0 for * evaluation using cross-validation, * 1 for evaluation using hold-out set, * and 2 for evaluation on the * training data (default 1).</pre> * * <pre> -M [FMEASURE|ACCURACY|TRUE_POS|TRUE_NEG|TP_RATE|PRECISION|RECALL] * Measure used for evaluation (default is FMEASURE). * </pre> * * <pre> -S <num> * Random number seed. * (default 1)</pre> * * <pre> -D * If set, classifier is run in debug mode and * may output additional info to the console</pre> * * <pre> -W * Full name of base classifier. * (default: weka.classifiers.functions.Logistic)</pre> * * <pre> * Options specific to classifier weka.classifiers.functions.Logistic: * </pre> * * <pre> -D * Turn on debugging output.</pre> * * <pre> -R <ridge> * Set the ridge in the log-likelihood.</pre> * * <pre> -M <number> * Set the maximum number of iterations (default -1, until convergence).</pre> * <!-- options-end --> * * Options after -- are passed to the designated sub-classifier. 
<p> * * @author Eibe Frank (eibe@cs.waikato.ac.nz) * @version $Revision: 1.37 $ */
public class ThresholdSelector
  extends RandomizableSingleClassifierEnhancer
  implements OptionHandler, Drawable {

  /** for serialization */
  static final long serialVersionUID = -1795038053239867444L;

  /** Range correction mode: no range correction. */
  public static final int RANGE_NONE = 0;

  /** Range correction mode: correct based on the min/max observed. */
  public static final int RANGE_BOUNDS = 1;

  /** Tags describing the type of correction applied to the threshold range. */
  public static final Tag [] TAGS_RANGE = {
    new Tag(RANGE_NONE, "No range correction"),
    new Tag(RANGE_BOUNDS, "Correct based on min/max observed")
  };

  /** Evaluation mode: entire training set. */
  public static final int EVAL_TRAINING_SET = 2;

  /** Evaluation mode: single tuned fold. */
  public static final int EVAL_TUNED_SPLIT = 1;

  /** Evaluation mode: n-fold cross-validation. */
  public static final int EVAL_CROSS_VALIDATION = 0;

  /** Tags describing the evaluation modes. */
  public static final Tag [] TAGS_EVAL = {
    new Tag(EVAL_TRAINING_SET, "Entire training set"),
    new Tag(EVAL_TUNED_SPLIT, "Single tuned fold"),
    new Tag(EVAL_CROSS_VALIDATION, "N-Fold cross validation")
  };

  /** Class selection mode: first class value. */
  public static final int OPTIMIZE_0 = 0;

  /** Class selection mode: second class value. */
  public static final int OPTIMIZE_1 = 1;

  /** Class selection mode: least frequent class value. */
  public static final int OPTIMIZE_LFREQ = 2;

  /** Class selection mode: most frequent class value. */
  public static final int OPTIMIZE_MFREQ = 3;

  /** Class selection mode: class value named "yes", "pos(itive)" or "1". */
  public static final int OPTIMIZE_POS_NAME = 4;

  /** Tags describing how to determine which class value to optimize for. */
  public static final Tag [] TAGS_OPTIMIZE = {
    new Tag(OPTIMIZE_0, "First class value"),
    new Tag(OPTIMIZE_1, "Second class value"),
    new Tag(OPTIMIZE_LFREQ, "Least frequent class value"),
    new Tag(OPTIMIZE_MFREQ, "Most frequent class value"),
    new Tag(OPTIMIZE_POS_NAME, "Class value named: \"yes\", \"pos(itive)\",\"1\"")
  };

  /** Measure ID: F-measure. */
  public static final int FMEASURE = 1;

  /** Measure ID: accuracy. */
  public static final int ACCURACY = 2;

  /** Measure ID: true-positive. */
  public static final int TRUE_POS = 3;

  /** Measure ID: true-negative. */
  public static final int TRUE_NEG = 4;

  /** Measure ID: true-positive rate. */
  public static final int TP_RATE = 5;

  /** Measure ID: precision. */
  public static final int PRECISION = 6;

  /** Measure ID: recall. */
  public static final int RECALL = 7;

  /** Tags describing the measures that can be used; each label equals its constant's name. */
  public static final Tag[] TAGS_MEASURE = {
    new Tag(FMEASURE, "FMEASURE"),
    new Tag(ACCURACY, "ACCURACY"),
    new Tag(TRUE_POS, "TRUE_POS"),
    new Tag(TRUE_NEG, "TRUE_NEG"),
    new Tag(TP_RATE, "TP_RATE"),
    new Tag(PRECISION, "PRECISION"),
    new Tag(RECALL, "RECALL")
  };

  /** The upper threshold used as the basis of correction (initially 1). */
  protected double m_HighThreshold = 1;

  /** The lower threshold used as the basis of correction (initially 0). */
  protected double m_LowThreshold = 0;

  /** The threshold that led to the best performance (initially -Double.MAX_VALUE). */
  protected double m_BestThreshold = -Double.MAX_VALUE;

  /** The best value that has been observed (initially -Double.MAX_VALUE). */
  protected double m_BestValue = - Double.MAX_VALUE;

  /** The number of folds used in cross-validation (initially 3). */
  protected int m_NumXValFolds = 3;

  /** Designated class value, determined during building. */
  protected int m_DesignatedClass = 0;

  /** Method to determine which class to optimize for (initially OPTIMIZE_POS_NAME). */
  protected int m_ClassMode = OPTIMIZE_POS_NAME;

  /** The evaluation mode (initially EVAL_TUNED_SPLIT). */
  protected int m_EvalMode = EVAL_TUNED_SPLIT;

  /** The range correction mode (initially RANGE_NONE). */
  protected int m_RangeMode = RANGE_NONE;

  /**
   * Evaluation measure used for determining the threshold (initially FMEASURE).
   * Declared without an access modifier, i.e. package-private.
   */
  int m_nMeasure = FMEASURE;

  /**
   * The minimum value for the criterion. If threshold adjustment
   * yields less than that, the default threshold of 0.5 is used.
   */
  protected static final double MIN_VALUE = 0.05;

  /**
   * Constructor. Sets the base classifier (m_Classifier) to a new
   * weka.classifiers.functions.Logistic instance.
   */
  public ThresholdSelector() {
    m_Classifier = new weka.classifiers.functions.Logistic();
  }

  /**
   * String describing default classifier.
* * @return the default classifier classname */ protected String defaultClassifierString() { return "weka.classifiers.functions.Logistic"; } /** * Collects the classifier predictions using the specified evaluation method. * * @param instances the set of <code>Instances</code> to generate * predictions for. * @param mode the evaluation mode. * @param numFolds the number of folds to use if not evaluating on the * full training set. * @return a <code>FastVector</code> containing the predictions. * @throws Exception if an error occurs generating the predictions. */ protected FastVector getPredictions(Instances instances, int mode, int numFolds) throws Exception { EvaluationUtils eu = new EvaluationUtils(); eu.setSeed(m_Seed); switch (mode) { case EVAL_TUNED_SPLIT: Instances trainData = null, evalData = null; Instances data = new Instances(instances); Random random = new Random(m_Seed); data.randomize(random); data.stratify(numFolds); // Make sure that both subsets contain at least one positive instance for (int subsetIndex = 0; subsetIndex < numFolds; subsetIndex++) { trainData = data.trainCV(numFolds, subsetIndex, random); evalData = data.testCV(numFolds, subsetIndex); if (checkForInstance(trainData) && checkForInstance(evalData)) { break; } } return eu.getTrainTestPredictions(m_Classifier, trainData, evalData); case EVAL_TRAINING_SET: return eu.getTrainTestPredictions(m_Classifier, instances, instances); case EVAL_CROSS_VALIDATION: return eu.getCVPredictions(m_Classifier, instances, numFolds); default: throw new RuntimeException("Unrecognized evaluation mode"); } } /** set measure used for determining threshold * @param newMeasure Tag representing measure to be used **/ public void setMeasure(SelectedTag newMeasure) { if (newMeasure.getTags() == TAGS_MEASURE) { m_nMeasure = newMeasure.getSelectedTag().getID(); } } /** get measure used for determining threshold * @return Tag representing measure used **/ public SelectedTag getMeasure() { return new 
SelectedTag(m_nMeasure, TAGS_MEASURE); } /** * Finds the best threshold, this implementation searches for the * highest FMeasure. If no FMeasure higher than MIN_VALUE is found, * the default threshold of 0.5 is used. * * @param predictions a <code>FastVector</code> containing the predictions. */ protected void findThreshold(FastVector predictions) { Instances curve = (new ThresholdCurve()).getCurve(predictions, m_DesignatedClass); double low = 1.0; double high = 0.0; //System.err.println(curve); if (curve.numInstances() > 0) { Instance maxInst = curve.instance(0); double maxValue = 0; int index1 = 0; int index2 = 0; switch (m_nMeasure) { case FMEASURE: index1 = curve.attribute(ThresholdCurve.FMEASURE_NAME).index(); maxValue = maxInst.value(index1);
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -