📄 evaluation.java
字号:
/**
 * Gets a copy of the cost matrix used by this evaluation.
 *
 * @return a copy of the cost matrix, or null if this evaluation was
 * created without one (default costs are used in that case)
 */
public CostMatrix costMatrix() {
  // The constructor accepts a null cost matrix, so guard before copying
  // instead of handing null to the CostMatrix copy constructor.
  // NOTE(review): the copy constructor presumably dereferences its
  // argument; returning null mirrors the "no cost matrix" field state.
  if (m_CostMatrix == null) {
    return null;
  }
  return new CostMatrix(m_CostMatrix);
}

// Accessors for running totals; aggregate(Evaluation) merges these
// across evaluations by addition.

/**
 * Gets the total cost, that is, the cost of each prediction times the
 * weight of the instance, summed over all instances.
 *
 * @return the total cost
 */
public final double totalCost() {
  return m_TotalCost;
}

/**
 * Gets the value of the sumErr variable.
 *
 * @return the value of the sumErr variable
 */
public final double sumErr() {
  return m_SumErr;
}

/**
 * Gets the value of the sumAbsErr variable.
 *
 * @return the value of the sumAbsErr variable
 */
public final double sumAbsErr() {
  return m_SumAbsErr;
}

/**
 * Gets the value of the sumSqrErr variable.
 *
 * @return the value of the sumSqrErr variable
 */
public final double sumSqrErr() {
  return m_SumSqrErr;
}

/**
 * Gets the value of the sumClass variable.
 *
 * @return the value of the sumClass variable
 */
public final double sumClass() {
  return m_SumClass;
}

/**
 * Gets the value of the sumSqrClass variable.
 *
 * @return the value of the sumSqrClass variable
 */
public final double sumSqrClass() {
  return m_SumSqrClass;
}

/**
 * Gets the value of the sumPredicted variable.
 *
 * @return the value of the sumPredicted variable
 */
public final double sumPredicted() {
  return m_SumPredicted;
}

/**
 * Gets the value of the sumSqrPredicted variable.
 *
 * @return the value of the sumSqrPredicted variable
 */
public final double sumSqrPredicted() {
  return m_SumSqrPredicted;
}

/**
 * Gets the value of the sumClassPredicted variable.
 *
 * @return the value of the sumClassPredicted variable
 */
public final double sumClassPredicted() {
  return m_SumClassPredicted;
}

/**
 * Gets the value of the sumPriorAbsErr variable.
 *
 * @return the value of the sumPriorAbsErr variable
 */
public final double sumPriorAbsErr() {
  return m_SumPriorAbsErr;
}

