⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 confusionmatrix.java

📁 常用机器学习算法,java编写源代码,内含常用分类算法,包括说明文档
💻 JAVA
字号:
/* Copyright (C) 2002 Dept. of Computer Science, Univ. of Massachusetts, Amherst   This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).   http://www.cs.umass.edu/~mccallum/mallet   This program toolkit free software; you can redistribute it and/or   modify it under the terms of the GNU General Public License as   published by the Free Software Foundation; either version 2 of the   License, or (at your option) any later version.   This program is distributed in the hope that it will be useful, but   WITHOUT ANY WARRANTY; without even the implied warranty of   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  For more   details see the GNU General Public License and the file README-LEGAL.   You should have received a copy of the GNU General Public License   along with this program; if not, write to the Free Software   Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA   02111-1307, USA. *//**    @author Andrew McCallum <a href="mailto:mccallum@cs.umass.edu">mccallum@cs.umass.edu</a> */package edu.umass.cs.mallet.base.classify.evaluate;import edu.umass.cs.mallet.base.classify.Classification;import edu.umass.cs.mallet.base.classify.Trial;import edu.umass.cs.mallet.base.types.Labeling;import edu.umass.cs.mallet.base.types.LabelVector;import edu.umass.cs.mallet.base.types.LabelAlphabet;import edu.umass.cs.mallet.base.types.Instance;import edu.umass.cs.mallet.base.types.Label;import edu.umass.cs.mallet.base.types.MatrixOps;import edu.umass.cs.mallet.base.util.MalletLogger;import java.util.ArrayList;import java.util.HashMap;import java.util.logging.*;import java.text.*;/** * Calculates and prints confusion matrix, accuracy, * and precision for a given clasification trial. */public class ConfusionMatrix{	private static Logger logger = MalletLogger.getLogger(ConfusionMatrix.class.getName());		int numClasses;  /**   * the list of classifications from the trial   */	ArrayList classifications;	/**	 * 2-d confiusion matrix	 */	int[][] values;	Trial trial;	/**	 * Constructs matrix and calculates values	 * @param t the trial to build matrix from	 */	public ConfusionMatrix(Trial t)	{		this.trial = t;		this.classifications = t.toArrayList();		Labeling tempLabeling =			((Classification)classifications.get(0)).getLabeling();		this.numClasses = tempLabeling.getLabelAlphabet().size();		values = new int[numClasses][numClasses];		for(int i=0; i < classifications.size(); i++)		{			LabelVector lv =				((Classification)classifications.get(i)).getLabelVector();			Instance inst = ((Classification)classifications.get(i)).getInstance();			int bestIndex = lv.getBestIndex();			int correctIndex = inst.getLabeling().getBestIndex();			assert(correctIndex != -1);			//System.out.println("Best index="+bestIndex+". Correct="+correctIndex);			values[correctIndex][bestIndex]++;		}				}	/** Return the count at row i (true) , column j (predicted) */	 double value(int i, int j) 	 {	    assert(i >= 0 && j >= 0 && i < numClasses && j < numClasses);	    return values[i][j];	    	 }	static private void appendJustifiedInt (StringBuffer sb, int i, boolean zeroDot) {		if (i < 100) sb.append (' ');		if (i < 10) sb.append (' ');		if (i == 0 && zeroDot)			sb.append (".");		else			sb.append (""+i);	}	public String toString () {		StringBuffer sb = new StringBuffer ();		int maxLabelNameLength = 0;		LabelAlphabet labelAlphabet = trial.getClassifier().getLabelAlphabet();		for (int i = 0; i < numClasses; i++) {			int len = labelAlphabet.lookupLabel(i).toString().length();			if (maxLabelNameLength < len)				maxLabelNameLength = len;		}		sb.append ("Confusion Matrix, row=true, column=predicted  accuracy="+trial.accuracy()+"\n");		for (int i = 0; i < maxLabelNameLength-5+4; i++) sb.append (' ');		sb.append ("label");		for (int c2 = 0; c2 < Math.min(10,numClasses); c2++)	sb.append ("   "+c2);		for (int c2 = 10; c2 < numClasses; c2++)	sb.append ("  "+c2);		sb.append ("  |total\n");		for (int c = 0; c < numClasses; c++) {			appendJustifiedInt (sb, c, false);			String labelName = labelAlphabet.lookupLabel(c).toString();			for (int i = 0; i < maxLabelNameLength-labelName.length(); i++) sb.append (' ');			sb.append (" "+labelName+" ");			for (int c2 = 0; c2 < numClasses; c2++) {				appendJustifiedInt (sb, values[c][c2], true);				sb.append (' ');			}			sb.append (" |"+ MatrixOps.sum(values[c]));			sb.append ('\n');		}		return sb.toString();	}		/**	 * Returns the precision of this predicted class	 */	public double getPrecision (int predictedClassIndex)	{		int total = 0;		for (int trueClassIndex=0; trueClassIndex < this.numClasses; trueClassIndex++) {			total += values[trueClassIndex][predictedClassIndex];		}		if (total == 0)			return 0.0;		else			return (double) (values[predictedClassIndex][predictedClassIndex]) / total;	}		/**	 * Returns percent of time that class2 is true class when 	 * class1 is predicted class	 * 	 */	public double getConfusionBetween (int class1, int class2)	{		int total = 0;		for (int trueClassIndex=0; trueClassIndex < this.numClasses; trueClassIndex++) {			total += values[trueClassIndex][class1];		}		if (total == 0)			return 0.0;		else			return (double) (values[class2][class1]) / total;	    	}	/**	 * Returns the percentage of instances with	 * true label = classIndex	 */	public double getClassPrior (int classIndex)	{		int sum= 0;		for(int i=0; i < numClasses; i++) 			sum += values[classIndex][i];		return (double)sum / classifications.size();	}  /**	 * prints to stdout the confusion matrix,	 * class frequency, precision, and recall	 */	/*	public void print()	{		double totalPrecision = 0;		double totalRecall = 0;		double totalF1 = 0;		HashMap index2class = new HashMap();		LabelVector lv =			((Classification)classifications.get(0)).getLabelVector();		DecimalFormat df = new DecimalFormat("###.##");		int [] numInstances = new int[this.numClasses];		for(int i=0; i<this.numClasses; i++){			int count = 0;			for(int j=0; j<this.numClasses; j++)				count += values[i][j];			numInstances[i] = count;			String label = lv.labelAtLocation(i).toString();			System.out.println("index "+i+": "+label+												 " "+count+" instances "+												 df.format(100*(double)count/classifications.size())				+"%");			index2class.put (new Integer (i), label);		}		System.out.println("Confusion Matrix");		for(int i=0; i<this.numClasses; i++){			for(int j=0; j<this.numClasses; j++)				System.out.print(values[j][i]+"\t\t");			System.out.println("");		}		for(int i=0; i<this.numClasses; i++){			double recall = 100.0*(double)values[j][j]/numInstances[i];			double precision;			int rowCount = 0;			for(int j=0; j<this.numClasses; j++)				rowCount += values[j][i];			if (rowCount == 0)				precision = 0;			else				precision = 100.0*(double)values[j][j] / rowCount;			double f1;			if (precision + recall == 0.0)				f1 = 0;			else				f1 = 2 * precision * recall / (precision + recall);			System.out.println("Class " +												 (String)index2class.get(new Integer (i)));			System.out.println("F1="+df.format(f1)+"%");			System.out.println("Recall="+df.format(recall)+"%");			System.out.println("Precision="+df.format(precision)+"%");			totalPrecision += precision;			totalRecall += recall;			totalF1 += f1;		}				int numCorrect = 0;		int totalInstances = 0;		for(int i=0; i<this.numClasses; i++)		{			numCorrect += values[j][j];			totalInstances+=numInstances[i];		}		System.out.println("Overall Accuracy="+											 df.format(100.0*(double)numCorrect/totalInstances)+"%");		System.out.println ("Average F1: " + (totalF1 / this.numClasses) +												"\nAverage Precision: " +												(totalPrecision / this.numClasses) +												"\nAverage Recall: " + (totalRecall / this.numClasses));	}*/}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -