⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 confidenceapplier.java

📁 一个很好的LIBSVM的JAVA源码，来自数据挖掘工具YALE工具包，可供研究和改进SVM算法的学者参考。
💻 JAVA
字号:
/*
 *  YALE - Yet Another Learning Environment
 *  Copyright (C) 2001-2004
 *      Simon Fischer, Ralf Klinkenberg, Ingo Mierswa, 
 *          Katharina Morik, Oliver Ritthoff
 *      Artificial Intelligence Unit
 *      Computer Science Department
 *      University of Dortmund
 *      44221 Dortmund,  Germany
 *  email: yale-team@lists.sourceforge.net
 *  web:   http://yale.cs.uni-dortmund.de/
 *
 *  This program is free software; you can redistribute it and/or
 *  modify it under the terms of the GNU General Public License as 
 *  published by the Free Software Foundation; either version 2 of the
 *  License, or (at your option) any later version. 
 *
 *  This program is distributed in the hope that it will be useful, but
 *  WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
 *  General Public License for more details.
 *
 *  You should have received a copy of the GNU General Public License
 *  along with this program; if not, write to the Free Software
 *  Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307
 *  USA.
 */
package edu.udo.cs.yale.operator.learner;

import edu.udo.cs.yale.operator.Operator;
import edu.udo.cs.yale.operator.OperatorChain;
import edu.udo.cs.yale.operator.OperatorException;
import edu.udo.cs.yale.operator.UserError;
import edu.udo.cs.yale.operator.Value;
import edu.udo.cs.yale.operator.IllegalInputException;
import edu.udo.cs.yale.operator.IOObject;
import edu.udo.cs.yale.operator.ResultObjectAdapter;
import edu.udo.cs.yale.operator.IOContainer;
import edu.udo.cs.yale.operator.IODescription;
import edu.udo.cs.yale.operator.SimpleResultObject;
import edu.udo.cs.yale.Statistics;
import edu.udo.cs.yale.operator.performance.PerformanceVector;
import edu.udo.cs.yale.operator.parameter.*;
import edu.udo.cs.yale.example.Attribute;
import edu.udo.cs.yale.example.ExampleSet;
import edu.udo.cs.yale.example.Example;
import edu.udo.cs.yale.example.ExampleReader;
import edu.udo.cs.yale.tools.LogService;

import java.util.List;
import java.util.LinkedList;
import java.util.Map;
import java.util.TreeMap;
import java.util.Set;
import java.io.*;

/** Finds the best threshold that optimizes the given performance criterion.
 *
 *  <p>The predicted label of the input example set is interpreted as a confidence value.
 *  A new prediction attribute (created from the label attribute) is added and becomes the
 *  predicted label; every example is assigned the nominal value index 2 if its confidence is
 *  strictly greater than the threshold and index 1 otherwise (presumably the positive and the
 *  negative class of the label mapping - verify against the label definition).</p>
 *
 *  <p>With <code>search_type = linear</code> the thresholds min, min + delta, min + 2 delta, ...
 *  up to the maximum confidence are tried (delta = <code>delta_threshold_min</code>). For each
 *  one the single inner operator is applied and must deliver a {@link PerformanceVector}.
 *  Afterwards the example set carries the predictions of the best threshold and is returned
 *  together with the best performance and a {@link ThresholdPerformances} object.
 *  With <code>search_type = fixed</code> the value of <code>fixed_threshold</code> is applied
 *  and no inner operator is allowed.</p>
 *
 *  @version $Id: ConfidenceApplier.java,v 2.13 2004/08/27 11:57:38 ingomierswa Exp $
 */
public class ConfidenceApplier extends OperatorChain {

    private static final Class[] INPUT_CLASSES = { ExampleSet.class };
    /** Static declaration only; the actual outputs depend on search_type (see checkIO). */
    private static final Class[] OUTPUT_CLASSES = { ExampleSet.class };

    /** Names of the search strategies, indexed by the LINEAR / FIXED constants below. */
    private static final String[] SEARCH_TYPES = { "linear", "fixed" };
    private static final int LINEAR = 0;
    private static final int FIXED  = 1;

    /** Nominal value index written when the confidence exceeds the threshold. */
    private static final int POSITIVE_INDEX = 2;
    /** Nominal value index written when the confidence does not exceed the threshold. */
    private static final int NEGATIVE_INDEX = 1;

    /** Threshold most recently applied to the example set; exposed as the value "threshold". */
    private double threshold = Double.NaN;
    /** Performances collected during the last linear search. */
    private ThresholdPerformances thresholdPerformances;

    public ConfidenceApplier() {
        addValue(new Value("threshold", "The last threshold used.") {
                public double getValue() {
                    return threshold;
                }
            });
    }

    /** Result object mapping every evaluated threshold to its performance and remembering
     *  the best one. */
    public static class ThresholdPerformances extends ResultObjectAdapter {

        /** Double threshold -> PerformanceVector, ordered by threshold. */
        private final Map performanceMap = new TreeMap();
        private double bestThreshold = Double.NEGATIVE_INFINITY;
        private PerformanceVector bestPerformance = null;

        /** Records the performance obtained for a threshold. It becomes the new best only if it
         *  compares strictly greater than the current best, so on ties the earlier added
         *  threshold is kept. */
        public void add(double threshold, PerformanceVector performance) {
            performanceMap.put(new Double(threshold), performance);
            if ((bestPerformance == null) || (performance.compareTo(bestPerformance) > 0)) {
                bestThreshold   = threshold;
                bestPerformance = performance;
            }
        }

        /** @return the best performance added so far, or null if nothing was added */
        public PerformanceVector getBestPerformance() { return bestPerformance; }

        /** @return the threshold of the best performance, or NEGATIVE_INFINITY if nothing was added */
        public double getBestThreshold() { return bestThreshold; }

        /** @return a read-only view of the evaluated thresholds (Doubles) in ascending order */
        public Set getThresholds() {
            return java.util.Collections.unmodifiableSet(performanceMap.keySet());
        }

        /** @return the performance recorded for the given threshold, or null if none */
        public PerformanceVector getPerformance(Double threshold) {
            return (PerformanceVector)performanceMap.get(threshold);
        }

        public String toString() {
            return "Best threshold: " + bestThreshold + "; " + bestPerformance;
        }

        public String toHTML() {
            String result = "<h1>Threshold</h1><b>Best threshold is: " + bestThreshold + "</b>";
            // Nothing may have been added yet; avoid a NullPointerException in that case.
            if (bestPerformance != null) {
                result += bestPerformance.toHTML();
            }
            return result;
        }

        public java.awt.Component getVisualisationComponent() {
            javax.swing.JLabel label = new javax.swing.JLabel("<html>" + toHTML() + "</html>");
            label.setFont(label.getFont().deriveFont(java.awt.Font.PLAIN));
            return label;
        }
    }

    /** Applies the configured threshold strategy to the input example set.
     *
     *  @return for linear search: the example set (with predictions of the best threshold),
     *          the best PerformanceVector and the ThresholdPerformances; for fixed: the example set
     *  @throws OperatorException if the example set has no predicted label (user error 107),
     *          if the linear search cannot evaluate any threshold, or if the inner operator fails
     */
    public IOObject[] apply() throws OperatorException {
        thresholdPerformances = new ThresholdPerformances();
        ExampleSet exampleSet = (ExampleSet)getInput(ExampleSet.class);
        // The incoming predicted label holds the confidences that are thresholded.
        Attribute confidence = exampleSet.getPredictedLabel();
        if (confidence == null) throw new UserError(this, 107, new Object[0]);

        // Crisp predictions go into a separate, newly added attribute; the confidences are
        // still read from the old attribute via the local variable above.
        Attribute predictedLabel = new Attribute(exampleSet.getLabel(), "prediction");
        exampleSet.getExampleTable().addAttribute(predictedLabel);
        exampleSet.setPredictedLabel(predictedLabel);

        switch (getParameterAsInt("search_type")) {
        case LINEAR:
            linearSearch(exampleSet, confidence);
            // The search leaves the predictions at the last threshold tried; re-apply the best
            // one so that the returned example set matches the returned best performance.
            threshold = thresholdPerformances.getBestThreshold();
            setPredictions(exampleSet, confidence, threshold);
            return new IOObject[] { exampleSet,
                                    thresholdPerformances.getBestPerformance(),
                                    thresholdPerformances };
        case FIXED:
            threshold = getParameterAsDouble("fixed_threshold");
            setPredictions(exampleSet, confidence, threshold);
            return new IOObject[] { exampleSet };
        default:
            throw new IllegalStateException("Illegal search type!");
        }
    }

    /** Appends one row (threshold followed by all criterion values) to this operator's
     *  experiment statistics, initializing the column names on the first call. */
    private void addRecord(double threshold, PerformanceVector performance) {
        Statistics statistics = getExperiment().getStatistics(getName());
        if (statistics.getNumberOfRows() == 0) {
            String[] names = new String[1 + performance.size()];
            names[0] = "threshold";
            for (int i = 0; i < performance.size(); i++) {
                names[i + 1] = performance.getCriterion(i).getName();
            }
            statistics.init(names);
        }

        Object[] row = new Object[1 + performance.size()];
        row[0] = new Double(threshold);
        for (int i = 0; i < performance.size(); i++) {
            row[i + 1] = new Double(performance.getCriterion(i).getValue());
        }
        statistics.add(row);
    }

    /** Tries the thresholds min, min + delta, ... up to the maximum confidence, evaluates each
     *  with the inner operator and records it in the statistics and in thresholdPerformances.
     *
     *  @throws OperatorException if delta_threshold_min is not positive, if no threshold lies in
     *          the confidence range, or if the inner operator fails
     */
    private void linearSearch(ExampleSet exampleSet, Attribute confidence) throws OperatorException {
        double delta = getParameterAsDouble("delta_threshold_min");
        // A zero, negative or NaN step never passes the maximum and would loop forever.
        if (!(delta > 0)) {
            throw new OperatorException(getName() + ": parameter 'delta_threshold_min' must be positive for linear search, but is " + delta + ".");
        }
        double min = confidence.getMinimum();
        double max = confidence.getMaximum();
        LogService.logMessage("Starting linear search from " + min +
                              " to " + max + "; step size is " + delta + ".", LogService.TASK);
        // Each threshold is computed from the index (not by accumulating delta) to avoid drift.
        for (int i = 0; min + delta * i <= max; i++) {
            threshold = min + delta * (double)i;
            setPredictions(exampleSet, confidence, threshold);
            PerformanceVector performance = evaluate(exampleSet);
            addRecord(threshold, performance);
            thresholdPerformances.add(threshold, performance);
        }
        // An empty or NaN range would otherwise yield null outputs in apply().
        if (thresholdPerformances.getBestPerformance() == null) {
            throw new OperatorException(getName() + ": no threshold could be evaluated, confidence range [" + min + ", " + max + "] is empty.");
        }
    }

    /** Writes POSITIVE_INDEX as predicted label of every example whose confidence is strictly
     *  greater than cutoff, and NEGATIVE_INDEX for all others. */
    private void setPredictions(ExampleSet exampleSet, Attribute confidence, double cutoff) {
        ExampleReader reader = exampleSet.getExampleReader();
        while (reader.hasNext()) {
            Example example = reader.next();
            example.setPredictedLabel(example.getValue(confidence) > cutoff ? POSITIVE_INDEX : NEGATIVE_INDEX);
        }
    }

    /** Applies the single inner operator to the example set and returns the PerformanceVector it
     *  delivers. */
    private PerformanceVector evaluate(ExampleSet exampleSet) throws OperatorException {
        IOContainer input = new IOContainer(new IOObject[] { exampleSet });
        return (PerformanceVector)getOperator(0).apply(input).getInput(PerformanceVector.class);
    }

    /** @return true unless search_type is "fixed", i.e. whether an inner evaluator is required */
    private boolean needEvaluation() {
        return getParameterAsInt("search_type") != FIXED;
    }

    /** Writes the collected statistics to the optional "filename" parameter, if set. Write
     *  failures are logged, not thrown. */
    public void experimentFinished() {
        super.experimentFinished();

        String filename = getParameterAsString("filename");
        // The parameter is optional: only resolve the file once we know a name was given.
        if (filename != null) {
            File statFile = getExperiment().resolveFileName(filename);
            LogService.logMessage(getName() + ": writing statistics to '" + filename + "'", LogService.INIT);
            PrintWriter out = null;
            try {
                out = new PrintWriter(new FileWriter(statFile));
                getExperiment().getStatistics(getName()).write(out);
                // PrintWriter swallows IOExceptions; checkError() flushes and reports them.
                if (out.checkError()) {
                    LogService.logMessage(getName() + ": Could not write to file '" + filename + "'", LogService.ERROR);
                }
            } catch (IOException e) {
                LogService.logMessage(getName() + ": Could not write to file '" + filename + "': " + e.getMessage(), LogService.ERROR);
            } finally {
                if (out != null) {
                    out.close();
                }
            }
        }
    }

    public Class[] getInputClasses() { return INPUT_CLASSES; }
    public Class[] getOutputClasses() { return OUTPUT_CLASSES; }

    /** @return always 1 */
    public int getNumberOfSteps() {
        return 1;
    }

    public int getMinNumberOfInnerOperators() { return 0; }
    public int getMaxNumberOfInnerOperators() { return 1; }

    /** Checks that an ExampleSet is given and that exactly one inner operator delivering a
     *  PerformanceVector exists unless search_type is "fixed" (in which case none is allowed).
     *
     *  @return the output classes, matching what apply() delivers for the chosen search type
     */
    public Class[] checkIO(Class[] input) throws IllegalInputException {
        if (!IODescription.containsClass(ExampleSet.class, input))
            throw new IllegalInputException(this, ExampleSet.class);

        if (needEvaluation()) {
            if (getNumberOfOperators() != 1) {
                throw new IllegalInputException(this, "If search_type is not \"fixed\", an inner evaluator is needed!");
            }
            Operator evaluator = getOperator(0);
            if (!IODescription.containsClass(PerformanceVector.class, evaluator.checkIO(new Class[] {ExampleSet.class}))) {
                throw new IllegalInputException(this, evaluator, PerformanceVector.class);
            }
            // apply() returns a ThresholdPerformances object, not a SimpleResultObject.
            return new Class[] {ExampleSet.class, PerformanceVector.class, ThresholdPerformances.class};
        } else {
            if (getNumberOfOperators() != 0) {
                throw new IllegalInputException(this, "If search_type is \"fixed\", there must not be an inner evaluator!");
            }
            return new Class[] {ExampleSet.class};
        }
    }

    public List getParameterTypes() {
        List types = super.getParameterTypes();

        types.add(new ParameterTypeCategory("search_type", "Specifies the way the best threshold is searched. Linear tries all values in a range, fixed uses the value specified by the parameter fixed_threshold.", SEARCH_TYPES, 0));
        types.add(new ParameterTypeDouble("fixed_threshold", "Fixed threshold for choosing the predicted label.", Double.NEGATIVE_INFINITY, Double.POSITIVE_INFINITY, 0.0));
        // Step size of the linear search; negative values can never terminate the search.
        types.add(new ParameterTypeDouble("delta_threshold_min", "Minimum threshold difference.", 0.0, Double.POSITIVE_INFINITY, 0.025));
        types.add(new ParameterTypeFile("filename", "File to save the statistics to.", true));
        return types;
    }

}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -