// EnsembleEvaluation.java (chunk) — scraped page header removed
if (template.classAttribute().isNumeric()) { text.append("\n\n\n" + testingEvaluation. toSummaryString("=== Cross-validation ===\n", printComplexityStatistics)); } else { text.append("\n\n\n" + testingEvaluation. toSummaryString("=== Stratified " + "cross-validation ===\n", printComplexityStatistics)); } } if (template.classAttribute().isNominal()) { if (classStatistics) { text.append("\n\n" + testingEvaluation.toClassDetailsString()); } text.append("\n\n" + testingEvaluation.toMatrixString()); } return text.toString(); } /** * Attempts to load a cost matrix. * * @param costFileName the filename of the cost matrix * @param numClasses the number of classes that should be in the cost matrix * (only used if the cost file is in old format). * @return a <code>CostMatrix</code> value, or null if costFileName is empty * @exception Exception if an error occurs. */ private static CostMatrix handleCostOption(String costFileName, int numClasses) throws Exception { if ((costFileName != null) && (costFileName.length() != 0)) { System.out.println( "NOTE: The behaviour of the -m option has changed between WEKA 3.0" +" and WEKA 3.1. -m now carries out cost-sensitive *evaluation*" +" only. 
For cost-sensitive *prediction*, use one of the" +" cost-sensitive metaschemes such as" +" weka.classifiers.meta.CostSensitiveClassifier or" +" weka.classifiers.meta.MetaCost"); Reader costReader = null; try { costReader = new BufferedReader(new FileReader(costFileName)); } catch (Exception e) { throw new Exception("Can't open file " + e.getMessage() + '.'); } try { // First try as a proper cost matrix format return new CostMatrix(costReader); } catch (Exception ex) { try { // Now try as the poxy old format :-) //System.err.println("Attempting to read old format cost file"); try { costReader.close(); // Close the old one costReader = new BufferedReader(new FileReader(costFileName)); } catch (Exception e) { throw new Exception("Can't open file " + e.getMessage() + '.'); } CostMatrix costMatrix = new CostMatrix(numClasses); //System.err.println("Created default cost matrix"); costMatrix.readOldFormat(costReader); return costMatrix; //System.err.println("Read old format"); } catch (Exception e2) { // re-throw the original exception //System.err.println("Re-throwing original exception"); throw ex; } } } else { return null; } } /** * Evaluates the classifier on a given set of instances. 
* * @param classifier machine learning classifier * @param data set of test instances for evaluation * @exception Exception if model could not be evaluated * successfully */ public void evaluateModel(Classifier classifier, Instances data) throws Exception { m_EnsembleSize = ((EnsembleClassifier)classifier).getEnsembleSize(); m_EnsembleWts = ((EnsembleClassifier)classifier).getEnsembleWts(); for(int j=0; j<m_EnsembleSize; j++) m_SumEnsembleWts += m_EnsembleWts[j]; //DEBUG System.out.println("Ensemble size = "+m_EnsembleSize); if(m_SumEnsembleWts == 0.0){ System.out.println("Ensemble wts sum to 0!"); for(int j=0; j<m_EnsembleWts.length; j++) System.out.print("\t"+m_EnsembleWts[j]); System.out.println(); } for (int i = 0; i < data.numInstances(); i++) { evaluateModelOnce((Classifier)classifier, data.instance(i)); } } /** * Evaluates the classifier on a single instance. * * @param classifier machine learning classifier * @param instance the test instance to be classified * @return the prediction made by the clasifier * @exception Exception if model could not be evaluated * successfully or the data contains string attributes */ public double evaluateModelOnce(Classifier classifier, Instance instance) throws Exception { Instance classMissing = (Instance)instance.copy(); double pred=0; classMissing.setDataset(instance.dataset()); classMissing.setClassMissing(); //DEBUG try{ if (m_ClassIsNominal) { if (classifier instanceof DistributionClassifier) { double [] dist = ((DistributionClassifier)classifier). 
distributionForInstance(classMissing); pred = Utils.maxIndex(dist); //=============== BEGIN EDIT melville =============== double []ensemblePreds = ((EnsembleClassifier)classifier).getEnsemblePredictions(classMissing); updateEnsembleStats(pred, instance, ensemblePreds); //=============== END EDIT melville =============== updateStatsForClassifier(dist, instance); } else { pred = classifier.classifyInstance(classMissing); //=============== BEGIN EDIT melville =============== double []ensemblePreds = ((EnsembleClassifier)classifier).getEnsemblePredictions(classMissing); updateEnsembleStats(pred, instance, ensemblePreds); //=============== END EDIT melville =============== updateStatsForClassifier(makeDistribution(pred), instance); } } else { pred = classifier.classifyInstance(classMissing); updateStatsForPredictor(pred, instance); } }catch(Exception e){ e.printStackTrace(); System.exit(0); } return pred; } //=============== BEGIN EDIT melville =============== /** * Update statistics for ensemble classifiers. 
* * @param pred ensemble prediction * @param instance test instance * @param ensemblePreds predictions of ensemble members */ public void updateEnsembleStats(double pred, Instance instance, double []ensemblePreds){ //System.out.print("Updating Ensemble Stats..."); double sumEnsembleError = 0, sumEnsembleDiversity = 0; double actualClass = instance.classValue(); for(int i=0; i<m_EnsembleSize; i++){ if(actualClass != ensemblePreds[i]) sumEnsembleError += m_EnsembleWts[i]; //if member's prediction differs from the ensemble prediction, diversity increases if(pred != ensemblePreds[i]) sumEnsembleDiversity += m_EnsembleWts[i]; } m_EnsembleIncorrect += ((sumEnsembleError/m_SumEnsembleWts)*instance.weight()); m_EnsembleDiversity += ((sumEnsembleDiversity/m_SumEnsembleWts)*instance.weight()); //System.out.println("\t"+m_EnsembleIncorrect+"\t"+m_EnsembleDiversity+"\tDone"); } //=============== END EDIT melville =============== /** * Evaluates the supplied distribution on a single instance. * * @param dist the supplied distribution * @param instance the test instance to be classified * @exception Exception if model could not be evaluated * successfully */ public double evaluateModelOnce(double [] dist, Instance instance) throws Exception { double pred; if (m_ClassIsNominal) { pred = Utils.maxIndex(dist); updateStatsForClassifier(dist, instance); } else { pred = dist[0]; updateStatsForPredictor(pred, instance); } return pred; } /** * Evaluates the supplied prediction on a single instance. 
* * @param prediction the supplied prediction * @param instance the test instance to be classified * @exception Exception if model could not be evaluated * successfully */ public void evaluateModelOnce(double prediction, Instance instance) throws Exception { if (m_ClassIsNominal) { updateStatsForClassifier(makeDistribution(prediction), instance); } else { updateStatsForPredictor(prediction, instance); } } /** * Wraps a static classifier in enough source to test using the weka * class libraries. * * @param classifier a Sourcable Classifier * @param className the name to give to the source code class * @return the source for a static classifier that can be tested with * weka libraries. */ protected static String wekaStaticWrapper(Sourcable classifier, String className) throws Exception { //String className = "StaticClassifier"; String staticClassifier = classifier.toSource(className); return "package weka.classifiers;\n" +"import weka.core.Attribute;\n" +"import weka.core.Instance;\n" +"import weka.core.Instances;\n" +"import weka.classifiers.Classifier;\n\n" +"public class WekaWrapper extends Classifier {\n\n" +" public void buildClassifier(Instances i) throws Exception {\n" +" }\n\n" +" public double classifyInstance(Instance i) throws Exception {\n\n" +" Object [] s = new Object [i.numAttributes()];\n" +" for (int j = 0; j < s.length; j++) {\n" +" if (!i.isMissing(j)) {\n" +" if (i.attribute(j).type() == Attribute.NOMINAL) {\n" +" s[j] = i.attribute(j).value((int) i.value(j));\n" +" } else if (i.attribute(j).type() == Attribute.NUMERIC) {\n" +" s[j] = new Double(i.value(j));\n" +" }\n" +" }\n" +" }\n" +" return " + className + ".classify(s);\n" +" }\n\n" +"}\n\n" +staticClassifier; // The static classifer class } /** * Gets the number of test instances that had a known class value * (actually the sum of the weights of test instances with known * class value). 
*
 * @return the number of test instances with known class
 */
public final double numInstances() {
  return m_WithClass;
}

/**
 * Gets the number of instances incorrectly classified (that is, for
 * which an incorrect prediction was made). (Actually the sum of the weights
 * of these instances)
 *
 * @return the number of incorrectly classified instances
 */
public final double incorrect() {
  return m_Incorrect;
}

/**
 * Gets the percentage of instances incorrectly classified (that is, for
 * which an incorrect prediction was made).
 *
 * @return the percent of incorrectly classified instances
 * (between 0 and 100)
 */
public final double pctIncorrect() {
  return 100 * m_Incorrect / m_WithClass;
}

/**
 * Gets the ensemble mean of percentage of instances incorrectly classified
 * (that is, for which an incorrect prediction was made).
 *
 * @return the percent of incorrectly classified instances
 * (between 0 and 100)
 */
public final double ensemblePctIncorrect() {
  return 100 * m_EnsembleIncorrect / m_WithClass;
}

/**
 * Gets the total cost, that is, the cost of each prediction times the
 * weight of the instance, summed over all instances.
 *
 * @return the total cost
 */
public final double totalCost() {
  return m_TotalCost;
}

/**
 * Gets the average cost, that is, total cost of misclassifications
 * (incorrect plus unclassified) over the total number of instances.
 *
 * @return the average cost.
 */
public final double avgCost() {
  return m_TotalCost / m_WithClass;
}

/**
 * Gets the number of instances correctly classified (that is, for
 * which a correct prediction was made). (Actually the sum of the weights
 * of these instances)
 *
 * @return the number of correctly classified instances
 */
public final double correct() {
  return m_Correct;
}

/**
 * Gets the percentage of instances correctly classified (that is, for
 * which a correct prediction was made). 
*
 * @return the percent of correctly classified instances (between 0 and 100)
 */
public final double pctCorrect() {
  return 100 * m_Correct / m_WithClass;
}

/**
 * Gets the ensemble mean of percentage of instances correctly classified
 * (that is, for which a correct prediction was made).
 *
 * @return the percent of correctly classified instances (between 0 and 100)
 */
public final double ensemblePctCorrect() {
  // Complement of the weighted ensemble error percentage.
  return 100 - ensemblePctIncorrect();
}

/**
 * Gets the mean ensemble diversity.
 *
 * @return the mean ensemble diversity (between 0 and 100)
 */
public final double ensembleDiversity() {
  return 100 * m_EnsembleDiversity / m_WithClass;
}

/**
 * Gets the number of instances not classified (that is, for
 * which no prediction was made by the classifier). (Actually the sum
 * of the weights of these instances)
 *
 * @return the number of unclassified instances
 */
public final double unclassified() {
  return m_Unclassified;
}

/**
 * Gets the percentage of instances not classified (that is, for
 * which no prediction was made by the classifier).
 *
 * @return the percent of unclassified instances (between 0 and 100)
 */
public final double pctUnclassified() {
  return 100 * m_Unclassified / m_WithClass;
}

/**
 * Returns the estimated error rate or the root mean squared error
 * (if the class is numeric). If a cost matrix was given this
 * error rate gives the average cost.
 *
 * @return the estimated error rate (between 0 and 1, or between 0 and
 * maximum cost)
 */
public final double errorRate() {
  // Numeric class: report RMSE instead of a misclassification rate.
  if (!m_ClassIsNominal) {
    return Math.sqrt(m_SumSqrErr / m_WithClass);
  }
  if (m_CostMatrix == null) {
    return m_Incorrect / m_WithClass;
  } else {
    return avgCost();
// NOTE(review): chunk truncated here — the closing braces of errorRate()
// were lost in extraction; the remainder of the method is outside this view.
// [scraped page footer (keyboard-shortcut help) removed]