📄 confidenceapplier.java
字号:
/* * YALE - Yet Another Learning Environment * Copyright (C) 2002, 2003 * Simon Fischer, Ralf Klinkenberg, Ingo Mierswa, * Katharina Morik, Oliver Ritthoff * Artificial Intelligence Unit * Computer Science Department * University of Dortmund * 44221 Dortmund, Germany * email: yale@ls8.cs.uni-dortmund.de * web: http://yale.cs.uni-dortmund.de/ * * This program is free software; you can redistribute it and/or * modify it under the terms of the GNU General Public License as * published by the Free Software Foundation; either version 2 of the * License, or (at your option) any later version. * * This program is distributed in the hope that it will be useful, but * WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU * General Public License for more details. * * You should have received a copy of the GNU General Public License * along with this program; if not, write to the Free Software * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 * USA. */package edu.udo.cs.yale.operator.learner;import edu.udo.cs.yale.operator.Operator;import edu.udo.cs.yale.operator.OperatorChain;import edu.udo.cs.yale.operator.OperatorException;import edu.udo.cs.yale.operator.UserError;import edu.udo.cs.yale.operator.Value;import edu.udo.cs.yale.operator.IllegalInputException;import edu.udo.cs.yale.operator.IOObject;import edu.udo.cs.yale.operator.ResultObjectAdapter;import edu.udo.cs.yale.operator.IOContainer;import edu.udo.cs.yale.operator.IODescription;import edu.udo.cs.yale.operator.SimpleResultObject;import edu.udo.cs.yale.Statistics;import edu.udo.cs.yale.operator.performance.PerformanceVector;import edu.udo.cs.yale.operator.parameter.*;import edu.udo.cs.yale.example.Attribute;import edu.udo.cs.yale.example.ExampleSet;import edu.udo.cs.yale.example.Example;import edu.udo.cs.yale.example.ExampleReader;import edu.udo.cs.yale.tools.LogService;import java.util.List;import java.util.LinkedList;import java.util.Map;import java.util.TreeMap;import java.util.Set;import java.io.*;/** Finds the best threshold that optimizes the given performance criterion. * * @version $Id: ConfidenceApplier.java,v 2.7 2003/07/03 16:01:30 fischer Exp $ */public class ConfidenceApplier extends OperatorChain { private static final Class[] INPUT_CLASSES = { ExampleSet.class }; private static final Class[] OUTPUT_CLASSES = { ExampleSet.class }; private static final String[] SEARCH_TYPES = { "linear", "fixed" }; //private static final int BINARY = 0; private static final int LINEAR = 0; private static final int FIXED = 1; private int positiveIndex = 2; private int negativeIndex = 1; private double threshold = Double.NaN; private ThresholdPerformances thresholdPerformances; public ConfidenceApplier() { addValue(new Value("threshold", "The last threshold used.") { public double getValue() { return threshold; } }); } public static class ThresholdPerformances extends ResultObjectAdapter { private Map performanceMap = new TreeMap(); private double bestThreshold = Double.NEGATIVE_INFINITY; private PerformanceVector bestPerformance = null; public void add(double threshold, PerformanceVector performance) { performanceMap.put(new Double(threshold), performance); if ((bestPerformance == null) || (performance.compareTo(bestPerformance) > 0)) { bestThreshold = threshold; bestPerformance = performance; } } public PerformanceVector getBestPerformance() { return bestPerformance; } public double getBestThreshold() { return bestThreshold; } public Set getThresholds() { return performanceMap.keySet(); } public PerformanceVector getPerformance(Double threshold) { return (PerformanceVector)performanceMap.get(threshold); } public String toString() { return "Best threshold: "+bestThreshold+ "; "+bestPerformance; } public String toHTML() { String result = "<h1>Threshold</h1><b>Best threshold is: "+bestThreshold+"</b>"; result += bestPerformance.toHTML(); return result; } public java.awt.Component getVisualisationComponent() { javax.swing.JLabel label = new javax.swing.JLabel("<html>"+toHTML()+"</html>"); label.setFont(label.getFont().deriveFont(java.awt.Font.PLAIN)); return label; } } public IOObject[] apply() throws OperatorException { thresholdPerformances = new ThresholdPerformances(); ExampleSet exampleSet = (ExampleSet)getInput(ExampleSet.class); Attribute confidence = exampleSet.getPredictedLabel(); if (confidence == null) throw new UserError(this, 107, new Object[0]); exampleSet.createPredictedLabel(); switch (getParameterAsInt("search_type")) { case LINEAR: linearSearch(exampleSet, confidence); return new IOObject[] { exampleSet, thresholdPerformances.bestPerformance, thresholdPerformances };// case BINARY:// return new IOObject[] { exampleSet, binarySearch(exampleSet, confidence, // confidence.getMinimum(), null,// confidence.getMaximum(), null),// new SimpleResultObject("Threshold", "Best threshold is: "+bestThreshold) }; case FIXED: setPredictions(exampleSet, confidence, getParameterAsDouble("fixed_threshold")); return new IOObject[] { exampleSet }; default: throw new RuntimeException("Illegal search type!"); } } private void addRecord(double threshold, PerformanceVector performance) { Statistics statistics = getExperiment().getStatistics(getName()); if (statistics.getNumberOfRows() == 0) { String[] names = new String[1 + performance.size()]; names[0] = "threshold"; for (int i = 0; i < performance.size(); i++) { names[i+1] = performance.get(i).getName(); } statistics.init(names); } Object row[] = new Object[1 + performance.size()]; row[0] = new Double(threshold); for (int i = 0; i < performance.size(); i++) { row[i+1] = new Double(performance.get(i).getValue()); } statistics.add(row); } private void linearSearch(ExampleSet exampleSet, Attribute confidence) throws OperatorException { double delta = getParameterAsDouble("delta_threshold_min"); LogService.logMessage("Starting linear search from "+confidence.getMinimum() + " to "+confidence.getMaximum()+"; step size is "+delta+".", LogService.TASK); for (int i = 0; confidence.getMinimum() + delta*i <= confidence.getMaximum(); i++) { threshold = confidence.getMinimum() + delta*(double)i; setPredictions(exampleSet, confidence, threshold); PerformanceVector performance = evaluate(exampleSet); addRecord(threshold, performance); thresholdPerformances.add(threshold, performance); } }// private PerformanceVector binarySearch(ExampleSet exampleSet, // Attribute confidence,// double lower, // PerformanceVector lowerPerformance,// double upper,// PerformanceVector upperPerformance) throws OperatorException {// if (lowerPerformance == null) {// setPredictions(exampleSet, confidence, lower);// lowerPerformance = evaluate(exampleSet);// }// if (upperPerformance == null) {// setPredictions(exampleSet, confidence, upper);// upperPerformance = evaluate(exampleSet);// }// boolean lowerIsBetter = lowerPerformance.compareTo(upperPerformance) > 0;// if (upper - lower < getParameterAsDouble("delta_threshold_min")) {// if (lowerIsBetter) {// bestThreshold = lower;// return lowerPerformance;// } else {// bestThreshold = upper;// return upperPerformance;// }// } else {// double mean = (lower + upper) / 2;// if (lowerIsBetter) {// return binarySearch(exampleSet, confidence, lower, lowerPerformance, mean, null);// } else {// return binarySearch(exampleSet, confidence, mean, null, upper, upperPerformance);// }// }// } private void setPredictions(ExampleSet exampleSet, Attribute confidence, double threshold) { ExampleReader reader = exampleSet.getExampleReader(); while (reader.hasNext()) { Example example = reader.next(); example.setPredictedLabel(example.getValue(confidence) > threshold ? positiveIndex : negativeIndex); } } private PerformanceVector evaluate(ExampleSet exampleSet) throws OperatorException { return (PerformanceVector)getOperator(0).apply(new IOContainer(new IOObject[] { exampleSet } )).getInput(PerformanceVector.class); } private boolean needEvaluation() { return getParameterAsInt("search_type") != FIXED; } public void experimentFinished() { super.experimentFinished(); String filename = getParameterAsString("filename"); if (filename != null) { LogService.logMessage(getName()+": writing statistics to '"+filename+"'", LogService.INIT); try { PrintWriter out = new PrintWriter(new FileWriter(filename)); getExperiment().getStatistics(getName()).write(out); out.close(); } catch (IOException e) { LogService.logMessage(getName() + ": Could not write to file '"+filename+"'", LogService.ERROR); } } } public Class[] getInputClasses() { return INPUT_CLASSES; } public Class[] getOutputClasses() { return OUTPUT_CLASSES; } public int getNumberOfSteps() { return 1; } public int getMinNumberOfInnerOperators() { return 0; } public int getMaxNumberOfInnerOperators() { return 1; } public Class[] checkIO(Class[] input) throws IllegalInputException{ if (!IODescription.containsClass(ExampleSet.class, input)) throw new IllegalInputException("No ExampleSet found in input!", this); if (needEvaluation()) { if (getNumberOfOperators() != 1) { throw new IllegalInputException("If search_type is not \"fixed\", an inner evaluator is needed!", this); } Operator evaluator = getOperator(0); if (!IODescription.containsClass(PerformanceVector.class, evaluator.checkIO(new Class[] {ExampleSet.class}))) { throw new IllegalInputException("Inner operator cannot evaluate example set!", this); } return new Class[] {ExampleSet.class, PerformanceVector.class, SimpleResultObject.class}; } else { if (getNumberOfOperators() != 0) { throw new IllegalInputException("If search_type is \"fixed\", there must not be an inner evaluator!", this); } return new Class[] {ExampleSet.class}; } } public List getParameterTypes() { List types = super.getParameterTypes(); types.add(new ParameterTypeCategory("search_type", "Specifies the way the best threshold is searched. Linear tries all values in a range, fixed uses the value specified by the parameter fixed_threshold.", SEARCH_TYPES, 0)); types.add(new ParameterTypeDouble("fixed_threshold", "Fixed threshold for choosing the predicted label.", Double.NEGATIVE_INFINITY, Double.POSITIVE_INFINITY, 0.0)); types.add(new ParameterTypeDouble("delta_threshold_min", "Minimum threshold difference.", Double.NEGATIVE_INFINITY, Double.POSITIVE_INFINITY, 0.025)); types.add(new ParameterTypeFile("filename", "File to save the statistics to.", true)); return types; }}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -