📄 evaluation.java
字号:
/* * This program is free software; you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation; either version 2 of the License, or * (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with this program; if not, write to the Free Software * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA. *//* * Evaluation.java * Copyright (C) 1999 Eibe Frank,Len Trigg * */package weka.classifiers;import java.util.*;import java.io.*;import weka.core.*;import weka.estimators.*;import java.util.zip.GZIPInputStream;import java.util.zip.GZIPOutputStream;/** * Class for evaluating machine learning models. <p> * * ------------------------------------------------------------------- <p> * * General options when evaluating a learning scheme from the command-line: <p> * * -t filename <br> * Name of the file with the training data. (required) <p> * * -T filename <br> * Name of the file with the test data. If missing a cross-validation * is performed. <p> * * -c index <br> * Index of the class attribute (1, 2, ...; default: last). <p> * * -x number <br> * The number of folds for the cross-validation (default: 10). <p> * * -s seed <br> * Random number seed for the cross-validation (default: 1). <p> * * -m filename <br> * The name of a file containing a cost matrix. <p> * * -l filename <br> * Loads classifier from the given file. <p> * * -d filename <br> * Saves classifier built from the training data into the given file. <p> * * -v <br> * Outputs no statistics for the training data. <p> * * -o <br> * Outputs statistics only, not the classifier. 
<p>
 *
 * -i <br>
 * Outputs information-retrieval statistics per class. <p>
 *
 * -k <br>
 * Outputs information-theoretic statistics. <p>
 *
 * -p range <br>
 * Outputs predictions for test instances, along with the attributes in
 * the specified range (and nothing else). Use '-p 0' if no attributes are
 * desired. <p>
 *
 * -r <br>
 * Outputs cumulative margin distribution (and nothing else). <p>
 *
 * -g <br>
 * Only for classifiers that implement "Graphable." Outputs
 * the graph representation of the classifier (and nothing
 * else). <p>
 *
 * ------------------------------------------------------------------- <p>
 *
 * Example usage as the main of a classifier (called FunkyClassifier):
 * <code> <pre>
 * public static void main(String [] args) {
 *   try {
 *     Classifier scheme = new FunkyClassifier();
 *     System.out.println(Evaluation.evaluateModel(scheme, args));
 *   } catch (Exception e) {
 *     System.err.println(e.getMessage());
 *   }
 * }
 * </pre> </code>
 * <p>
 *
 * ------------------------------------------------------------------ <p>
 *
 * Example usage from within an application:
 * <code> <pre>
 * Instances trainInstances = ... instances got from somewhere
 * Instances testInstances = ... instances got from somewhere
 * Classifier scheme = ... scheme got from somewhere
 *
 * Evaluation evaluation = new Evaluation(trainInstances);
 * evaluation.evaluateModel(scheme, testInstances);
 * System.out.println(evaluation.toSummaryString());
 * </pre> </code>
 *
 *
 * @author Eibe Frank (eibe@cs.waikato.ac.nz)
 * @author Len Trigg (trigg@cs.waikato.ac.nz)
 * @version $Revision: 1.1.1.1 $
 */
public class Evaluation implements Summarizable {

  /** The number of classes. */
  private int m_NumClasses;

  /** The number of folds for a cross-validation. */
  private int m_NumFolds;

  /** The weight of all incorrectly classified instances. */
  private double m_Incorrect;

  /** The weight of all correctly classified instances. */
  private double m_Correct;

  /** The weight of all unclassified instances. */
  private double m_Unclassified;

  /** The weight of all instances that had no class assigned to them. */
  private double m_MissingClass;

  /** The weight of all instances that had a class assigned to them. */
  private double m_WithClass;

  /** Array for storing the confusion matrix. */
  private double [][] m_ConfusionMatrix;

  /** The names of the classes. */
  private String [] m_ClassNames;

  /** Is the class nominal or numeric? */
  private boolean m_ClassIsNominal;

  /** The prior probabilities of the classes. */
  private double [] m_ClassPriors;

  /** The sum of counts for priors. */
  private double m_ClassPriorsSum;

  /** The cost matrix (if given). */
  private CostMatrix m_CostMatrix;

  /** The total cost of predictions (includes instance weights). */
  private double m_TotalCost;

  /** Sum of errors. */
  private double m_SumErr;

  /** Sum of absolute errors. */
  private double m_SumAbsErr;

  /** Sum of squared errors. */
  private double m_SumSqrErr;

  /** Sum of class values. */
  private double m_SumClass;

  /** Sum of squared class values. */
  private double m_SumSqrClass;

  /** Sum of predicted values. */
  private double m_SumPredicted;

  /** Sum of squared predicted values. */
  private double m_SumSqrPredicted;

  /** Sum of predicted * class values. */
  private double m_SumClassPredicted;

  /** Sum of absolute errors of the prior. */
  private double m_SumPriorAbsErr;

  /**
   * Sum of squared errors of the prior.
   * NOTE(review): the earlier comment here repeated "absolute errors";
   * "squared" is taken from the field name -- confirm against its uses.
   */
  private double m_SumPriorSqrErr;

  /** Total Kononenko & Bratko Information. */
  private double m_SumKBInfo;

  /** Resolution of the margin histogram. */
  private static int k_MarginResolution = 500;

  /** Cumulative margin distribution. */
  private double m_MarginCounts [];

  /** Number of non-missing class training instances seen. */
  private int m_NumTrainClassVals;

  /** Array containing all numeric training class values seen. */
  private double [] m_TrainClassVals;

  /** Array containing all numeric training class weights. */
  private double [] m_TrainClassWeights;

  /** Numeric class error estimator for prior. */
  private Estimator m_PriorErrorEstimator;

  /** Numeric class error estimator for scheme. */
  private Estimator m_ErrorEstimator;

  /**
   * The minimum probability accepted from an estimator to avoid
   * taking log(0) in Sf calculations.
   */
  private static final double MIN_SF_PROB = Double.MIN_VALUE;

  /** Total entropy of prior predictions. */
  private double m_SumPriorEntropy;

  /** Total entropy of scheme predictions. */
  private double m_SumSchemeEntropy;

  /**
   * Initializes all the counters for the evaluation. Delegates to the
   * two-argument constructor with a null cost matrix.
   *
   * @param data set of training instances, to get some header
   * information and prior class distribution information
   * @exception Exception if the class is not defined
   */
  public Evaluation(Instances data) throws Exception {
    this(data, null);
  }

  /**
   * Initializes all the counters for the evaluation and also takes a
   * cost matrix as parameter.
* * @param data set of instances, to get some header information * @param costMatrix the cost matrix---if null, default costs will be used * @exception Exception if cost matrix is not compatible with * data, the class is not defined or the class is numeric */ public Evaluation(Instances data, CostMatrix costMatrix) throws Exception { m_NumClasses = data.numClasses(); m_NumFolds = 1; m_ClassIsNominal = data.classAttribute().isNominal(); if (m_ClassIsNominal) { m_ConfusionMatrix = new double [m_NumClasses][m_NumClasses]; m_ClassNames = new String [m_NumClasses]; for(int i = 0; i < m_NumClasses; i++) { m_ClassNames[i] = data.classAttribute().value(i); } } m_CostMatrix = costMatrix; if (m_CostMatrix != null) { if (!m_ClassIsNominal) { throw new Exception("Class has to be nominal if cost matrix " + "given!"); } if (m_CostMatrix.size() != m_NumClasses) { throw new Exception("Cost matrix not compatible with data!"); } } m_ClassPriors = new double [m_NumClasses]; setPriors(data); m_MarginCounts = new double [k_MarginResolution + 1]; } /** * Returns a copy of the confusion matrix. * * @return a copy of the confusion matrix as a two-dimensional array */ public double[][] confusionMatrix() { double[][] newMatrix = new double[m_ConfusionMatrix.length][0]; for (int i = 0; i < m_ConfusionMatrix.length; i++) { newMatrix[i] = new double[m_ConfusionMatrix[i].length]; System.arraycopy(m_ConfusionMatrix[i], 0, newMatrix[i], 0, m_ConfusionMatrix[i].length); } return newMatrix; } /** * Performs a (stratified if class is nominal) cross-validation * for a classifier on a set of instances. * * @param classifier the classifier with any options set. 
* @param data the data on which the cross-validation is to be * performed * @param numFolds the number of folds for the cross-validation * @exception Exception if a classifier could not be generated * successfully or the class is not defined */ public void crossValidateModel(Classifier classifier, Instances data, int numFolds) throws Exception { // Make a copy of the data we can reorder data = new Instances(data); if (data.classAttribute().isNominal()) { data.stratify(numFolds); } // Do the folds for (int i = 0; i < numFolds; i++) { Instances train = data.trainCV(numFolds, i); setPriors(train); classifier.buildClassifier(train); Instances test = data.testCV(numFolds, i); evaluateModel(classifier, test); } m_NumFolds = numFolds; } /** * Performs a (stratified if class is nominal) cross-validation * for a classifier on a set of instances. * * @param classifier a string naming the class of the classifier * @param data the data on which the cross-validation is to be * performed * @param numFolds the number of folds for the cross-validation * @param options the options to the classifier. Any options * accepted by the classifier will be removed from this array. * @exception Exception if a classifier could not be generated * successfully or the class is not defined */ public void crossValidateModel(String classifierString, Instances data, int numFolds, String[] options) throws Exception { crossValidateModel(Classifier.forName(classifierString, options), data, numFolds); } /** * Evaluates a classifier with the options given in an array of * strings. <p> * * Valid options are: <p> * * -t filename <br> * Name of the file with the training data. (required) <p> * * -T filename <br> * Name of the file with the test data. If missing a cross-validation * is performed. <p> * * -c index <br> * Index of the class attribute (1, 2, ...; default: last). <p> * * -x number <br> * The number of folds for the cross-validation (default: 10). 
<p> * * -s seed <br> * Random number seed for the cross-validation (default: 1). <p> * * -m filename <br> * The name of a file containing a cost matrix. <p> * * -l filename <br> * Loads classifier from the given file. <p> * * -d filename <br> * Saves classifier built from the training data into the given file. <p> * * -v <br> * Outputs no statistics for the training data. <p> * * -o <br> * Outputs statistics only, not the classifier. <p> * * -i <br> * Outputs detailed information-retrieval statistics per class. <p> * * -k <br> * Outputs information-theoretic statistics. <p> * * -p range <br> * Outputs predictions for test instances, along with the attributes in * the specified range (and nothing else). Use '-p 0' if no attributes are * desired. <p> * * -r <br> * Outputs cumulative margin distribution (and nothing else). <p> * * -g <br> * Only for classifiers that implement "Graphable." Outputs * the graph representation of the classifier (and nothing * else). <p> * * @param classifierString class of machine learning classifier as a string * @param options the array of string containing the options * @exception Exception if model could not be evaluated successfully * @return a string describing the results */ public static String evaluateModel(String classifierString, String [] options) throws Exception { Classifier classifier; // Create classifier try { classifier = (Classifier)Class.forName(classifierString).newInstance(); } catch (Exception e) { throw new Exception("Can't find class with name " + classifierString + '.'); } return evaluateModel(classifier, options); } /** * A test method for this class. Just extracts the first command line * argument as a classifier class name and calls evaluateModel. * @param args an array of command line arguments, the first of which
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -