evaluation.java
来自「Java 编写的多种数据挖掘算法 包括聚类、分类、预处理等」· Java 代码 · 共 2,078 行 · 第 1/5 页
JAVA
2,078 行
if (objectOutputFileName.endsWith(".koml")) { KOML.write(xmlOutputStream, classifier); } xmlOutputStream.close(); } } // If classifier is drawable output string describing graph if ((classifier instanceof Drawable) && (printGraph)){ return ((Drawable)classifier).graph(); } // Output the classifier as equivalent source if ((classifier instanceof Sourcable) && (printSource)){ return wekaStaticWrapper((Sourcable) classifier, sourceClass); } // Output test instance predictions only if (printClassifications) { return printClassifications(classifier, new Instances(template, 0), testFileName, classIndex, attributesToOutput); } // Output model if (!(noOutput || printMargins)) { if (classifier instanceof OptionHandler) { if (schemeOptionsText != null) { text.append("\nOptions: "+schemeOptionsText); text.append("\n"); } } text.append("\n" + classifier.toString() + "\n"); } if (!printMargins && (costMatrix != null)) { text.append("\n=== Evaluation Cost Matrix ===\n\n") .append(costMatrix.toString()); } // Compute error estimate from training data if ((trainStatistics) && (trainFileName.length() != 0)) { if ((classifier instanceof UpdateableClassifier) && (testFileName.length() != 0) && (costMatrix == null)) { // Classifier was trained incrementally, so we have to // reopen the training data in order to test on it. trainReader = new BufferedReader(new FileReader(trainFileName)); // Incremental testing train = new Instances(trainReader, 1); if (classIndex != -1) { train.setClassIndex(classIndex - 1); } else { train.setClassIndex(train.numAttributes() - 1); } testTimeStart = System.currentTimeMillis(); while (train.readInstance(trainReader)) { trainingEvaluation. evaluateModelOnce((Classifier)classifier, train.instance(0)); train.delete(0); } testTimeElapsed = System.currentTimeMillis() - testTimeStart; trainReader.close(); } else { testTimeStart = System.currentTimeMillis(); trainingEvaluation.evaluateModel(classifier, train); testTimeElapsed = System.currentTimeMillis() - testTimeStart; } // Print the results of the training evaluation if (printMargins) { return trainingEvaluation.toCumulativeMarginDistributionString(); } else { text.append("\nTime taken to build model: " + Utils.doubleToString(trainTimeElapsed / 1000.0,2) + " seconds"); text.append("\nTime taken to test model on training data: " + Utils.doubleToString(testTimeElapsed / 1000.0,2) + " seconds"); text.append(trainingEvaluation. toSummaryString("\n\n=== Error on training" + " data ===\n", printComplexityStatistics)); if (template.classAttribute().isNominal()) { if (classStatistics) { text.append("\n\n" + trainingEvaluation.toClassDetailsString()); } text.append("\n\n" + trainingEvaluation.toMatrixString()); } } } // Compute proper error estimates if (testFileName.length() != 0) { // Testing is on the supplied test data while (test.readInstance(testReader)) { testingEvaluation.evaluateModelOnce((Classifier)classifier, test.instance(0)); test.delete(0); } testReader.close(); text.append("\n\n" + testingEvaluation. toSummaryString("=== Error on test data ===\n", printComplexityStatistics)); } else if (trainFileName.length() != 0) { // Testing is via cross-validation on training data Random random = new Random(seed); // use untrained (!) classifier for cross-validation classifier = Classifier.makeCopy(classifierBackup); testingEvaluation.crossValidateModel(classifier, train, folds, random); if (template.classAttribute().isNumeric()) { text.append("\n\n\n" + testingEvaluation. toSummaryString("=== Cross-validation ===\n", printComplexityStatistics)); } else { text.append("\n\n\n" + testingEvaluation. toSummaryString("=== Stratified " + "cross-validation ===\n", printComplexityStatistics)); } } if (template.classAttribute().isNominal()) { if (classStatistics) { text.append("\n\n" + testingEvaluation.toClassDetailsString()); } text.append("\n\n" + testingEvaluation.toMatrixString()); } return text.toString(); } /** * Attempts to load a cost matrix. * * @param costFileName the filename of the cost matrix * @param numClasses the number of classes that should be in the cost matrix * (only used if the cost file is in old format). * @return a <code>CostMatrix</code> value, or null if costFileName is empty * @exception Exception if an error occurs. */ protected static CostMatrix handleCostOption(String costFileName, int numClasses) throws Exception { if ((costFileName != null) && (costFileName.length() != 0)) { System.out.println( "NOTE: The behaviour of the -m option has changed between WEKA 3.0" +" and WEKA 3.1. -m now carries out cost-sensitive *evaluation*" +" only. For cost-sensitive *prediction*, use one of the" +" cost-sensitive metaschemes such as" +" weka.classifiers.meta.CostSensitiveClassifier or" +" weka.classifiers.meta.MetaCost"); Reader costReader = null; try { costReader = new BufferedReader(new FileReader(costFileName)); } catch (Exception e) { throw new Exception("Can't open file " + e.getMessage() + '.'); } try { // First try as a proper cost matrix format return new CostMatrix(costReader); } catch (Exception ex) { try { // Now try as the poxy old format :-) //System.err.println("Attempting to read old format cost file"); try { costReader.close(); // Close the old one costReader = new BufferedReader(new FileReader(costFileName)); } catch (Exception e) { throw new Exception("Can't open file " + e.getMessage() + '.'); } CostMatrix costMatrix = new CostMatrix(numClasses); //System.err.println("Created default cost matrix"); costMatrix.readOldFormat(costReader); return costMatrix; //System.err.println("Read old format"); } catch (Exception e2) { // re-throw the original exception //System.err.println("Re-throwing original exception"); throw ex; } } } else { return null; } } /** * Evaluates the classifier on a given set of instances. Note that * the data must have exactly the same format (e.g. order of * attributes) as the data used to train the classifier! Otherwise * the results will generally be meaningless. * * @param classifier machine learning classifier * @param data set of test instances for evaluation * @exception Exception if model could not be evaluated * successfully */ public double[] evaluateModel(Classifier classifier, Instances data) throws Exception { double predictions[] = new double[data.numInstances()]; // Need to be able to collect predictions if appropriate (for AUC) for (int i = 0; i < data.numInstances(); i++) { predictions[i] = evaluateModelOnceAndRecordPrediction((Classifier)classifier, data.instance(i)); } return predictions; } /** * Evaluates the classifier on a single instance and records the * prediction (if the class is nominal). * * @param classifier machine learning classifier * @param instance the test instance to be classified * @return the prediction made by the clasifier * @exception Exception if model could not be evaluated * successfully or the data contains string attributes */ public double evaluateModelOnceAndRecordPrediction(Classifier classifier, Instance instance) throws Exception { Instance classMissing = (Instance)instance.copy(); double pred = 0; classMissing.setDataset(instance.dataset()); classMissing.setClassMissing(); if (m_ClassIsNominal) { if (m_Predictions == null) { m_Predictions = new FastVector(); } double [] dist = classifier.distributionForInstance(classMissing); pred = Utils.maxIndex(dist); if (dist[(int)pred] <= 0) { pred = Instance.missingValue(); } updateStatsForClassifier(dist, instance); m_Predictions.addElement(new NominalPrediction(instance.classValue(), dist, instance.weight())); } else { pred = classifier.classifyInstance(classMissing); updateStatsForPredictor(pred, instance); } return pred; } /** * Evaluates the classifier on a single instance. * * @param classifier machine learning classifier * @param instance the test instance to be classified * @return the prediction made by the clasifier * @exception Exception if model could not be evaluated * successfully or the data contains string attributes */ public double evaluateModelOnce(Classifier classifier, Instance instance) throws Exception { Instance classMissing = (Instance)instance.copy(); double pred = 0; classMissing.setDataset(instance.dataset()); classMissing.setClassMissing(); if (m_ClassIsNominal) { double [] dist = classifier.distributionForInstance(classMissing); pred = Utils.maxIndex(dist); if (dist[(int)pred] <= 0) { pred = Instance.missingValue(); } updateStatsForClassifier(dist, instance); } else { pred = classifier.classifyInstance(classMissing); updateStatsForPredictor(pred, instance); } return pred; } /** * Evaluates the supplied distribution on a single instance. * * @param dist the supplied distribution * @param instance the test instance to be classified * @exception Exception if model could not be evaluated * successfully */ public double evaluateModelOnce(double [] dist, Instance instance) throws Exception { double pred; if (m_ClassIsNominal) { pred = Utils.maxIndex(dist); if (dist[(int)pred] <= 0) { pred = Instance.missingValue(); } updateStatsForClassifier(dist, instance); } else { pred = dist[0]; updateStatsForPredictor(pred, instance); } return pred; } /** * Evaluates the supplied prediction on a single instance. * * @param prediction the supplied prediction * @param instance the test instance to be classified * @exception Exception if model could not be evaluated * successfully */ public void evaluateModelOnce(double prediction, Instance instance) throws Exception { if (m_ClassIsNominal) { updateStatsForClassifier(makeDistribution(prediction), instance); } else { updateStatsForPredictor(prediction, instance); } } /** * Returns the predictions that have been collected. * * @return a reference to the FastVector containing the predictions * that have been collected. This should be null if no predictions * have been collected (e.g. if the class is numeric). */ public FastVector predictions() { return m_Predictions; } /** * Wraps a static classifier in enough source to test using the weka * class libraries. * * @param classifier a Sourcable Classifier * @param className the name to give to the source code class * @return the source for a static classifier that can be tested with * weka libraries. */ protected static String wekaStaticWrapper(Sourcable classifier, String className) throws Exception { //String className = "StaticClassifier"; String staticClassifier = classifier.toSource(className); return "package weka.classifiers;\n" +"import weka.core.Attribute;\n" +"import weka.core.Instance;\n" +"import weka.core.Instances;\n" +"import weka.classifiers.Classifier;\n\n" +"public class WekaWrapper extends Classifier {\n\n" +" public void buildClassifier(Instances i) throws Exception {\n" +" }\n\n" +" public double classifyInstance(Instance i) throws Exception {\n\n" +" Object [] s = new Object [i.numAttributes()];\n" +" for (int j = 0; j < s.length; j++) {\n" +" if (!i.isMissing(j)) {\n" +" if (i.attribute(j).type() == Attribute.NOMINAL) {\n" +" s[j] = i.attribute(j).value((int) i.value(j));\n" +" } else if (i.attribute(j).type() == Attribute.NUMERIC) {\n" +" s[j] = new Double(i.value(j));\n" +" }\n" +" }\n" +" }\n" +" return " + className + ".classify(s);\n" +" }\n\n" +"}\n\n" +staticClassifier; // The static classifer class } /** * Gets the number of test instances that had a known class value * (actually the sum of the weights of test instances with known * class value). * * @return the number of test instances with known class */ public final double numInstances() { return m_WithClass; } /** * Gets the number of instances incorrectly classified (that is, for * which an incorrect prediction was made). (Actually the sum of the weights * of these instances) * * @return the number of incorrectly classified instances */ public final double incorrect() { return m_Incorrect; } /** * Gets the percentage of instances incorrectly classified (that is, for * which an incorrect prediction was made). * * @return the percent of incorrectly classified instances * (between 0 and 100)
⌨️ 快捷键说明
复制代码Ctrl + C
搜索代码Ctrl + F
全屏模式F11
增大字号Ctrl + =
减小字号Ctrl + -
显示快捷键?