⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 sequentialevaluation.java

📁 把 sequential 有导师学习问题转化为传统的有导师学习问题
💻 JAVA
📖 第 1 页 / 共 5 页
字号:
/* *    This program is free software; you can redistribute it and/or modify *    it under the terms of the GNU General Public License as published by *    the Free Software Foundation; either version 2 of the License, or *    (at your option) any later version. * *    This program is distributed in the hope that it will be useful, *    but WITHOUT ANY WARRANTY; without even the implied warranty of *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the *    GNU General Public License for more details. * *    You should have received a copy of the GNU General Public License *    along with this program; if not, write to the Free Software *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA. *//* *    SequentialEvaluation.java *    Copyright (C) 1999 Eibe Frank,Len Trigg * */package weka.classifiers;import java.util.*;import java.io.*;import weka.core.*;import weka.estimators.*;import java.util.zip.GZIPInputStream;import java.util.zip.GZIPOutputStream;import weka.classifiers.meta.*;/** * Class for evaluating sequential machine learning models. <p> * See classifiers/meta/RSW.java for more info. <p> * * ------------------------------------------------------------------- <p> * * General options when evaluating a learning scheme from the command-line: <p> * * -t filename <br> * Name of the file with the training data. (required) <p> * * -T filename <br> * Name of the file with the test data. If missing a cross-validation  * is performed. <p> * * -c index <br> * Index of the class attribute (1, 2, ...; default: last). <p> * * -x number <br> * The number of folds for the cross-validation (default: 10). <p> * * -s seed <br> * Random number seed for the cross-validation (default: 1). <p> * * -m filename <br> * The name of a file containing a cost matrix. <p> * * -l filename <br> * Loads classifier from the given file. <p> * * -d filename <br> * Saves classifier built from the training data into the given file. 
<p> * * -v <br> * Outputs no statistics for the training data. <p> * * -o <br> * Outputs statistics only, not the classifier. <p> *  * -i <br> * Outputs information-retrieval statistics per class. <p> * * -k <br> * Outputs information-theoretic statistics. <p> * * -p range <br> * Outputs predictions for test instances, along with the attributes in  * the specified range (and nothing else). Use '-p 0' if no attributes are * desired. <p> * * -r <br> * Outputs cumulative margin distribution (and nothing else). <p> * * -g <br>  * Only for classifiers that implement "Graphable." Outputs * the graph representation of the classifier (and nothing * else). <p> * * ------------------------------------------------------------------- <p> * * Example usage as the main of a classifier (called FunkyClassifier): * <code> <pre> * public static void main(String [] args) { *   try { *     Classifier scheme = new FunkyClassifier(); *     System.out.println(Evaluation.evaluateModel(scheme, args)); *   } catch (Exception e) { *     System.err.println(e.getMessage()); *   } * } * </pre> </code>  * <p> * * ------------------------------------------------------------------ <p> * * Example usage from within an application: * <code> <pre> * Instances trainInstances = ... instances got from somewhere * Instances testInstances = ... instances got from somewhere * Classifier scheme = ... scheme got from somewhere * * Evaluation evaluation = new Evaluation(trainInstances); * evaluation.evaluateModel(scheme, testInstances); * System.out.println(evaluation.toSummaryString()); * </pre> </code>  * * * @author   Eibe Frank (eibe@cs.waikato.ac.nz) * @author   Len Trigg (trigg@cs.waikato.ac.nz) * @author   Saket Joshi (joshi@cs.orst.edu) * @version  $Revision: 1.1 $  */public class SequentialEvaluation implements Summarizable {  /** The number of classes. */  private int m_NumClasses;  /** The number of folds for a cross-validation. 
*/
  private int m_NumFolds;

  /** The weight of all incorrectly classified instances. */
  private double m_Incorrect;

  /** The weight of all correctly classified instances. */
  private double m_Correct;

  /** The weight of all incorrectly classified sequences. */
  private double m_SeqIncorrect;

  /** The weight of all correctly classified sequences. */
  private double m_SeqCorrect;

  /** The weight of all unclassified instances. */
  private double m_Unclassified;

  /** The weight of all instances that had no class assigned to them. */
  private double m_MissingClass;

  /** The weight of all instances that had a class assigned to them. */
  private double m_WithClass;

  /**
   * The number of all sequences. Held as a double; the fold-splitting and
   * sequence-shuffling code in this class casts it to int before use.
   */
  private double m_SeqCount;

  /** Array for storing the confusion matrix. */
  private double [][] m_ConfusionMatrix;

  /** The names of the classes. */
  private String [] m_ClassNames;

  /** Is the class nominal or numeric? */
  private boolean m_ClassIsNominal;

  /** The prior probabilities of the classes. */
  private double [] m_ClassPriors;

  /** The sum of counts for priors. */
  private double m_ClassPriorsSum;

  /** The cost matrix (if given). */
  private CostMatrix m_CostMatrix;

  /** The total cost of predictions (includes instance weights). */
  private double m_TotalCost;

  /** Sum of errors. */
  private double m_SumErr;

  /** Sum of absolute errors. */
  private double m_SumAbsErr;

  /** Sum of squared errors. */
  private double m_SumSqrErr;

  /** Sum of class values. */
  private double m_SumClass;

  /** Sum of squared class values. */
  private double m_SumSqrClass;

  /** Sum of predicted values. */
  private double m_SumPredicted;

  /** Sum of squared predicted values. */
  private double m_SumSqrPredicted;

  /** Sum of predicted * class values. */
  private double m_SumClassPredicted;

  /** Sum of absolute errors of the prior. */
  private double m_SumPriorAbsErr;

  /**
   * Sum of squared errors of the prior (presumably, going by the field
   * name -- TODO confirm against the code that updates it).
   */
  private double m_SumPriorSqrErr;

  /** Total Kononenko & Bratko Information. */
  private double m_SumKBInfo;

  /** Resolution of the margin histogram. */
  private static int k_MarginResolution = 500;

  /** Cumulative margin distribution. */
  private double m_MarginCounts [];

  /** Number of non-missing class training instances seen. */
  private int m_NumTrainClassVals;

  /** Array containing all numeric training class values seen. */
  private double [] m_TrainClassVals;

  /** Array containing all numeric training class weights. */
  private double [] m_TrainClassWeights;

  /** Numeric class error estimator for prior. */
  private Estimator m_PriorErrorEstimator;

  /** Numeric class error estimator for scheme. */
  private Estimator m_ErrorEstimator;

  /**
   * The minimum probability accepted from an estimator to avoid
   * taking log(0) in Sf calculations.
   */
  private static final double MIN_SF_PROB = Double.MIN_VALUE;

  /** Total entropy of prior predictions. */
  private double m_SumPriorEntropy;

  /** Total entropy of scheme predictions. */
  private double m_SumSchemeEntropy;

  /**
   * Initializes all the counters for the evaluation. Delegates to the
   * two-argument constructor with a null cost matrix, i.e. default costs.
   *
   * @param data set of training instances, to get some header
   * information and prior class distribution information
   * @exception Exception if the class is not defined
   */
  public SequentialEvaluation(Instances data) throws Exception {

    this(data, null);
  }

  /**
   * Initializes all the counters for the evaluation and also takes a
   * cost matrix as parameter.
*   * @param data set of instances, to get some header information   * @param costMatrix the cost matrix---if null, default costs will be used   * @exception Exception if cost matrix is not compatible with    * data, the class is not defined or the class is numeric   */  public SequentialEvaluation(Instances data, CostMatrix costMatrix)        throws Exception {        m_NumClasses = data.numClasses();    m_NumFolds = 1;    m_ClassIsNominal = data.classAttribute().isNominal();    if (m_ClassIsNominal) {      m_ConfusionMatrix = new double [m_NumClasses][m_NumClasses];      m_ClassNames = new String [m_NumClasses];      for(int i = 0; i < m_NumClasses; i++) {	m_ClassNames[i] = data.classAttribute().value(i);      }    }    m_CostMatrix = costMatrix;    if (m_CostMatrix != null) {      if (!m_ClassIsNominal) {	throw new Exception("Class has to be nominal if cost matrix " + 			    "given!");      }      if (m_CostMatrix.size() != m_NumClasses) {	throw new Exception("Cost matrix not compatible with data!");      }    }    m_ClassPriors = new double [m_NumClasses];    setPriors(data);    m_MarginCounts = new double [k_MarginResolution + 1];  }  /**   * Returns a copy of the confusion matrix.   *   * @return a copy of the confusion matrix as a two-dimensional array   */  public double[][] confusionMatrix() {    double[][] newMatrix = new double[m_ConfusionMatrix.length][0];    for (int i = 0; i < m_ConfusionMatrix.length; i++) {      newMatrix[i] = new double[m_ConfusionMatrix[i].length];      System.arraycopy(m_ConfusionMatrix[i], 0, newMatrix[i], 0,		       m_ConfusionMatrix[i].length);    }    return newMatrix;  }  /**   * Performs a (stratified if class is nominal) cross-validation    * for a classifier on a set of instances.   *   * @param classifier the classifier with any options set.   
* @param data the data on which the cross-validation is to be    * performed    * @param numFolds the number of folds for the cross-validation   * @exception Exception if a classifier could not be generated    * successfully or the class is not defined   */  public void crossValidateModel(SequentialClassifier classifier,				 Instances data, int numFolds)     throws Exception {        // Make a copy of the data we can reorder    data = new Instances(data);/*    if (data.classAttribute().isNominal()) {      data.stratify(numFolds);    }*/    // Do the folds    try {    for (int i = 0; i < numFolds; i++) {      Instances [] dataCV = getCVData(data, numFolds, i);      Instances train = dataCV[0];      setPriors(train);      classifier.buildClassifier(train);      Instances test = dataCV[1];      evaluateModel(classifier, test);    }     m_NumFolds = numFolds;    } catch(Exception e) {      e.printStackTrace();    }  }  /**   * Produces the training and test sets for Cross Validation.   *   * @param random a random number generator   */  public Instances [] getCVData(Instances data, int folds, int i) throws Exception {    Instances [] result = new Instances[2];    int slice = (int) m_SeqCount/folds;    if(slice <= 0) {	throw new Exception("Too many folds specified");    }    if(m_SeqCount%folds != 0) slice++;          int lower = slice*i +1;    int upper = slice*(i+1);    Instances trainCV = new Instances(data);    Instances testCV = new Instances(data,0);        for(int j = 0; j<trainCV.numInstances(); j++) {		if((trainCV.instance(j).value(0) >= lower) && (trainCV.instance(j).value(0) <= upper)) {	    Instance inst = new Instance(trainCV.instance(j));	    trainCV.delete(j);	    inst.setValue(0, inst.value(0)-lower+1);	    testCV.add(inst);	}	if(trainCV.instance(j).value(0) > upper) {	    trainCV.instance(j).setValue(0, trainCV.instance(j).value(0) - upper);	}    }    result[0] = trainCV;    result[1] = testCV;    return result;  }      /**   * Shuffles the instances in 
the set so that they are ordered    * randomly.   *   * @param random a random number generator   */  public Instances seqRandomize(Random random, Instances data) {    Instances result = new Instances(data, 0);    Instances trainResult = new Instances(data, 0);    Instances testResult = new Instances(data, 0);    Vector mod = new Vector((int)m_SeqCount);    for(int j = 0; j<m_SeqCount; j++) {      mod.addElement(new Integer(j+1));    }        for (int j = (int)m_SeqCount; j > 0; j--) {      int n = random.nextInt(j);      int index = ((Integer)mod.elementAt(n)).intValue();      mod.remove(n);      for(int i = 0; i<data.numInstances(); i++) {        if(data.instance(i).value(0) == index) {	  result.add(data.instance(i));	}      }    }      return result;  }  /**   * Performs a (stratified if class is nominal) cross-validation    * for a classifier on a set of instances.   *   * @param classifier a string naming the class of the classifier   * @param data the data on which the cross-validation is to be    * performed    * @param numFolds the number of folds for the cross-validation   * @param options the options to the classifier. Any options   * accepted by the classifier will be removed from this array.   * @exception Exception if a classifier could not be generated    * successfully or the class is not defined   */  public void crossValidateModel(String classifierString,				 Instances data, int numFolds,				 String[] options)        throws Exception {        crossValidateModel((SequentialClassifier)Classifier.forName(classifierString, options),		       data, numFolds);

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -