⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 thresholdfinder.java

📁 一个很好的 LIBSVM 的 Java 源码，可供研究和改进 SVM 算法的学者参考。来自数据挖掘工具 YALE 工具包。
💻 JAVA
字号:
/*
 *  YALE - Yet Another Learning Environment
 *  Copyright (C) 2001-2004
 *      Simon Fischer, Ralf Klinkenberg, Ingo Mierswa, 
 *          Katharina Morik, Oliver Ritthoff
 *      Artificial Intelligence Unit
 *      Computer Science Department
 *      University of Dortmund
 *      44221 Dortmund,  Germany
 *  email: yale-team@lists.sourceforge.net
 *  web:   http://yale.cs.uni-dortmund.de/
 *
 *  This program is free software; you can redistribute it and/or
 *  modify it under the terms of the GNU General Public License as 
 *  published by the Free Software Foundation; either version 2 of the
 *  License, or (at your option) any later version. 
 *
 *  This program is distributed in the hope that it will be useful, but
 *  WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
 *  General Public License for more details.
 *
 *  You should have received a copy of the GNU General Public License
 *  along with this program; if not, write to the Free Software
 *  Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307
 *  USA.
 */
package edu.udo.cs.yale.operator;

import java.util.Arrays;
import java.util.List;
import java.util.LinkedList;
import java.util.Iterator;

import edu.udo.cs.yale.Statistics;
import edu.udo.cs.yale.example.Attribute;
import edu.udo.cs.yale.example.Example;
import edu.udo.cs.yale.example.ExampleReader;
import edu.udo.cs.yale.example.ExampleSet;
import edu.udo.cs.yale.operator.performance.PerformanceVector;
import edu.udo.cs.yale.operator.performance.EstimatedPerformance;
import edu.udo.cs.yale.operator.parameter.*;
import edu.udo.cs.yale.gui.SimplePlotterDialog;

/**
 * This operator finds the best threshold for crisp classification based on user defined
 * misclassification costs.
 *
 * <p>The label of the input example set must be nominal with exactly two values. The value
 * returned by {@code Example.getPredictedLabel()} is used as the confidence for the positive
 * class. Examples are sorted by descending confidence. Sweeping through them traces the ROC
 * curve in (weighted) FP/TP space. The cut point lying on the cost isometric with the highest
 * TP-axis intercept determines the delivered {@link Threshold}. Optionally the ROC curve is
 * plotted, and the area under it (AUC) is delivered as a {@link PerformanceVector}.</p>
 *
 *  @version $Id$
 */
public class ThresholdFinder extends Operator {

    /** Pairs a prediction confidence with the true label index and the example weight. */
    private static class WeightedConfidenceAndLabel implements Comparable {

        private final double confidence;
        private final double label;
        private final double weight;

        public WeightedConfidenceAndLabel(double confidence, double label) {
            this(confidence, label, 1.0d);
        }

        public WeightedConfidenceAndLabel(double confidence, double label, double weight) {
            this.confidence = confidence;
            this.label = label;
            this.weight = weight;
        }

        /** Orders by DESCENDING confidence so that the most confident examples come first. */
        public int compareTo(Object obj) {
            return Double.compare(((WeightedConfidenceAndLabel) obj).confidence, this.confidence);
        }

        public double getLabel() {
            return this.label;
        }

        public double getConfidence() {
            return this.confidence;
        }

        public double getWeight() {
            return weight;
        }

        public String toString() {
            return "conf: " + confidence + ", label: " + label + ", weight: " + weight;
        }
    }

    /**
     * Approximate upper bound on the number of points plotted in the ROC curve
     * (the final point of the curve is always plotted in addition).
     */
    private static final int MAX_ROC_POINTS = 200;

    // The parameters of this operator:
    private static final String COSTS_POS  = "misclassification_costs_positive";
    private static final String COSTS_NEG  = "misclassification_costs_negative";
    private static final String SHOW_PLOT  = "show_roc_plot";
    private static final String CREATE_AUC = "create_AUC_performance";

    /**
     * @see edu.udo.cs.yale.operator.Operator#getInputClasses()
     */
    public Class[] getInputClasses() {
        return new Class[] { ExampleSet.class };
    }

    /**
     * @see edu.udo.cs.yale.operator.Operator#getOutputClasses()
     */
    public Class[] getOutputClasses() {
        if (getParameterAsBoolean(CREATE_AUC)) {
            return new Class[] { ExampleSet.class, Threshold.class, PerformanceVector.class };
        }
        return new Class[] { ExampleSet.class, Threshold.class };
    }

    /**
     * Determines the cost-optimal threshold and, optionally, plots the ROC curve and delivers the AUC.
     *
     * @return the input example set, the found {@link Threshold} and, if {@code create_AUC_performance}
     *         is set, a {@link PerformanceVector} holding the AUC
     * @throws OperatorException (as {@link UserError}) if the label is missing, not nominal,
     *         or does not have exactly two values
     * @see edu.udo.cs.yale.operator.Operator#apply()
     */
    public IOObject[] apply() throws OperatorException {
        ExampleSet exampleSet = (ExampleSet) this.getInput(ExampleSet.class);
        Attribute label = exampleSet.getLabel();
        if (label == null) {
            throw new UserError(this, 105);
        }
        if (!label.isNominal()) {
            throw new UserError(this, 101, label, "threshold finding");
        }
        if (!label.isBooleanClassification()) {
            throw new UserError(this, 118, new Object[] { label, new Integer(label.getValues().size()), new Integer(2) });
        }

        // Collect (confidence, label, weight) triples and sort them by descending confidence.
        ExampleReader reader = exampleSet.getExampleReader();
        WeightedConfidenceAndLabel[] calArray = new WeightedConfidenceAndLabel[exampleSet.getSize()];
        Attribute weightAttr = exampleSet.getWeight();
        int index = 0;
        while (reader.hasNext()) {
            Example example = (Example) reader.next();
            if (weightAttr == null) {
                calArray[index++] = new WeightedConfidenceAndLabel(example.getPredictedLabel(), example.getLabel());
            } else {
                calArray[index++] = new WeightedConfidenceAndLabel(example.getPredictedLabel(), example.getLabel(),
                                                                   example.getValue(weightAttr));
            }
        }
        Arrays.sort(calArray);

        int negativeLabelIndex = label.getNegativeIndex();
        int positiveLabelIndex = label.getPositiveIndex();

        // Slope of the cost isometrics in UNNORMALIZED (weighted FP, weighted TP) space:
        // total costs = FN * costs_pos + FP * costs_neg = (P - TP) * costs_pos + FP * costs_neg,
        // so minimizing the costs means maximizing TP - FP * (costs_neg / costs_pos).
        // In normalized ROC space (FP/N, TP/P) this slope is additionally multiplied by N / P,
        // which is done when plotting below.
        final double slope = getParameterAsDouble(COSTS_NEG) / getParameterAsDouble(COSTS_POS);

        // Weighted true positives and total weight of all examples covered so far.
        double tp  = 0.0d;
        double sum = 0.0d;

        // The task is to find the isometric that crosses the TP axis as high as possible.
        // The TP-axis intercept of the best isometric seen so far is stored in bestIsometricsTpValue,
        // the corresponding threshold in bestThreshold. The start values belong to the ROC point
        // (0, 0), whose intercept is 0. That point means "classify every example as negative",
        // presumably achieved by a threshold of 1 (assumes confidences lie in [0, 1] -- TODO confirm).
        double bestIsometricsTpValue = 0.0d;
        double bestThreshold = 1.0d;

        // Iterate through the examples sorted by descending confidence. In each iteration the
        // example with the next highest confidence of being positive is added to the covered set.
        List statsData = new LinkedList();
        statsData.add(new double[] { 0.0d, 0.0d }); // first point of the ROC curve
        for (int i = 0; i < calArray.length; i++) {
            WeightedConfidenceAndLabel wcl = calArray[i];
            double weight = wcl.getWeight();
            if (wcl.getLabel() == positiveLabelIndex) {
                tp += weight;
            } else {
                // Candidate cut directly above this negative example: everything covered so far
                // is classified positive, this example and all following ones negative.
                // c is the TP-axis intercept of the isometric through the current ROC point.
                double fp = sum - tp;
                double c = tp - (fp * slope);
                if (c > bestIsometricsTpValue) {
                    bestIsometricsTpValue = c;
                    bestThreshold = wcl.getConfidence();
                }
            }
            sum += weight;
            // Record the ROC point AFTER covering this example. Both coordinates must include it,
            // otherwise the FP step of a negative example is merged with the next example's
            // TP step, which draws diagonals instead of a staircase and inflates the AUC.
            statsData.add(new double[] { sum - tp, tp });
        }

        // The loop only evaluates cuts in front of negative examples. The cut behind the last
        // example (every example classified positive, ROC point (N, P)) must be checked
        // separately, otherwise a trailing run of positive examples can never be covered.
        double allPositiveIntercept = tp - ((sum - tp) * slope);
        if (allPositiveIntercept > bestIsometricsTpValue) {
            bestIsometricsTpValue = allPositiveIntercept;
            // A threshold below every confidence classifies all examples as positive.
            bestThreshold = Double.NEGATIVE_INFINITY;
        }

        // Scaling for plotting.
        // NOTE(review): with no positive (or no negative) weight the divisions below yield NaN/Infinity.
        double sumPos = tp;
        double sumNeg = sum - sumPos;
        bestIsometricsTpValue /= sumPos;

        // Build the plot data and calculate the AUC (area under the ROC curve).
        double aucSum = 0.0d;
        double[] last = null;
        Statistics stats = new Statistics("ROCplot");
        stats.init(new String[] { "FP/N", "TP/P", "Slope" });
        // Use ceil (not round) so that sampling every eachPoint-th point never exceeds MAX_ROC_POINTS.
        // statsData always contains at least the start point, hence eachPoint >= 1.
        int eachPoint = (int) Math.ceil((double) statsData.size() / (double) MAX_ROC_POINTS);
        int pointCounter = 0;
        Iterator it = statsData.iterator();
        while (it.hasNext()) {
            double[] point = (double[]) it.next();
            double fpDivN  = point[0] / sumNeg;  // false positives divided by sum of all negatives
            double tpDivP  = point[1] / sumPos;  // true positives divided by sum of all positives
            boolean isLastPoint = !it.hasNext();
            if (((pointCounter % eachPoint) == 0) || isLastPoint) {  // always draw the end of the curve
                stats.add(new Object[] { new Double(fpDivN), new Double(tpDivP),
                                         new Double(bestIsometricsTpValue + (fpDivN * slope * (sumNeg / sumPos))) });
            }
            if (last != null) {
                // trapezoid between the previous and the current ROC point
                aucSum += (fpDivN - last[0]) * (tpDivP + last[1]) / 2.0d;
            }
            last = new double[] { fpDivN, tpDivP };
            pointCounter++;
        }

        if (getParameterAsBoolean(SHOW_PLOT)) {
            SimplePlotterDialog plotter = new SimplePlotterDialog(stats);
            plotter.setXAxis(0);
            plotter.plotColumn(1, true);
            plotter.plotColumn(2, true);
            plotter.setDrawRange(0.0d, 1.0d, 0.0d, 1.0d);
            plotter.show();
        }

        Threshold threshold = new Threshold(bestThreshold, label.mapIndex(negativeLabelIndex),
                                            label.mapIndex(positiveLabelIndex));
        if (getParameterAsBoolean(CREATE_AUC)) {
            PerformanceVector aucPerformanceVector = new PerformanceVector();
            aucPerformanceVector.addCriterion(new EstimatedPerformance("AUC", aucSum, 1, false));
            aucPerformanceVector.setMainCriterionName("AUC");
            return new IOObject[] { exampleSet, threshold, aucPerformanceVector };
        }
        return new IOObject[] { exampleSet, threshold };
    }

    /** Declares the cost, plotting and AUC parameters in addition to the inherited ones. */
    public List getParameterTypes() {
        List list = super.getParameterTypes();
        list.add(new ParameterTypeDouble(COSTS_POS, "The costs assigned when a positive example is classified as negative.",
                                         0, Double.POSITIVE_INFINITY, 1));
        list.add(new ParameterTypeDouble(COSTS_NEG, "The costs assigned when a negative example is classified as positive.",
                                         0, Double.POSITIVE_INFINITY, 1));
        list.add(new ParameterTypeBoolean(SHOW_PLOT, "Display a plot of the ROC curve.", false));
        list.add(new ParameterTypeBoolean(CREATE_AUC, "Indicates if the area under the ROC curve should be delivered as performance criterion.", false));
        return list;
    }

}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -