// File: Evaluation.java
/*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 2 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
*/
/*
* Evaluation.java
* Copyright (C) 1999 Eibe Frank, Len Trigg
*
*/
package weka.classifiers;
import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.FileReader;
import java.io.InputStream;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.OutputStream;
import java.io.Reader;
import java.util.Enumeration;
import java.util.Random;
import java.util.zip.GZIPInputStream;
import java.util.zip.GZIPOutputStream;
import weka.classifiers.functions.Logistic;
import weka.core.Drawable;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.Range;
import weka.core.Summarizable;
import weka.core.Utils;
import weka.estimators.Estimator;
import weka.estimators.KernelEstimator;
/**
* Class for evaluating machine learning models. <p>
*
* ------------------------------------------------------------------- <p>
*
* General options when evaluating a learning scheme from the command-line: <p>
*
* -t filename <br>
* Name of the file with the training data. (required) <p>
*
* -T filename <br>
* Name of the file with the test data. If missing a cross-validation
* is performed. <p>
*
* -c index <br>
* Index of the class attribute (1, 2, ...; default: last). <p>
*
* -x number <br>
* The number of folds for the cross-validation (default: 10). <p>
*
* -s seed <br>
* Random number seed for the cross-validation (default: 1). <p>
*
* -m filename <br>
* The name of a file containing a cost matrix. <p>
*
* -l filename <br>
* Loads classifier from the given file. <p>
*
* -d filename <br>
* Saves classifier built from the training data into the given file. <p>
*
* -v <br>
* Outputs no statistics for the training data. <p>
*
* -o <br>
* Outputs statistics only, not the classifier. <p>
*
* -i <br>
* Outputs information-retrieval statistics per class. <p>
*
* -k <br>
* Outputs information-theoretic statistics. <p>
*
* -p range <br>
* Outputs predictions for test instances, along with the attributes in
* the specified range (and nothing else). Use '-p 0' if no attributes are
* desired. <p>
*
* -r <br>
* Outputs cumulative margin distribution (and nothing else). <p>
*
* -g <br>
* Only for classifiers that implement "Graphable." Outputs
* the graph representation of the classifier (and nothing
* else). <p>
*
* ------------------------------------------------------------------- <p>
*
* Example usage as the main of a classifier (called FunkyClassifier):
* <code> <pre>
* public static void main(String [] args) {
* try {
* Classifier scheme = new FunkyClassifier();
* System.out.println(Evaluation.evaluateModel(scheme, args));
* } catch (Exception e) {
* System.err.println(e.getMessage());
* }
* }
* </pre> </code>
* <p>
*
* ------------------------------------------------------------------ <p>
*
* Example usage from within an application:
* <code> <pre>
* Instances trainInstances = ... instances got from somewhere
* Instances testInstances = ... instances got from somewhere
* Classifier scheme = ... scheme got from somewhere
*
* Evaluation evaluation = new Evaluation(trainInstances);
* evaluation.evaluateModel(scheme, testInstances);
* System.out.println(evaluation.toSummaryString());
* </pre> </code>
*
*
* @author Eibe Frank (eibe@cs.waikato.ac.nz)
* @author Len Trigg (trigg@cs.waikato.ac.nz)
* @version $Revision$
*/
public class Evaluation implements Summarizable {
/** The number of classes. */
protected int m_NumClasses;
/** The number of folds for a cross-validation. */
protected int m_NumFolds;
/** The weight of all incorrectly classified instances. */
protected double m_Incorrect;
/** The weight of all correctly classified instances. */
protected double m_Correct;
/** The weight of all unclassified instances. */
protected double m_Unclassified;
/** The weight of all instances that had no class assigned to them. */
protected double m_MissingClass;
/** The weight of all instances that had a class assigned to them. */
protected double m_WithClass;
/** Array for storing the confusion matrix. */
protected double [][] m_ConfusionMatrix;
/** The names of the classes. */
protected String [] m_ClassNames;
/** Is the class nominal or numeric? */
protected boolean m_ClassIsNominal;
/** The prior probabilities of the classes */
protected double [] m_ClassPriors;
/** The sum of counts for priors */
protected double m_ClassPriorsSum;
/** The cost matrix (if given). */
protected CostMatrix m_CostMatrix;
/** The total cost of predictions (includes instance weights) */
protected double m_TotalCost;
/** Sum of errors. */
protected double m_SumErr;
/** Sum of absolute errors. */
protected double m_SumAbsErr;
/** Sum of squared errors. */
protected double m_SumSqrErr;
/** Sum of class values. */
protected double m_SumClass;
/** Sum of squared class values. */
protected double m_SumSqrClass;
/** Sum of predicted values. */
protected double m_SumPredicted;
/** Sum of squared predicted values. */
protected double m_SumSqrPredicted;
/** Sum of predicted * class values. */
protected double m_SumClassPredicted;
/** Sum of absolute errors of the prior */
protected double m_SumPriorAbsErr;
/** Sum of squared errors of the prior */
protected double m_SumPriorSqrErr;
/** Total Kononenko & Bratko Information */
protected double m_SumKBInfo;
/** Resolution of the margin histogram */
protected static int k_MarginResolution = 500;
/** Cumulative margin distribution */
protected double m_MarginCounts [];
/** Number of non-missing class training instances seen */
protected int m_NumTrainClassVals;
/** Array containing all numeric training class values seen */
protected double [] m_TrainClassVals;
/** Array containing all numeric training class weights */
protected double [] m_TrainClassWeights;
/** Numeric class error estimator for prior */
protected Estimator m_PriorErrorEstimator;
/** Numeric class error estimator for scheme */
protected Estimator m_ErrorEstimator;
/**
 * The minimum probability accepted from an estimator to avoid
 * taking log(0) in Sf calculations.
 */
protected static final double MIN_SF_PROB = Double.MIN_VALUE;
/** Total entropy of prior predictions */
protected double m_SumPriorEntropy;
/** Total entropy of scheme predictions */
protected double m_SumSchemeEntropy;
//<<30/01/2005, Frank J. Xu
//For calculating decision tree classifier evaluation measure.
/** Correctly classified records with confidence larger than specified value. */
protected int [] m_CorrectNumWithSpecifiedConf;
//30/01/2005, Frank J. Xu>>
/**
 * Initializes all the counters for the evaluation, using default costs
 * (no cost matrix).
 *
 * @param data set of training instances, to get some header
 * information and prior class distribution information
 * @exception Exception if the class is not defined
 */
public Evaluation(Instances data) throws Exception {
// Delegate to the two-argument constructor with a null cost matrix.
this(data, null);
}
/**
 * Initializes all the counters for the evaluation and also takes a
 * cost matrix as parameter.
 *
 * @param data set of instances, to get some header information
 * @param costMatrix the cost matrix---if null, default costs will be used
 * @exception Exception if cost matrix is not compatible with
 * data, the class is not defined or the class is numeric
 */
public Evaluation(Instances data, CostMatrix costMatrix)
    throws Exception {

  // Cache the dataset's basic header facts.
  m_NumClasses = data.numClasses();
  m_NumFolds = 1;
  m_ClassIsNominal = data.classAttribute().isNominal();

  if (m_ClassIsNominal) {
    // Nominal class: allocate the confusion matrix and per-class bookkeeping.
    m_ConfusionMatrix = new double[m_NumClasses][m_NumClasses];
    m_ClassNames = new String[m_NumClasses];
    m_CorrectNumWithSpecifiedConf = new int[m_NumClasses];
    for (int classIndex = 0; classIndex < m_NumClasses; classIndex++) {
      m_ClassNames[classIndex] = data.classAttribute().value(classIndex);
      // Counter for correct classifications above a confidence threshold
      // (decision-tree evaluation measure; added 30/01/2005, Frank J. Xu).
      m_CorrectNumWithSpecifiedConf[classIndex] = 0;
    }
  }

  // A cost matrix only makes sense for a nominal class of matching size.
  m_CostMatrix = costMatrix;
  if (m_CostMatrix != null) {
    if (!m_ClassIsNominal) {
      throw new Exception("Class has to be nominal if cost matrix given!");
    }
    if (m_CostMatrix.size() != m_NumClasses) {
      throw new Exception("Cost matrix not compatible with data!");
    }
  }

  // Initialize priors from the supplied data and the margin histogram.
  m_ClassPriors = new double[m_NumClasses];
  setPriors(data);
  m_MarginCounts = new double[k_MarginResolution + 1];
}
/**
 * Returns a copy of the confusion matrix.
 *
 * @return a copy of the confusion matrix as a two-dimensional array
 */
public double[][] confusionMatrix() {

  // Deep-copy row by row so callers cannot mutate our internal counts.
  double[][] copy = new double[m_ConfusionMatrix.length][];
  for (int row = 0; row < m_ConfusionMatrix.length; row++) {
    copy[row] = (double[]) m_ConfusionMatrix[row].clone();
  }
  return copy;
}
/**
 * Performs a (stratified if class is nominal) cross-validation
 * for a classifier on a set of instances.
 *
 * @param classifier the classifier with any options set.
 * @param data the data on which the cross-validation is to be
 * performed
 * @param numFolds the number of folds for the cross-validation
 * @param random random number generator for randomization
 * @exception Exception if a classifier could not be generated
 * successfully or the class is not defined
 */
public void crossValidateModel(Classifier classifier,
    Instances data, int numFolds, Random random)
    throws Exception {

  // Work on a private copy so the caller's instance ordering is untouched.
  data = new Instances(data);
  data.randomize(random);
  if (data.classAttribute().isNominal()) {
    data.stratify(numFolds);
  }

  // Train and evaluate once per fold; statistics accumulate across folds.
  for (int fold = 0; fold < numFolds; fold++) {
    Instances trainSet = data.trainCV(numFolds, fold, random);
    setPriors(trainSet);
    classifier.buildClassifier(trainSet);
    evaluateModel(classifier, data.testCV(numFolds, fold));
  }
  m_NumFolds = numFolds;
}
/**
 * Performs a (stratified if class is nominal) cross-validation
 * for a classifier on a set of instances.
 *
 * @param classifierString a string naming the class of the classifier
 * @param data the data on which the cross-validation is to be
 * performed
 * @param numFolds the number of folds for the cross-validation
 * @param options the options to the classifier. Any options
 * accepted by the classifier will be removed from this array.
 * @param random the random number generator for randomizing the data
 * @exception Exception if a classifier could not be generated
 * successfully or the class is not defined
 */
public void crossValidateModel(String classifierString,
Instances data, int numFolds,
String[] options, Random random)
throws Exception {
// Instantiate the named classifier, then delegate to the main overload.
crossValidateModel(Classifier.forName(classifierString, options),
data, numFolds, random);
}
/**
* Evaluates a classifier with the options given in an array of
* strings. <p>
*
* Valid options are: <p>
*
* -t filename <br>
* Name of the file with the training data. (required) <p>
*
* -T filename <br>
* Name of the file with the test data. If missing a cross-validation
* is performed. <p>
*
* -c index <br>
* Index of the class attribute (1, 2, ...; default: last). <p>
*
* -x number <br>
* The number of folds for the cross-validation (default: 10). <p>
*
* -s seed <br>
* Random number seed for the cross-validation (default: 1). <p>
*