evaluation.java
来自「Java 编写的多种数据挖掘算法 包括聚类、分类、预处理等」· Java 代码 · 共 2,078 行 · 第 1/5 页
JAVA
2,078 行
/* * This program is free software; you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation; either version 2 of the License, or * (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with this program; if not, write to the Free Software * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA. *//* * Evaluation.java * Copyright (C) 1999 Eibe Frank,Len Trigg * */package weka.classifiers;import java.util.*;import java.io.*;import weka.classifiers.xml.XMLClassifier;import weka.classifiers.evaluation.NominalPrediction;import weka.classifiers.evaluation.ThresholdCurve;import weka.core.*;import weka.core.xml.KOML;import weka.core.xml.XMLOptions;import weka.core.xml.XMLSerialization;import weka.estimators.*;import java.util.zip.GZIPInputStream;import java.util.zip.GZIPOutputStream;/** * Class for evaluating machine learning models. <p> * * ------------------------------------------------------------------- <p> * * General options when evaluating a learning scheme from the command-line: <p> * * -t filename <br> * Name of the file with the training data. (required) <p> * * -T filename <br> * Name of the file with the test data. If missing a cross-validation * is performed. <p> * * -c index <br> * Index of the class attribute (1, 2, ...; default: last). <p> * * -x number <br> * The number of folds for the cross-validation (default: 10). <p> * * -s seed <br> * Random number seed for the cross-validation (default: 1). <p> * * -m filename <br> * The name of a file containing a cost matrix. <p> * * -l filename <br> * Loads classifier from the given file. 
In case the filename ends with ".xml" * the options are loaded from XML. <p> * * -d filename <br> * Saves classifier built from the training data into the given file. In case * the filename ends with ".xml" the options are saved XML, not the model. <p> * * -v <br> * Outputs no statistics for the training data. <p> * * -o <br> * Outputs statistics only, not the classifier. <p> * * -i <br> * Outputs information-retrieval statistics per class. <p> * * -k <br> * Outputs information-theoretic statistics. <p> * * -p range <br> * Outputs predictions for test instances, along with the attributes in * the specified range (and nothing else). Use '-p 0' if no attributes are * desired. <p> * * -r <br> * Outputs cumulative margin distribution (and nothing else). <p> * * -g <br> * Only for classifiers that implement "Graphable." Outputs * the graph representation of the classifier (and nothing * else). <p> * * -xml filename | xml-string <br> * Retrieves the options from the XML-data instead of the command line. <p> * * ------------------------------------------------------------------- <p> * * Example usage as the main of a classifier (called FunkyClassifier): * <code> <pre> * public static void main(String [] args) { * try { * Classifier scheme = new FunkyClassifier(); * System.out.println(Evaluation.evaluateModel(scheme, args)); * } catch (Exception e) { * System.err.println(e.getMessage()); * } * } * </pre> </code> * <p> * * ------------------------------------------------------------------ <p> * * Example usage from within an application: * <code> <pre> * Instances trainInstances = ... instances got from somewhere * Instances testInstances = ... instances got from somewhere * Classifier scheme = ... 
scheme got from somewhere
 *
 * Evaluation evaluation = new Evaluation(trainInstances);
 * evaluation.evaluateModel(scheme, testInstances);
 * System.out.println(evaluation.toSummaryString());
 * </pre> </code>
 *
 * @author Eibe Frank (eibe@cs.waikato.ac.nz)
 * @author Len Trigg (trigg@cs.waikato.ac.nz)
 * @version $Revision: 1.62 $
 */
public class Evaluation implements Summarizable {

  /** The number of classes. */
  protected int m_NumClasses;

  /** The number of folds for a cross-validation. */
  protected int m_NumFolds;

  /** The weight of all incorrectly classified instances. */
  protected double m_Incorrect;

  /** The weight of all correctly classified instances. */
  protected double m_Correct;

  /** The weight of all unclassified instances. */
  protected double m_Unclassified;

  /** The weight of all instances that had no class assigned to them. */
  protected double m_MissingClass;

  /** The weight of all instances that had a class assigned to them. */
  protected double m_WithClass;

  /** Array for storing the confusion matrix. */
  protected double [][] m_ConfusionMatrix;

  /** The names of the classes. */
  protected String [] m_ClassNames;

  /** Is the class nominal or numeric? */
  protected boolean m_ClassIsNominal;

  /** The prior probabilities of the classes. */
  protected double [] m_ClassPriors;

  /** The sum of counts for priors. */
  protected double m_ClassPriorsSum;

  /** The cost matrix (if given). */
  protected CostMatrix m_CostMatrix;

  /** The total cost of predictions (includes instance weights). */
  protected double m_TotalCost;

  /** Sum of errors. */
  protected double m_SumErr;

  /** Sum of absolute errors. */
  protected double m_SumAbsErr;

  /** Sum of squared errors. */
  protected double m_SumSqrErr;

  /** Sum of class values. */
  protected double m_SumClass;

  /** Sum of squared class values. */
  protected double m_SumSqrClass;

  /** Sum of predicted values. */
  protected double m_SumPredicted;

  /** Sum of squared predicted values. */
  protected double m_SumSqrPredicted;

  /** Sum of predicted * class values. */
  protected double m_SumClassPredicted;

  /** Sum of absolute errors of the prior. */
  protected double m_SumPriorAbsErr;

  /** Sum of squared errors of the prior. */
  protected double m_SumPriorSqrErr;

  /** Total Kononenko & Bratko Information. */
  protected double m_SumKBInfo;

  /** Resolution of the margin histogram. */
  protected static int k_MarginResolution = 500;

  /** Cumulative margin distribution. */
  protected double m_MarginCounts [];

  /** Number of non-missing class training instances seen. */
  protected int m_NumTrainClassVals;

  /** Array containing all numeric training class values seen. */
  protected double [] m_TrainClassVals;

  /** Array containing all numeric training class weights. */
  protected double [] m_TrainClassWeights;

  /** Numeric class error estimator for prior. */
  protected Estimator m_PriorErrorEstimator;

  /** Numeric class error estimator for scheme. */
  protected Estimator m_ErrorEstimator;

  /**
   * The minimum probability accepted from an estimator to avoid
   * taking log(0) in Sf calculations.
   */
  protected static final double MIN_SF_PROB = Double.MIN_VALUE;

  /** Total entropy of prior predictions. */
  protected double m_SumPriorEntropy;

  /** Total entropy of scheme predictions. */
  protected double m_SumSchemeEntropy;

  /** The list of predictions that have been generated (for computing AUC). */
  private FastVector m_Predictions;

  /**
   * Initializes all the counters for the evaluation.
   *
   * @param data set of training instances, to get some header
   * information and prior class distribution information
   * @exception Exception if the class is not defined
   */
  public Evaluation(Instances data) throws Exception {

    this(data, null);
  }

  /**
   * Initializes all the counters for the evaluation and also takes a
   * cost matrix as parameter.
* * @param data set of instances, to get some header information * @param costMatrix the cost matrix---if null, default costs will be used * @exception Exception if cost matrix is not compatible with * data, the class is not defined or the class is numeric */ public Evaluation(Instances data, CostMatrix costMatrix) throws Exception { m_NumClasses = data.numClasses(); m_NumFolds = 1; m_ClassIsNominal = data.classAttribute().isNominal(); if (m_ClassIsNominal) { m_ConfusionMatrix = new double [m_NumClasses][m_NumClasses]; m_ClassNames = new String [m_NumClasses]; for(int i = 0; i < m_NumClasses; i++) { m_ClassNames[i] = data.classAttribute().value(i); } } m_CostMatrix = costMatrix; if (m_CostMatrix != null) { if (!m_ClassIsNominal) { throw new Exception("Class has to be nominal if cost matrix " + "given!"); } if (m_CostMatrix.size() != m_NumClasses) { throw new Exception("Cost matrix not compatible with data!"); } } m_ClassPriors = new double [m_NumClasses]; setPriors(data); m_MarginCounts = new double [k_MarginResolution + 1]; } /** * Returns the area under ROC for those predictions that have been collected * in the evaluateClassifier(Classifier, Instances) method. Returns * Instance.missingValue() if the area is not available. * * @param classIndex the index of the class to consider as "positive" * @return the area under the ROC curve or not a number */ public double areaUnderROC(int classIndex) { // Check if any predictions have been collected if (m_Predictions == null) { return Instance.missingValue(); } else { ThresholdCurve tc = new ThresholdCurve(); Instances result = tc.getCurve(m_Predictions, classIndex); return tc.getROCArea(result); } } /** * Returns a copy of the confusion matrix. 
* * @return a copy of the confusion matrix as a two-dimensional array */ public double[][] confusionMatrix() { double[][] newMatrix = new double[m_ConfusionMatrix.length][0]; for (int i = 0; i < m_ConfusionMatrix.length; i++) { newMatrix[i] = new double[m_ConfusionMatrix[i].length]; System.arraycopy(m_ConfusionMatrix[i], 0, newMatrix[i], 0, m_ConfusionMatrix[i].length); } return newMatrix; } /** * Performs a (stratified if class is nominal) cross-validation * for a classifier on a set of instances. Now performs * a deep copy of the classifier before each call to * buildClassifier() (just in case the classifier is not * initialized properly). * * @param classifier the classifier with any options set. * @param data the data on which the cross-validation is to be * performed * @param numFolds the number of folds for the cross-validation * @param random random number generator for randomization * @exception Exception if a classifier could not be generated * successfully or the class is not defined */ public void crossValidateModel(Classifier classifier, Instances data, int numFolds, Random random) throws Exception { // Make a copy of the data we can reorder data = new Instances(data); data.randomize(random); if (data.classAttribute().isNominal()) { data.stratify(numFolds); } // Do the folds for (int i = 0; i < numFolds; i++) { Instances train = data.trainCV(numFolds, i, random); setPriors(train); Classifier copiedClassifier = Classifier.makeCopy(classifier); copiedClassifier.buildClassifier(train); Instances test = data.testCV(numFolds, i); evaluateModel(copiedClassifier, test); } m_NumFolds = numFolds; } /** * Performs a (stratified if class is nominal) cross-validation * for a classifier on a set of instances. * * @param classifier a string naming the class of the classifier * @param data the data on which the cross-validation is to be * performed * @param numFolds the number of folds for the cross-validation * @param options the options to the classifier. 
Any options
   * accepted by the classifier will be removed from this array.
   * @param random the random number generator for randomizing the data
   * @exception Exception if a classifier could not be generated
   * successfully or the class is not defined
   */
  public void crossValidateModel(String classifierString, Instances data,
                                 int numFolds, String[] options,
                                 Random random) throws Exception {

    // Instantiate the named classifier (consuming its options), then
    // delegate to the Classifier-based cross-validation above.
    crossValidateModel(Classifier.forName(classifierString, options),
                       data, numFolds, random);
  }

  /**
   * Evaluates a classifier with the options given in an array of
   * strings. <p>
   *
   * Valid options are: <p>
   *
   * -t filename <br>
   * Name of the file with the training data. (required) <p>
   *
   * -T filename <br>
   * Name of the file with the test data. If missing a cross-validation
   * is performed. <p>
   *
   * -c index <br>
   * Index of the class attribute (1, 2, ...; default: last). <p>
   *
   * -x number <br>
   * The number of folds for the cross-validation (default: 10). <p>
   *
   * -s seed <br>
   * Random number seed for the cross-validation (default: 1). <p>
   *
   * -m filename <br>
   * The name of a file containing a cost matrix. <p>
⌨️ 快捷键说明
复制代码Ctrl + C
搜索代码Ctrl + F
全屏模式F11
增大字号Ctrl + =
减小字号Ctrl + -
显示快捷键?