⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 grapher.java

📁 wekaUT 是 University of Texas at Austin 开发的、基于 weka 的半监督学习(semi-supervised learning)分类器
💻 JAVA
字号:
/*
 *    This program is free software; you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation; either version 2 of the License, or
 *    (at your option) any later version.
 *
 *    This program is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with this program; if not, write to the Free Software
 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

/*
 *    Grapher.java
 *    Copyright (C) 2002 Raymond J. Mooney
 *
 */

package weka.experiment;

import java.util.*;
import java.io.*;
import weka.core.*;

/**
 * Class for producing performance graphs for any metric from learning curve results.
 * Currently supports gnuplot format with various types of error bars.
 *
 * <p>Typical use: construct a Grapher on an arff result file, optionally set
 * {@link #metric}, then call {@link #gnuplotAllDatasets()}.
 */
public class Grapher {

    /** Experimental result data in arff format */
    protected Instances data;

    /** Names of datasets in data */
    protected String[] datasets;

    /**
     * Map from scheme + options name (String) to result data in the form of an
     * array of Stats (Stats[]), one slot for each learning curve point in
     * {@link #points}. A slot is null when no result line was seen for that
     * scheme at that point.
     */
    protected HashMap schemeMap;

    /** Ordered (ascending) array of points on the learning curve, in number of training examples */
    protected int[] points;

    /** Name of original file of experimental result data in arff format */
    protected String arffFileName;

    /** The name of the performance metric to plot */
    public String metric = "Percent_correct";

    /** Set if desire error bars of particular type */
    public short errorBars = NONE;

    /** errorBar value for no error bars */
    public static final short NONE = 0;

    /** errorBar value for error bars using standard deviations */
    public static final short STD_DEV = 1;

    /** errorBar value for error bars using 95% confidence intervals */
    public static final short CONF_INF = 2;

    /** errorBar value for error bars using min and max values */
    public static final short MIN_MAX = 3;

    /** Set if desire error bars based on 95% confidence intervals (not read anywhere in this class) */
    public boolean confIntErrorBars = false;

    /** The name of the dataset to plot performance for */
    public String dataset;

    /**
     * Create an initial Grapher and load in data, names of datasets,
     * and set of points on learning curve. The current dataset is
     * initialized to the first dataset named in the result file.
     *
     * @param arffFileName name of the arff file of experimental results
     * @param errorBars one of {@link #NONE}, {@link #STD_DEV},
     *                  {@link #CONF_INF} or {@link #MIN_MAX}
     * @throws Exception if the result file cannot be read or parsed
     */
    public Grapher(String arffFileName, short errorBars) throws Exception {
        this.arffFileName = arffFileName;
        this.errorBars = errorBars;
        setData();
        setDatasets();
        setPoints();
        dataset = datasets[0];
    }

    /**
     * Load data for graph in from the given Experiment result file in arff format.
     *
     * @throws Exception if the file cannot be opened or parsed
     */
    protected void setData() throws Exception {
        BufferedReader reader = new BufferedReader(new FileReader(arffFileName));
        try {
            data = new Instances(reader);
        } finally {
            // Release the file handle even when parsing fails
            reader.close();
        }
    }

    /**
     * Set array of points on learning curve from Key_Total_instances values in data.
     * The values are parsed as integers and sorted ascending so that they can
     * later be located with a binary search.
     */
    protected void setPoints() throws Exception {
        Attribute attr = data.attribute("Key_Total_instances");
        points = new int[attr.numValues()];
        for (int i = 0; i < points.length; i++) {
            points[i] = Integer.parseInt(attr.value(i));
        }
        Arrays.sort(points);
    }

    /** Set array of dataset names from Key_Dataset values in data */
    protected void setDatasets() throws Exception {
        Attribute attr = data.attribute("Key_Dataset");
        datasets = new String[attr.numValues()];
        for (int i = 0; i < datasets.length; i++) {
            datasets[i] = attr.value(i);
        }
    }

    /**
     * Read in data for the current values of dataset and metric by indexing
     * for each scheme+options name an array of Stats objects for each point on the
     * learning curve. Replaces any previous contents of {@link #schemeMap}.
     *
     * @throws Error if a result line for the current dataset is found but the
     *               result file has no attribute named {@link #metric}
     */
    protected void processData() throws Exception {
        schemeMap = new HashMap();

        // Attribute lookups do not depend on the instance, so do them once
        Attribute datasetAttr = data.attribute("Key_Dataset");
        Attribute schemeAttr = data.attribute("Key_Scheme");
        Attribute optionsAttr = data.attribute("Key_Scheme_options");
        Attribute totalAttr = data.attribute("Key_Total_instances");
        Attribute metricAttr = data.attribute(metric);

        // Go through each data line in the data
        Enumeration instances = data.enumerateInstances();
        while (instances.hasMoreElements()) {
            Instance inst = (Instance) instances.nextElement();
            // If this is not a line for the current dataset, skip it
            if (!inst.stringValue(datasetAttr).equals(dataset)) {
                continue;
            }
            // Get the full name of the scheme by concatenating the system
            // name and the set of system options
            String name = inst.stringValue(schemeAttr) + inst.stringValue(optionsAttr);
            // See if this scheme already has a Stats vector for points
            Stats[] pointsStats = (Stats[]) schemeMap.get(name);
            if (pointsStats == null) {
                // If not create one
                pointsStats = new Stats[points.length];
                schemeMap.put(name, pointsStats);
            }
            // Get the number of training instances for this line
            int point = Integer.parseInt(inst.stringValue(totalAttr));
            // Find the position in the array of points associated with this point
            int pointPos = Arrays.binarySearch(points, point);
            if (pointPos < 0) {
                // binarySearch signals "not found" with a negative value; fail
                // with a clear message instead of an array index error
                throw new IllegalStateException("Learning curve point " + point
                        + " is not among the Key_Total_instances values");
            }
            // Get the Stats performance metric object for this point
            Stats stats = pointsStats[pointPos];
            if (stats == null) {
                // If there is none, create one
                stats = new Stats();
                pointsStats[pointPos] = stats;
            }
            if (metricAttr == null) {
                throw new Error("Unrecognized metric:" + metric);
            }
            // Get the value of the performance metric for this line
            double metricValue = inst.value(metricAttr);
            // Add this value to the Stats object for this scheme and point
            // that keeps track of the running sum to eventually compute an average
            stats.add(metricValue);
        }
    }

    /**
     * Generate gnuplot files for plotting a learning curve for the current
     * dataset and metric. Assumes a processData was last performed for
     * this dataset and metric.
     *
     * <p>Writes one command file named {@code <stem>_<metric>.gplot} and one
     * data file per scheme named {@code <stem>_<metric>_<scheme>}, where the
     * stem is the result file name without extension (plus the dataset name
     * when the result file covers more than one dataset).
     *
     * @throws Exception if any of the output files cannot be written
     */
    public void gnuplot() throws Exception {
        // Find min and max values of the performance metric
        double yMin = Double.POSITIVE_INFINITY, yMax = Double.NEGATIVE_INFINITY;
        // Index of last point on the learning curve (this may differ
        // for different datasets).
        int last_point = -1, last_index = 0;

        // Iterate through each scheme and each of its plot points
        Iterator schemeEntries = schemeMap.entrySet().iterator();
        while (schemeEntries.hasNext()) {
            Map.Entry schemeEntry = (Map.Entry) schemeEntries.next();
            Stats[] pointsStats = (Stats[]) schemeEntry.getValue();
            for (int i = 0; i < points.length; i++) {
                Stats stats = pointsStats[i];
                // No results were recorded for this scheme at this point
                if (stats == null) {
                    continue;
                }
                // Keep track of which is the last point on the
                // learning curve on this dataset
                if (points[i] > last_point) {
                    last_point = points[i];
                    last_index = i;
                }
                // Calculate final mean and other summary stats
                stats.calculateDerived();
                if (stats.mean < yMin) {
                    yMin = stats.mean;
                }
                if (stats.mean > yMax) {
                    yMax = stats.mean;
                }
            }
        }

        // Use result file name stem as a stem for plot files
        String fileStem = removeFileExtension(arffFileName);
        // Also include the name of the dataset in the plot-file stem if
        // there are results for more than one dataset in this result file
        if (datasets.length > 1) {
            fileStem = fileStem + dataset;
        }
        String fileName = fileStem + "_" + metric + ".gplot";

        // Create a file for the gnuplot
        PrintWriter out = new PrintWriter(new FileWriter(fileName));
        try {
            // Write proper gnuplot commands in this file
            out.println("set xlabel \"Number of Training Examples\"");
            out.println("set ylabel \"" + metric.replace('_', ' ') + "\"");
            out.println("\nset terminal postscript color\nset size 0.75,0.75\n\nset data style linespoints");
            // Move the key of curve names to the lower right corner, good for learning
            // curves and train time plots that go from lower left to top right
            out.println("set key " + 0.85 * points[last_index] + ","
                    + (yMin + 0.25 * (yMax - yMin)));
            out.print("\nplot ");

            // For each scheme, add it to the plot command to plot this scheme's learning
            // curve for the metric and create a data file for the average data for the
            // learning curve points
            schemeEntries = schemeMap.entrySet().iterator();
            while (schemeEntries.hasNext()) {
                Map.Entry schemeEntry = (Map.Entry) schemeEntries.next();
                String scheme = cleanSchemeName((String) schemeEntry.getKey());
                Stats[] pointsStats = (Stats[]) schemeEntry.getValue();

                // Create a data file for this scheme
                String dataFileName = fileStem + "_" + metric + "_" + scheme;
                out.print("'" + dataFileName + "' title \"" + scheme + "\"");
                if (errorBars != NONE) {
                    out.print(", '" + dataFileName + "' notitle with errorbars");
                }
                if (schemeEntries.hasNext()) {
                    out.print(", ");
                }

                PrintWriter dataOut = new PrintWriter(new FileWriter(dataFileName));
                try {
                    // Write out a line for each data point on the learning curve for the metric
                    for (int i = 0; i <= last_index; i++) {
                        Stats stats = pointsStats[i];
                        // A scheme may lack results at some points; such slots were
                        // skipped when computing the summary stats above and must be
                        // skipped here too, otherwise this dereferences null
                        if (stats == null) {
                            continue;
                        }
                        dataOut.print(points[i] + " " + stats.mean);
                        // Add a third (and maybe fourth) entry for the error bar.
                        // Just a third indicates a delta about the mean, a third
                        // and fourth indicates a lower and upper bound
                        if (errorBars == STD_DEV) {
                            dataOut.print(" " + stats.stdDev);
                        } else if (errorBars == CONF_INF) {
                            // Delta of 1.96 standard deviations.
                            // NOTE(review): this scales the standard deviation, not the
                            // standard error of the mean -- confirm that is intended.
                            dataOut.print(" " + 1.96 * stats.stdDev);
                        } else if (errorBars == MIN_MAX) {
                            dataOut.print(" " + stats.min + " " + stats.max);
                        }
                        dataOut.println("");
                    }
                } finally {
                    dataOut.close();
                }
            }
        } finally {
            // Close the command file even if writing a data file failed
            out.close();
        }
    }

    /** Clean the name of a scheme to make it appropriate for a file name */
    private String cleanSchemeName(String schemeName) {
        return Utils.removeSubstring(schemeName, "weka.classifiers.").replace(' ', '_');
    }

    /**
     * Return the name of a file with the extension removed.
     *
     * @param fileName the file name, possibly including a path
     * @return the text before the last '.', or fileName itself if it contains no '.'
     */
    public static String removeFileExtension(String fileName) {
        int pos = fileName.lastIndexOf(".");
        if (pos == -1) {
            return fileName;
        }
        return fileName.substring(0, pos);
    }

    /** Produce a gnuplot for each dataset in the result file */
    public void gnuplotAllDatasets() throws Exception {
        for (int i = 0; i < datasets.length; i++) {
            dataset = datasets[i];
            processData();
            gnuplot();
        }
    }

    /**
     * Create gnuplot graphs of learning curves. The first argument should
     * be the name of an arff file of experimental result for a learning curve experiment.
     * If present, the second argument should be the name of a performance metric in
     * result file to plot (which defaults to Percent_correct). Options are:
     * <ul>
     * <li> -s: Plot error bars of standard deviations.
     * <li> -c: Plot error bars of 95% confidence intervals.
     * <li> -m: Plot error bars of min and max values.
     * </ul>
     *
     * @throws IllegalArgumentException if no result file name is given
     */
    public static void main(String[] args) throws Exception {
        int current = 0;
        short errorBars = NONE;
        // Only look for an option flag when there is an argument to inspect
        if (args.length > 0) {
            if (args[current].equals("-s")) {
                errorBars = STD_DEV;
                current++;
            } else if (args[current].equals("-c")) {
                errorBars = CONF_INF;
                current++;
            } else if (args[current].equals("-m")) {
                errorBars = MIN_MAX;
                current++;
            }
        }
        if (current >= args.length) {
            throw new IllegalArgumentException(
                    "Usage: Grapher [-s | -c | -m] <results.arff> [metric]");
        }
        Grapher grapher = new Grapher(args[current++], errorBars);
        if (args.length > current) {
            grapher.metric = args[current++];
        }
        grapher.gnuplotAllDatasets();
    }
}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -