// confidenceapplier.java
// (code-viewer residue: "font size" control)
/*
* YALE - Yet Another Learning Environment
* Copyright (C) 2001-2004
* Simon Fischer, Ralf Klinkenberg, Ingo Mierswa,
* Katharina Morik, Oliver Ritthoff
* Artificial Intelligence Unit
* Computer Science Department
* University of Dortmund
* 44221 Dortmund, Germany
* email: yale-team@lists.sourceforge.net
* web: http://yale.cs.uni-dortmund.de/
*
* This program is free software; you can redistribute it and/or
* modify it under the terms of the GNU General Public License as
* published by the Free Software Foundation; either version 2 of the
* License, or (at your option) any later version.
*
* This program is distributed in the hope that it will be useful, but
* WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307
* USA.
*/
package edu.udo.cs.yale.operator.learner;
import edu.udo.cs.yale.operator.Operator;
import edu.udo.cs.yale.operator.OperatorChain;
import edu.udo.cs.yale.operator.OperatorException;
import edu.udo.cs.yale.operator.UserError;
import edu.udo.cs.yale.operator.Value;
import edu.udo.cs.yale.operator.IllegalInputException;
import edu.udo.cs.yale.operator.IOObject;
import edu.udo.cs.yale.operator.ResultObjectAdapter;
import edu.udo.cs.yale.operator.IOContainer;
import edu.udo.cs.yale.operator.IODescription;
import edu.udo.cs.yale.operator.SimpleResultObject;
import edu.udo.cs.yale.Statistics;
import edu.udo.cs.yale.operator.performance.PerformanceVector;
import edu.udo.cs.yale.operator.parameter.*;
import edu.udo.cs.yale.example.Attribute;
import edu.udo.cs.yale.example.ExampleSet;
import edu.udo.cs.yale.example.Example;
import edu.udo.cs.yale.example.ExampleReader;
import edu.udo.cs.yale.tools.LogService;
import java.util.List;
import java.util.LinkedList;
import java.util.Map;
import java.util.TreeMap;
import java.util.Set;
import java.io.*;
/**
 * Turns the confidence values stored in the predicted label attribute of an
 * example set into crisp predictions by comparing them against a threshold.
 *
 * <p>Two search types are supported:</p>
 * <ul>
 *   <li><b>linear</b>: every threshold between the minimum and the maximum
 *       confidence (step size <code>delta_threshold_min</code>) is tried, the
 *       resulting predictions are evaluated by the single inner operator, and
 *       the predictions of the best threshold are delivered together with its
 *       performance and a {@link ThresholdPerformances} result object.</li>
 *   <li><b>fixed</b>: the threshold given by <code>fixed_threshold</code> is
 *       applied once; no inner operator is allowed.</li>
 * </ul>
 *
 * @version $Id: ConfidenceApplier.java,v 2.13 2004/08/27 11:57:38 ingomierswa Exp $
 */
public class ConfidenceApplier extends OperatorChain {

    private static final Class[] INPUT_CLASSES = { ExampleSet.class };
    private static final Class[] OUTPUT_CLASSES = { ExampleSet.class };

    /** Names of the search types; the array index is the parameter value. */
    private static final String[] SEARCH_TYPES = { "linear", "fixed" };

    //private static final int BINARY = 0;
    private static final int LINEAR = 0;
    private static final int FIXED = 1;

    /**
     * Value written as predicted label if the confidence exceeds the threshold.
     * NOTE(review): presumably the index of the positive class value - confirm.
     */
    private int positiveIndex = 2;

    /**
     * Value written as predicted label if the confidence does not exceed the threshold.
     * NOTE(review): presumably the index of the negative class value - confirm.
     */
    private int negativeIndex = 1;

    /** The threshold most recently applied to the example set. */
    private double threshold = Double.NaN;

    /** Performances collected during the current call of {@link #apply()}. */
    private ThresholdPerformances thresholdPerformances;

    /** Creates the operator and registers the loggable value "threshold". */
    public ConfidenceApplier() {
        addValue(new Value("threshold", "The last threshold used.") {
            public double getValue() {
                return threshold;
            }
        });
    }

    /**
     * Result object that maps each tried threshold to the performance it
     * achieved and keeps track of the best threshold seen so far.
     */
    public static class ThresholdPerformances extends ResultObjectAdapter {

        /** Maps threshold (Double) to PerformanceVector, sorted by threshold. */
        private Map performanceMap = new TreeMap();
        private double bestThreshold = Double.NEGATIVE_INFINITY;
        private PerformanceVector bestPerformance = null;

        /**
         * Records the performance of a threshold. The threshold becomes the
         * best one if it is the first or if its performance compares strictly
         * greater than the best performance so far.
         */
        public void add(double threshold, PerformanceVector performance) {
            performanceMap.put(new Double(threshold), performance);
            if ((bestPerformance == null) ||
                (performance.compareTo(bestPerformance) > 0)) {
                bestThreshold = threshold;
                bestPerformance = performance;
            }
        }

        /** Returns the best performance or null if nothing was added yet. */
        public PerformanceVector getBestPerformance() { return bestPerformance; }

        /** Returns the best threshold (negative infinity if nothing was added yet). */
        public double getBestThreshold() { return bestThreshold; }

        /** Returns the key set (Double objects) of all recorded thresholds. */
        public Set getThresholds() { return performanceMap.keySet(); }

        /** Returns the performance recorded for the given threshold or null. */
        public PerformanceVector getPerformance(Double threshold) {
            return (PerformanceVector)performanceMap.get(threshold);
        }

        public String toString() {
            return "Best threshold: "+bestThreshold+ "; "+bestPerformance;
        }

        /** Returns an HTML summary; the performance part is omitted if none was recorded. */
        public String toHTML() {
            String result = "<h1>Threshold</h1><b>Best threshold is: "+bestThreshold+"</b>";
            // Guard against an empty search: bestPerformance stays null until add() is called.
            if (bestPerformance != null) {
                result += bestPerformance.toHTML();
            }
            return result;
        }

        /** Returns a label showing the HTML summary in a plain font. */
        public java.awt.Component getVisualisationComponent() {
            javax.swing.JLabel label = new javax.swing.JLabel("<html>"+toHTML()+"</html>");
            label.setFont(label.getFont().deriveFont(java.awt.Font.PLAIN));
            return label;
        }
    }

    /**
     * Replaces the predicted label of the input example set by a new attribute
     * "prediction" whose values are derived from the former predicted label
     * (the confidence) and a threshold.
     *
     * @return for search type "linear": the example set (carrying the predictions
     *         of the best threshold), the best performance, and the
     *         {@link ThresholdPerformances}; for "fixed": the example set only
     * @throws UserError (code 107) if the example set has no predicted label
     */
    public IOObject[] apply() throws OperatorException {
        thresholdPerformances = new ThresholdPerformances();
        ExampleSet exampleSet = (ExampleSet)getInput(ExampleSet.class);
        Attribute confidence = exampleSet.getPredictedLabel();
        if (confidence == null) throw new UserError(this, 107, new Object[0]);

        // The crisp predictions go into a fresh attribute modelled after the label.
        Attribute predictedLabel = new Attribute(exampleSet.getLabel(), "prediction");
        exampleSet.getExampleTable().addAttribute(predictedLabel);
        exampleSet.setPredictedLabel(predictedLabel);

        switch (getParameterAsInt("search_type")) {
        case LINEAR:
            linearSearch(exampleSet, confidence);
            // The search leaves the predictions of the last (largest) threshold in
            // the example set; deliver the predictions of the best threshold so
            // that they match the returned performance.
            if (thresholdPerformances.getBestPerformance() != null) {
                threshold = thresholdPerformances.getBestThreshold();
                setPredictions(exampleSet, confidence, threshold);
            }
            return new IOObject[] { exampleSet,
                                    thresholdPerformances.getBestPerformance(),
                                    thresholdPerformances };
            // case BINARY:
            // return new IOObject[] { exampleSet, binarySearch(exampleSet, confidence,
            // confidence.getMinimum(), null,
            // confidence.getMaximum(), null),
            // new SimpleResultObject("Threshold", "Best threshold is: "+bestThreshold) };
        case FIXED:
            // Remember the threshold so that the logged value "threshold" is meaningful.
            threshold = getParameterAsDouble("fixed_threshold");
            setPredictions(exampleSet, confidence, threshold);
            return new IOObject[] { exampleSet };
        default:
            throw new RuntimeException("Illegal search type!");
        }
    }

    /**
     * Appends one row (threshold followed by the value of each criterion) to
     * the statistics of this operator; the column names are initialized when
     * the statistics do not contain any row yet.
     */
    private void addRecord(double threshold, PerformanceVector performance) {
        Statistics statistics = getExperiment().getStatistics(getName());
        if (statistics.getNumberOfRows() == 0) {
            String[] names = new String[1 + performance.size()];
            names[0] = "threshold";
            for (int i = 0; i < performance.size(); i++) {
                names[i+1] = performance.getCriterion(i).getName();
            }
            statistics.init(names);
        }
        Object[] row = new Object[1 + performance.size()];
        row[0] = new Double(threshold);
        for (int i = 0; i < performance.size(); i++) {
            row[i+1] = new Double(performance.getCriterion(i).getValue());
        }
        statistics.add(row);
    }

    /**
     * Tries all thresholds <code>min + i * delta</code> that do not exceed the
     * maximum confidence, evaluates each of them with the inner operator and
     * records the results in the statistics and in {@link #thresholdPerformances}.
     *
     * @throws IllegalArgumentException if <code>delta_threshold_min</code> is
     *         not strictly positive (the search would never terminate)
     */
    private void linearSearch(ExampleSet exampleSet, Attribute confidence) throws OperatorException {
        double delta = getParameterAsDouble("delta_threshold_min");
        // A zero or negative step never reaches the maximum; fail fast instead of looping forever.
        if (!(delta > 0.0d)) {
            throw new IllegalArgumentException("Parameter delta_threshold_min must be positive but was "+delta+".");
        }
        // The range of the confidence attribute is read once for the whole search.
        double minimum = confidence.getMinimum();
        double maximum = confidence.getMaximum();
        LogService.logMessage("Starting linear search from "+minimum +
                              " to "+maximum+"; step size is "+delta+".", LogService.TASK);
        // Multiply instead of accumulating to avoid summing up rounding errors.
        for (int i = 0; minimum + delta*i <= maximum; i++) {
            threshold = minimum + delta*(double)i;
            setPredictions(exampleSet, confidence, threshold);
            PerformanceVector performance = evaluate(exampleSet);
            addRecord(threshold, performance);
            thresholdPerformances.add(threshold, performance);
        }
    }

    // private PerformanceVector binarySearch(ExampleSet exampleSet,
    // Attribute confidence,
    // double lower,
    // PerformanceVector lowerPerformance,
    // double upper,
    // PerformanceVector upperPerformance) throws OperatorException {
    // if (lowerPerformance == null) {
    // setPredictions(exampleSet, confidence, lower);
    // lowerPerformance = evaluate(exampleSet);
    // }
    // if (upperPerformance == null) {
    // setPredictions(exampleSet, confidence, upper);
    // upperPerformance = evaluate(exampleSet);
    // }
    // boolean lowerIsBetter = lowerPerformance.compareTo(upperPerformance) > 0;
    // if (upper - lower < getParameterAsDouble("delta_threshold_min")) {
    // if (lowerIsBetter) {
    // bestThreshold = lower;
    // return lowerPerformance;
    // } else {
    // bestThreshold = upper;
    // return upperPerformance;
    // }
    // } else {
    // double mean = (lower + upper) / 2;
    // if (lowerIsBetter) {
    // return binarySearch(exampleSet, confidence, lower, lowerPerformance, mean, null);
    // } else {
    // return binarySearch(exampleSet, confidence, mean, null, upper, upperPerformance);
    // }
    // }
    // }

    /**
     * Sets the predicted label of every example to {@link #positiveIndex} if its
     * confidence value is strictly greater than the threshold and to
     * {@link #negativeIndex} otherwise.
     */
    private void setPredictions(ExampleSet exampleSet, Attribute confidence, double threshold) {
        ExampleReader reader = exampleSet.getExampleReader();
        while (reader.hasNext()) {
            Example example = reader.next();
            example.setPredictedLabel(example.getValue(confidence) > threshold ?
                                      positiveIndex : negativeIndex);
        }
    }

    /**
     * Applies the first inner operator to the example set and returns the
     * performance vector found in its output.
     */
    private PerformanceVector evaluate(ExampleSet exampleSet) throws OperatorException {
        IOContainer evaluatorInput = new IOContainer(new IOObject[] { exampleSet });
        return (PerformanceVector)getOperator(0).apply(evaluatorInput).getInput(PerformanceVector.class);
    }

    /** Returns true unless the search type is "fixed". */
    private boolean needEvaluation() {
        return getParameterAsInt("search_type") != FIXED;
    }

    /**
     * Writes the statistics of this operator to the file given by the
     * parameter "filename", if that parameter is set.
     */
    public void experimentFinished() {
        super.experimentFinished();
        String filename = getParameterAsString("filename");
        // The parameter is optional: check it before resolving the file name.
        if (filename == null) {
            return;
        }
        File statFile = getExperiment().resolveFileName(filename);
        LogService.logMessage(getName()+": writing statistics to '"+filename+"'", LogService.INIT);
        PrintWriter out = null;
        try {
            out = new PrintWriter(new FileWriter(statFile));
            getExperiment().getStatistics(getName()).write(out);
            // PrintWriter never throws on write errors; ask for its error flag explicitly.
            if (out.checkError()) {
                LogService.logMessage(getName() + ": Could not write to file '"+filename+"'", LogService.ERROR);
            }
        } catch (IOException e) {
            LogService.logMessage(getName() + ": Could not write to file '"+filename+"'", LogService.ERROR);
        } finally {
            // Release the file handle even if writing the statistics failed.
            if (out != null) {
                out.close();
            }
        }
    }

    public Class[] getInputClasses() { return INPUT_CLASSES; }

    public Class[] getOutputClasses() { return OUTPUT_CLASSES; }

    public int getNumberOfSteps() {
        return 1;
    }

    public int getMinNumberOfInnerOperators() { return 0; }

    public int getMaxNumberOfInnerOperators() { return 1; }

    /**
     * Checks that an example set is available and that the number of inner
     * operators fits the search type: exactly one evaluator delivering a
     * performance vector unless the search type is "fixed", none otherwise.
     *
     * @return the classes delivered by {@link #apply()} for the current search type
     */
    public Class[] checkIO(Class[] input) throws IllegalInputException{
        if (!IODescription.containsClass(ExampleSet.class, input))
            throw new IllegalInputException(this, ExampleSet.class);
        if (needEvaluation()) {
            if (getNumberOfOperators() != 1) {
                throw new IllegalInputException(this, "If search_type is not \"fixed\", an inner evaluator is needed!");
            }
            Operator evaluator = getOperator(0);
            if (!IODescription.containsClass(PerformanceVector.class, evaluator.checkIO(new Class[] {ExampleSet.class}))) {
                throw new IllegalInputException(this, evaluator, PerformanceVector.class);
            }
            // apply() delivers a ThresholdPerformances object as third output.
            return new Class[] {ExampleSet.class, PerformanceVector.class, ThresholdPerformances.class};
        } else {
            if (getNumberOfOperators() != 0) {
                throw new IllegalInputException(this, "If search_type is \"fixed\", there must not be an inner evaluator!");
            }
            return new Class[] {ExampleSet.class};
        }
    }

    /** Adds the parameters search_type, fixed_threshold, delta_threshold_min and filename. */
    public List getParameterTypes() {
        List types = super.getParameterTypes();
        types.add(new ParameterTypeCategory("search_type", "Specifies the way the best threshold is searched. Linear tries all values in a range, fixed uses the value specified by the parameter fixed_threshold.", SEARCH_TYPES, 0));
        types.add(new ParameterTypeDouble("fixed_threshold", "Fixed threshold for choosing the predicted label.", Double.NEGATIVE_INFINITY, Double.POSITIVE_INFINITY, 0.0));
        types.add(new ParameterTypeDouble("delta_threshold_min", "Minimum threshold difference.", Double.NEGATIVE_INFINITY, Double.POSITIVE_INFINITY, 0.025));
        types.add(new ParameterTypeFile("filename", "File to save the statistics to.", true));
        return types;
    }
}
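// Code-viewer residue (not part of the source): keyboard shortcut help -
// copy code: Ctrl + C; search code: Ctrl + F; full screen: F11;
// toggle theme: Ctrl + Shift + D; show shortcuts: ?;
// increase font size: Ctrl + =; decrease font size: Ctrl + -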