evaluation.java

来自「Java 编写的多种数据挖掘算法 包括聚类、分类、预处理等」· Java 代码 · 共 2,078 行 · 第 1/5 页

JAVA
2,078
字号
        if (objectOutputFileName.endsWith(".koml")) {
          KOML.write(xmlOutputStream, classifier);
        }
        xmlOutputStream.close();
      }
    }

    // NOTE(review): this is the continuation of a method whose start lies
    // outside this chunk (presumably evaluateModel(Classifier, String[]) —
    // the option-driven command-line evaluation); a broken comment and a
    // string literal split by text extraction have been rejoined below.

    // If classifier is drawable output string describing graph
    if ((classifier instanceof Drawable)
	&& (printGraph)) {
      return ((Drawable)classifier).graph();
    }

    // Output the classifier as equivalent source
    if ((classifier instanceof Sourcable)
	&& (printSource)) {
      return wekaStaticWrapper((Sourcable) classifier, sourceClass);
    }

    // Output test instance predictions only
    if (printClassifications) {
      return printClassifications(classifier, new Instances(template, 0),
				  testFileName, classIndex, attributesToOutput);
    }

    // Output model
    if (!(noOutput || printMargins)) {
      if (classifier instanceof OptionHandler) {
	if (schemeOptionsText != null) {
	  text.append("\nOptions: "+schemeOptionsText);
	  text.append("\n");
	}
      }
      text.append("\n" + classifier.toString() + "\n");
    }

    if (!printMargins && (costMatrix != null)) {
      text.append("\n=== Evaluation Cost Matrix ===\n\n")
        .append(costMatrix.toString());
    }

    // Compute error estimate from training data
    if ((trainStatistics) &&
	(trainFileName.length() != 0)) {

      if ((classifier instanceof UpdateableClassifier) &&
	  (testFileName.length() != 0) &&
	  (costMatrix == null)) {

	// Classifier was trained incrementally, so we have to
	// reopen the training data in order to test on it.
	trainReader = new BufferedReader(new FileReader(trainFileName));

	// Incremental testing: read one instance at a time, evaluate it,
	// then drop it so the whole training set is never held in memory.
	train = new Instances(trainReader, 1);
	if (classIndex != -1) {
	  train.setClassIndex(classIndex - 1);
	} else {
	  train.setClassIndex(train.numAttributes() - 1);
	}
	testTimeStart = System.currentTimeMillis();
	while (train.readInstance(trainReader)) {
	  trainingEvaluation.evaluateModelOnce((Classifier)classifier,
					       train.instance(0));
	  train.delete(0);
	}
	testTimeElapsed = System.currentTimeMillis() - testTimeStart;
	trainReader.close();
      } else {
	testTimeStart = System.currentTimeMillis();
	trainingEvaluation.evaluateModel(classifier,
					 train);
	testTimeElapsed = System.currentTimeMillis() - testTimeStart;
      }

      // Print the results of the training evaluation
      if (printMargins) {
	return trainingEvaluation.toCumulativeMarginDistributionString();
      } else {
	text.append("\nTime taken to build model: " +
		    Utils.doubleToString(trainTimeElapsed / 1000.0,2) +
		    " seconds");
	text.append("\nTime taken to test model on training data: " +
		    Utils.doubleToString(testTimeElapsed / 1000.0,2) +
		    " seconds");
	text.append(trainingEvaluation.
		    toSummaryString("\n\n=== Error on training" +
				    " data ===\n", printComplexityStatistics));
	if (template.classAttribute().isNominal()) {
	  if (classStatistics) {
	    text.append("\n\n" + trainingEvaluation.toClassDetailsString());
	  }
	  text.append("\n\n" + trainingEvaluation.toMatrixString());
	}
      }
    }

    // Compute proper error estimates
    if (testFileName.length() != 0) {

      // Testing is on the supplied test data, streamed one instance at a time
      while (test.readInstance(testReader)) {
        testingEvaluation.evaluateModelOnce((Classifier)classifier,
                                            test.instance(0));
        test.delete(0);
      }
      testReader.close();

      text.append("\n\n" + testingEvaluation.
		  toSummaryString("=== Error on test data ===\n",
				  printComplexityStatistics));
    } else if (trainFileName.length() != 0) {

      // Testing is via cross-validation on training data
      Random random = new Random(seed);
      // use untrained (!) classifier for cross-validation
      classifier = Classifier.makeCopy(classifierBackup);
      testingEvaluation.crossValidateModel(classifier, train, folds, random);
      if (template.classAttribute().isNumeric()) {
	text.append("\n\n\n" + testingEvaluation.
		    toSummaryString("=== Cross-validation ===\n",
				    printComplexityStatistics));
      } else {
	text.append("\n\n\n" + testingEvaluation.
		    toSummaryString("=== Stratified " +
				    "cross-validation ===\n",
				    printComplexityStatistics));
      }
    }

    if (template.classAttribute().isNominal()) {
      if (classStatistics) {
	text.append("\n\n" + testingEvaluation.toClassDetailsString());
      }
      text.append("\n\n" + testingEvaluation.toMatrixString());
    }
    return text.toString();
  }

  /**
   * Attempts to load a cost matrix.
   *
   * @param costFileName the filename of the cost matrix
   * @param numClasses the number of classes that should be in the cost matrix
   * (only used if the cost file is in old format).
   * @return a <code>CostMatrix</code> value, or null if costFileName is empty
   * @exception Exception if an error occurs.
   */
  protected static CostMatrix handleCostOption(String costFileName,
					       int numClasses)
    throws Exception {

    if ((costFileName != null) && (costFileName.length() != 0)) {
      System.out.println(
	   "NOTE: The behaviour of the -m option has changed between WEKA 3.0"
	   +" and WEKA 3.1. -m now carries out cost-sensitive *evaluation*"
	   +" only. For cost-sensitive *prediction*, use one of the"
	   +" cost-sensitive metaschemes such as"
	   +" weka.classifiers.meta.CostSensitiveClassifier or"
	   +" weka.classifiers.meta.MetaCost");
      Reader costReader = null;
      try {
	costReader = new BufferedReader(new FileReader(costFileName));
      } catch (Exception e) {
	throw new Exception("Can't open file " + e.getMessage() + '.');
      }
      try {
	// First try as a proper cost matrix format
	return new CostMatrix(costReader);
      } catch (Exception ex) {
	try {
	  // Now try as the poxy old format :-)
	  //System.err.println("Attempting to read old format cost file");
	  try {
	    costReader.close(); // Close the old one
	    costReader = new BufferedReader(new FileReader(costFileName));
	  } catch (Exception e) {
	    throw new Exception("Can't open file " + e.getMessage() + '.');
	  }
	  CostMatrix costMatrix = new CostMatrix(numClasses);
	  //System.err.println("Created default cost matrix");
	  costMatrix.readOldFormat(costReader);
	  return costMatrix;
	  //System.err.println("Read old format");
	} catch (Exception e2) {
	  // re-throw the original exception
	  //System.err.println("Re-throwing original exception");
	  throw ex;
	}
      }
    } else {
      return null;
    }
  }

  /**
   * Evaluates the classifier on a given set of instances. Note that
   * the data must have exactly the same format (e.g. order of
   * attributes) as the data used to train the classifier! Otherwise
   * the results will generally be meaningless.
*   * @param classifier machine learning classifier   * @param data set of test instances for evaluation   * @exception Exception if model could not be evaluated    * successfully    */  public double[] evaluateModel(Classifier classifier,			    Instances data) throws Exception {    double predictions[] = new double[data.numInstances()];    // Need to be able to collect predictions if appropriate (for AUC)    for (int i = 0; i < data.numInstances(); i++) {      predictions[i] = evaluateModelOnceAndRecordPrediction((Classifier)classifier, 							    data.instance(i));    }    return predictions;  }  /**   * Evaluates the classifier on a single instance and records the   * prediction (if the class is nominal).   *   * @param classifier machine learning classifier   * @param instance the test instance to be classified   * @return the prediction made by the clasifier   * @exception Exception if model could not be evaluated    * successfully or the data contains string attributes   */  public double evaluateModelOnceAndRecordPrediction(Classifier classifier,						     Instance instance) throws Exception {    Instance classMissing = (Instance)instance.copy();    double pred = 0;    classMissing.setDataset(instance.dataset());    classMissing.setClassMissing();    if (m_ClassIsNominal) {      if (m_Predictions == null) {	m_Predictions = new FastVector();      }      double [] dist = classifier.distributionForInstance(classMissing);      pred = Utils.maxIndex(dist);      if (dist[(int)pred] <= 0) {	pred = Instance.missingValue();      }      updateStatsForClassifier(dist, instance);      m_Predictions.addElement(new NominalPrediction(instance.classValue(), dist, 						     instance.weight()));    } else {      pred = classifier.classifyInstance(classMissing);      updateStatsForPredictor(pred, instance);    }    return pred;  }  /**   * Evaluates the classifier on a single instance.   
*   * @param classifier machine learning classifier   * @param instance the test instance to be classified   * @return the prediction made by the clasifier   * @exception Exception if model could not be evaluated    * successfully or the data contains string attributes   */  public double evaluateModelOnce(Classifier classifier,				  Instance instance) throws Exception {    Instance classMissing = (Instance)instance.copy();    double pred = 0;    classMissing.setDataset(instance.dataset());    classMissing.setClassMissing();    if (m_ClassIsNominal) {	double [] dist = classifier.distributionForInstance(classMissing);	pred = Utils.maxIndex(dist);	if (dist[(int)pred] <= 0) {	  pred = Instance.missingValue();	}	updateStatsForClassifier(dist, instance);    } else {      pred = classifier.classifyInstance(classMissing);      updateStatsForPredictor(pred, instance);    }    return pred;  }  /**   * Evaluates the supplied distribution on a single instance.   *   * @param dist the supplied distribution   * @param instance the test instance to be classified   * @exception Exception if model could not be evaluated    * successfully   */  public double evaluateModelOnce(double [] dist, 				  Instance instance) throws Exception {    double pred;    if (m_ClassIsNominal) {      pred = Utils.maxIndex(dist);      if (dist[(int)pred] <= 0) {	pred = Instance.missingValue();      }      updateStatsForClassifier(dist, instance);    } else {      pred = dist[0];      updateStatsForPredictor(pred, instance);    }    return pred;  }  /**   * Evaluates the supplied prediction on a single instance.   
*   * @param prediction the supplied prediction   * @param instance the test instance to be classified   * @exception Exception if model could not be evaluated    * successfully   */  public void evaluateModelOnce(double prediction,				Instance instance) throws Exception {    if (m_ClassIsNominal) {      updateStatsForClassifier(makeDistribution(prediction), 			       instance);    } else {      updateStatsForPredictor(prediction, instance);    }  }  /**   * Returns the predictions that have been collected.   *   * @return a reference to the FastVector containing the predictions   * that have been collected. This should be null if no predictions   * have been collected (e.g. if the class is numeric).   */  public FastVector predictions() {    return m_Predictions;  }  /**   * Wraps a static classifier in enough source to test using the weka   * class libraries.   *   * @param classifier a Sourcable Classifier   * @param className the name to give to the source code class   * @return the source for a static classifier that can be tested with   * weka libraries.   
*/  protected static String wekaStaticWrapper(Sourcable classifier,                                             String className)     throws Exception {    //String className = "StaticClassifier";    String staticClassifier = classifier.toSource(className);    return "package weka.classifiers;\n"    +"import weka.core.Attribute;\n"    +"import weka.core.Instance;\n"    +"import weka.core.Instances;\n"    +"import weka.classifiers.Classifier;\n\n"    +"public class WekaWrapper extends Classifier {\n\n"    +"  public void buildClassifier(Instances i) throws Exception {\n"    +"  }\n\n"    +"  public double classifyInstance(Instance i) throws Exception {\n\n"    +"    Object [] s = new Object [i.numAttributes()];\n"    +"    for (int j = 0; j < s.length; j++) {\n"    +"      if (!i.isMissing(j)) {\n"    +"        if (i.attribute(j).type() == Attribute.NOMINAL) {\n"    +"          s[j] = i.attribute(j).value((int) i.value(j));\n"    +"        } else if (i.attribute(j).type() == Attribute.NUMERIC) {\n"    +"          s[j] = new Double(i.value(j));\n"    +"        }\n"    +"      }\n"    +"    }\n"    +"    return " + className + ".classify(s);\n"    +"  }\n\n"    +"}\n\n"    +staticClassifier; // The static classifer class  }  /**   * Gets the number of test instances that had a known class value   * (actually the sum of the weights of test instances with known    * class value).   *   * @return the number of test instances with known class   */  public final double numInstances() {    return m_WithClass;  }  /**   * Gets the number of instances incorrectly classified (that is, for   * which an incorrect prediction was made). (Actually the sum of the weights   * of these instances)   *   * @return the number of incorrectly classified instances    */  public final double incorrect() {    return m_Incorrect;  }  /**   * Gets the percentage of instances incorrectly classified (that is, for   * which an incorrect prediction was made).   
*   * @return the percent of incorrectly classified instances    * (between 0 and 100)

⌨️ 快捷键说明

复制代码Ctrl + C
搜索代码Ctrl + F
全屏模式F11
增大字号Ctrl + =
减小字号Ctrl + -
显示快捷键?