⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 thresholdcurve.java

📁 :<<数据挖掘--实用机器学习技术及java实现>>一书的配套源程序
💻 JAVA
字号:
/* *    This program is free software; you can redistribute it and/or modify *    it under the terms of the GNU General Public License as published by *    the Free Software Foundation; either version 2 of the License, or *    (at your option) any later version. * *    This program is distributed in the hope that it will be useful, *    but WITHOUT ANY WARRANTY; without even the implied warranty of *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the *    GNU General Public License for more details. * *    You should have received a copy of the GNU General Public License *    along with this program; if not, write to the Free Software *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA. *//* *    ThresholdCurve.java *    Copyright (C) 2000 Intelligenesis Corp. * */package weka.classifiers.evaluation;import weka.core.Utils;import weka.core.Attribute;import weka.core.FastVector;import weka.core.Instance;import weka.core.Instances;import weka.classifiers.DistributionClassifier;/** * Generates points illustrating prediction tradeoffs that can be obtained * by varying the threshold value between classes. For example, the typical  * threshold value of 0.5 means the predicted probability of "positive" must be * higher than 0.5 for the instance to be predicted as "positive". The  * resulting dataset can be used to visualize precision/recall tradeoff, or  * for ROC curve analysis (true positive rate vs false positive rate). * * @author Len Trigg (len@intelligenesis.net) * @version $Revision: 1.12 $ */public class ThresholdCurve {  /** The name of the relation used in threshold curve datasets */  public final static String RELATION_NAME = "ThresholdCurve";  public final static String TRUE_POS_NAME  = "True Positives";  public final static String FALSE_NEG_NAME = "False Negatives";  public final static String FALSE_POS_NAME = "False Positives";  public final static String TRUE_NEG_NAME  = "True Negatives";  public final static String FP_RATE_NAME   = "False Positive Rate";  public final static String TP_RATE_NAME   = "True Positive Rate";  public final static String PRECISION_NAME = "Precision";  public final static String RECALL_NAME    = "Recall";  public final static String FALLOUT_NAME   = "Fallout";  public final static String FMEASURE_NAME  = "FMeasure";  public final static String THRESHOLD_NAME = "Threshold";  /**   * Calculates the performance stats for the default class and return    * results as a set of Instances. The   * structure of these Instances is as follows:<p> <ul>    * <li> <b>True Positives </b>   * <li> <b>False Negatives</b>   * <li> <b>False Positives</b>   * <li> <b>True Negatives</b>   * <li> <b>False Positive Rate</b>   * <li> <b>True Positive Rate</b>   * <li> <b>Precision</b>   * <li> <b>Recall</b>     * <li> <b>Fallout</b>     * <li> <b>Threshold</b> contains the probability threshold that gives   * rise to the previous performance values.    * </ul> <p>   * For the definitions of these measures, see TwoClassStats <p>   *   * @see TwoClassStats   * @param classIndex index of the class of interest.   * @return datapoints as a set of instances, null if no predictions   * have been made.   */  public Instances getCurve(FastVector predictions) {    if (predictions.size() == 0) {      return null;    }    return getCurve(predictions,                     ((NominalPrediction)predictions.elementAt(0))                    .distribution().length - 1);  }  /**   * Calculates the performance stats for the desired class and return    * results as a set of Instances.   *   * @param classIndex index of the class of interest.   * @return datapoints as a set of instances.   */  public Instances getCurve(FastVector predictions, int classIndex) {    if ((predictions.size() == 0) ||        (((NominalPrediction)predictions.elementAt(0))         .distribution().length <= classIndex)) {      return null;    }    double totPos = 0, totNeg = 0;    double [] probs = getProbabilities(predictions, classIndex);    // Get distribution of positive/negatives    for (int i = 0; i < probs.length; i++) {      NominalPrediction pred = (NominalPrediction)predictions.elementAt(i);      if (pred.actual() == Prediction.MISSING_VALUE) {        System.err.println(getClass().getName()                            + " Skipping prediction with missing class value");        continue;      }      if (pred.weight() < 0) {        System.err.println(getClass().getName()                            + " Skipping prediction with negative weight");        continue;      }      if (pred.actual() == classIndex) {        totPos += pred.weight();      } else {        totNeg += pred.weight();      }    }    Instances insts = makeHeader();    int [] sorted = Utils.sort(probs);    TwoClassStats tc = new TwoClassStats(totPos, totNeg, 0, 0);    for (int i = 0; i < sorted.length; i++) {      NominalPrediction pred = (NominalPrediction)predictions.elementAt(sorted[i]);      if (pred.actual() == Prediction.MISSING_VALUE) {        System.err.println(getClass().getName()                           + " Skipping prediction with missing class value");        continue;      }      if (pred.weight() < 0) {        System.err.println(getClass().getName()                            + " Skipping prediction with negative weight");        continue;      }      if (pred.actual() == classIndex) {        tc.setTruePositive(tc.getTruePositive() - pred.weight());        tc.setFalseNegative(tc.getFalseNegative() + pred.weight());      } else {        tc.setFalsePositive(tc.getFalsePositive() - pred.weight());        tc.setTrueNegative(tc.getTrueNegative() + pred.weight());      }      /*      System.out.println(tc + " " + probs[sorted[i]]                          + " " + (pred.actual() == classIndex));      */      if ((i != (sorted.length - 1)) &&          ((i == 0) ||            (probs[sorted[i]] != probs[sorted[i - 1]]))) {        insts.add(makeInstance(tc, probs[sorted[i]]));      }    }    return insts;  }  /**   * Calculates the n point precision result, which is the precision averaged   * over n evenly spaced (w.r.t recall) samples of the curve.   *   * @param tcurve a previously extracted threshold curve Instances.   * @param n the number of points to average over.   * @return the n-point precision.   */  public static double getNPointPrecision(Instances tcurve, int n) {    if (!RELATION_NAME.equals(tcurve.relationName())         || (tcurve.numInstances() == 0)) {      return Double.NaN;    }    int recallInd = tcurve.attribute(RECALL_NAME).index();    int precisInd = tcurve.attribute(PRECISION_NAME).index();    double [] recallVals = tcurve.attributeToDoubleArray(recallInd);    int [] sorted = Utils.sort(recallVals);    double isize = 1.0 / (n - 1);    double psum = 0;    for (int i = 0; i < n; i++) {      int pos = binarySearch(sorted, recallVals, i * isize);      double recall = recallVals[sorted[pos]];      double precis = tcurve.instance(sorted[pos]).value(precisInd);      /*      System.err.println("Point " + (i + 1) + ": i=" + pos                          + " r=" + (i * isize)                         + " p'=" + precis                          + " r'=" + recall);      */      // interpolate figures for non-endpoints      while ((pos != 0) && (pos < sorted.length - 1)) {        pos++;        double recall2 = recallVals[sorted[pos]];        if (recall2 != recall) {          double precis2 = tcurve.instance(sorted[pos]).value(precisInd);          double slope = (precis2 - precis) / (recall2 - recall);          double offset = precis - recall * slope;          precis = isize * i * slope + offset;          /*          System.err.println("Point2 " + (i + 1) + ": i=" + pos                              + " r=" + (i * isize)                             + " p'=" + precis2                              + " r'=" + recall2                             + " p''=" + precis);          */          break;        }      }      psum += precis;    }    return psum / n;  }  /**   * Calculates the area under the ROC curve.  This is normalised so   * that 0.5 is random, 1.0 is perfect and 0.0 is bizarre.   *   * @param tcurve a previously extracted threshold curve Instances.   * @return the ROC area, or Double.NaN if you don't pass in    * a ThresholdCurve generated Instances.    */  public static double getROCArea(Instances tcurve) {    final int n = tcurve.numInstances();    if (!RELATION_NAME.equals(tcurve.relationName())         || (n == 0)) {      return Double.NaN;    }    final int tpInd = tcurve.attribute(TRUE_POS_NAME).index();    final int fpInd = tcurve.attribute(FALSE_POS_NAME).index();    final double [] tpVals = tcurve.attributeToDoubleArray(tpInd);    final double [] fpVals = tcurve.attributeToDoubleArray(fpInd);    final double tp0 = tpVals[0];    final double fp0 = fpVals[0];    double area = 0.0;    //starts at high values and goes down    double xlast = 1.0;    double ylast = 1.0;    for (int i = 1; i < n; i++) {      final double x = fpVals[i] / fp0;      final double y = tpVals[i] / tp0;      final double areaDelta = (y + ylast) * (xlast - x) / 2.0;      /*      System.err.println("[" + i + "]"                         + " x=" + x                         + " y'=" + y                         + " xl=" + xlast                         + " yl=" + ylast                         + " a'=" + areaDelta);      */      area += areaDelta;      xlast = x;      ylast = y;    }    //make sure ends at 0,0    if (xlast > 0.0) {      final double areaDelta = ylast * xlast / 2.0;      //System.err.println(" a'=" + areaDelta);      area += areaDelta;    }    //System.err.println(" area'=" + area);    return area;  }  /**   * Gets the index of the instance with the closest threshold value to the   * desired target   *   * @param tcurve a set of instances that have been generated by this class   * @param threshold the target threshold   * @return the index of the instance that has threshold closest to   * the target, or -1 if this could not be found (i.e. no data, or   * bad threshold target)   */  public static int getThresholdInstance(Instances tcurve, double threshold) {    if (!RELATION_NAME.equals(tcurve.relationName())         || (tcurve.numInstances() == 0)        || (threshold < 0)        || (threshold > 1.0)) {      return -1;    }    if (tcurve.numInstances() == 1) {      return 0;    }    double [] tvals = tcurve.attributeToDoubleArray(tcurve.numAttributes() - 1);    int [] sorted = Utils.sort(tvals);    return binarySearch(sorted, tvals, threshold);  }  private static int binarySearch(int [] index, double [] vals, double target) {        int lo = 0, hi = index.length - 1;    while (hi - lo > 1) {      int mid = lo + (hi - lo) / 2;      double midval = vals[index[mid]];      if (target > midval) {        lo = mid;      } else if (target < midval) {        hi = mid;      } else {        while ((mid > 0) && (vals[index[mid - 1]] == target)) {          mid --;        }        return mid;      }    }    return lo;  }  private double [] getProbabilities(FastVector predictions, int classIndex) {    // sort by predicted probability of the desired class.    double [] probs = new double [predictions.size()];    for (int i = 0; i < probs.length; i++) {      NominalPrediction pred = (NominalPrediction)predictions.elementAt(i);      probs[i] = pred.distribution()[classIndex];    }    return probs;  }  private Instances makeHeader() {    FastVector fv = new FastVector();    fv.addElement(new Attribute(TRUE_POS_NAME));    fv.addElement(new Attribute(FALSE_NEG_NAME));    fv.addElement(new Attribute(FALSE_POS_NAME));    fv.addElement(new Attribute(TRUE_NEG_NAME));    fv.addElement(new Attribute(FP_RATE_NAME));    fv.addElement(new Attribute(TP_RATE_NAME));    fv.addElement(new Attribute(PRECISION_NAME));    fv.addElement(new Attribute(RECALL_NAME));    fv.addElement(new Attribute(FALLOUT_NAME));    fv.addElement(new Attribute(FMEASURE_NAME));    fv.addElement(new Attribute(THRESHOLD_NAME));          return new Instances(RELATION_NAME, fv, 100);  }    private Instance makeInstance(TwoClassStats tc, double prob) {    int count = 0;    double [] vals = new double[11];    vals[count++] = tc.getTruePositive();    vals[count++] = tc.getFalseNegative();    vals[count++] = tc.getFalsePositive();    vals[count++] = tc.getTrueNegative();    vals[count++] = tc.getFalsePositiveRate();    vals[count++] = tc.getTruePositiveRate();    vals[count++] = tc.getPrecision();    vals[count++] = tc.getRecall();    vals[count++] = tc.getFallout();    vals[count++] = tc.getFMeasure();    vals[count++] = prob;    return new Instance(1.0, vals);  }    /**   * Tests the ThresholdCurve generation from the command line.   * The classifier is currently hardcoded. Pipe in an arff file.   *   * @param args currently ignored   */  public static void main(String [] args) {    try {            Instances inst = new Instances(new java.io.InputStreamReader(System.in));      if (false) {        System.out.println(ThresholdCurve.getNPointPrecision(inst, 11));      } else {        inst.setClassIndex(inst.numAttributes() - 1);        ThresholdCurve tc = new ThresholdCurve();        EvaluationUtils eu = new EvaluationUtils();        DistributionClassifier classifier = new weka.classifiers.SMO();        FastVector predictions = new FastVector();        for (int i = 0; i < 2; i++) { // Do two runs.          eu.setSeed(i);          predictions.appendElements(eu.getCVPredictions(classifier, inst, 10));          //System.out.println("\n\n\n");        }        Instances result = tc.getCurve(predictions);        System.out.println(result);      }    } catch (Exception ex) {      ex.printStackTrace();    }  }}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -