⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 accuracycoverage.java

📁 常用机器学习算法,java编写源代码,内含常用分类算法,包括说明文档
💻 JAVA
字号:
/* Copyright (C) 2002 Univ. of Massachusetts Amherst, Computer Science Dept.   This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).   http://www.cs.umass.edu/~mccallum/mallet   This software is provided under the terms of the Common Public License,   version 1.0, as published by http://www.opensource.org.  For further   information, see the file `LICENSE' included with this distribution. *//**    @author Aron Culotta <a href="mailto:culotta@cs.umass.edu">culotta@cs.umass.edu</a> */package edu.umass.cs.mallet.base.classify.evaluate;import edu.umass.cs.mallet.base.types.LabelVector;import edu.umass.cs.mallet.base.types.Instance;import edu.umass.cs.mallet.base.types.InstanceList;import edu.umass.cs.mallet.base.classify.Trial;import edu.umass.cs.mallet.base.classify.Classifier;import edu.umass.cs.mallet.base.classify.Classification;import edu.umass.cs.mallet.base.classify.Trial;import edu.umass.cs.mallet.base.classify.evaluate.GraphItem;import edu.umass.cs.mallet.base.util.MalletLogger;import edu.umass.cs.mallet.base.util.PrintUtilities;import java.awt.*;import java.awt.event.*;import javax.swing.*;import java.util.*;import java.util.logging.*;import java.text.DecimalFormat;/** * Methods for calculating and displaying the accuracy v. * coverage data for a Trial */public class AccuracyCoverage implements ActionListener{	private static Logger logger = MalletLogger.getLogger(AccuracyCoverage.class.getName());	static final int DEFAULT_NUM_BUCKETS = 20;	static final int DEFAULT_MAX_X = 100;	private ArrayList classifications;	private double [] accuracyValues;	private int numBuckets;	private double step;	private Graph2 graph;	private JFrame frame;		/**	 * Constructs object, sorts classifications, and creates	 * accuracyValues array     * @param t trial to get data from     * @param numBuckets number of x-axis measurements to find accuracy     */	public AccuracyCoverage(Trial t, int numBuckets, String title, String dataName)	{		this.classifications = new ArrayList(t.toArrayList());		this.numBuckets = numBuckets;		this.step = (double)DEFAULT_MAX_X/numBuckets;		this.accuracyValues = new double[numBuckets];		this.frame = null;		logger.info("Constructing AccCov with " + 											 this.classifications.size()); 		sortClassifications();/*		for(int i=0; i<classifications.size(); i++)		{			Classification c = (Classification)this.classifications.get(i);			LabelVector distr = c.getLabelVector();			System.out.println(distr.getBestValue());		}*/		createAccuracyArray();		this.graph = new Graph2(			title, 0, 100,			"Coverage", "Accuracy");		addDataToGraph(this.accuracyValues, numBuckets, dataName);	}	public AccuracyCoverage(Trial t, String title, String name)	{		this(t, DEFAULT_NUM_BUCKETS, title, name);	}	public AccuracyCoverage(Trial t, String title)	{		this(t, DEFAULT_NUM_BUCKETS, title, "unnamed");	}		public AccuracyCoverage(Classifier C, InstanceList ilist, String title)	{		this(new Trial(C, ilist), DEFAULT_NUM_BUCKETS, title, "unnamed");	}		public AccuracyCoverage(Classifier C, InstanceList ilist, int numBuckets, String title)	{		this(new Trial(C, ilist), numBuckets, title, "unnamed");	}		/**	 * Finds the "area under the acc/cov curve"	 * steps by one percentage point and calcs area	 * of trapezoid	 */	public double cumulativeAccuracy()	{		double area = 0.0;		for(int i=1; i<100; i++)		{			double leftAccuracy = accuracyAtCoverage((double)i/100);			double rightAccuracy = accuracyAtCoverage((double)(i+1)/100);			area += .5*(leftAccuracy + rightAccuracy);		}		return area;			}		/**	 * Creates array of accuracy values for coverage	 * at each step as defined by numBuckets.	 	 */	public void createAccuracyArray()	{//		System.out.println("Creating accuracyArray. Step= "+step);		for(int i=0 ; i<numBuckets; i++)		{			accuracyValues[i] =				accuracyAtCoverage(step													 *(double)(i+1)/100.0);		}	}		/**	 * accuracy at a given coverage percentage	 * @param cov coverage percentage	 * @return accuracy value	 */	public double accuracyAtCoverage(double cov)	{		assert(cov <= 1 && cov > 0);		int numTrials = (int)(Math.round((double)classifications.size()*cov));		int numCorrect = 0;//		System.out.println("NumTrials="+numTrials);		for(int i= classifications.size()-1; 				i >= classifications.size()-numTrials; i--)		{			Classification temp = (Classification)classifications.get(i); 			if(temp.bestLabelIsCorrect())		    numCorrect++;		}//		System.out.println("Accuracy at cov "+cov+" is "+		//(double)numCorrect/numTrials);		return((double)numCorrect/numTrials);	}		/**	 * Sort classifications ArrayList 	 * by winner's value	 */	public void sortClassifications()	{		Collections.sort(classifications, new  ClassificationComparator());	}			public void addDataToGraph(double [] accValues, int nBuckets, String name)	{		Vector values = new Vector(nBuckets);		for(int i=0; i<nBuckets; i++)		{			GraphItem temp = new GraphItem("",																		 (int)(accValues[i]*100),																		 Color.black);			values.add(temp);		}		logger.info("Sending "+values.size()+" elements to graph");		this.graph.addItemVector(values, name);	}	/** * Displays the accuracy v. coverage graph */	public void displayGraph()	{		Vector values = new Vector(this.numBuckets);		JButton printButton = new JButton("Print");	  frame = new JFrame("Graph");		DecimalFormat df = new DecimalFormat();		printButton.addActionListener(this);				frame.addWindowListener			(new WindowAdapter() 				{					public void windowClosing(WindowEvent e) 					{						System.exit(0);					}				}				);		// Get content pane		Container pane = frame.getContentPane();				// Set layout manager		pane.setLayout( new FlowLayout() );		assert(graph!= null); // make sure we've got data in the graph		// Add to pane		pane.add( graph );		pane.add( printButton );		frame.pack();				// Center the frame		Toolkit toolkit = Toolkit.getDefaultToolkit();				// Get the current screen size		Dimension scrnsize = toolkit.getScreenSize();				// Get the frame size		Dimension framesize= frame.getSize();				// Set X,Y location		frame.setLocation ( (int) (scrnsize.getWidth()															 - frame.getWidth() ) / 2 ,												(int) (scrnsize.getHeight()															 - frame.getHeight()) / 2);				frame.setVisible(true);	}			public void actionPerformed(ActionEvent event)	{		PrintUtilities.printComponent(graph);	}	public void addTrial(Trial t, String name)	{		addTrial(t, DEFAULT_NUM_BUCKETS, name);	}		public void addTrial(Trial t, int nBuckets, String name)	{		AccuracyCoverage newData = new AccuracyCoverage(t, nBuckets, "untitled", name);		double [] accValues = newData.accuracyValues();		addDataToGraph(accValues, nBuckets, name);	}	public double[] accuracyValues()	{		return this.accuracyValues;	}	public class ClassificationComparator implements Comparator	{		public final int compare (Object a, Object b)		{			LabelVector x = (LabelVector) (((Classification)a).getLabelVector());			LabelVector y = (LabelVector) (((Classification)b).getLabelVector());			double difference = x.getBestValue() - y.getBestValue();			int toReturn = 0;			if(difference > 0)				toReturn = 1;			else if (difference < 0)				toReturn = -1;			return(toReturn);				}			}	}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -