⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 thresholdselector.java

📁 :<<数据挖掘--实用机器学习技术及java实现>>一书的配套源程序
💻 JAVA
📖 第 1 页 / 共 2 页
字号:
/* *    This program is free software; you can redistribute it and/or modify *    it under the terms of the GNU General Public License as published by *    the Free Software Foundation; either version 2 of the License, or *    (at your option) any later version. * *    This program is distributed in the hope that it will be useful, *    but WITHOUT ANY WARRANTY; without even the implied warranty of *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the *    GNU General Public License for more details. * *    You should have received a copy of the GNU General Public License *    along with this program; if not, write to the Free Software *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA. *//* *    ThresholdSelector.java *    Copyright (C) 1999 Eibe Frank * */package weka.classifiers;import java.util.Enumeration;import java.util.Random;import java.util.Vector;import weka.classifiers.evaluation.EvaluationUtils;import weka.classifiers.evaluation.ThresholdCurve;import weka.core.Attribute;import weka.core.AttributeStats;import weka.core.FastVector;import weka.core.Instance;import weka.core.Instances;import weka.core.Option;import weka.core.OptionHandler;import weka.core.SelectedTag;import weka.core.Tag;import weka.core.Utils;import weka.core.UnsupportedClassTypeException;/** * Class for selecting a threshold on a probability output by a * distribution classifier. The threshold is set so that a given * performance measure is optimized. Currently this is the * F-measure. Performance is measured either on the training data, a hold-out * set or using cross-validation. In addition, the probabilities returned * by the base learner can have their range expanded so that the output * probabilities will reside between 0 and 1 (this is useful if the scheme * normally produces probabilities in a very narrow range).<p> * * Valid options are:<p> * * -C num <br> * The class for which threshold is determined. Valid values are: * 1, 2 (for first and second classes, respectively), 3 (for whichever * class is least frequent), 4 (for whichever class value is most  * frequent), and 5 (for the first class named any of "yes","pos(itive)", * "1", or method 3 if no matches). (default 5). <p> * * -W classname <br> * Specify the full class name of the base classifier. <p> * * -X num <br>  * Number of folds used for cross validation. If just a * hold-out set is used, this determines the size of the hold-out set * (default 3).<p> * * -R integer <br> * Sets whether confidence range correction is applied. This can be used * to ensure the confidences range from 0 to 1. Use 0 for no range correction, * 1 for correction based on the min/max values seen during threshold selection * (default 0).<p> * * -S seed <br> * Random number seed (default 1).<p> * * -E integer <br> * Sets the evaluation mode. Use 0 for evaluation using cross-validation, * 1 for evaluation using hold-out set, and 2 for evaluation on the * training data (default 1).<p> * * Options after -- are passed to the designated sub-classifier. <p> * * @author Eibe Frank (eibe@cs.waikato.ac.nz) * @version $Revision: 1.22 $  */public class ThresholdSelector extends DistributionClassifier   implements OptionHandler {  /* Type of correction applied to threshold range */   public final static int RANGE_NONE = 0;  public final static int RANGE_BOUNDS = 1;  public static final Tag [] TAGS_RANGE = {    new Tag(RANGE_NONE, "No range correction"),    new Tag(RANGE_BOUNDS, "Correct based on min/max observed")  };  /* The evaluation modes */  public final static int EVAL_TRAINING_SET = 2;  public final static int EVAL_TUNED_SPLIT = 1;  public final static int EVAL_CROSS_VALIDATION = 0;  public static final Tag [] TAGS_EVAL = {    new Tag(EVAL_TRAINING_SET, "Entire training set"),    new Tag(EVAL_TUNED_SPLIT, "Single tuned fold"),    new Tag(EVAL_CROSS_VALIDATION, "N-Fold cross validation")  };  /* How to determine which class value to optimize for */  public final static int OPTIMIZE_0     = 0;  public final static int OPTIMIZE_1     = 1;  public final static int OPTIMIZE_LFREQ = 2;  public final static int OPTIMIZE_MFREQ = 3;  public final static int OPTIMIZE_POS_NAME = 4;  public static final Tag [] TAGS_OPTIMIZE = {    new Tag(OPTIMIZE_0, "First class value"),    new Tag(OPTIMIZE_1, "Second class value"),    new Tag(OPTIMIZE_LFREQ, "Least frequent class value"),    new Tag(OPTIMIZE_MFREQ, "Most frequent class value"),    new Tag(OPTIMIZE_POS_NAME, "Class value named: \"yes\", \"pos(itive)\",\"1\"")  };  /** The generated base classifier */  protected DistributionClassifier m_Classifier =     new weka.classifiers.ZeroR();  /** The upper threshold used as the basis of correction */  protected double m_HighThreshold = 1;  /** The lower threshold used as the basis of correction */  protected double m_LowThreshold = 0;  /** The threshold that lead to the best performance */  protected double m_BestThreshold = -Double.MAX_VALUE;  /** The best value that has been observed */  protected double m_BestValue = - Double.MAX_VALUE;    /** The number of folds used in cross-validation */  protected int m_NumXValFolds = 3;  /** Random number seed */  protected int m_Seed = 1;  /** Designated class value, determined during building */  protected int m_DesignatedClass = 0;  /** Method to determine which class to optimize for */  protected int m_ClassMode = OPTIMIZE_POS_NAME;  /** The evaluation mode */  protected int m_EvalMode = EVAL_TUNED_SPLIT;  /** The range correction mode */  protected int m_RangeMode = RANGE_NONE;  /** The minimum value for the criterion. If threshold adjustment      yields less than that, the default threshold of 0.5 is used. */  protected final static double MIN_VALUE = 0.05;  /**   * Collects the classifier predictions using the specified evaluation method.   *   * @param instances the set of <code>Instances</code> to generate   * predictions for.   * @param mode the evaluation mode.   * @param numFolds the number of folds to use if not evaluating on the   * full training set.   * @return a <code>FastVector</code> containing the predictions.   * @exception Exception if an error occurs generating the predictions.   */  protected FastVector getPredictions(Instances instances, int mode, int numFolds)     throws Exception {    EvaluationUtils eu = new EvaluationUtils();    eu.setSeed(m_Seed);        switch (mode) {    case EVAL_TUNED_SPLIT:      Instances trainData = null, evalData = null;      Instances data = new Instances(instances);      data.randomize(new Random(m_Seed));      data.stratify(numFolds);            // Make sure that both subsets contain at least one positive instance      for (int subsetIndex = 0; subsetIndex < numFolds; subsetIndex++) {        trainData = data.trainCV(numFolds, subsetIndex);        evalData = data.testCV(numFolds, subsetIndex);        if (checkForInstance(trainData) && checkForInstance(evalData)) {          break;        }      }      return eu.getTrainTestPredictions(m_Classifier, trainData, evalData);    case EVAL_TRAINING_SET:      return eu.getTrainTestPredictions(m_Classifier, instances, instances);    case EVAL_CROSS_VALIDATION:      return eu.getCVPredictions(m_Classifier, instances, numFolds);    default:      throw new RuntimeException("Unrecognized evaluation mode");    }  }  /**   * Finds the best threshold, this implementation searches for the   * highest FMeasure. If no FMeasure higher than MIN_VALUE is found,   * the default threshold of 0.5 is used.   *   * @param predictions a <code>FastVector</code> containing the predictions.   */  protected void findThreshold(FastVector predictions) {    Instances curve = (new ThresholdCurve()).getCurve(predictions, m_DesignatedClass);    //System.err.println(curve);    double low = 1.0;    double high = 0.0;    if (curve.numInstances() > 0) {      Instance maxFM = curve.instance(0);      int indexFM = curve.attribute(ThresholdCurve.FMEASURE_NAME).index();      int indexThreshold = curve.attribute(ThresholdCurve.THRESHOLD_NAME).index();      for (int i = 1; i < curve.numInstances(); i++) {        Instance current = curve.instance(i);        if (current.value(indexFM) > maxFM.value(indexFM)) {          maxFM = current;        }        if (m_RangeMode == RANGE_BOUNDS) {          double thresh = current.value(indexThreshold);          if (thresh < low) {            low = thresh;          }          if (thresh > high) {            high = thresh;          }        }      }      if (maxFM.value(indexFM) > MIN_VALUE) {        m_BestThreshold = maxFM.value(indexThreshold);        m_BestValue = maxFM.value(indexFM);        //System.err.println("maxFM: " + maxFM);      }      if (m_RangeMode == RANGE_BOUNDS) {        m_LowThreshold = low;        m_HighThreshold = high;        //System.err.println("Threshold range: " + low + " - " + high);      }    }  }  /**   * Returns an enumeration describing the available options   *   * @return an enumeration of all the available options   */  public Enumeration listOptions() {    Vector newVector = new Vector(6);    newVector.addElement(new Option(              "\tThe class for which threshold is determined. Valid values are:\n" +              "\t1, 2 (for first and second classes, respectively), 3 (for whichever\n" +              "\tclass is least frequent), and 4 (for whichever class value is most\n" +              "\tfrequent), and 5 (for the first class named any of \"yes\",\"pos(itive)\"\n" +              "\t\"1\", or method 3 if no matches). (default 5).",	      "C", 1, "-C <integer>"));    newVector.addElement(new Option(	      "\tFull name of classifier to perform parameter selection on.\n"	      + "\teg: weka.classifiers.NaiveBayes",	      "W", 1, "-W <classifier class name>"));    newVector.addElement(new Option(	      "\tNumber of folds used for cross validation. If just a\n" +	      "\thold-out set is used, this determines the size of the hold-out set\n" +	      "\t(default 3).",	      "X", 1, "-X <number of folds>"));    newVector.addElement(new Option(	      "\tSets whether confidence range correction is applied. This\n" +              "\tcan be used to ensure the confidences range from 0 to 1.\n" +              "\tUse 0 for no range correction, 1 for correction based on\n" +              "\tthe min/max values seen during threshold selection\n"+              "\t(default 0).",	      "R", 1, "-R <integer>"));    newVector.addElement(new Option(	      "\tSets the random number seed (default 1).",	      "S", 1, "-S <random number seed>"));    newVector.addElement(new Option(	      "\tSets the evaluation mode. Use 0 for\n" +	      "\tevaluation using cross-validation,\n" +	      "\t1 for evaluation using hold-out set,\n" +	      "\tand 2 for evaluation on the\n" +	      "\ttraining data (default 1).",	      "E", 1, "-E <integer>"));    if ((m_Classifier != null) &&	(m_Classifier instanceof OptionHandler)) {      newVector.addElement(new Option("",	        "", 0,		"\nOptions specific to sub-classifier "	        + m_Classifier.getClass().getName()		+ ":\n(use -- to signal start of sub-classifier options)"));      Enumeration enum = ((OptionHandler)m_Classifier).listOptions();      while (enum.hasMoreElements()) {	newVector.addElement(enum.nextElement());      }    }    return newVector.elements();  }  /**   * Parses a given list of options. Valid options are:<p>   *   * -C num <br>   * The class for which threshold is determined. Valid values are:   * 1, 2 (for first and second classes, respectively), 3 (for whichever   * class is least frequent), 4 (for whichever class value is most    * frequent), and 5 (for the first class named any of "yes","pos(itive)",   * "1", or method 3 if no matches). (default 3). <p>   *   * -W classname <br>   * Specify the full class name of classifier to perform cross-validation   * selection on.<p>   *   * -X num <br>    * Number of folds used for cross validation. If just a   * hold-out set is used, this determines the size of the hold-out set   * (default 3).<p>   *   * -R integer <br>   * Sets whether confidence range correction is applied. This can be used   * to ensure the confidences range from 0 to 1. Use 0 for no range correction,   * 1 for correction based on the min/max values seen during threshold    * selection (default 0).<p>   *   * -S seed <br>   * Random number seed (default 1).<p>   *   * -E integer <br>   * Sets the evaluation mode. Use 0 for evaluation using cross-validation,   * 1 for evaluation using hold-out set, and 2 for evaluation on the   * training data (default 1).<p>   *   * Options after -- are passed to the designated sub-classifier. <p>   *   * @param options the list of options as an array of strings   * @exception Exception if an option is not supported   */  public void setOptions(String[] options) throws Exception {        String classString = Utils.getOption('C', options);    if (classString.length() != 0) {      setDesignatedClass(new SelectedTag(Integer.parseInt(classString) - 1,                                          TAGS_OPTIMIZE));    } else {      setDesignatedClass(new SelectedTag(OPTIMIZE_LFREQ, TAGS_OPTIMIZE));    }    String modeString = Utils.getOption('E', options);    if (modeString.length() != 0) {      setEvaluationMode(new SelectedTag(Integer.parseInt(modeString),                                          TAGS_EVAL));    } else {      setEvaluationMode(new SelectedTag(EVAL_TUNED_SPLIT, TAGS_EVAL));    }    String rangeString = Utils.getOption('R', options);    if (rangeString.length() != 0) {      setRangeCorrection(new SelectedTag(Integer.parseInt(rangeString) - 1,                                          TAGS_RANGE));    } else {      setRangeCorrection(new SelectedTag(RANGE_NONE, TAGS_RANGE));    }    String foldsString = Utils.getOption('X', options);    if (foldsString.length() != 0) {      setNumXValFolds(Integer.parseInt(foldsString));    } else {      setNumXValFolds(3);    }    String randomString = Utils.getOption('S', options);    if (randomString.length() != 0) {      setSeed(Integer.parseInt(randomString));    } else {      setSeed(1);    }    String classifierName = Utils.getOption('W', options);    if (classifierName.length() == 0) {      throw new Exception("A classifier must be specified with"			  + " the -W option.");    }    setDistributionClassifier((DistributionClassifier)Classifier.		  forName(classifierName,			  Utils.partitionOptions(options)));  }  /**   * Gets the current settings of the Classifier.   *   * @return an array of strings suitable for passing to setOptions   */  public String [] getOptions() {    String [] classifierOptions = new String [0];    if ((m_Classifier != null) && 	(m_Classifier instanceof OptionHandler)) {      classifierOptions = ((OptionHandler)m_Classifier).getOptions();    }    int current = 0;    String [] options = new String [classifierOptions.length + 13];    options[current++] = "-C"; options[current++] = "" + (m_DesignatedClass + 1);    options[current++] = "-X"; options[current++] = "" + getNumXValFolds();    options[current++] = "-S"; options[current++] = "" + getSeed();    if (getDistributionClassifier() != null) {      options[current++] = "-W";      options[current++] = getDistributionClassifier().getClass().getName();    }    options[current++] = "-E"; options[current++] = "" + m_EvalMode;    options[current++] = "-R"; options[current++] = "" + m_RangeMode;    options[current++] = "--";

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -