📄 sequentialevaluation.java
字号:
// Compute error estimate from training data if ((trainStatistics) && (trainFileName.length() != 0)) { if ((classifier instanceof UpdateableClassifier) && (testFileName.length() != 0) && (costMatrix == null)) { // Classifier was trained incrementally, so we have to // reopen the training data in order to test on it. trainReader = new BufferedReader(new FileReader(trainFileName)); // Incremental testing train = new Instances(trainReader); train.sort(0); trainReader.close(); double count = train.lastInstance().value(0); if (classIndex != -1) { train.setClassIndex(classIndex - 1); } else { train.setClassIndex(train.numAttributes() - 1); } testTimeStart = System.currentTimeMillis(); //start for(int i=0; i<count; i++) { Instances instances = new Instances(train, 0); while((train.numInstances() > 0) && ((int)train.firstInstance().value(0) == i+1)) { instances.add(train.firstInstance()); train.delete(0); } //end trainingEvaluation.evaluateModelOnce((SequentialClassifier)classifier, instances); } testTimeElapsed = System.currentTimeMillis() - testTimeStart; } else { testTimeStart = System.currentTimeMillis(); trainingEvaluation.evaluateModel(classifier, train); testTimeElapsed = System.currentTimeMillis() - testTimeStart; } // Print the results of the training evaluation if (printMargins) { return trainingEvaluation.toCumulativeMarginDistributionString(); } else { text.append("\nTime taken to build model: " + Utils.doubleToString(trainTimeElapsed / 1000.0,2) + " seconds"); text.append("\nTime taken to test model on training data: " + Utils.doubleToString(testTimeElapsed / 1000.0,2) + " seconds"); text.append(trainingEvaluation. 
toSummaryString("\n\n=== Error on training" + " data ===\n", printComplexityStatistics)); if (template.classAttribute().isNominal()) { if (classStatistics) { text.append("\n\n" + trainingEvaluation.toClassDetailsString()); } text.append("\n\n" + trainingEvaluation.toMatrixString()); } } } // Compute proper error estimates if (testFileName.length() != 0) { // Testing is on the supplied test data testReader = new BufferedReader(new FileReader(testFileName)); test = new Instances(testReader); //test.setClassIndex(classIndex-1); if (classIndex != -1) { test.setClassIndex(classIndex - 1); } else { test.setClassIndex(test.numAttributes() - 1); } testReader.close(); test.sort(0); /* for(int i=0; i<max; i++) { Instances instances = new Instances(test, 0); while((test.numInstances() > 0) && ((int)test.firstInstance().value(0) == i+1)) { instances.add(test.firstInstance()); test.delete(0); } testingEvaluation.evaluateModelOnce((SequentialClassifier)classifier, instances); }*/ try { testingEvaluation.evaluateModel((SequentialClassifier)classifier, test); text.append("\n\n" + testingEvaluation. toSummaryString("=== Error on test data ===\n", printComplexityStatistics)); } catch(Exception e) { e.printStackTrace(); } // System.out.println("and here too"); } else if (trainFileName.length() != 0) { // Testing is via cross-validation on training data random = new Random(seed); random.setSeed(seed);// train = randomize(train, random); testingEvaluation. crossValidateModel(classifier, train, folds); if (template.classAttribute().isNumeric()) { text.append("\n\n\n" + testingEvaluation. toSummaryString("=== Cross-validation ===\n", printComplexityStatistics)); } else { text.append("\n\n\n" + testingEvaluation. 
toSummaryString("=== Stratified " + "cross-validation ===\n", printComplexityStatistics)); } } if (template.classAttribute().isNominal()) { if (classStatistics) { text.append("\n\n" + testingEvaluation.toClassDetailsString()); } text.append("\n\n" + testingEvaluation.toMatrixString()); } return text.toString();// return("over"); } /** * Attempts to load a cost matrix. * * @param costFileName the filename of the cost matrix * @param numClasses the number of classes that should be in the cost matrix * (only used if the cost file is in old format). * @return a <code>CostMatrix</code> value, or null if costFileName is empty * @exception Exception if an error occurs. */ private static CostMatrix handleCostOption(String costFileName, int numClasses) throws Exception { if ((costFileName != null) && (costFileName.length() != 0)) { System.out.println( "NOTE: The behaviour of the -m option has changed between WEKA 3.0" +" and WEKA 3.1. -m now carries out cost-sensitive *evaluation*" +" only. For cost-sensitive *prediction*, use one of the" +" cost-sensitive metaschemes such as" +" weka.classifiers.meta.CostSensitiveClassifier or" +" weka.classifiers.meta.MetaCost"); Reader costReader = null; try { costReader = new BufferedReader(new FileReader(costFileName)); } catch (Exception e) { throw new Exception("Can't open file " + e.getMessage() + '.'); } try { // First try as a proper cost matrix format return new CostMatrix(costReader); } catch (Exception ex) { try { // Now try as the poxy old format :-) //System.err.println("Attempting to read old format cost file"); try { costReader.close(); // Close the old one costReader = new BufferedReader(new FileReader(costFileName)); } catch (Exception e) { throw new Exception("Can't open file " + e.getMessage() + '.'); } CostMatrix costMatrix = new CostMatrix(numClasses); //System.err.println("Created default cost matrix"); costMatrix.readOldFormat(costReader); return costMatrix; //System.err.println("Read old format"); } catch 
(Exception e2) { // re-throw the original exception //System.err.println("Re-throwing original exception"); throw ex; } } } else { return null; } } /** * Evaluates the classifier on a given set of instances. * * @param classifier machine learning classifier * @param data set of test instances for evaluation * @exception Exception if model could not be evaluated * successfully */ public void evaluateModel(SequentialClassifier classifier, Instances data) throws Exception { double [] predicted; Instances testData = new Instances(data); m_SeqCount = testData.lastInstance().value(0); for(int i=0; i<m_SeqCount; i++) { Instances instances = new Instances(testData, 0); while((testData.numInstances() > 0) && ((int)testData.firstInstance().value(0) == i+1)) { instances.add(testData.firstInstance()); testData.delete(0); } instances.sort(1); evaluateModelOnce((SequentialClassifier)classifier, instances); } } /** * Evaluates the classifier on a single instance. * * @param classifier machine learning classifier * @param instance the test instance to be classified * @return the prediction made by the clasifier * @exception Exception if model could not be evaluated * successfully or the data contains string attributes */ public double [] evaluateModelOnce(SequentialClassifier classifier, Instances instances) throws Exception { Instances classMissing = new Instances(instances); int count = instances.numInstances(); double [] pred = new double[count]; if (m_ClassIsNominal) { if (classifier instanceof SequentialClassifier) { double [][] dist = ((SequentialClassifier)classifier). 
distributionForSequence(classMissing); for(int i = 0; i<count; i++) { pred[i] = Utils.maxIndex(dist[i]); } updateStatsForClassifier(dist, instances); } else { pred = classifier.classifySequence(classMissing); updateStatsForClassifier(makeDistribution(pred), instances); } } else { pred = classifier.classifySequence(classMissing); updateStatsForPredictor(pred, instances); } return pred; } /** * Evaluates the supplied distribution on a single instance. * * @param dist the supplied distribution * @param instance the test instance to be classified * @exception Exception if model could not be evaluated * successfully */ public double [] evaluateModelOnce(double [][] dist, Instances instances) throws Exception { double pred[] = new double[dist.length]; if (m_ClassIsNominal) { for(int i = 0; i<pred.length; i++) { pred[i] = Utils.maxIndex(dist[i]); } updateStatsForClassifier(dist, instances); } else { pred = dist[0]; updateStatsForPredictor(pred, instances); } return pred; } /** * Evaluates the supplied prediction on a single instance. * * @param prediction the supplied prediction * @param instance the test instance to be classified * @exception Exception if model could not be evaluated * successfully */ public void evaluateModelOnce(double [] prediction, Instances instances) throws Exception { if (m_ClassIsNominal) { updateStatsForClassifier(makeDistribution(prediction), instances); } else { updateStatsForPredictor(prediction, instances); } } /** * Wraps a static classifier in enough source to test using the weka * class libraries. * * @param classifier a Sourcable Classifier * @param className the name to give to the source code class * @return the source for a static classifier that can be tested with * weka libraries. 
*/
  protected static String wekaStaticWrapper(Sourcable classifier,
                                            String className)
    throws Exception {

    // Generate the classifier's own source first; it goes after the wrapper.
    String generatedSource = classifier.toSource(className);

    StringBuffer wrapper = new StringBuffer();
    // Header of the generated wrapper class.
    wrapper.append("package weka.classifiers;\n");
    wrapper.append("import weka.core.Attribute;\n");
    wrapper.append("import weka.core.Instance;\n");
    wrapper.append("import weka.core.Instances;\n");
    wrapper.append("import weka.classifiers.Classifier;\n\n");
    wrapper.append("public class WekaWrapper extends Classifier {\n\n");
    // Training is a no-op: the wrapped classifier is already static.
    wrapper.append(" public void buildClassifier(Instances i) throws Exception {\n");
    wrapper.append(" }\n\n");
    // classifyInstance converts the instance to an Object[] and delegates.
    wrapper.append(" public double classifyInstance(Instance i) throws Exception {\n\n");
    wrapper.append(" Object [] s = new Object [i.numAttributes()];\n");
    wrapper.append(" for (int j = 0; j < s.length; j++) {\n");
    wrapper.append(" if (!i.isMissing(j)) {\n");
    wrapper.append(" if (i.attribute(j).type() == Attribute.NOMINAL) {\n");
    wrapper.append(" s[j] = i.attribute(j).value((int) i.value(j));\n");
    wrapper.append(" } else if (i.attribute(j).type() == Attribute.NUMERIC) {\n");
    wrapper.append(" s[j] = new Double(i.value(j));\n");
    wrapper.append(" }\n");
    wrapper.append(" }\n");
    wrapper.append(" }\n");
    wrapper.append(" return ").append(className).append(".classify(s);\n");
    wrapper.append(" }\n\n");
    wrapper.append("}\n\n");
    // The static classifier class itself follows the wrapper.
    wrapper.append(generatedSource);
    return wrapper.toString();
  }

  /**
   * Returns how many test instances had a known class value. Strictly,
   * this is the summed weight of those instances.
   *
   * @return the number of test instances with known class
   */
  public final double numInstances() {
    return m_WithClass;
  }

  /**
   * Returns how many instances received an incorrect prediction. Strictly,
   * this is the summed weight of those instances.
   *
   * @return the number of incorrectly classified instances
   */
  public final double incorrect() {
    return m_Incorrect;
  }

  /**
   * Returns the share of instances that received an incorrect prediction,
   * as a percentage of the instances with a known class.
   *
   * @return the percent of incorrectly classified instances
   * (between 0 and 100)
   */
  public final double pctIncorrect() {
    return 100 * m_Incorrect / m_WithClass;
  }

  /**
   * Returns how many sequences were classified incorrectly, i.e. sequences
   * in which at least one instance received a wrong prediction.
   *
   * @return the number of incorrectly classified sequences
   */
  public final double seqIncorrect() {
    return m_SeqIncorrect;
  }

  /**
   * Returns the share of sequences in which at least one instance received
   * a wrong prediction, as a percentage of the sequence count.
   *
   * @return the percent of incorrectly classified sequences
   * (between 0 and 100)
   */
  public final double seqPctIncorrect() {
    return 100 * m_SeqIncorrect / m_SeqCount;
  }

  /**
   * Returns the total cost: the cost of each prediction multiplied by the
   * instance weight, summed over all instances.
   *
   * @return the total cost
   */
  public final double totalCost() {
    return m_TotalCost;
  }

  /**
   * Returns the average cost: the total cost of misclassifications
   * (incorrect plus unclassified) divided by the number of instances with
   * a known class.
   *
   * @return the average cost.
   */
  public final double avgCost() {
    return m_TotalCost / m_WithClass;
  }

  /**
   * Gets the number of instances correctly classified (that is, for
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -