📄 multiclassificationperformance.java
字号:
/*
* YALE - Yet Another Learning Environment
* Copyright (C) 2001-2004
* Simon Fischer, Ralf Klinkenberg, Ingo Mierswa,
* Katharina Morik, Oliver Ritthoff
* Artificial Intelligence Unit
* Computer Science Department
* University of Dortmund
* 44221 Dortmund, Germany
* email: yale-team@lists.sourceforge.net
* web: http://yale.cs.uni-dortmund.de/
*
* This program is free software; you can redistribute it and/or
* modify it under the terms of the GNU General Public License as
* published by the Free Software Foundation; either version 2 of the
* License, or (at your option) any later version.
*
* This program is distributed in the hope that it will be useful, but
* WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307
* USA.
*/
package edu.udo.cs.yale.operator.performance;
import edu.udo.cs.yale.tools.math.Averagable;
import edu.udo.cs.yale.example.Attribute;
import edu.udo.cs.yale.example.Example;
import edu.udo.cs.yale.example.ExampleSet;
import edu.udo.cs.yale.tools.LogService;
import java.util.Collection;
import java.util.Iterator;
/** Measures the accuracy and classification error for both binary classification problems and multi class problems.
*
* @version $Id: MultiClassificationPerformance.java,v 2.4 2004/08/28 17:04:48 ingomierswa Exp $
*/
public class MultiClassificationPerformance extends MeasuredPerformance {

    /** Indicates an undefined type (should not happen). */
    public static final int UNDEFINED = -1;

    /** Indicates accuracy. */
    public static final int ACCURACY = 0;

    /** Indicates classification error. */
    public static final int ERROR = 1;

    /** The names of the criteria, indexed by type ({@link #ACCURACY}, {@link #ERROR}). */
    public static final String[] NAME = { "accuracy", "classification_error" };

    /** The descriptions of the criteria, indexed by type ({@link #ACCURACY}, {@link #ERROR}). */
    public static final String[] DESCRIPTION = {
        "Relative number of correctly classified examples",
        "Relative number of misclassified examples"
    };

    /**
     * The confusion matrix. First index: true label, second index: predicted label.
     * Allocated (square, one row/column per label value) by {@link #startCounting(ExampleSet)}.
     */
    private int[][] counter;

    /**
     * The class names of the label, in the iteration order of the label's values.
     * Used for logging and result display.
     */
    private String[] classNames;

    /** The type of this performance: {@link #ACCURACY}, {@link #ERROR}, or {@link #UNDEFINED}. */
    private int type = ACCURACY;

    /**
     * Creates a MultiClassificationPerformance with undefined type.
     * Note: while the type is {@link #UNDEFINED}, {@link #getName()} and {@link #getDescription()}
     * index the name/description arrays with -1 and therefore throw; a concrete type must be
     * assigned first (e.g. by {@link #cloneAveragable(Averagable)}).
     */
    public MultiClassificationPerformance() {
        this(UNDEFINED);
    }

    /**
     * Creates a MultiClassificationPerformance with the given type.
     *
     * @param type one of {@link #ACCURACY}, {@link #ERROR} or {@link #UNDEFINED}
     */
    public MultiClassificationPerformance(int type) {
        this.type = type;
    }

    /**
     * Creates a MultiClassificationPerformance for the criterion with the given name.
     *
     * @param name a criterion name from {@link #NAME}
     * @return a new instance of the matching type, or null if the name is not known
     */
    public static MultiClassificationPerformance newInstance(String name) {
        for (int i = 0; i < NAME.length; i++) {
            if (NAME[i].equals(name)) {
                return new MultiClassificationPerformance(i);
            }
        }
        return null;
    }

    /**
     * Initializes the criterion for the label of the given example set: allocates an empty
     * confusion matrix and records the label's class names.
     *
     * @param eSet the example set whose label defines the classes
     */
    public void startCounting(ExampleSet eSet) {
        Collection values = eSet.getLabel().getValues();
        int numberOfClasses = values.size();
        this.counter = new int[numberOfClasses][numberOfClasses];
        this.classNames = new String[numberOfClasses];
        int n = 0;
        for (Iterator i = values.iterator(); i.hasNext(); n++) {
            classNames[n] = (String) i.next();
        }
    }

    /**
     * Increases the confusion matrix cell for the example's (true label, predicted label) pair.
     * NOTE(review): the label and predicted label doubles are truncated to int and used directly as
     * matrix indices; this presumably relies on them being indices into the label's value list,
     * in the same order as in startCounting — confirm.
     *
     * @param example the example to count
     */
    public void countExample(Example example) {
        int label = (int) example.getLabel();
        int plabel = (int) example.getPredictedLabel();
        counter[label][plabel]++;
    }

    /**
     * Returns either the accuracy or the classification error over all counted examples.
     *
     * @return the accuracy if the type is {@link #ACCURACY}, otherwise 1 - accuracy;
     *         NaN if no example has been counted
     */
    public double getValue() {
        int correct = 0;
        int total = 0;
        for (int i = 0; i < counter.length; i++) {
            correct += counter[i][i];
            for (int j = 0; j < counter[i].length; j++) {
                total += counter[i][j];
            }
        }
        if (total == 0) {
            return Double.NaN;
        }
        double accuracy = (double) correct / (double) total;
        // Any type other than ACCURACY (including UNDEFINED) is reported as an error.
        if (type == ACCURACY) {
            return accuracy;
        } else {
            return 1.0d - accuracy;
        }
    }

    /** Returns true: the value is a fraction and is displayed as a percentage. */
    public boolean formatPercent() {
        return true;
    }

    /** No variance is tracked for this criterion; always returns NaN. */
    public double getVariance() {
        return Double.NaN;
    }

    /** Returns the name of this criterion (see the note on the no-arg constructor). */
    public String getName() {
        return NAME[type];
    }

    /** Returns the description of this criterion (see the note on the no-arg constructor). */
    public String getDescription() {
        return DESCRIPTION[type];
    }

    // ================================================================================

    /**
     * Returns the fitness, which is always the accuracy: the value itself for
     * {@link #ACCURACY}, or 1 - error otherwise.
     */
    public double getFitness() {
        if (type == ACCURACY) {
            return getValue();
        } else {
            return 1.0d - getValue();
        }
    }

    /** Returns 1, the accuracy of a perfect classifier. */
    public double getMaxFitness() {
        return 1.0d;
    }

    /**
     * Copies type, class names and confusion matrix FROM the given performance INTO this one.
     * The arrays are copied, so later changes to either object do not affect the other.
     *
     * @param newPC the MultiClassificationPerformance to copy from
     */
    protected void cloneAveragable(Averagable newPC) {
        MultiClassificationPerformance source = (MultiClassificationPerformance) newPC;
        this.type = source.type;
        this.classNames = (String[]) source.classNames.clone();
        this.counter = new int[source.counter.length][];
        for (int i = 0; i < source.counter.length; i++) {
            this.counter[i] = (int[]) source.counter[i].clone();
        }
    }

    /**
     * Lets the superclass build its average, then adds the other confusion matrix cell by cell
     * to this one.
     * NOTE(review): assumes both matrices have the same size and class ordering; a smaller
     * matrix in the other performance would throw — confirm callers guarantee this.
     *
     * @param performance another MultiClassificationPerformance
     */
    public void buildAverage(Averagable performance) {
        super.buildAverage(performance);
        MultiClassificationPerformance other = (MultiClassificationPerformance) performance;
        for (int i = 0; i < this.counter.length; i++) {
            for (int j = 0; j < this.counter[i].length; j++) {
                this.counter[i][j] += other.counter[i][j];
            }
        }
    }

    // ================================================================================

    /**
     * Returns the per-class value for the true class with the given index: the fraction of that
     * class's examples that were predicted correctly for {@link #ACCURACY}, or one minus that
     * fraction otherwise. Returns NaN if no example of that class was counted.
     * Shared by toString() and toHTML() so both report the same per-class numbers.
     */
    private double perClassValue(int classIndex) {
        int total = 0;
        for (int j = 0; j < counter[classIndex].length; j++) {
            total += counter[classIndex][j];
        }
        double correctFraction = counter[classIndex][classIndex] / (double) total;
        return (type == ACCURACY) ? correctFraction : (1.0d - correctFraction);
    }

    /**
     * Returns the superclass description followed by the confusion matrix (columns: true class,
     * rows: predicted class) and a per-class row matching this criterion's type.
     */
    public String toString() {
        StringBuffer result = new StringBuffer(super.toString());
        result.append("\nConfusionMatrix:\nTrue:");
        for (int i = 0; i < this.counter.length; i++) {
            result.append('\t').append(classNames[i]);
        }
        for (int i = 0; i < this.counter.length; i++) {
            result.append('\n').append(classNames[i]).append(':');
            for (int j = 0; j < this.counter[i].length; j++) {
                // transposed access: row = predicted class i, column = true class j
                result.append('\t').append(this.counter[j][i]);
            }
        }
        result.append("\nper class:");
        for (int i = 0; i < this.counter.length; i++) {
            result.append('\t').append(formatValue(perClassValue(i)));
        }
        return result.toString();
    }

    /**
     * Returns a HTML table for the confusion matrix (columns: true class, rows: predicted class)
     * with a per-class row matching this criterion's type.
     * NOTE(review): the plain-text super.toString() is prepended, and class names are inserted
     * without HTML escaping — confirm both are intended.
     */
    public String toHTML() {
        StringBuffer result = new StringBuffer(super.toString());
        result.append("<table bgcolor=\"#E3D8C3\" border=\"1\"><tr bgcolor=\"#ccccff\"><td></td>");
        for (int i = 0; i < this.counter.length; i++) {
            result.append("<td><b>true " + classNames[i] + "</b></td>");
        }
        result.append("</tr>");
        for (int i = 0; i < this.counter.length; i++) {
            result.append("<tr><td bgcolor=\"#ccccff\"><b>pred. " + classNames[i] + "</b></td>");
            for (int j = 0; j < this.counter[i].length; j++) {
                // transposed access: row = predicted class i, column = true class j
                result.append("<td>" + this.counter[j][i] + "</td>");
            }
            result.append("</tr>");
        }
        result.append("<tr bgcolor=\"#ccccff\"><td><b>per class:</b></td>");
        for (int i = 0; i < this.counter.length; i++) {
            result.append("<td>" + formatValue(perClassValue(i)) + "</td>");
        }
        result.append("</tr>");
        result.append("</table>");
        return result.toString();
    }
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -