⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 sequentialevaluation.java

📁 把 sequential 有导师学习问题转化为传统的有导师学习问题
💻 JAVA
📖 第 1 页 / 共 5 页
字号:
    // Compute error estimate from training data    if ((trainStatistics) &&	(trainFileName.length() != 0)) {      if ((classifier instanceof UpdateableClassifier) &&	  (testFileName.length() != 0) &&	  (costMatrix == null)) {	// Classifier was trained incrementally, so we have to 	// reopen the training data in order to test on it.	trainReader = new BufferedReader(new FileReader(trainFileName));	// Incremental testing	train = new Instances(trainReader);		train.sort(0);		trainReader.close();	double count = train.lastInstance().value(0);	if (classIndex != -1) {	  train.setClassIndex(classIndex - 1);	} else {	  train.setClassIndex(train.numAttributes() - 1);	}	testTimeStart = System.currentTimeMillis();		//start	for(int i=0; i<count; i++) {	  Instances instances = new Instances(train, 0);	  while((train.numInstances() > 0) &&		  ((int)train.firstInstance().value(0) == i+1))	{													  	instances.add(train.firstInstance());								train.delete(0);										  }		//end		  trainingEvaluation.evaluateModelOnce((SequentialClassifier)classifier, 			    instances);	}		testTimeElapsed = System.currentTimeMillis() - testTimeStart;      } else {	testTimeStart = System.currentTimeMillis();	trainingEvaluation.evaluateModel(classifier, 					 train);	testTimeElapsed = System.currentTimeMillis() - testTimeStart;      }	      // Print the results of the training evaluation      if (printMargins) {	return trainingEvaluation.toCumulativeMarginDistributionString();      } else {	text.append("\nTime taken to build model: " +		    Utils.doubleToString(trainTimeElapsed / 1000.0,2) +		    " seconds");	text.append("\nTime taken to test model on training data: " +		    Utils.doubleToString(testTimeElapsed / 1000.0,2) +		    " seconds");	text.append(trainingEvaluation.		    
toSummaryString("\n\n=== Error on training" + 				    " data ===\n", printComplexityStatistics));	if (template.classAttribute().isNominal()) {	  if (classStatistics) {	    text.append("\n\n" + trainingEvaluation.toClassDetailsString());	  }	  text.append("\n\n" + trainingEvaluation.toMatrixString());	}	      }    }    // Compute proper error estimates    if (testFileName.length() != 0) {         // Testing is on the supplied test data      testReader = new BufferedReader(new FileReader(testFileName));      test = new Instances(testReader);      //test.setClassIndex(classIndex-1);      if (classIndex != -1) {	test.setClassIndex(classIndex - 1);      } else {	test.setClassIndex(test.numAttributes() - 1);      }      testReader.close();      test.sort(0);      /*            for(int i=0; i<max; i++) {	Instances instances = new Instances(test, 0);	while((test.numInstances() > 0) &&	   ((int)test.firstInstance().value(0) == i+1))	{													  	instances.add(test.firstInstance());								test.delete(0);										}	testingEvaluation.evaluateModelOnce((SequentialClassifier)classifier, instances);      }*/	try {      testingEvaluation.evaluateModel((SequentialClassifier)classifier, test);      text.append("\n\n" + testingEvaluation.		  toSummaryString("=== Error on test data ===\n",				  printComplexityStatistics));		} catch(Exception e) {	 e.printStackTrace();	}			  	// System.out.println("and here too");			      } else if (trainFileName.length() != 0) {      // Testing is via cross-validation on training data      random = new Random(seed);      random.setSeed(seed);//      train = randomize(train, random);      testingEvaluation.	crossValidateModel(classifier, train, folds);      if (template.classAttribute().isNumeric()) {	text.append("\n\n\n" + testingEvaluation.		    toSummaryString("=== Cross-validation ===\n",				    printComplexityStatistics));      } else {	text.append("\n\n\n" + testingEvaluation.		    
toSummaryString("=== Stratified " + 				    "cross-validation ===\n",				    printComplexityStatistics));      }    }    if (template.classAttribute().isNominal()) {      if (classStatistics) {	text.append("\n\n" + testingEvaluation.toClassDetailsString());      }      text.append("\n\n" + testingEvaluation.toMatrixString());    }        return text.toString();//	return("over");  }  /**   * Attempts to load a cost matrix.   *   * @param costFileName the filename of the cost matrix   * @param numClasses the number of classes that should be in the cost matrix   * (only used if the cost file is in old format).   * @return a <code>CostMatrix</code> value, or null if costFileName is empty   * @exception Exception if an error occurs.   */  private static CostMatrix handleCostOption(String costFileName,                                              int numClasses)     throws Exception {    if ((costFileName != null) && (costFileName.length() != 0)) {      System.out.println(           "NOTE: The behaviour of the -m option has changed between WEKA 3.0"           +" and WEKA 3.1. -m now carries out cost-sensitive *evaluation*"           +" only. 
For cost-sensitive *prediction*, use one of the"           +" cost-sensitive metaschemes such as"           +" weka.classifiers.meta.CostSensitiveClassifier or"           +" weka.classifiers.meta.MetaCost");      Reader costReader = null;      try {        costReader = new BufferedReader(new FileReader(costFileName));      } catch (Exception e) {        throw new Exception("Can't open file " + e.getMessage() + '.');      }      try {        // First try as a proper cost matrix format        return new CostMatrix(costReader);      } catch (Exception ex) {        try {          // Now try as the poxy old format :-)          //System.err.println("Attempting to read old format cost file");          try {            costReader.close(); // Close the old one            costReader = new BufferedReader(new FileReader(costFileName));          } catch (Exception e) {            throw new Exception("Can't open file " + e.getMessage() + '.');          }          CostMatrix costMatrix = new CostMatrix(numClasses);          //System.err.println("Created default cost matrix");          costMatrix.readOldFormat(costReader);          return costMatrix;          //System.err.println("Read old format");        } catch (Exception e2) {          // re-throw the original exception          //System.err.println("Re-throwing original exception");          throw ex;        }      }    } else {      return null;    }  }  /**   * Evaluates the classifier on a given set of instances.   
*   * @param classifier machine learning classifier   * @param data set of test instances for evaluation   * @exception Exception if model could not be evaluated    * successfully   */  public void evaluateModel(SequentialClassifier classifier,			    Instances data) throws Exception {    double [] predicted;    Instances testData = new Instances(data);    m_SeqCount = testData.lastInstance().value(0);	    for(int i=0; i<m_SeqCount; i++) {      Instances instances = new Instances(testData, 0);      while((testData.numInstances() > 0) &&		  ((int)testData.firstInstance().value(0) == i+1)) {													instances.add(testData.firstInstance());							testData.delete(0);      }      instances.sort(1);      evaluateModelOnce((SequentialClassifier)classifier, instances);    }  }    /**   * Evaluates the classifier on a single instance.   *   * @param classifier machine learning classifier   * @param instance the test instance to be classified   * @return the prediction made by the clasifier   * @exception Exception if model could not be evaluated    * successfully or the data contains string attributes   */  public double [] evaluateModelOnce(SequentialClassifier classifier,				  Instances instances) throws Exception {      Instances classMissing = new Instances(instances);    int count = instances.numInstances();    double [] pred = new double[count];    if (m_ClassIsNominal) {      if (classifier instanceof SequentialClassifier) {	double [][] dist = ((SequentialClassifier)classifier).				 
distributionForSequence(classMissing);		for(int i = 0; i<count; i++) {	  pred[i] = Utils.maxIndex(dist[i]);	}  		updateStatsForClassifier(dist,				 instances);      } else {	pred = classifier.classifySequence(classMissing);	updateStatsForClassifier(makeDistribution(pred),				 instances);      }    } else {      pred = classifier.classifySequence(classMissing);      updateStatsForPredictor(pred,			      instances);    }    return pred;  }  /**   * Evaluates the supplied distribution on a single instance.   *   * @param dist the supplied distribution   * @param instance the test instance to be classified   * @exception Exception if model could not be evaluated    * successfully   */  public double [] evaluateModelOnce(double [][] dist, 				  Instances instances) throws Exception {    double pred[] = new double[dist.length];    if (m_ClassIsNominal) {      for(int i = 0; i<pred.length; i++) {        pred[i] = Utils.maxIndex(dist[i]);      }	      updateStatsForClassifier(dist, instances);    } else {      pred = dist[0];      updateStatsForPredictor(pred, instances);    }    return pred;  }  /**   * Evaluates the supplied prediction on a single instance.   *   * @param prediction the supplied prediction   * @param instance the test instance to be classified   * @exception Exception if model could not be evaluated    * successfully   */  public void evaluateModelOnce(double [] prediction,				Instances instances) throws Exception {        if (m_ClassIsNominal) {      updateStatsForClassifier(makeDistribution(prediction), 			       instances);    } else {      updateStatsForPredictor(prediction, instances);    }  }  /**   * Wraps a static classifier in enough source to test using the weka   * class libraries.   *   * @param classifier a Sourcable Classifier   * @param className the name to give to the source code class   * @return the source for a static classifier that can be tested with   * weka libraries.   
*/
  protected static String wekaStaticWrapper(Sourcable classifier,
                                            String className)
    throws Exception {

    // The classifier's own source is generated first and appended after
    // the WekaWrapper class that delegates to <className>.classify(...).
    String generated = classifier.toSource(className);

    StringBuffer out = new StringBuffer();
    out.append("package weka.classifiers;\n");
    out.append("import weka.core.Attribute;\n");
    out.append("import weka.core.Instance;\n");
    out.append("import weka.core.Instances;\n");
    out.append("import weka.classifiers.Classifier;\n\n");
    out.append("public class WekaWrapper extends Classifier {\n\n");
    out.append("  public void buildClassifier(Instances i) throws Exception {\n");
    out.append("  }\n\n");
    out.append("  public double classifyInstance(Instance i) throws Exception {\n\n");
    out.append("    Object [] s = new Object [i.numAttributes()];\n");
    out.append("    for (int j = 0; j < s.length; j++) {\n");
    out.append("      if (!i.isMissing(j)) {\n");
    out.append("        if (i.attribute(j).type() == Attribute.NOMINAL) {\n");
    out.append("          s[j] = i.attribute(j).value((int) i.value(j));\n");
    out.append("        } else if (i.attribute(j).type() == Attribute.NUMERIC) {\n");
    out.append("          s[j] = new Double(i.value(j));\n");
    out.append("        }\n");
    out.append("      }\n");
    out.append("    }\n");
    out.append("    return ").append(className).append(".classify(s);\n");
    out.append("  }\n\n");
    out.append("}\n\n");
    out.append(generated);
    return out.toString();
  }

  /**
   * Returns the total weight of the test instances whose class value was
   * known; for unit weights this is simply their count.
   *
   * @return the number of test instances with known class
   */
  public final double numInstances() {
    final double withKnownClass = m_WithClass;
    return withKnownClass;
  }

  /**
   * Returns the total weight of the instances that received a wrong
   * prediction.
   *
   * @return the number of incorrectly classified instances
   */
  public final double incorrect() {
    final double wrong = m_Incorrect;
    return wrong;
  }

  /**
   * Returns the share of instances (by weight) that received a wrong
   * prediction, expressed as a percentage.
   *
   * @return the percent of incorrectly classified instances
   * (between 0 and 100)
   */
  public final double pctIncorrect() {
    final double scaled = m_Incorrect * 100;
    return scaled / m_WithClass;
  }

  /**
   * Returns how many sequences contained at least one wrongly predicted
   * instance.
   *
   * @return the number of incorrectly classified sequences
   */
  public final double seqIncorrect() {
    final double wrongSequences = m_SeqIncorrect;
    return wrongSequences;
  }

  /**
   * Returns the share of sequences containing at least one wrongly
   * predicted instance, expressed as a percentage of m_SeqCount.
   *
   * @return the percent of incorrectly classified sequences
   * (between 0 and 100)
   */
  public final double seqPctIncorrect() {
    final double scaled = m_SeqIncorrect * 100;
    return scaled / m_SeqCount;
  }

  /**
   * Returns the accumulated cost: each prediction's cost multiplied by the
   * instance weight, summed over all instances.
   *
   * @return the total cost
   */
  public final double totalCost() {
    final double cost = m_TotalCost;
    return cost;
  }

  /**
   * Returns the total cost divided by the total weight of instances with a
   * known class.
   *
   * @return the average cost.
   */
  public final double avgCost() {
    final double cost = m_TotalCost;
    return cost / m_WithClass;
  }

  /**
   * Gets the number of instances correctly classified (that is, for

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -