📄 universalclassificationcriterion.java
字号:
/* * YALE - Yet Another Learning Environment * Copyright (C) 2002, 2003 * Simon Fischer, Ralf Klinkenberg, Ingo Mierswa, * Katharina Morik, Oliver Ritthoff * Artificial Intelligence Unit * Computer Science Department * University of Dortmund * 44221 Dortmund, Germany * email: yale@ls8.cs.uni-dortmund.de * web: http://yale.cs.uni-dortmund.de/ * * This program is free software; you can redistribute it and/or * modify it under the terms of the GNU General Public License as * published by the Free Software Foundation; either version 2 of the * License, or (at your option) any later version. * * This program is distributed in the hope that it will be useful, but * WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU * General Public License for more details. * * You should have received a copy of the GNU General Public License * along with this program; if not, write to the Free Software * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 * USA. */package edu.udo.cs.yale.operator.performance;import edu.udo.cs.yale.example.Attribute;import edu.udo.cs.yale.example.Example;import edu.udo.cs.yale.example.ExampleSet;import edu.udo.cs.yale.tools.LogService;/** This class encapsulates the four well known classification criteria: * accuracy, precision, recall, and classification error. Furthermore it can be used to * calculate the fallout. * * @version $Id: UniversalClassificationCriterion.java,v 2.11 2003/07/02 13:57:11 fischer Exp $ */public class UniversalClassificationCriterion extends MeasuredPerformance { public static final int ACCURACY = 0; public static final int PRECISION = 1; public static final int RECALL = 2; public static final int CLASS_ERROR = 3; public static final int FALLOUT = 4; public static final int F_MEASURE = 5; public static final int FALSE_POSITIVE = 6; public static final int FALSE_NEGATIVE = 7; public static final int TRUE_POSITIVE = 8; public static final int TRUE_NEGATIVE = 9; private static final int POSITIVE = 0; private static final int NEGATIVE = 1; public static final String[] NAME = { "accuracy", "precision", "recall", "classification_error", "fallout", "f_measure", "false_positive", "false_negative", "true_positive", "true_negative" }; private int type = 0; private int positive; public UniversalClassificationCriterion() { type = positive = -1; } public UniversalClassificationCriterion(int type) { this(type, Attribute.FIRST_CLASS_INDEX); } public UniversalClassificationCriterion(int type, int positive) { this.type = type; this.positive = positive; } /** For test cases only. */ public UniversalClassificationCriterion(int type, int[][] counter) { this.type = type; this.counter[0][0] = counter[0][0]; this.counter[0][1] = counter[0][1]; this.counter[1][0] = counter[1][0]; this.counter[1][1] = counter[1][1]; } /** true label, predicted label. */ private int[][] counter = new int[2][2]; public void startCounting(ExampleSet set) { counter = new int[2][2]; } public void countExample(Example example) { int label = (example.getLabel() == positive) ? POSITIVE : NEGATIVE; int plabel = (example.getPredictedLabel() == positive) ? POSITIVE : NEGATIVE; counter[label][plabel]++; } public double getValue() { int x=0, y=0; switch (type) { case ACCURACY: x = counter[POSITIVE][POSITIVE] + counter[NEGATIVE][NEGATIVE]; y = counter[POSITIVE][POSITIVE] + counter[POSITIVE][NEGATIVE] + counter[NEGATIVE][POSITIVE] + counter[NEGATIVE][NEGATIVE]; break; case PRECISION: x = counter[POSITIVE][POSITIVE]; y = counter[POSITIVE][POSITIVE] + counter[NEGATIVE][POSITIVE]; break; case RECALL: x = counter[POSITIVE][POSITIVE]; y = counter[POSITIVE][POSITIVE] + counter[POSITIVE][NEGATIVE]; break; case CLASS_ERROR: x = counter[POSITIVE][NEGATIVE] + counter[NEGATIVE][POSITIVE]; y = counter[POSITIVE][POSITIVE] + counter[POSITIVE][NEGATIVE] + counter[NEGATIVE][POSITIVE] + counter[NEGATIVE][NEGATIVE]; break; case FALLOUT: x = counter[POSITIVE][NEGATIVE]; y = counter[NEGATIVE][POSITIVE] + counter[NEGATIVE][NEGATIVE]; break; case F_MEASURE: x = counter[POSITIVE][POSITIVE]; x *= x; x *= 2; y = x + counter[POSITIVE][POSITIVE]*counter[POSITIVE][NEGATIVE] + counter[POSITIVE][POSITIVE]*counter[NEGATIVE][POSITIVE]; break; case FALSE_NEGATIVE: x = counter[POSITIVE][NEGATIVE]; y = 1; break; case FALSE_POSITIVE: x = counter[NEGATIVE][POSITIVE]; y = 1; break; case TRUE_NEGATIVE: x = counter[NEGATIVE][NEGATIVE]; y = 1; break; case TRUE_POSITIVE: x = counter[POSITIVE][POSITIVE]; y = 1; break; default: throw new RuntimeException("Illegal value for type in UniversalClassificationCriterion: "+type); } if (y == 0) return Double.NaN; return (double)x/(double)y; } public double getVariance() { return Double.NaN; } public String getName() { return NAME[type]; } public double getFitness() { switch (type) { case ACCURACY: case PRECISION: case RECALL: case TRUE_POSITIVE: case TRUE_NEGATIVE: case F_MEASURE: return getValue(); case FALLOUT: case CLASS_ERROR: case FALSE_POSITIVE: case FALSE_NEGATIVE: if (getValue() == 0d) return Double.POSITIVE_INFINITY; return 1/getValue(); default: throw new RuntimeException("Illegal value for type in UniversalClassificationCriterion: "+type); } } public double getMaxFitness() { switch (type) { case ACCURACY: case PRECISION: case RECALL: case F_MEASURE: return 1.0d; case TRUE_POSITIVE: case TRUE_NEGATIVE: case FALLOUT: case CLASS_ERROR: case FALSE_POSITIVE: case FALSE_NEGATIVE: return Double.POSITIVE_INFINITY; default: throw new RuntimeException("Illegal value for type in UniversalClassificationCriterion: "+type); } } public static UniversalClassificationCriterion newInstance(String name) { for (int i = 0; i < NAME.length; i++) { if (NAME[i].equals(name)) return new UniversalClassificationCriterion(i); } return null; } public boolean formatPercent() { switch (type) { case TRUE_POSITIVE: case TRUE_NEGATIVE: case FALSE_POSITIVE: case FALSE_NEGATIVE: return false; default: return true; } } void clonePerformanceCriterion(PerformanceCriterion newPC) { super.clonePerformanceCriterion(newPC); UniversalClassificationCriterion newUCC = (UniversalClassificationCriterion)newPC; this.type = newUCC.type; this.positive = newUCC.positive; this.counter = new int[2][2]; this.counter[0][0] = newUCC.counter[0][0]; this.counter[1][0] = newUCC.counter[1][0]; this.counter[0][1] = newUCC.counter[0][1]; this.counter[1][1] = newUCC.counter[1][1]; } public void buildAverage(PerformanceCriterion performance) { super.buildAverage(performance); UniversalClassificationCriterion other = (UniversalClassificationCriterion)performance; if (this.type != other.type) throw new RuntimeException("Cannot build average of different error types ("+NAME[this.type]+"/"+NAME[other.type]+")."); if (this.positive != other.positive) throw new RuntimeException("Cannot build average for different positive classes ("+this.positive+"/"+other.positive+")."); this.counter[0][0] += other.counter[0][0]; this.counter[0][1] += other.counter[0][1]; this.counter[1][0] += other.counter[1][0]; this.counter[1][1] += other.counter[1][1]; }}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -