📄 SequentialEvaluation.java
/*
 *    This program is free software; you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation; either version 2 of the License, or
 *    (at your option) any later version.
 *
 *    This program is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with this program; if not, write to the Free Software
 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

/*
 *    SequentialEvaluation.java
 *    Copyright (C) 1999 Eibe Frank, Len Trigg
 *
 */

package weka.classifiers;

import java.util.*;
import java.io.*;
import java.util.zip.GZIPInputStream;
import java.util.zip.GZIPOutputStream;

import weka.core.*;
import weka.estimators.*;
import weka.classifiers.meta.*;

/**
 * Class for evaluating sequential machine learning models. <p>
 *
 * See classifiers/meta/RSW.java for more info. <p>
 *
 * ------------------------------------------------------------------- <p>
 *
 * General options when evaluating a learning scheme from the command-line: <p>
 *
 * -t filename <br>
 * Name of the file with the training data. (required) <p>
 *
 * -T filename <br>
 * Name of the file with the test data. If missing, a cross-validation
 * is performed. <p>
 *
 * -c index <br>
 * Index of the class attribute (1, 2, ...; default: last). <p>
 *
 * -x number <br>
 * The number of folds for the cross-validation (default: 10). <p>
 *
 * -s seed <br>
 * Random number seed for the cross-validation (default: 1). <p>
 *
 * -m filename <br>
 * The name of a file containing a cost matrix. <p>
 *
 * -l filename <br>
 * Loads classifier from the given file. <p>
 *
 * -d filename <br>
 * Saves classifier built from the training data into the given file. <p>
 *
 * -v <br>
 * Outputs no statistics for the training data. <p>
 *
 * -o <br>
 * Outputs statistics only, not the classifier. <p>
 *
 * -i <br>
 * Outputs information-retrieval statistics per class. <p>
 *
 * -k <br>
 * Outputs information-theoretic statistics. <p>
 *
 * -p range <br>
 * Outputs predictions for test instances, along with the attributes in
 * the specified range (and nothing else). Use '-p 0' if no attributes are
 * desired. <p>
 *
 * -r <br>
 * Outputs cumulative margin distribution (and nothing else). <p>
 *
 * -g <br>
 * Only for classifiers that implement "Graphable". Outputs
 * the graph representation of the classifier (and nothing
 * else). <p>
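 *
 * For example (an illustrative command line, assuming the scheme's main
 * method runs this evaluation in the style of the FunkyClassifier example
 * below, with "train.arff" as a placeholder file name), a 10-fold
 * cross-validation with per-class information-retrieval statistics could be
 * requested with: <p>
 *
 * <code>java weka.classifiers.meta.RSW -t train.arff -x 10 -i</code> <p>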
 *
 * ------------------------------------------------------------------- <p>
 *
 * Example usage as the main of a classifier (called FunkyClassifier):
 * <code> <pre>
 * public static void main(String [] args) {
 *   try {
 *     Classifier scheme = new FunkyClassifier();
 *     System.out.println(Evaluation.evaluateModel(scheme, args));
 *   } catch (Exception e) {
 *     System.err.println(e.getMessage());
 *   }
 * }
 * </pre> </code>
 * <p>
 *
 * ------------------------------------------------------------------ <p>
 *
 * Example usage from within an application:
 * <code> <pre>
 * Instances trainInstances = ... instances got from somewhere
 * Instances testInstances = ... instances got from somewhere
 * Classifier scheme = ... scheme got from somewhere
 *
 * Evaluation evaluation = new Evaluation(trainInstances);
 * evaluation.evaluateModel(scheme, testInstances);
 * System.out.println(evaluation.toSummaryString());
 * </pre> </code>
 *
 * @author  Eibe Frank (eibe@cs.waikato.ac.nz)
 * @author  Len Trigg (trigg@cs.waikato.ac.nz)
 * @author  Saket Joshi (joshi@cs.orst.edu)
 * @version $Revision: 1.1 $
 */
public class SequentialEvaluation implements Summarizable {

  /** The number of classes. */
  private int m_NumClasses;

  /** The number of folds for a cross-validation. */
  private int m_NumFolds;

  /** The weight of all incorrectly classified instances. */
  private double m_Incorrect;

  /** The weight of all correctly classified instances. */
  private double m_Correct;

  /** The weight of all incorrectly classified sequences. */
  private double m_SeqIncorrect;

  /** The weight of all correctly classified sequences. */
  private double m_SeqCorrect;

  /** The weight of all unclassified instances. */
  private double m_Unclassified;

  /** The weight of all instances that had no class assigned to them. */
  private double m_MissingClass;

  /** The weight of all instances that had a class assigned to them. */
  private double m_WithClass;

  /** The number of all sequences. */
  private double m_SeqCount;

  /** Array for storing the confusion matrix. */
  private double [][] m_ConfusionMatrix;

  /** The names of the classes. */
  private String [] m_ClassNames;

  /** Is the class nominal or numeric? */
  private boolean m_ClassIsNominal;

  /** The prior probabilities of the classes. */
  private double [] m_ClassPriors;

  /** The sum of counts for priors. */
  private double m_ClassPriorsSum;

  /** The cost matrix (if given). */
  private CostMatrix m_CostMatrix;

  /** The total cost of predictions (includes instance weights). */
  private double m_TotalCost;

  /** Sum of errors. */
  private double m_SumErr;

  /** Sum of absolute errors. */
  private double m_SumAbsErr;

  /** Sum of squared errors. */
  private double m_SumSqrErr;

  /** Sum of class values. */
  private double m_SumClass;

  /** Sum of squared class values. */
  private double m_SumSqrClass;

  /** Sum of predicted values. */
  private double m_SumPredicted;

  /** Sum of squared predicted values. */
  private double m_SumSqrPredicted;

  /** Sum of predicted * class values. */
  private double m_SumClassPredicted;

  /** Sum of absolute errors of the prior. */
  private double m_SumPriorAbsErr;

  /** Sum of squared errors of the prior. */
  private double m_SumPriorSqrErr;

  /** Total Kononenko & Bratko Information. */
  private double m_SumKBInfo;

  /** Resolution of the margin histogram. */
  private static int k_MarginResolution = 500;

  /** Cumulative margin distribution. */
  private double m_MarginCounts [];

  /** Number of non-missing class training instances seen. */
  private int m_NumTrainClassVals;

  /** Array containing all numeric training class values seen. */
  private double [] m_TrainClassVals;

  /** Array containing all numeric training class weights. */
  private double [] m_TrainClassWeights;

  /** Numeric class error estimator for prior. */
  private Estimator m_PriorErrorEstimator;

  /** Numeric class error estimator for scheme. */
  private Estimator m_ErrorEstimator;

  /**
   * The minimum probability accepted from an estimator to avoid
   * taking log(0) in Sf calculations.
   */
  private static final double MIN_SF_PROB = Double.MIN_VALUE;

  /** Total entropy of prior predictions. */
  private double m_SumPriorEntropy;

  /** Total entropy of scheme predictions. */
  private double m_SumSchemeEntropy;
  /**
   * Initializes all the counters for the evaluation.
   *
   * @param data set of training instances, to get some header
   * information and prior class distribution information
   * @exception Exception if the class is not defined
   */
  public SequentialEvaluation(Instances data) throws Exception {

    this(data, null);
  }

  /**
   * Initializes all the counters for the evaluation and also takes a
   * cost matrix as parameter.
   *
   * @param data set of instances, to get some header information
   * @param costMatrix the cost matrix---if null, default costs will be used
   * @exception Exception if cost matrix is not compatible with
   * data, the class is not defined or the class is numeric
   */
  public SequentialEvaluation(Instances data, CostMatrix costMatrix)
       throws Exception {

    m_NumClasses = data.numClasses();
    m_NumFolds = 1;
    m_ClassIsNominal = data.classAttribute().isNominal();

    if (m_ClassIsNominal) {
      m_ConfusionMatrix = new double [m_NumClasses][m_NumClasses];
      m_ClassNames = new String [m_NumClasses];
      for (int i = 0; i < m_NumClasses; i++) {
        m_ClassNames[i] = data.classAttribute().value(i);
      }
    }
    m_CostMatrix = costMatrix;
    if (m_CostMatrix != null) {
      if (!m_ClassIsNominal) {
        throw new Exception("Class has to be nominal if cost matrix " +
                            "given!");
      }
      if (m_CostMatrix.size() != m_NumClasses) {
        throw new Exception("Cost matrix not compatible with data!");
      }
    }
    m_ClassPriors = new double [m_NumClasses];
    setPriors(data);
    m_MarginCounts = new double [k_MarginResolution + 1];
  }

  /**
   * Returns a copy of the confusion matrix.
   *
   * @return a copy of the confusion matrix as a two-dimensional array
   */
  public double[][] confusionMatrix() {

    double[][] newMatrix = new double[m_ConfusionMatrix.length][0];

    for (int i = 0; i < m_ConfusionMatrix.length; i++) {
      newMatrix[i] = new double[m_ConfusionMatrix[i].length];
      System.arraycopy(m_ConfusionMatrix[i], 0, newMatrix[i], 0,
                       m_ConfusionMatrix[i].length);
    }
    return newMatrix;
  }

  /**
   * Performs a (stratified if class is nominal) cross-validation
   * for a classifier on a set of instances.
   *
   * @param classifier the classifier with any options set
   * @param data the data on which the cross-validation is to be performed
   * @param numFolds the number of folds for the cross-validation
   * @exception Exception if a classifier could not be generated
   * successfully or the class is not defined
   */
  public void crossValidateModel(SequentialClassifier classifier,
                                 Instances data, int numFolds)
       throws Exception {

    // Make a copy of the data we can reorder
    data = new Instances(data);
    /*
    if (data.classAttribute().isNominal()) {
      data.stratify(numFolds);
    }
    */

    // Do the folds
    try {
      for (int i = 0; i < numFolds; i++) {
        Instances [] dataCV = getCVData(data, numFolds, i);
        Instances train = dataCV[0];
        setPriors(train);
        classifier.buildClassifier(train);
        Instances test = dataCV[1];
        evaluateModel(classifier, test);
      }
      m_NumFolds = numFolds;
    } catch (Exception e) {
      e.printStackTrace();
    }
  }
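  /**
   * Illustrative helper (not part of the original file): mirrors the fold
   * arithmetic used by getCVData() below, so the sequence ranges assigned to
   * each test fold can be inspected. For example, with 20 sequences and 5
   * folds it prints "fold 0: test sequences 1-4", "fold 1: test sequences
   * 5-8", and so on; all remaining sequences form that fold's training set.
   *
   * @param seqCount the total number of sequences in the data
   * @param folds the number of cross-validation folds
   */
  private static void printFoldRanges(int seqCount, int folds) {

    int slice = seqCount / folds;
    if (seqCount % folds != 0) {
      slice++;                      // round up so every sequence is covered
    }
    for (int i = 0; i < folds; i++) {
      int lower = slice * i + 1;    // first sequence index in fold i's test set
      int upper = slice * (i + 1);  // last sequence index in fold i's test set
      System.out.println("fold " + i + ": test sequences "
                         + lower + "-" + upper);
    }
  }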
  /**
   * Produces the training and test sets for cross-validation. The first
   * attribute of each instance is assumed to hold its sequence index.
   *
   * @param data the data to be split
   * @param folds the total number of folds
   * @param i the index of the fold to produce (0, 1, ...)
   * @return an array holding the training set at index 0 and the test set
   * at index 1
   * @exception Exception if there are more folds than sequences
   */
  public Instances [] getCVData(Instances data, int folds, int i)
       throws Exception {

    Instances [] result = new Instances[2];
    int slice = (int) m_SeqCount / folds;
    if (slice <= 0) {
      throw new Exception("Too many folds specified");
    }
    if (m_SeqCount % folds != 0) {
      slice++;
    }
    // Sequences with indices in [lower, upper] form the test set for fold i
    int lower = slice * i + 1;
    int upper = slice * (i + 1);
    Instances trainCV = new Instances(data);
    Instances testCV = new Instances(data, 0);
    for (int j = 0; j < trainCV.numInstances(); j++) {
      if ((trainCV.instance(j).value(0) >= lower)
          && (trainCV.instance(j).value(0) <= upper)) {
        // Move the instance to the test set, renumbering its sequence
        // so that test sequences start from 1
        Instance inst = new Instance(trainCV.instance(j));
        trainCV.delete(j);
        inst.setValue(0, inst.value(0) - lower + 1);
        testCV.add(inst);
        j--;        // re-examine the instance that shifted into position j
        continue;   // also avoids indexing past the end after a deletion
      }
      if (trainCV.instance(j).value(0) > upper) {
        // Renumber the training sequences that follow the test slice
        trainCV.instance(j).setValue(0, trainCV.instance(j).value(0) - upper);
      }
    }
    result[0] = trainCV;
    result[1] = testCV;
    return result;
  }

  /**
   * Shuffles the sequences in the data set so that they appear in random
   * order (instances belonging to the same sequence stay together).
   *
   * @param random a random number generator
   * @param data the data to be shuffled
   * @return a copy of the data with the sequences reordered randomly
   */
  public Instances seqRandomize(Random random, Instances data) {

    Instances result = new Instances(data, 0);
    // Pool of sequence indices that have not been placed yet
    Vector mod = new Vector((int) m_SeqCount);
    for (int j = 0; j < m_SeqCount; j++) {
      mod.addElement(new Integer(j + 1));
    }
    // Repeatedly draw a random sequence index and append all of its instances
    for (int j = (int) m_SeqCount; j > 0; j--) {
      int n = random.nextInt(j);
      int index = ((Integer) mod.elementAt(n)).intValue();
      mod.remove(n);
      for (int i = 0; i < data.numInstances(); i++) {
        if (data.instance(i).value(0) == index) {
          result.add(data.instance(i));
        }
      }
    }
    return result;
  }

  /**
   * Performs a (stratified if class is nominal) cross-validation
   * for a classifier on a set of instances.
   *
   * @param classifierString a string naming the class of the classifier
   * @param data the data on which the cross-validation is to be performed
   * @param numFolds the number of folds for the cross-validation
   * @param options the options to the classifier. Any options
   * accepted by the classifier will be removed from this array.
   * @exception Exception if a classifier could not be generated
   * successfully or the class is not defined
   */
  public void crossValidateModel(String classifierString, Instances data,
                                 int numFolds, String[] options)
       throws Exception {

    crossValidateModel((SequentialClassifier) Classifier.forName(
                           classifierString, options), data, numFolds);
  }

  // ... (the remaining methods of the class, e.g. evaluateModel(), setPriors()
  // and toSummaryString(), are not included in this listing)
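  /**
   * Illustrative usage sketch (not part of the original file): cross-validates
   * a sequential classifier on a training file, in the same style as the
   * Evaluation example in the class comment above. It assumes that RSW (see
   * classifiers/meta/RSW.java) is a SequentialClassifier and that the first
   * attribute of the data holds the sequence index, as the methods above
   * expect; the file name is a placeholder.
   *
   * @param trainFile the name of an ARFF file with the training sequences
   * @exception Exception if the data cannot be read or evaluation fails
   */
  public static void exampleCrossValidation(String trainFile) throws Exception {

    Instances train = new Instances(
        new BufferedReader(new FileReader(trainFile)));
    train.setClassIndex(train.numAttributes() - 1);    // class is the last attribute
    SequentialClassifier scheme = new RSW();           // scheme to be evaluated
    SequentialEvaluation evaluation = new SequentialEvaluation(train);
    evaluation.crossValidateModel(scheme, train, 10);  // 10-fold sequential CV
    System.out.println(evaluation.toSummaryString());
  }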