⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 evaluation.java

📁 MacroWeka扩展了著名数据挖掘工具weka
💻 JAVA
📖 第 1 页 / 共 5 页
字号:
	}
	testTimeElapsed = System.currentTimeMillis() - testTimeStart;
	trainReader.close();
      } else {
	testTimeStart = System.currentTimeMillis();
	trainingEvaluation.evaluateModel(classifier, 
					 train);
	testTimeElapsed = System.currentTimeMillis() - testTimeStart;
      }

      // Print the results of the training evaluation
      if (printMargins) {
	return trainingEvaluation.toCumulativeMarginDistributionString();
      } else {
	text.append("\nTime taken to build model: " +
		    Utils.doubleToString(trainTimeElapsed / 1000.0,2) +
		    " seconds");
	text.append("\nTime taken to test model on training data: " +
		    Utils.doubleToString(testTimeElapsed / 1000.0,2) +
		    " seconds");
	text.append(trainingEvaluation.
		    toSummaryString("\n\n=== Error on training" + 
				    " data ===\n", printComplexityStatistics));
	if (template.classAttribute().isNominal()) {
	  if (classStatistics) {
	    text.append("\n\n" + trainingEvaluation.toClassDetailsString());
	  }
	  text.append("\n\n" + trainingEvaluation.toMatrixString());
	}
	
      }
    }

    // Compute proper error estimates
    if (testFileName.length() != 0) {

      // Testing is on the supplied test data
      while (test.readInstance(testReader)) {
	  
	testingEvaluation.evaluateModelOnce((Classifier)classifier, 
                                            test.instance(0));
	test.delete(0);
      }
      testReader.close();

      text.append("\n\n" + testingEvaluation.
		  toSummaryString("=== Error on test data ===\n",
				  printComplexityStatistics));
    } else if (trainFileName.length() != 0) {

      // Testing is via cross-validation on training data
      Random random = new Random(seed);
      // use untrained (!) classifier for cross-validation
      classifier = Classifier.makeCopy(classifierBackup);
      testingEvaluation.crossValidateModel(classifier, train, folds, random);
      if (template.classAttribute().isNumeric()) {
	text.append("\n\n\n" + testingEvaluation.
		    toSummaryString("=== Cross-validation ===\n",
				    printComplexityStatistics));
      } else {
	text.append("\n\n\n" + testingEvaluation.
		    toSummaryString("=== Stratified " + 
				    "cross-validation ===\n",
				    printComplexityStatistics));
      }
    }
    if (template.classAttribute().isNominal()) {
      if (classStatistics) {
	text.append("\n\n" + testingEvaluation.toClassDetailsString());
      }
      text.append("\n\n" + testingEvaluation.toMatrixString());
    }
    return text.toString();
  }


  /**
   * Attempts to load a cost matrix.
   *
   * @param costFileName the filename of the cost matrix
   * @param numClasses the number of classes that should be in the cost matrix
   * (only used if the cost file is in old format).
   * @return a <code>CostMatrix</code> value, or null if costFileName is empty
   * @throws Exception if an error occurs.
   */
  protected static CostMatrix handleCostOption(String costFileName, 
                                             int numClasses) 
    throws Exception {

    // No cost file requested: null tells the caller there is no cost matrix.
    if ((costFileName == null) || (costFileName.length() == 0)) {
      return null;
    }

    System.out.println(
         "NOTE: The behaviour of the -m option has changed between WEKA 3.0"
         +" and WEKA 3.1. -m now carries out cost-sensitive *evaluation*"
         +" only. For cost-sensitive *prediction*, use one of the"
         +" cost-sensitive metaschemes such as"
         +" weka.classifiers.meta.CostSensitiveClassifier or"
         +" weka.classifiers.meta.MetaCost");

    Reader costReader = null;
    try {
      try {
        costReader = new BufferedReader(new FileReader(costFileName));
      } catch (Exception e) {
        // Same message as before, but keep the underlying failure as the
        // cause so its stack trace is not lost.
        throw new Exception("Can't open file " + e.getMessage() + '.', e);
      }

      try {
        // First try as a proper cost matrix format
        return new CostMatrix(costReader);
      } catch (Exception ex) {
        try {
          // Fall back to the old format. The first reader has been consumed,
          // so close it and start again from the beginning of the file.
          costReader.close();
          costReader = new BufferedReader(new FileReader(costFileName));
          CostMatrix costMatrix = new CostMatrix(numClasses);
          costMatrix.readOldFormat(costReader);
          return costMatrix;
        } catch (Exception e2) {
          // Any failure of the fallback (including reopening the file) is
          // reported as the original, new-format parsing error.
          throw ex;
        }
      }
    } finally {
      // Release the file handle on every path: success, fallback or failure.
      if (costReader != null) {
        try {
          costReader.close();
        } catch (Exception ignored) {
          // Best effort: a failed close on a read-only reader must not mask
          // the parsed result or the exception already in flight.
        }
      }
    }
  }

  /**
   * Evaluates the classifier on a given set of instances. Note that
   * the data must have exactly the same format (e.g. order of
   * attributes) as the data used to train the classifier! Otherwise
   * the results will generally be meaningless.
   *
   * @param classifier machine learning classifier
   * @param data set of test instances for evaluation
   * @return the predictions
   * @throws Exception if model could not be evaluated 
   * successfully 
   */
  public double[] evaluateModel(Classifier classifier,
                                Instances data) throws Exception {

    // One result slot per instance, sized before evaluation starts.
    final double[] results = new double[data.numInstances()];

    int index = 0;
    while (index < data.numInstances()) {
      // Each call also updates this object's running statistics.
      results[index] = evaluateModelOnce(classifier, data.instance(index));
      index++;
    }
    return results;
  }
  
  /**
   * Evaluates the classifier on a single instance.
   *
   * @param classifier machine learning classifier
   * @param instance the test instance to be classified
   * @return the prediction made by the classifier
   * @throws Exception if model could not be evaluated 
   * successfully or the data contains string attributes
   */
  public double evaluateModelOnce(Classifier classifier,
                                  Instance instance) throws Exception {

    // The classifier is given a copy whose class value is set to missing;
    // the statistics are updated against the untouched original instance.
    final Instance unlabeled = (Instance) instance.copy();
    unlabeled.setDataset(instance.dataset());
    unlabeled.setClassMissing();

    if (!m_ClassIsNominal) {
      final double numericPrediction = classifier.classifyInstance(unlabeled);
      updateStatsForPredictor(numericPrediction, instance);
      return numericPrediction;
    }

    final double[] distribution = classifier.distributionForInstance(unlabeled);
    final int winner = Utils.maxIndex(distribution);
    // A winning entry that is not positive is reported as a missing value.
    final double prediction =
      (distribution[winner] <= 0) ? Instance.missingValue() : winner;
    updateStatsForClassifier(distribution, instance);
    return prediction;
  }

  /**
   * Evaluates the supplied distribution on a single instance.
   *
   * @param dist the supplied distribution
   * @param instance the test instance to be classified
   * @return the prediction
   * @throws Exception if model could not be evaluated 
   * successfully
   */
  public double evaluateModelOnce(double [] dist, 
                                  Instance instance) throws Exception {

    // Non-nominal class: the first entry of the supplied array is taken as
    // the prediction.
    if (!m_ClassIsNominal) {
      final double numericPrediction = dist[0];
      updateStatsForPredictor(numericPrediction, instance);
      return numericPrediction;
    }

    final int winner = Utils.maxIndex(dist);
    // A winning entry that is not positive is reported as a missing value.
    final double prediction =
      (dist[winner] <= 0) ? Instance.missingValue() : winner;
    updateStatsForClassifier(dist, instance);
    return prediction;
  }

  /**
   * Evaluates the supplied prediction on a single instance.
   *
   * @param prediction the supplied prediction
   * @param instance the test instance to be classified
   * @throws Exception if model could not be evaluated 
   * successfully
   */
  public void evaluateModelOnce(double prediction,
                                Instance instance) throws Exception {

    if (!m_ClassIsNominal) {
      updateStatsForPredictor(prediction, instance);
      return;
    }

    // Nominal class: updateStatsForClassifier takes a distribution, so the
    // single prediction is converted with makeDistribution first.
    final double[] distribution = makeDistribution(prediction);
    updateStatsForClassifier(distribution, instance);
  }


  /**
   * Wraps a static classifier in enough source to test using the weka
   * class libraries.
   *
   * @param classifier a Sourcable Classifier
   * @param className the name to give to the source code class
   * @return the source for a static classifier that can be tested with
   * weka libraries.
   * @throws Exception if code-generation fails
   */
  protected static String wekaStaticWrapper(Sourcable classifier, 
                                            String className) 
    throws Exception {

    // Generate the classifier's own source first, as the original did, so a
    // code-generation failure surfaces before any wrapper text is built.
    final String generatedClass = classifier.toSource(className);

    final StringBuffer wrapper = new StringBuffer();

    // Package declaration and imports of the generated compilation unit.
    wrapper.append("package weka.classifiers;\n\n");
    wrapper.append("import weka.core.Attribute;\n");
    wrapper.append("import weka.core.Instance;\n");
    wrapper.append("import weka.core.Instances;\n");
    wrapper.append("import weka.classifiers.Classifier;\n\n");

    // WekaWrapper: a Classifier subclass with an empty buildClassifier.
    wrapper.append("public class WekaWrapper extends Classifier {\n\n");
    wrapper.append("  public void buildClassifier(Instances i) throws Exception {\n");
    wrapper.append("  }\n\n");

    // classifyInstance: copies non-missing nominal/numeric attribute values
    // into an Object array and delegates to <className>.classify.
    wrapper.append("  public double classifyInstance(Instance i) throws Exception {\n\n");
    wrapper.append("    Object [] s = new Object [i.numAttributes()];\n");
    wrapper.append("    for (int j = 0; j < s.length; j++) {\n");
    wrapper.append("      if (!i.isMissing(j)) {\n");
    wrapper.append("        if (i.attribute(j).type() == Attribute.NOMINAL) {\n");
    wrapper.append("          s[j] = i.attribute(j).value((int) i.value(j));\n");
    wrapper.append("        } else if (i.attribute(j).type() == Attribute.NUMERIC) {\n");
    wrapper.append("          s[j] = new Double(i.value(j));\n");
    wrapper.append("        }\n");
    wrapper.append("      }\n");
    wrapper.append("    }\n");
    wrapper.append("    return ").append(className).append(".classify(s);\n");
    wrapper.append("  }\n\n");
    wrapper.append("}\n\n");

    // The static classifier class follows the wrapper in the same source text.
    wrapper.append(generatedClass);
    return wrapper.toString();
  }

  /**
   * Gets the number of test instances that had a known class value
   * (actually the sum of the weights of test instances with known 
   * class value).
   *
   * @return the number of test instances with known class
   */
  public final double numInstances() {
    // Returned as a double because it is a sum of instance weights.
    return this.m_WithClass;
  }

  /**
   * Gets the number of instances incorrectly classified (that is, for
   * which an incorrect prediction was made). (Actually the sum of the weights
   * of these instances)
   *
   * @return the number of incorrectly classified instances 
   */
  public final double incorrect() {
    // Weighted total of the incorrectly classified instances.
    return this.m_Incorrect;
  }

  /**
   * Gets the percentage of instances incorrectly classified (that is, for
   * which an incorrect prediction was made).
   *
   * @return the percent of incorrectly classified instances 
   * (between 0 and 100)
   */
  public final double pctIncorrect() {
    // Scale first, then divide: same evaluation order as 100 * a / b.
    final double scaled = 100 * this.m_Incorrect;
    return scaled / this.m_WithClass;
  }

  /**
   * Gets the total cost, that is, the cost of each prediction times the
   * weight of the instance, summed over all instances.
   *
   * @return the total cost
   */
  public final double totalCost() {
    // Accumulated cost over all evaluated instances.
    return this.m_TotalCost;
  }
  
  /**
   * Gets the average cost, that is, total cost of misclassifications
   * (incorrect plus unclassified) over the total number of instances.
   *
   * @return the average cost.  
   */
  public final double avgCost() {
    // Total cost divided by the weighted number of instances with a class.
    final double total = this.m_TotalCost;
    final double instances = this.m_WithClass;
    return total / instances;
  }

  /**
   * Gets the number of instances correctly classified (that is, for
   * which a correct prediction was made). (Actually the sum of the weights
   * of these instances)
   *
   * @return the number of correctly classified instances
   */
  public final double correct() {
    // Weighted total of the correctly classified instances.
    return this.m_Correct;
  }

  /**
   * Gets the percentage of instances correctly classified (that is, for
   * which a correct prediction was made).
   *
   * @return the percent of correctly classified instances (between 0 and 100)
   */
  public final double pctCorrect() {
    // Scale first, then divide: same evaluation order as 100 * a / b.
    final double scaled = 100 * this.m_Correct;
    return scaled / this.m_WithClass;
  }
  
  /**
   * Gets the number of instances not classified (that is, for
   * which no prediction was made by the classifier). (Actually the sum
   * of the weights of these instances)
   *
   * @return the number of unclassified instances
   */
  public final double unclassified() {
    // Weighted total of the instances for which no prediction was made.
    return this.m_Unclassified;
  }

  /**
   * Gets the percentage of instances not classified (that is, for
   * which no prediction was made by the classifier).
   *
   * @return the percent of unclassified instances (between 0 and 100)
   */
  public final double pctUnclassified() {
    // Scale first, then divide: same evaluation order as 100 * a / b.
    final double scaled = 100 * this.m_Unclassified;
    return scaled / this.m_WithClass;
  }

  /**
   * Returns the estimated error rate or the root mean squared error
   * (if the class is numeric). If a cost matrix was given this
   * error rate gives the average cost.
   *
   * @return the estimated error rate (between 0 and 1, or between 0 and 
   * maximum cost)
   */
  public final double errorRate() {

    if (m_ClassIsNominal) {
      // With a cost matrix the average cost is reported; without one, the
      // weighted fraction of incorrect predictions.
      return (m_CostMatrix != null)
        ? avgCost()
        : m_Incorrect / m_WithClass;
    }

    // Non-nominal class: root of the mean squared error.
    return Math.sqrt(m_SumSqrErr / m_WithClass);
  }

  /**
   * Returns value of kappa statistic if class is nominal.
   *
   * @return the value of the kappa statistic
   */
  public final double kappa() {
    

    double[] sumRows = new double[m_ConfusionMatrix.length];
    double[] sumColumns = new double[m_ConfusionMatrix.length];
    double sumOfWeights = 0;
    for (int i = 0; i < m_ConfusionMatrix.length; i++) {
      for (int j = 0; j < m_ConfusionMatrix.length; j++) {
	sumRows[i] += m_ConfusionMatrix[i][j];
	sumColumns[j] += m_ConfusionMatrix[i][j];
	sumOfWeights += m_ConfusionMatrix[i][j];
      }

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -