📄 evaluation.java
字号:
} // Build the classifier if no object file provided if ((classifier instanceof UpdateableClassifier) && (testFileName.length() != 0) && (costMatrix == null) && (trainFileName.length() != 0)) { // Build classifier incrementally trainingEvaluation.setPriors(train); testingEvaluation.setPriors(train); trainTimeStart = System.currentTimeMillis(); if (objectInputFileName.length() == 0) { classifier.buildClassifier(train); } while (train.readInstance(trainReader)) { trainingEvaluation.updatePriors(train.instance(0)); testingEvaluation.updatePriors(train.instance(0)); ((UpdateableClassifier)classifier). updateClassifier(train.instance(0)); train.delete(0); } trainTimeElapsed = System.currentTimeMillis() - trainTimeStart; trainReader.close(); } else if (objectInputFileName.length() == 0) { // Build classifier in one go tempTrain = new Instances(train); trainingEvaluation.setPriors(tempTrain); testingEvaluation.setPriors(tempTrain); trainTimeStart = System.currentTimeMillis(); classifier.buildClassifier(tempTrain); trainTimeElapsed = System.currentTimeMillis() - trainTimeStart; } // Save the classifier if an object output file is provided if (objectOutputFileName.length() != 0) { OutputStream os = new FileOutputStream(objectOutputFileName); if (objectOutputFileName.endsWith(".gz")) { os = new GZIPOutputStream(os); } ObjectOutputStream objectOutputStream = new ObjectOutputStream(os); objectOutputStream.writeObject(classifier); objectOutputStream.flush(); objectOutputStream.close(); } // If classifier is drawable output string describing graph if ((classifier instanceof Drawable) && (printGraph)){ return ((Drawable)classifier).graph(); } // Output the classifier as equivalent source if ((classifier instanceof Sourcable) && (printSource)){ return wekaStaticWrapper((Sourcable) classifier, sourceClass); } // Output test instance predictions only if (printClassifications) { return printClassifications(classifier, new Instances(template, 0), testFileName, classIndex, attributesToOutput); } // Output model if (!(noOutput || printMargins)) { if (classifier instanceof OptionHandler) { if (schemeOptionsText != null) { text.append("\nOptions: "+schemeOptionsText); text.append("\n"); } } text.append("\n" + classifier.toString() + "\n"); } if (!printMargins && (costMatrix != null)) { text.append("\n=== Evaluation Cost Matrix ===\n\n") .append(costMatrix.toString()); } // Compute error estimate from training data if ((trainStatistics) && (trainFileName.length() != 0)) { if ((classifier instanceof UpdateableClassifier) && (testFileName.length() != 0) && (costMatrix == null)) { // Classifier was trained incrementally, so we have to // reopen the training data in order to test on it. trainReader = new BufferedReader(new FileReader(trainFileName)); // Incremental testing train = new Instances(trainReader, 1); if (classIndex != -1) { train.setClassIndex(classIndex - 1); } else { train.setClassIndex(train.numAttributes() - 1); } testTimeStart = System.currentTimeMillis(); while (train.readInstance(trainReader)) { trainingEvaluation. evaluateModelOnce((Classifier)classifier, train.instance(0)); train.delete(0); } testTimeElapsed = System.currentTimeMillis() - testTimeStart; trainReader.close(); } else { testTimeStart = System.currentTimeMillis(); trainingEvaluation.evaluateModel(classifier, train); testTimeElapsed = System.currentTimeMillis() - testTimeStart; } // Print the results of the training evaluation if (printMargins) { return trainingEvaluation.toCumulativeMarginDistributionString(); } else { text.append("\nTime taken to build model: " + Utils.doubleToString(trainTimeElapsed / 1000.0,2) + " seconds"); text.append("\nTime taken to test model on training data: " + Utils.doubleToString(testTimeElapsed / 1000.0,2) + " seconds"); text.append(trainingEvaluation. toSummaryString("\n\n=== Error on training" + " data ===\n", printComplexityStatistics)); if (template.classAttribute().isNominal()) { if (classStatistics) { text.append("\n\n" + trainingEvaluation.toClassDetailsString()); } text.append("\n\n" + trainingEvaluation.toMatrixString()); } } } // Compute proper error estimates if (testFileName.length() != 0) { // Testing is on the supplied test data while (test.readInstance(testReader)) { testingEvaluation.evaluateModelOnce((Classifier)classifier, test.instance(0)); test.delete(0); } testReader.close(); text.append("\n\n" + testingEvaluation. toSummaryString("=== Error on test data ===\n", printComplexityStatistics)); } else if (trainFileName.length() != 0) { // Testing is via cross-validation on training data random = new Random(seed); random.setSeed(seed); train.randomize(random); // Added by Sebastian Celis for parallelization if(runInParallel) { testingEvaluation. crossValidateModelParallel(classifier, train, folds, otherComputers); } else { testingEvaluation. crossValidateModel(classifier, train, folds); } if (template.classAttribute().isNumeric()) { text.append("\n\n\n" + testingEvaluation. toSummaryString("=== Cross-validation ===\n", printComplexityStatistics)); } else { text.append("\n\n\n" + testingEvaluation. toSummaryString("=== Stratified " + "cross-validation ===\n", printComplexityStatistics)); } // Added by Sebastian Celis for parallelization if(runInParallel) { text.append("\nCross-validation ran in parallel using this "+ "computer and the following machines:\n"); if(otherComputers.length() == 0) text.append("Was not able to connect to other computers. "+ "Parallelism did not occur."); else text.append(otherComputers); } } if (template.classAttribute().isNominal()) { if (classStatistics) { text.append("\n\n" + testingEvaluation.toClassDetailsString()); } text.append("\n\n" + testingEvaluation.toMatrixString()); } return text.toString(); } /** * Attempts to load a cost matrix. * * @param costFileName the filename of the cost matrix * @param numClasses the number of classes that should be in the cost matrix * (only used if the cost file is in old format). * @return a <code>CostMatrix</code> value, or null if costFileName is empty * @exception Exception if an error occurs. */ private static CostMatrix handleCostOption(String costFileName, int numClasses) throws Exception { if ((costFileName != null) && (costFileName.length() != 0)) { System.out.println( "NOTE: The behaviour of the -m option has changed between WEKA 3.0" +" and WEKA 3.1. -m now carries out cost-sensitive *evaluation*" +" only. For cost-sensitive *prediction*, use one of the" +" cost-sensitive metaschemes such as" +" weka.classifiers.CostSensitiveClassifier or" +" weka.classifiers.MetaCost"); Reader costReader = null; try { costReader = new BufferedReader(new FileReader(costFileName)); } catch (Exception e) { throw new Exception("Can't open file " + e.getMessage() + '.'); } try { // First try as a proper cost matrix format return new CostMatrix(costReader); } catch (Exception ex) { try { // Now try as the poxy old format :-) //System.err.println("Attempting to read old format cost file"); try { costReader.close(); // Close the old one costReader = new BufferedReader(new FileReader(costFileName)); } catch (Exception e) { throw new Exception("Can't open file " + e.getMessage() + '.'); } CostMatrix costMatrix = new CostMatrix(numClasses); //System.err.println("Created default cost matrix"); costMatrix.readOldFormat(costReader); return costMatrix; //System.err.println("Read old format"); } catch (Exception e2) { // re-throw the original exception //System.err.println("Re-throwing original exception"); throw ex; } } } else { return null; } } /** * Evaluates the classifier on a given set of instances. * * @param classifier machine learning classifier * @param data set of test instances for evaluation * @exception Exception if model could not be evaluated * successfully */ public void evaluateModel(Classifier classifier, Instances data) throws Exception { double [] predicted; for (int i = 0; i < data.numInstances(); i++) { evaluateModelOnce((Classifier)classifier, data.instance(i)); } } /** * Evaluates the classifier on a single instance. * * @param classifier machine learning classifier * @param instance the test instance to be classified * @return the prediction made by the clasifier * @exception Exception if model could not be evaluated * successfully or the data contains string attributes */ public double evaluateModelOnce(Classifier classifier, Instance instance) throws Exception { Instance classMissing = (Instance)instance.copy(); double pred=0; classMissing.setDataset(instance.dataset()); classMissing.setClassMissing(); if (m_ClassIsNominal) { if (classifier instanceof DistributionClassifier) { double [] dist = ((DistributionClassifier)classifier). distributionForInstance(classMissing); pred = Utils.maxIndex(dist); updateStatsForClassifier(dist, instance); } else { pred = classifier.classifyInstance(classMissing); updateStatsForClassifier(makeDistribution(pred), instance); } } else { pred = classifier.classifyInstance(classMissing); updateStatsForPredictor(pred, instance); } return pred; } /** * Evaluates the supplied distribution on a single instance. * * @param dist the supplied distribution * @param instance the test instance to be classified * @exception Exception if model could not be evaluated * successfully */ public double evaluateModelOnce(double [] dist, Instance instance) throws Exception { double pred; if (m_ClassIsNominal) { pred = Utils.maxIndex(dist); updateStatsForClassifier(dist, instance); } else { pred = dist[0]; updateStatsForPredictor(pred, instance); } return pred; } /** * Evaluates the supplied prediction on a single instance. * * @param prediction the supplied prediction * @param instance the test instance to be classified * @exception Exception if model could not be evaluated * successfully */ public void evaluateModelOnce(double prediction, Instance instance) throws Exception { if (m_ClassIsNominal) { updateStatsForClassifier(makeDistribution(prediction), instance); } else { updateStatsForPredictor(prediction, instance); } } /** * Wraps a static classifier in enough source to test using the weka * class libraries. * * @param classifier a Sourcable Classifier * @param className the name to give to the source code class * @return the source for a static classifier that can be tested with * weka libraries. */ protected static String wekaStaticWrapper(Sourcable classifier, String className) throws Exception { //String className = "StaticClassifier"; String staticClassifier = classifier.toSource(className); return "package weka.classifiers;\n" +"import weka.core.Attribute;\n" +"import weka.core.Instance;\n" +"import weka.core.Instances;\n" +"import weka.classifiers.Classifier;\n\n" +"public class WekaWrapper extends Classifier {\n\n" +" public void buildClassifier(Instances i) throws Exception {\n" +" }\n\n" +" public double classifyInstance(Instance i) throws Exception {\n\n" +" Object [] s = new Object [i.numAttributes()];\n" +" for (int j = 0; j < s.length; j++) {\n" +" if (!i.isMissing(j)) {\n" +" if (i.attribute(j).type() == Attribute.NOMINAL) {\n" +" s[j] = i.attribute(j).value((int) i.value(j));\n" +" } else if (i.attribute(j).type() == Attribute.NUMERIC) {\n" +" s[j] = new Double(i.value(j));\n" +" }\n" +" }\n" +" }\n" +" return " + className + ".classify(s);\n" +" }\n\n" +"}\n\n" +staticClassifier; // The static classifer class } /** * Returns the estimated error rate or the root mean squared error * (if the class is numeric). If a cost matrix was given this * error rate gives the average cost. * * @return the estimated error rate (between 0 and 1, or between 0 and * maximum cost) */
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -