📄 thresholdfinder.java
字号:
/*
* YALE - Yet Another Learning Environment
* Copyright (C) 2001-2004
* Simon Fischer, Ralf Klinkenberg, Ingo Mierswa,
* Katharina Morik, Oliver Ritthoff
* Artificial Intelligence Unit
* Computer Science Department
* University of Dortmund
* 44221 Dortmund, Germany
* email: yale-team@lists.sourceforge.net
* web: http://yale.cs.uni-dortmund.de/
*
* This program is free software; you can redistribute it and/or
* modify it under the terms of the GNU General Public License as
* published by the Free Software Foundation; either version 2 of the
* License, or (at your option) any later version.
*
* This program is distributed in the hope that it will be useful, but
* WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307
* USA.
*/
package edu.udo.cs.yale.operator;
import java.util.Arrays;
import java.util.List;
import java.util.LinkedList;
import java.util.Iterator;
import edu.udo.cs.yale.Statistics;
import edu.udo.cs.yale.example.Attribute;
import edu.udo.cs.yale.example.Example;
import edu.udo.cs.yale.example.ExampleReader;
import edu.udo.cs.yale.example.ExampleSet;
import edu.udo.cs.yale.operator.performance.PerformanceVector;
import edu.udo.cs.yale.operator.performance.EstimatedPerformance;
import edu.udo.cs.yale.operator.parameter.*;
import edu.udo.cs.yale.gui.SimplePlotterDialog;
/**
 * This operator finds the best threshold for crisp classification of a binominal label
 * based on user defined misclassification costs.
 *
 * <p>The examples are sorted by descending prediction value (as returned by
 * {@code getPredictedLabel()}). They are then covered one after another, which traces the
 * ROC curve in absolute (weighted) counts. Before each uncovered negative example, the
 * intercept of the iso-cost line through the current point, {@code tp - slope * fp} with
 * {@code slope = costs_negative / costs_positive}, is evaluated. The prediction value where
 * this intercept is maximal is delivered as {@link Threshold}.</p>
 *
 * <p>Optionally the ROC curve is plotted, and the area under it (AUC) is delivered as a
 * {@link PerformanceVector}.</p>
 *
 * @version $Id$
 */
public class ThresholdFinder extends Operator {

    /**
     * Prediction value, true label index and weight of a single example.
     *
     * <p>Instances are ordered by DESCENDING confidence, so sorting an array of them puts the
     * examples with the highest prediction value first.</p>
     */
    private static class WeightedConfidenceAndLabel implements Comparable {

        private final double confidence;

        private final double label;

        private final double weight;

        /** Creates an entry with weight 1. */
        public WeightedConfidenceAndLabel(double confidence, double label) {
            this(confidence, label, 1);
        }

        public WeightedConfidenceAndLabel(double confidence, double label, double weight) {
            this.confidence = confidence;
            this.label = label;
            this.weight = weight;
        }

        /** Orders by descending confidence; the arguments of Double.compare are swapped on purpose. */
        public int compareTo(Object obj) {
            return Double.compare(((WeightedConfidenceAndLabel) obj).confidence, this.confidence);
        }

        public double getLabel() {
            return this.label;
        }

        public double getConfidence() {
            return this.confidence;
        }

        public double getWeight() {
            return this.weight;
        }

        public String toString() {
            return "conf: " + confidence + ", label: " + label + ", weight: " + weight;
        }
    }

    /** Defines the maximum amount of points which is plotted in the ROC curve. */
    private static final int MAX_ROC_POINTS = 200;

    // The parameters of this operator:
    private static final String COSTS_POS = "misclassification_costs_positive";

    private static final String COSTS_NEG = "misclassification_costs_negative";

    private static final String SHOW_PLOT = "show_roc_plot";

    private static final String CREATE_AUC = "create_AUC_performance";

    /**
     * @see edu.udo.cs.yale.operator.Operator#getInputClasses()
     */
    public Class[] getInputClasses() {
        return new Class[] { ExampleSet.class };
    }

    /**
     * Returns the example set and the threshold, plus a performance vector when the AUC
     * parameter is set.
     *
     * @see edu.udo.cs.yale.operator.Operator#getOutputClasses()
     */
    public Class[] getOutputClasses() {
        if (getParameterAsBoolean(CREATE_AUC)) {
            return new Class[] { ExampleSet.class, Threshold.class, PerformanceVector.class };
        }
        return new Class[] { ExampleSet.class, Threshold.class };
    }

    /**
     * Computes the cost-optimal threshold for the input example set. Optionally it also shows
     * the ROC plot and delivers the AUC.
     *
     * @return the unchanged example set, the found {@link Threshold} and, if requested, a
     *         performance vector containing the AUC
     * @throws OperatorException if the example set has no label, a non-nominal label or a
     *         label that is not a boolean classification
     * @see edu.udo.cs.yale.operator.Operator#apply()
     */
    public IOObject[] apply() throws OperatorException {
        ExampleSet exampleSet = (ExampleSet) this.getInput(ExampleSet.class);
        Attribute label = exampleSet.getLabel();
        if (label == null)
            throw new UserError(this, 105);
        if (!label.isNominal())
            throw new UserError(this, 101, label, "threshold finding");
        if (!label.isBooleanClassification())
            throw new UserError(this, 118, new Object[] { label, new Integer(label.getValues().size()), new Integer(2) });

        // Collect (prediction, label, weight) of every example and sort by descending prediction.
        ExampleReader reader = exampleSet.getExampleReader();
        WeightedConfidenceAndLabel[] calArray = new WeightedConfidenceAndLabel[exampleSet.getSize()];
        Attribute weightAttr = exampleSet.getWeight();
        int index = 0;
        while (reader.hasNext()) {
            Example example = (Example) reader.next();
            if (weightAttr == null) {
                calArray[index++] = new WeightedConfidenceAndLabel(example.getPredictedLabel(), example.getLabel());
            } else {
                calArray[index++] = new WeightedConfidenceAndLabel(example.getPredictedLabel(), example.getLabel(),
                        example.getValue(weightAttr));
            }
        }
        Arrays.sort(calArray);

        int negativeLabelIndex = label.getNegativeIndex();
        int positiveLabelIndex = label.getPositiveIndex();

        // COSTS_NEG is charged per negative classified as positive (FP), COSTS_POS per positive
        // classified as negative (FN). Minimizing costsNeg * FP + costsPos * (P - TP) is equivalent
        // to maximizing TP - FP * (costsNeg / costsPos). In absolute (weighted) counts the iso-cost
        // lines therefore have this slope. In the normalized space (FP/N, TP/P) used for plotting,
        // the slope is additionally multiplied by N / P.
        final double slope = getParameterAsDouble(COSTS_NEG) / getParameterAsDouble(COSTS_POS);

        double tp = 0.0d;  // weighted true positives covered so far
        double sum = 0.0d; // total weight covered so far

        // The task is to find the iso-cost line that crosses the TP axis as high as possible.
        // The initial values stand for the origin of ROC space (nothing classified as positive),
        // whose intercept is 0.
        // NOTE(review): bestThreshold = 1 presumably means "nothing positive" only if predictions
        // lie in [0, 1]; confirm against the comparison used by Threshold.
        double bestIsometricsTpValue = 0.0d;
        double bestThreshold = 1.0d;

        // ROC points in absolute weighted counts {fp, tp}, starting at the origin.
        List statsData = new LinkedList();
        statsData.add(new double[] { 0.0d, 0.0d });
        for (int k = 0; k < calArray.length; k++) {
            WeightedConfidenceAndLabel wcl = calArray[k];
            double weight = wcl.getWeight();
            // State BEFORE covering this example: all examples ranked earlier are covered.
            double fp = sum - tp;
            // Intercept on the TP axis of the iso-cost line through the current point.
            double c = tp - (fp * slope);
            if (wcl.getLabel() == positiveLabelIndex) {
                tp += weight;
            } else if (c > bestIsometricsTpValue) {
                // NOTE(review): candidates are only evaluated before negatives. The point where
                // every example is classified positive (after the final run of positives) is
                // never considered.
                bestIsometricsTpValue = c;
                bestThreshold = wcl.getConfidence();
            }
            sum += weight;
            // Record the point AFTER covering the example (fp must include negative weights).
            // Emit it only at the end of a group of equal predictions: tied examples cannot be
            // separated by any threshold, so they form one diagonal segment of the curve.
            boolean lastOfGroup = (k == calArray.length - 1)
                    || (Double.compare(calArray[k + 1].getConfidence(), wcl.getConfidence()) != 0);
            if (lastOfGroup) {
                statsData.add(new double[] { sum - tp, tp });
            }
        }

        // Normalization constants; after the loop the last recorded point is (sumNeg, sumPos).
        // NOTE(review): if one class has zero total weight, the divisions below produce NaN or
        // infinite values for the plot and the AUC.
        double sumPos = tp;
        double sumNeg = sum - sumPos;
        bestIsometricsTpValue /= sumPos;
        double normalizedSlope = slope * (sumNeg / sumPos);

        // Build the plot data and integrate the area under the curve with the trapezoidal rule.
        double aucSum = 0.0d;
        double[] last = null;
        Statistics stats = new Statistics("ROCplot");
        stats.init(new String[] { "FP/N", "TP/P", "Slope" });
        // Draw only every eachPoint-th point so that about MAX_ROC_POINTS points are plotted.
        int eachPoint = (int) Math.round((double) statsData.size() / (double) MAX_ROC_POINTS);
        int pointCounter = 0;
        Iterator points = statsData.iterator();
        while (points.hasNext()) {
            double[] point = (double[]) points.next();
            double fpDivN = point[0] / sumNeg; // false positives divided by sum of all negatives
            double tpDivP = point[1] / sumPos; // true positives divided by sum of all positives
            // The final point is always drawn so the plotted curve reaches its end.
            if ((eachPoint < 1) || ((pointCounter % eachPoint) == 0) || !points.hasNext()) {
                stats.add(new Object[] { new Double(fpDivN), new Double(tpDivP),
                        new Double(bestIsometricsTpValue + (fpDivN * normalizedSlope)) });
            }
            if (last != null) {
                aucSum += (fpDivN - last[0]) * (last[1] + tpDivP) / 2.0d;
            }
            last = new double[] { fpDivN, tpDivP };
            pointCounter++;
        }

        if (getParameterAsBoolean(SHOW_PLOT)) {
            SimplePlotterDialog plotter = new SimplePlotterDialog(stats);
            plotter.setXAxis(0);
            plotter.plotColumn(1, true);
            plotter.plotColumn(2, true);
            plotter.setDrawRange(0.0d, 1.0d, 0.0d, 1.0d);
            plotter.show();
        }

        Threshold threshold = new Threshold(bestThreshold, label.mapIndex(negativeLabelIndex),
                label.mapIndex(positiveLabelIndex));
        if (getParameterAsBoolean(CREATE_AUC)) {
            PerformanceVector aucPerformanceVector = new PerformanceVector();
            aucPerformanceVector.addCriterion(new EstimatedPerformance("AUC", aucSum, 1, false));
            aucPerformanceVector.setMainCriterionName("AUC");
            return new IOObject[] { exampleSet, threshold, aucPerformanceVector };
        }
        return new IOObject[] { exampleSet, threshold };
    }

    /**
     * Declares the two misclassification costs (non-negative, default 1) and the
     * plot and AUC switches (default false).
     */
    public List getParameterTypes() {
        List list = super.getParameterTypes();
        list.add(new ParameterTypeDouble(COSTS_POS, "The costs assigned when a positive example is classified as negative.",
                0, Double.POSITIVE_INFINITY, 1));
        list.add(new ParameterTypeDouble(COSTS_NEG, "The costs assigned when a negative example is classified as positive.",
                0, Double.POSITIVE_INFINITY, 1));
        list.add(new ParameterTypeBoolean(SHOW_PLOT, "Display a plot of the ROC curve.", false));
        list.add(new ParameterTypeBoolean(CREATE_AUC, "Indicates if the area under the ROC curve should be delivered as performance criterion.", false));
        return list;
    }
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -