⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 thresholdselector.java

📁 代码是一个分类器的实现,其中使用了部分weka的源代码。可以将项目导入eclipse运行
💻 JAVA
📖 第 1 页 / 共 3 页
字号:
/* *    This program is free software; you can redistribute it and/or modify *    it under the terms of the GNU General Public License as published by *    the Free Software Foundation; either version 2 of the License, or *    (at your option) any later version. * *    This program is distributed in the hope that it will be useful, *    but WITHOUT ANY WARRANTY; without even the implied warranty of *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the *    GNU General Public License for more details. * *    You should have received a copy of the GNU General Public License *    along with this program; if not, write to the Free Software *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA. *//* *    ThresholdSelector.java *    Copyright (C) 1999 Eibe Frank * */package weka.classifiers.meta;import weka.classifiers.RandomizableSingleClassifierEnhancer;import weka.classifiers.evaluation.EvaluationUtils;import weka.classifiers.evaluation.ThresholdCurve;import weka.core.Attribute;import weka.core.AttributeStats;import weka.core.Capabilities;import weka.core.Drawable;import weka.core.FastVector;import weka.core.Instance;import weka.core.Instances;import weka.core.Option;import weka.core.OptionHandler;import weka.core.SelectedTag;import weka.core.Tag;import weka.core.Utils;import weka.core.Capabilities.Capability;import java.util.Enumeration;import java.util.Random;import java.util.Vector;/** <!-- globalinfo-start --> * A metaclassifier that selecting a mid-point threshold on the probability output by a Classifier. The midpoint threshold is set so that a given performance measure is optimized. Currently this is the F-measure. Performance is measured either on the training data, a hold-out set or using cross-validation. 
In addition, the probabilities returned by the base learner can have their range expanded so that the output probabilities will reside between 0 and 1 (this is useful if the scheme normally produces probabilities in a very narrow range). * <p/> <!-- globalinfo-end --> * <!-- options-start --> * Valid options are: <p/> *  * <pre> -C &lt;integer&gt; *  The class for which threshold is determined. Valid values are: *  1, 2 (for first and second classes, respectively), 3 (for whichever *  class is least frequent), and 4 (for whichever class value is most *  frequent), and 5 (for the first class named any of "yes","pos(itive)" *  "1", or method 3 if no matches). (default 5).</pre> *  * <pre> -X &lt;number of folds&gt; *  Number of folds used for cross validation. If just a *  hold-out set is used, this determines the size of the hold-out set *  (default 3).</pre> *  * <pre> -R &lt;integer&gt; *  Sets whether confidence range correction is applied. This *  can be used to ensure the confidences range from 0 to 1. *  Use 0 for no range correction, 1 for correction based on *  the min/max values seen during threshold selection *  (default 0).</pre> *  * <pre> -E &lt;integer&gt; *  Sets the evaluation mode. Use 0 for *  evaluation using cross-validation, *  1 for evaluation using hold-out set, *  and 2 for evaluation on the *  training data (default 1).</pre> *  * <pre> -M [FMEASURE|ACCURACY|TRUE_POS|TRUE_NEG|TP_RATE|PRECISION|RECALL] *  Measure used for evaluation (default is FMEASURE). * </pre> *  * <pre> -S &lt;num&gt; *  Random number seed. *  (default 1)</pre> *  * <pre> -D *  If set, classifier is run in debug mode and *  may output additional info to the console</pre> *  * <pre> -W *  Full name of base classifier. 
*  (default: weka.classifiers.functions.Logistic)</pre> *  * <pre>  * Options specific to classifier weka.classifiers.functions.Logistic: * </pre> *  * <pre> -D *  Turn on debugging output.</pre> *  * <pre> -R &lt;ridge&gt; *  Set the ridge in the log-likelihood.</pre> *  * <pre> -M &lt;number&gt; *  Set the maximum number of iterations (default -1, until convergence).</pre> *  <!-- options-end --> * * Options after -- are passed to the designated sub-classifier. <p> * * @author Eibe Frank (eibe@cs.waikato.ac.nz) * @version $Revision: 1.37 $  */public class ThresholdSelector   extends RandomizableSingleClassifierEnhancer   implements OptionHandler, Drawable {  /** for serialization */  static final long serialVersionUID = -1795038053239867444L;  /** no range correction */  public static final int RANGE_NONE = 0;  /** Correct based on min/max observed */  public static final int RANGE_BOUNDS = 1;  /** Type of correction applied to threshold range */   public static final Tag [] TAGS_RANGE = {    new Tag(RANGE_NONE, "No range correction"),    new Tag(RANGE_BOUNDS, "Correct based on min/max observed")  };  /** entire training set */  public static final int EVAL_TRAINING_SET = 2;  /** single tuned fold */  public static final int EVAL_TUNED_SPLIT = 1;  /** n-fold cross-validation */  public static final int EVAL_CROSS_VALIDATION = 0;  /** The evaluation modes */  public static final Tag [] TAGS_EVAL = {    new Tag(EVAL_TRAINING_SET, "Entire training set"),    new Tag(EVAL_TUNED_SPLIT, "Single tuned fold"),    new Tag(EVAL_CROSS_VALIDATION, "N-Fold cross validation")  };  /** first class value */  public static final int OPTIMIZE_0     = 0;  /** second class value */  public static final int OPTIMIZE_1     = 1;  /** least frequent class value */  public static final int OPTIMIZE_LFREQ = 2;  /** most frequent class value */  public static final int OPTIMIZE_MFREQ = 3;  /** class value name, either 'yes' or 'pos(itive)' */  public static final int OPTIMIZE_POS_NAME = 4;  
/** How to determine which class value to optimize for */  public static final Tag [] TAGS_OPTIMIZE = {    new Tag(OPTIMIZE_0, "First class value"),    new Tag(OPTIMIZE_1, "Second class value"),    new Tag(OPTIMIZE_LFREQ, "Least frequent class value"),    new Tag(OPTIMIZE_MFREQ, "Most frequent class value"),    new Tag(OPTIMIZE_POS_NAME, "Class value named: \"yes\", \"pos(itive)\",\"1\"")  };  /** F-measure */  public static final int FMEASURE  = 1;  /** accuracy */  public static final int ACCURACY  = 2;  /** true-positive */  public static final int TRUE_POS  = 3;  /** true-negative */  public static final int TRUE_NEG  = 4;  /** true-positive rate */  public static final int TP_RATE   = 5;  /** precision */  public static final int PRECISION = 6;  /** recall */  public static final int RECALL    = 7;  /** the measure to use */  public static final Tag[] TAGS_MEASURE = {    new Tag(FMEASURE,  "FMEASURE"),    new Tag(ACCURACY,  "ACCURACY"),    new Tag(TRUE_POS,  "TRUE_POS"),    new Tag(TRUE_NEG,  "TRUE_NEG"),     new Tag(TP_RATE,   "TP_RATE"),       new Tag(PRECISION, "PRECISION"),     new Tag(RECALL,    "RECALL")  };  /** The upper threshold used as the basis of correction */  protected double m_HighThreshold = 1;  /** The lower threshold used as the basis of correction */  protected double m_LowThreshold = 0;  /** The threshold that lead to the best performance */  protected double m_BestThreshold = -Double.MAX_VALUE;  /** The best value that has been observed */  protected double m_BestValue = - Double.MAX_VALUE;    /** The number of folds used in cross-validation */  protected int m_NumXValFolds = 3;  /** Designated class value, determined during building */  protected int m_DesignatedClass = 0;  /** Method to determine which class to optimize for */  protected int m_ClassMode = OPTIMIZE_POS_NAME;  /** The evaluation mode */  protected int m_EvalMode = EVAL_TUNED_SPLIT;  /** The range correction mode */  protected int m_RangeMode = RANGE_NONE;  /** evaluation 
measure used for determining threshold **/
  // Declared without an access modifier, so this field is package-private.
  int m_nMeasure = FMEASURE;

  /**
   * The minimum value for the criterion. If threshold adjustment
   * yields less than that, the default threshold of 0.5 is used.
   */
  protected static final double MIN_VALUE = 0.05;

  /**
   * Constructor. Installs a new weka.classifiers.functions.Logistic
   * instance as the base classifier (m_Classifier).
   */
  public ThresholdSelector() {
    m_Classifier = new weka.classifiers.functions.Logistic();
  }

  /**
   * String describing default classifier.
   *
   * @return the default classifier classname; this is the fully qualified
   *         name of the Logistic class that the constructor instantiates
   */
  protected String defaultClassifierString() {
    return "weka.classifiers.functions.Logistic";
  }

  /**
   * Collects the classifier predictions using the specified evaluation method.
   *
   * @param instances the set of <code>Instances</code> to generate
   * predictions for.
   * @param mode the evaluation mode.
   * @param numFolds the number of folds to use if not evaluating on the
   * full training set.
   * @return a <code>FastVector</code> containing the predictions.
   * @throws Exception if an error occurs generating the predictions.
*/  protected FastVector getPredictions(Instances instances, int mode, int numFolds)     throws Exception {    EvaluationUtils eu = new EvaluationUtils();    eu.setSeed(m_Seed);        switch (mode) {    case EVAL_TUNED_SPLIT:      Instances trainData = null, evalData = null;      Instances data = new Instances(instances);      Random random = new Random(m_Seed);      data.randomize(random);      data.stratify(numFolds);            // Make sure that both subsets contain at least one positive instance      for (int subsetIndex = 0; subsetIndex < numFolds; subsetIndex++) {        trainData = data.trainCV(numFolds, subsetIndex, random);        evalData = data.testCV(numFolds, subsetIndex);        if (checkForInstance(trainData) && checkForInstance(evalData)) {          break;        }      }      return eu.getTrainTestPredictions(m_Classifier, trainData, evalData);    case EVAL_TRAINING_SET:      return eu.getTrainTestPredictions(m_Classifier, instances, instances);    case EVAL_CROSS_VALIDATION:      return eu.getCVPredictions(m_Classifier, instances, numFolds);    default:      throw new RuntimeException("Unrecognized evaluation mode");    }  }    /** set measure used for determining threshold     * @param newMeasure Tag representing measure to be used     **/    public void setMeasure(SelectedTag newMeasure) {	if (newMeasure.getTags() == TAGS_MEASURE) {	    m_nMeasure = newMeasure.getSelectedTag().getID();	}    }    /** get measure used for determining threshold      * @return Tag representing measure used     **/    public SelectedTag getMeasure() {	return new SelectedTag(m_nMeasure, TAGS_MEASURE);    }  /**   * Finds the best threshold, this implementation searches for the   * highest FMeasure. If no FMeasure higher than MIN_VALUE is found,   * the default threshold of 0.5 is used.   *   * @param predictions a <code>FastVector</code> containing the predictions.   
*/  protected void findThreshold(FastVector predictions) {    Instances curve = (new ThresholdCurve()).getCurve(predictions, m_DesignatedClass);    double low = 1.0;    double high = 0.0;    //System.err.println(curve);    if (curve.numInstances() > 0) {      Instance maxInst = curve.instance(0);      double maxValue = 0;       int index1 = 0;      int index2 = 0;      switch (m_nMeasure) {        case FMEASURE:          index1 = curve.attribute(ThresholdCurve.FMEASURE_NAME).index();          maxValue = maxInst.value(index1);

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -