
thresholdselector.java

Collection: Java source code for the entire workflow of the ALPHAMINER data mining software
Language: Java
Page 1 of 2
/*
 *    This program is free software; you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation; either version 2 of the License, or
 *    (at your option) any later version.
 *
 *    This program is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with this program; if not, write to the Free Software
 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

/*
 *    ThresholdSelector.java
 *    Copyright (C) 1999 Eibe Frank
 *
 */

package weka.classifiers.meta;

import java.util.Enumeration;
import java.util.Random;
import java.util.Vector;

import weka.classifiers.Classifier;
import weka.classifiers.Evaluation;
import weka.classifiers.evaluation.EvaluationUtils;
import weka.classifiers.evaluation.ThresholdCurve;
import weka.core.Attribute;
import weka.core.AttributeStats;
import weka.core.Drawable;
import weka.core.FastVector;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.SelectedTag;
import weka.core.Tag;
import weka.core.UnsupportedClassTypeException;
import weka.core.Utils;

/**
 * Class for selecting a threshold on a probability output by a
 * distribution classifier. The threshold is set so that a given
 * performance measure is optimized. Currently this is the
 * F-measure. Performance is measured either on the training data, a hold-out
 * set or using cross-validation. In addition, the probabilities returned
 * by the base learner can have their range expanded so that the output
 * probabilities will reside between 0 and 1 (this is useful if the scheme
 * normally produces probabilities in a very narrow range).<p>
 *
 * Valid options are:<p>
 *
 * -C num <br>
 * The class for which threshold is determined. Valid values are:
 * 1, 2 (for first and second classes, respectively), 3 (for whichever
 * class is least frequent), 4 (for whichever class value is most 
 * frequent), and 5 (for the first class named any of "yes","pos(itive)",
 * "1", or method 3 if no matches). (default 5). <p>
 *
 * -W classname <br>
 * Specify the full class name of the base classifier. <p>
 *
 * -X num <br> 
 * Number of folds used for cross validation. If just a
 * hold-out set is used, this determines the size of the hold-out set
 * (default 3).<p>
 *
 * -R integer <br>
 * Sets whether confidence range correction is applied. This can be used
 * to ensure the confidences range from 0 to 1. Use 0 for no range correction,
 * 1 for correction based on the min/max values seen during threshold selection
 * (default 0).<p>
 *
 * -S seed <br>
 * Random number seed (default 1).<p>
 *
 * -E integer <br>
 * Sets the evaluation mode. Use 0 for evaluation using cross-validation,
 * 1 for evaluation using hold-out set, and 2 for evaluation on the
 * training data (default 1).<p>
 *
 * Options after -- are passed to the designated sub-classifier. <p>
 *
 * @author Eibe Frank (eibe@cs.waikato.ac.nz)
 * @version $Revision$ 
 */
public class ThresholdSelector extends Classifier 
  implements OptionHandler, Drawable {
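
  /*
   * Usage sketch (illustrative, not part of the original file): as a
   * meta-classifier, ThresholdSelector wraps a base learner selected with -W.
   * Assuming a standard Weka installation and a two-class training file named
   * train.arff (both assumptions), it could be run from the command line as
   *
   *   java weka.classifiers.meta.ThresholdSelector -t train.arff \
   *        -C 5 -X 3 -E 1 -W weka.classifiers.functions.Logistic
   *
   * or used programmatically:
   *
   *   ThresholdSelector ts = new ThresholdSelector();
   *   ts.setClassifier(new weka.classifiers.functions.Logistic());
   *   ts.setEvaluationMode(new SelectedTag(EVAL_TUNED_SPLIT, TAGS_EVAL));
   *   ts.buildClassifier(data);   // data: a two-class Instances object
   */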

  /* Type of correction applied to threshold range */ 
  public static final int RANGE_NONE = 0;
  public static final int RANGE_BOUNDS = 1;
  public static final Tag [] TAGS_RANGE = {
    new Tag(RANGE_NONE, "No range correction"),
    new Tag(RANGE_BOUNDS, "Correct based on min/max observed")
  };

  /* The evaluation modes */
  public static final int EVAL_TRAINING_SET = 2;
  public static final int EVAL_TUNED_SPLIT = 1;
  public static final int EVAL_CROSS_VALIDATION = 0;
  public static final Tag [] TAGS_EVAL = {
    new Tag(EVAL_TRAINING_SET, "Entire training set"),
    new Tag(EVAL_TUNED_SPLIT, "Single tuned fold"),
    new Tag(EVAL_CROSS_VALIDATION, "N-Fold cross validation")
  };

  /* How to determine which class value to optimize for */
  public static final int OPTIMIZE_0     = 0;
  public static final int OPTIMIZE_1     = 1;
  public static final int OPTIMIZE_LFREQ = 2;
  public static final int OPTIMIZE_MFREQ = 3;
  public static final int OPTIMIZE_POS_NAME = 4;
  public static final Tag [] TAGS_OPTIMIZE = {
    new Tag(OPTIMIZE_0, "First class value"),
    new Tag(OPTIMIZE_1, "Second class value"),
    new Tag(OPTIMIZE_LFREQ, "Least frequent class value"),
    new Tag(OPTIMIZE_MFREQ, "Most frequent class value"),
    new Tag(OPTIMIZE_POS_NAME, "Class value named: \"yes\", \"pos(itive)\",\"1\"")
  };

  /** The generated base classifier */
  protected Classifier m_Classifier = new weka.classifiers.functions.Logistic();

  /** The upper threshold used as the basis of correction */
  protected double m_HighThreshold = 1;

  /** The lower threshold used as the basis of correction */
  protected double m_LowThreshold = 0;
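
  /* Note (an assumption, not stated on this page of the listing): when
     RANGE_BOUNDS is selected, these observed bounds are presumably used to
     rescale the base learner's probability for the designated class, e.g.
         corrected = (p - m_LowThreshold) / (m_HighThreshold - m_LowThreshold),
     so that the corrected confidences span 0 to 1; the code applying the
     correction is not shown on this page. */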

  /** The threshold that led to the best performance */
  protected double m_BestThreshold = -Double.MAX_VALUE;

  /** The best value that has been observed */
  protected double m_BestValue = - Double.MAX_VALUE;
  
  /** The number of folds used in cross-validation */
  protected int m_NumXValFolds = 3;

  /** Random number seed */
  protected int m_Seed = 1;

  /** Designated class value, determined during building */
  protected int m_DesignatedClass = 0;

  /** Method to determine which class to optimize for */
  protected int m_ClassMode = OPTIMIZE_POS_NAME;

  /** The evaluation mode */
  protected int m_EvalMode = EVAL_TUNED_SPLIT;

  /** The range correction mode */
  protected int m_RangeMode = RANGE_NONE;

  /** The minimum value for the criterion. If threshold adjustment
      yields less than that, the default threshold of 0.5 is used. */
  protected static final double MIN_VALUE = 0.05;

  /**
   * Collects the classifier predictions using the specified evaluation method.
   *
   * @param instances the set of <code>Instances</code> to generate
   * predictions for.
   * @param mode the evaluation mode.
   * @param numFolds the number of folds to use if not evaluating on the
   * full training set.
   * @return a <code>FastVector</code> containing the predictions.
   * @exception Exception if an error occurs generating the predictions.
   */
  protected FastVector getPredictions(Instances instances, int mode, int numFolds) 
    throws Exception {

    EvaluationUtils eu = new EvaluationUtils();
    eu.setSeed(m_Seed);
    
    switch (mode) {
    case EVAL_TUNED_SPLIT:
      Instances trainData = null, evalData = null;
      Instances data = new Instances(instances);
      Random random = new Random(m_Seed);
      data.randomize(random);
      data.stratify(numFolds);
      
      // Make sure that both subsets contain at least one positive instance
      for (int subsetIndex = 0; subsetIndex < numFolds; subsetIndex++) {
        trainData = data.trainCV(numFolds, subsetIndex, random);
        evalData = data.testCV(numFolds, subsetIndex);
        if (checkForInstance(trainData) && checkForInstance(evalData)) {
          break;
        }
      }
      return eu.getTrainTestPredictions(m_Classifier, trainData, evalData);
    case EVAL_TRAINING_SET:
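      // Training and evaluating on the same data is fast, but gives an
      // optimistic performance estimate.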
      return eu.getTrainTestPredictions(m_Classifier, instances, instances);
    case EVAL_CROSS_VALIDATION:
      return eu.getCVPredictions(m_Classifier, instances, numFolds);
    default:
      throw new RuntimeException("Unrecognized evaluation mode");
    }
  }

  /**
   * Finds the best threshold, this implementation searches for the
   * highest FMeasure. If no FMeasure higher than MIN_VALUE is found,
   * the default threshold of 0.5 is used.
   *
   * @param predictions a <code>FastVector</code> containing the predictions.
   */
  protected void findThreshold(FastVector predictions) {

    Instances curve = (new ThresholdCurve()).getCurve(predictions, m_DesignatedClass);
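    // Each row of the curve describes one candidate threshold, with precomputed
    // statistics for it (including an F-measure column and the threshold value
    // itself), which are looked up by attribute name below.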

    //System.err.println(curve);
    double low = 1.0;
    double high = 0.0;
    if (curve.numInstances() > 0) {
      Instance maxFM = curve.instance(0);
      int indexFM = curve.attribute(ThresholdCurve.FMEASURE_NAME).index();
      int indexThreshold = curve.attribute(ThresholdCurve.THRESHOLD_NAME).index();
      for (int i = 1; i < curve.numInstances(); i++) {
        Instance current = curve.instance(i);
        if (current.value(indexFM) > maxFM.value(indexFM)) {
          maxFM = current;
        }
        if (m_RangeMode == RANGE_BOUNDS) {
          double thresh = current.value(indexThreshold);
          if (thresh < low) {
            low = thresh;
          }
          if (thresh > high) {
            high = thresh;
          }
        }
      }
      if (maxFM.value(indexFM) > MIN_VALUE) {
        m_BestThreshold = maxFM.value(indexThreshold);
        m_BestValue = maxFM.value(indexFM);
        //System.err.println("maxFM: " + maxFM);
      }
      if (m_RangeMode == RANGE_BOUNDS) {
        m_LowThreshold = low;
        m_HighThreshold = high;
        //System.err.println("Threshold range: " + low + " - " + high);
      }
    }
  }

  /**
   * Returns an enumeration describing the available options.
   *
   * @return an enumeration of all the available options.
   */
  public Enumeration listOptions() {

    Vector newVector = new Vector(6);

    newVector.addElement(new Option(
              "\tThe class for which threshold is determined. Valid values are:\n" +
              "\t1, 2 (for first and second classes, respectively), 3 (for whichever\n" +
              "\tclass is least frequent), and 4 (for whichever class value is most\n" +
              "\tfrequent), and 5 (for the first class named any of \"yes\",\"pos(itive)\"\n" +
              "\t\"1\", or method 3 if no matches). (default 5).",
	      "C", 1, "-C <integer>"));
    newVector.addElement(new Option(
	      "\tFull name of classifier to perform parameter selection on.\n"
	      + "\teg: weka.classifiers.bayes.NaiveBayes",
	      "W", 1, "-W <classifier class name>"));
    newVector.addElement(new Option(
	      "\tNumber of folds used for cross validation. If just a\n" +
	      "\thold-out set is used, this determines the size of the hold-out set\n" +
	      "\t(default 3).",
	      "X", 1, "-X <number of folds>"));
    newVector.addElement(new Option(
	      "\tSets whether confidence range correction is applied. This\n" +
              "\tcan be used to ensure the confidences range from 0 to 1.\n" +
              "\tUse 0 for no range correction, 1 for correction based on\n" +
              "\tthe min/max values seen during threshold selection\n"+
              "\t(default 0).",
	      "R", 1, "-R <integer>"));
    newVector.addElement(new Option(
	      "\tSets the random number seed (default 1).",
	      "S", 1, "-S <random number seed>"));
    newVector.addElement(new Option(
	      "\tSets the evaluation mode. Use 0 for\n" +
	      "\tevaluation using cross-validation,\n" +
	      "\t1 for evaluation using hold-out set,\n" +
	      "\tand 2 for evaluation on the\n" +
	      "\ttraining data (default 1).",
	      "E", 1, "-E <integer>"));

    if ((m_Classifier != null) &&
	(m_Classifier instanceof OptionHandler)) {
      newVector.addElement(new Option("",
	        "", 0,
		"\nOptions specific to sub-classifier "
	        + m_Classifier.getClass().getName()
		+ ":\n(use -- to signal start of sub-classifier options)"));
      Enumeration em = ((OptionHandler)m_Classifier).listOptions();
      while (em.hasMoreElements()) {
	newVector.addElement(em.nextElement());
      }
    }
    return newVector.elements();
  }

  /**
   * Parses a given list of options. Valid options are:<p>
   *
   * -C num <br>
   * The class for which threshold is determined. Valid values are:
   * 1, 2 (for first and second classes, respectively), 3 (for whichever
   * class is least frequent), 4 (for whichever class value is most 
   * frequent), and 5 (for the first class named any of "yes","pos(itive)",
   * "1", or method 3 if no matches). (default 3). <p>
   *
   * -W classname <br>
   * Specify the full class name of classifier to perform cross-validation
   * selection on.<p>
   *
   * -X num <br> 
   * Number of folds used for cross validation. If just a
   * hold-out set is used, this determines the size of the hold-out set
   * (default 3).<p>
   *
   * -R integer <br>
   * Sets whether confidence range correction is applied. This can be used
   * to ensure the confidences range from 0 to 1. Use 0 for no range correction,
   * 1 for correction based on the min/max values seen during threshold 
   * selection (default 0).<p>
   *
   * -S seed <br>
   * Random number seed (default 1).<p>
   *
   * -E integer <br>
   * Sets the evaluation mode. Use 0 for evaluation using cross-validation,
   * 1 for evaluation using hold-out set, and 2 for evaluation on the
   * training data (default 1).<p>
   *
   * Options after -- are passed to the designated sub-classifier. <p>
   *
   * @param options the list of options as an array of strings
   * @exception Exception if an option is not supported
   */
  public void setOptions(String[] options) throws Exception {
    
    String classString = Utils.getOption('C', options);
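    // -C is given 1-based on the command line (1..5); subtracting 1 maps it
    // onto the 0-based values in TAGS_OPTIMIZE.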
    if (classString.length() != 0) {
      setDesignatedClass(new SelectedTag(Integer.parseInt(classString) - 1, 
                                         TAGS_OPTIMIZE));
    } else {
      setDesignatedClass(new SelectedTag(OPTIMIZE_LFREQ, TAGS_OPTIMIZE));
    }

    String modeString = Utils.getOption('E', options);
    if (modeString.length() != 0) {
      setEvaluationMode(new SelectedTag(Integer.parseInt(modeString), 
                                         TAGS_EVAL));
    } else {
      setEvaluationMode(new SelectedTag(EVAL_TUNED_SPLIT, TAGS_EVAL));
    }

    String rangeString = Utils.getOption('R', options);
    if (rangeString.length() != 0) {
      // -R uses the tag values directly (0 = no correction, 1 = min/max bounds),
      // matching both the documentation above and the value written by getOptions()
      setRangeCorrection(new SelectedTag(Integer.parseInt(rangeString), 
                                         TAGS_RANGE));
    } else {
      setRangeCorrection(new SelectedTag(RANGE_NONE, TAGS_RANGE));
    }

    String foldsString = Utils.getOption('X', options);
    if (foldsString.length() != 0) {
      setNumXValFolds(Integer.parseInt(foldsString));
    } else {
      setNumXValFolds(3);
    }

    String randomString = Utils.getOption('S', options);
    if (randomString.length() != 0) {
      setSeed(Integer.parseInt(randomString));
    } else {
      setSeed(1);
    }

    String classifierName = Utils.getOption('W', options);
    if (classifierName.length() == 0) {
      throw new Exception("A classifier must be specified with"
			  + " the -W option.");
    }

    setClassifier(Classifier.forName(classifierName,
				     Utils.partitionOptions(options)));
  }

  /**
   * Gets the current settings of the Classifier.
   *
   * @return an array of strings suitable for passing to setOptions
   */
  public String [] getOptions() {

    String [] classifierOptions = new String [0];
    if ((m_Classifier != null) && 
	(m_Classifier instanceof OptionHandler)) {
      classifierOptions = ((OptionHandler)m_Classifier).getOptions();
    }

    int current = 0;
    String [] options = new String [classifierOptions.length + 13];
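    // 13 fixed slots: six flag/value pairs (-C, -X, -S, -W, -E, -R) plus the
    // "--" separator; any slots left unused are padded with empty strings below.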

    options[current++] = "-C"; options[current++] = "" + (m_DesignatedClass + 1);
    options[current++] = "-X"; options[current++] = "" + getNumXValFolds();
    options[current++] = "-S"; options[current++] = "" + getSeed();

    if (getClassifier() != null) {
      options[current++] = "-W";
      options[current++] = getClassifier().getClass().getName();
    }
    options[current++] = "-E"; options[current++] = "" + m_EvalMode;
    options[current++] = "-R"; options[current++] = "" + m_RangeMode;
    options[current++] = "--";

    System.arraycopy(classifierOptions, 0, options, current, 
		     classifierOptions.length);
    current += classifierOptions.length;
    while (current < options.length) {
      options[current++] = "";
    }
    return options;
  }

  /**
   * Generates the classifier.
   *
   * @param instances set of instances serving as training data 
   * @exception Exception if the classifier has not been generated successfully
   */
  public void buildClassifier(Instances instances) 
    throws Exception {

    if (instances.numClasses() > 2) {

// (Source listing continues on page 2 of 2.)