/**
 * Gets the value of the sumPriorSqrErr variable.
* * @return the value of the sumPriorSqrErr variable */ public final double sumPriorSqrErr() { return m_SumPriorSqrErr; } /** * Gets the value of the sumKBInfo variable. * * @return the value of the sumKBInfo variable */ public final double sumKBInfo() { return m_SumKBInfo; } /** * Gets the value of the marginResolution variable. * * @return the value of the marginResolution variable */ public final int marginResolution() { return k_MarginResolution; } /** * Gets a copy of the marginCounts array. * * @return a copy of the marginCounts array */ public double [] marginCounts() { double [] newCounts = new double[m_MarginCounts.length]; for(int i = 0; i < m_MarginCounts.length; i++) { newCounts[i] = m_MarginCounts[i]; } return newCounts; } /** * Gets the value of the numTrainClassVals variable. * * @return the value of the numTrainClassVals variable */ public final int numTrainClassVals() { return m_NumTrainClassVals; } /** * Gets a copy of the trainClassVals array. * * @return a copy of the trainClassVals array */ public double [] trainClassVals() { double [] newVals = new double[m_TrainClassVals.length]; for(int i = 0; i < m_TrainClassVals.length; i++) { newVals[i] = m_TrainClassVals[i]; } return newVals; } /** * Gets a copy of the trainClassWeights array. * * @return a copy of the trainClassWeights array */ public double [] trainClassWeights() { double [] newWeights = new double[m_TrainClassWeights.length]; for(int i = 0; i < m_TrainClassWeights.length; i++) { newWeights[i] = m_TrainClassWeights[i]; } return newWeights; } /** * Gets the value of the sumPriorEntropy variable. * * @return the value of the sumPriorEntropy variable */ public final double sumPriorEntropy() { return m_SumPriorEntropy; } /** * Gets the value of the sumSchemeEntropy variable. 
* * @return the value of the sumSchemeEntropy variable */ public final double sumSchemeEntropy() { return m_SumSchemeEntropy; } /** * Gets the number of test instances that had a known class value * (actually the sum of the weights of test instances with known * class value). * * @return the number of test instances with known class */ public final double numInstances() { return m_WithClass; } /** * Gets the percentage of instances incorrectly classified (that is, for * which an incorrect prediction was made). * * @return the percent of incorrectly classified instances * (between 0 and 100) */ public final double pctIncorrect() { return 100 * m_Incorrect / m_WithClass; } /** * Gets the average cost, that is, total cost of misclassifications * (incorrect plus unclassified) over the total number of instances. * * @return the average cost. */ public final double avgCost() { return m_TotalCost / m_WithClass; } /** * Gets the percentage of instances correctly classified (that is, for * which a correct prediction was made). * * @return the percent of correctly classified instances * (between 0 and 100) */ public final double pctCorrect() { return 100 * m_Correct / m_WithClass; } /** * Gets the percentage of instances not classified (that is, for * which no prediction was made by the classifier). * * @return the percent of unclassified instances (between 0 and 100) */ public final double pctUnclassified() { return 100 * m_Unclassified / m_WithClass; } /** * Aggregates data obtained from running different folds on different * machines. Used when the -a flag is set to run the cross-validation * in parallel. 
* * @param evaluation the data sent back from another machine */ public void aggregate(Evaluation evaluation) { this.m_Incorrect += evaluation.incorrect(); this.m_Correct += evaluation.correct(); this.m_Unclassified += evaluation.unclassified(); this.m_MissingClass += evaluation.missingClass(); this.m_WithClass += evaluation.withClass(); double [][] newMatrix = evaluation.confusionMatrix(); if(newMatrix != null) { for(int i = 0; i < this.m_ConfusionMatrix.length; i++) for(int j = 0; j < this.m_ConfusionMatrix[i].length; j++) this.m_ConfusionMatrix[i][j] += newMatrix[i][j]; } double [] newClassPriors = evaluation.classPriors(); if(newClassPriors != null) { for(int i = 0; i < this.m_ClassPriors.length; i++) this.m_ClassPriors[i] = newClassPriors[i]; } this.m_ClassPriorsSum = evaluation.classPriorsSum(); this.m_TotalCost += evaluation.totalCost(); this.m_SumErr += evaluation.sumErr(); this.m_SumAbsErr += evaluation.sumAbsErr(); this.m_SumSqrErr += evaluation.sumSqrErr(); this.m_SumClass += evaluation.sumClass(); this.m_SumSqrClass += evaluation.sumSqrClass(); this.m_SumPredicted += evaluation.sumPredicted(); this.m_SumSqrPredicted += evaluation.sumSqrPredicted(); this.m_SumClassPredicted += evaluation.sumClassPredicted(); this.m_SumPriorAbsErr += evaluation.sumPriorAbsErr(); this.m_SumPriorSqrErr += evaluation.sumPriorSqrErr(); this.m_SumKBInfo += evaluation.sumKBInfo(); double [] newMarginCounts = evaluation.marginCounts(); if(newMarginCounts != null) { for(int i = 0; i < this.m_MarginCounts.length; i++) this.m_MarginCounts[i] += newMarginCounts[i]; } this.m_SumPriorEntropy += evaluation.sumPriorEntropy(); this.m_SumSchemeEntropy += evaluation.sumSchemeEntropy(); } /** * Initializes all the counters for the evaluation and also takes a * cost matrix as parameter. 
* * @param data set of instances, to get some header information * @param costMatrix the cost matrix---if null, default costs will be used * @exception Exception if cost matrix is not compatible with * data, the class is not defined or the class is numeric */ public Evaluation(Instances data, CostMatrix costMatrix) throws Exception { m_NumClasses = data.numClasses(); m_NumFolds = 1; m_ClassIsNominal = data.classAttribute().isNominal(); if (m_ClassIsNominal) { m_ConfusionMatrix = new double [m_NumClasses][m_NumClasses]; m_ClassNames = new String [m_NumClasses]; for(int i = 0; i < m_NumClasses; i++) { m_ClassNames[i] = data.classAttribute().value(i); } } m_CostMatrix = costMatrix; if (m_CostMatrix != null) { if (!m_ClassIsNominal) { throw new Exception("Class has to be nominal if cost matrix " + "given!"); } if (m_CostMatrix.size() != m_NumClasses) { throw new Exception("Cost matrix not compatible with data!"); } } m_ClassPriors = new double [m_NumClasses]; setPriors(data); m_MarginCounts = new double [k_MarginResolution + 1]; } /** * Performs a (stratified if class is nominal) cross-validation * for a classifier on a set of instances. * * @param classifier the classifier with any options set. 
* @param data the data on which the cross-validation is to be * performed * @param numFolds the number of folds for the cross-validation * @exception Exception if a classifier could not be generated * successfully or the class is not defined * */ public void crossValidateModel(Classifier classifier, Instances data, int numFolds) throws Exception { // Make a copy of the data we can reorder data = new Instances(data); if (data.classAttribute().isNominal()) { data.stratify(numFolds); } // Do the folds for (int i = 0; i < numFolds; i++) { Instances train = data.trainCV(numFolds, i); setPriors(train); classifier.buildClassifier(train); Instances test = data.testCV(numFolds, i); evaluateModel(classifier, test); } m_NumFolds = numFolds; } /** * Performs a (stratified if class is nominal) cross-validation * for a classifier on a set of instances. This cross-validation * is run in parallel by connecting to the machines described in * ~/.weka-parallel. * * @param classifier the classifier with any options set. * @param data the data on which the cross-validation is to be * performed * @param numFolds the number of folds for the cross-validation * @param otherComputers will eventually hold the names of all of the * computers that the program was actually able to connect to and * receive data from * @exception Exception if a classifier could not be generated, * if the class is not defined, or if there was an incorrect number * of folds selected */ public void crossValidateModelParallel(Classifier classifier, Instances data, int numFolds, StringBuffer otherComputers) throws Exception
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -