// evaluation.java
        }
        testTimeElapsed = System.currentTimeMillis() - testTimeStart;
        trainReader.close();
      } else {
        testTimeStart = System.currentTimeMillis();
        trainingEvaluation.evaluateModel(classifier, train);
        testTimeElapsed = System.currentTimeMillis() - testTimeStart;
      }

      // Print the results of the training evaluation
      if (printMargins) {
        return trainingEvaluation.toCumulativeMarginDistributionString();
      } else {
        text.append("\nTime taken to build model: "
            + Utils.doubleToString(trainTimeElapsed / 1000.0, 2) + " seconds");
        text.append("\nTime taken to test model on training data: "
            + Utils.doubleToString(testTimeElapsed / 1000.0, 2) + " seconds");
        text.append(trainingEvaluation.toSummaryString(
            "\n\n=== Error on training data ===\n", printComplexityStatistics));
        if (template.classAttribute().isNominal()) {
          if (classStatistics) {
            text.append("\n\n" + trainingEvaluation.toClassDetailsString());
          }
          text.append("\n\n" + trainingEvaluation.toMatrixString());
        }
      }
    }
    // Compute proper error estimates
    if (testFileName.length() != 0) {
      // Testing is on the supplied test data
      while (test.readInstance(testReader)) {
        testingEvaluation.evaluateModelOnce(classifier, test.instance(0));
        test.delete(0);
      }
      testReader.close();
      text.append("\n\n" + testingEvaluation.toSummaryString(
          "=== Error on test data ===\n", printComplexityStatistics));
    } else if (trainFileName.length() != 0) {
      // Testing is via cross-validation on the training data
      Random random = new Random(seed);
      // use an untrained (!) copy of the classifier for cross-validation
      classifier = Classifier.makeCopy(classifierBackup);
      testingEvaluation.crossValidateModel(classifier, train, folds, random);
      if (template.classAttribute().isNumeric()) {
        text.append("\n\n\n" + testingEvaluation.toSummaryString(
            "=== Cross-validation ===\n", printComplexityStatistics));
      } else {
        text.append("\n\n\n" + testingEvaluation.toSummaryString(
            "=== Stratified cross-validation ===\n", printComplexityStatistics));
      }
    }
    if (template.classAttribute().isNominal()) {
      if (classStatistics) {
        text.append("\n\n" + testingEvaluation.toClassDetailsString());
      }
      text.append("\n\n" + testingEvaluation.toMatrixString());
    }
    return text.toString();
  }
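
  // For reference: this evaluation routine is typically driven through the
  // static entry point. A sketch of assumed typical WEKA usage (J48 stands
  // in for any classifier; file name and fold count are illustrative):
  //
  //   String results = Evaluation.evaluateModel(new J48(),
  //       new String[] { "-t", "train.arff", "-x", "10" });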
  /**
   * Attempts to load a cost matrix.
   *
   * @param costFileName the filename of the cost matrix
   * @param numClasses the number of classes that should be in the cost matrix
   * (only used if the cost file is in old format)
   * @return a <code>CostMatrix</code> value, or null if costFileName is empty
   * @throws Exception if an error occurs
   */
  protected static CostMatrix handleCostOption(String costFileName,
                                               int numClasses) throws Exception {

    if ((costFileName != null) && (costFileName.length() != 0)) {
      System.out.println(
          "NOTE: The behaviour of the -m option has changed between WEKA 3.0"
          + " and WEKA 3.1. -m now carries out cost-sensitive *evaluation*"
          + " only. For cost-sensitive *prediction*, use one of the"
          + " cost-sensitive metaschemes such as"
          + " weka.classifiers.meta.CostSensitiveClassifier or"
          + " weka.classifiers.meta.MetaCost");
      Reader costReader = null;
      try {
        costReader = new BufferedReader(new FileReader(costFileName));
      } catch (Exception e) {
        throw new Exception("Can't open file " + e.getMessage() + '.');
      }
      try {
        // First try as a proper cost matrix format
        return new CostMatrix(costReader);
      } catch (Exception ex) {
        try {
          // Now try as the old format
          try {
            costReader.close(); // close the first reader
            costReader = new BufferedReader(new FileReader(costFileName));
          } catch (Exception e) {
            throw new Exception("Can't open file " + e.getMessage() + '.');
          }
          CostMatrix costMatrix = new CostMatrix(numClasses);
          costMatrix.readOldFormat(costReader);
          return costMatrix;
        } catch (Exception e2) {
          // reading the old format failed too: re-throw the original exception
          throw ex;
        }
      }
    } else {
      return null;
    }
  }
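
  // Illustrative usage sketch (assumed names: "costs.matrix" as the cost
  // file, "train" as the training Instances):
  //
  //   CostMatrix cm = handleCostOption("costs.matrix", train.numClasses());
  //   Evaluation eval = (cm == null) ? new Evaluation(train)
  //                                  : new Evaluation(train, cm);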

  /**
   * Evaluates the classifier on a given set of instances. Note that
   * the data must have exactly the same format (e.g. order of
   * attributes) as the data used to train the classifier! Otherwise
   * the results will generally be meaningless.
   *
   * @param classifier machine learning classifier
   * @param data set of test instances for evaluation
   * @return the predictions
   * @throws Exception if model could not be evaluated successfully
   */
  public double[] evaluateModel(Classifier classifier,
                                Instances data) throws Exception {

    double[] predictions = new double[data.numInstances()];
    for (int i = 0; i < data.numInstances(); i++) {
      predictions[i] = evaluateModelOnce(classifier, data.instance(i));
    }
    return predictions;
  }
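
  // A minimal usage sketch ("train" and "test" are assumed Instances in the
  // same format, "cls" an already-built Classifier):
  //
  //   Evaluation eval = new Evaluation(train);
  //   double[] preds = eval.evaluateModel(cls, test);
  //   System.out.println(eval.toSummaryString("\n=== Results ===\n", false));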

  /**
   * Evaluates the classifier on a single instance.
   *
   * @param classifier machine learning classifier
   * @param instance the test instance to be classified
   * @return the prediction made by the classifier
   * @throws Exception if model could not be evaluated
   * successfully or the data contains string attributes
   */
  public double evaluateModelOnce(Classifier classifier,
                                  Instance instance) throws Exception {

    Instance classMissing = (Instance) instance.copy();
    double pred = 0;
    classMissing.setDataset(instance.dataset());
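    // Hide the true class from the classifier: the prediction is made on a
    // copy with the class set to missing, while the statistics below are
    // updated against the original (labelled) instance.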
    classMissing.setClassMissing();
    if (m_ClassIsNominal) {
      double[] dist = classifier.distributionForInstance(classMissing);
      pred = Utils.maxIndex(dist);
      if (dist[(int) pred] <= 0) {
        pred = Instance.missingValue();
      }
      updateStatsForClassifier(dist, instance);
    } else {
      pred = classifier.classifyInstance(classMissing);
      updateStatsForPredictor(pred, instance);
    }
    return pred;
  }

  /**
   * Evaluates the supplied distribution on a single instance.
   *
   * @param dist the supplied distribution
   * @param instance the test instance to be classified
   * @return the prediction
   * @throws Exception if model could not be evaluated successfully
   */
  public double evaluateModelOnce(double[] dist,
                                  Instance instance) throws Exception {

    double pred;
    if (m_ClassIsNominal) {
      pred = Utils.maxIndex(dist);
      if (dist[(int) pred] <= 0) {
        pred = Instance.missingValue();
      }
      updateStatsForClassifier(dist, instance);
    } else {
      pred = dist[0];
      updateStatsForPredictor(pred, instance);
    }
    return pred;
  }

  /**
   * Evaluates the supplied prediction on a single instance.
   *
   * @param prediction the supplied prediction
   * @param instance the test instance to be classified
   * @throws Exception if model could not be evaluated successfully
   */
  public void evaluateModelOnce(double prediction,
                                Instance instance) throws Exception {

    if (m_ClassIsNominal) {
      updateStatsForClassifier(makeDistribution(prediction), instance);
    } else {
      updateStatsForPredictor(prediction, instance);
    }
  }
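
  // The three evaluateModelOnce overloads accept, respectively, a trained
  // classifier, a precomputed class distribution, and a bare prediction.
  // Illustrative sketch for the distribution variant ("cls", "inst" and
  // "eval" are assumed names, with a nominal class):
  //
  //   double[] dist = cls.distributionForInstance(inst);
  //   double pred = eval.evaluateModelOnce(dist, inst);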

  /**
   * Wraps a static classifier in enough source to test using the weka
   * class libraries.
   *
   * @param classifier a Sourcable Classifier
   * @param className the name to give to the source code class
   * @return the source for a static classifier that can be tested with
   * weka libraries
   * @throws Exception if code-generation fails
   */
  protected static String wekaStaticWrapper(Sourcable classifier,
                                            String className) throws Exception {

    String staticClassifier = classifier.toSource(className);
    return "package weka.classifiers;\n\n"
        + "import weka.core.Attribute;\n"
        + "import weka.core.Instance;\n"
        + "import weka.core.Instances;\n"
        + "import weka.classifiers.Classifier;\n\n"
        + "public class WekaWrapper extends Classifier {\n\n"
        + "  public void buildClassifier(Instances i) throws Exception {\n"
        + "  }\n\n"
        + "  public double classifyInstance(Instance i) throws Exception {\n\n"
        + "    Object [] s = new Object [i.numAttributes()];\n"
        + "    for (int j = 0; j < s.length; j++) {\n"
        + "      if (!i.isMissing(j)) {\n"
        + "        if (i.attribute(j).type() == Attribute.NOMINAL) {\n"
        + "          s[j] = i.attribute(j).value((int) i.value(j));\n"
        + "        } else if (i.attribute(j).type() == Attribute.NUMERIC) {\n"
        + "          s[j] = new Double(i.value(j));\n"
        + "        }\n"
        + "      }\n"
        + "    }\n"
        + "    return " + className + ".classify(s);\n"
        + "  }\n"
        + "}\n\n"
        + staticClassifier; // the static classifier class
  }
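
  // Illustrative usage sketch ("trainedJ48" is an assumed, already-built
  // classifier that implements Sourcable, e.g. weka.classifiers.trees.J48):
  //
  //   String source = wekaStaticWrapper((Sourcable) trainedJ48, "J48Static");
  //
  // The returned source can then be compiled against the weka libraries.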

  /**
   * Gets the number of test instances that had a known class value
   * (actually the sum of the weights of test instances with known
   * class value).
   *
   * @return the number of test instances with known class
   */
  public final double numInstances() {
    return m_WithClass;
  }

  /**
   * Gets the number of instances incorrectly classified (that is, for
   * which an incorrect prediction was made). (Actually the sum of the
   * weights of these instances.)
   *
   * @return the number of incorrectly classified instances
   */
  public final double incorrect() {
    return m_Incorrect;
  }

  /**
   * Gets the percentage of instances incorrectly classified (that is, for
   * which an incorrect prediction was made).
   *
   * @return the percent of incorrectly classified instances
   * (between 0 and 100)
   */
  public final double pctIncorrect() {
    return 100 * m_Incorrect / m_WithClass;
  }

  /**
   * Gets the total cost, that is, the cost of each prediction times the
   * weight of the instance, summed over all instances.
   *
   * @return the total cost
   */
  public final double totalCost() {
    return m_TotalCost;
  }

  /**
   * Gets the average cost, that is, the total cost of misclassifications
   * (incorrect plus unclassified) divided by the total number of instances.
   *
   * @return the average cost
   */
  public final double avgCost() {
    return m_TotalCost / m_WithClass;
  }

  /**
   * Gets the number of instances correctly classified (that is, for
   * which a correct prediction was made). (Actually the sum of the
   * weights of these instances.)
   *
   * @return the number of correctly classified instances
   */
  public final double correct() {
    return m_Correct;
  }

  /**
   * Gets the percentage of instances correctly classified (that is, for
   * which a correct prediction was made).
   *
   * @return the percent of correctly classified instances (between 0 and 100)
   */
  public final double pctCorrect() {
    return 100 * m_Correct / m_WithClass;
  }

  /**
   * Gets the number of instances not classified (that is, for
   * which no prediction was made by the classifier). (Actually the sum
   * of the weights of these instances.)
   *
   * @return the number of unclassified instances
   */
  public final double unclassified() {
    return m_Unclassified;
  }

  /**
   * Gets the percentage of instances not classified (that is, for
   * which no prediction was made by the classifier).
   *
   * @return the percent of unclassified instances (between 0 and 100)
   */
  public final double pctUnclassified() {
    return 100 * m_Unclassified / m_WithClass;
  }

  /**
   * Returns the estimated error rate, or the root mean squared error
   * if the class is numeric. If a cost matrix was given, this
   * error rate gives the average cost.
   *
   * @return the estimated error rate (between 0 and 1, or between 0 and
   * maximum cost)
   */
  public final double errorRate() {

    if (!m_ClassIsNominal) {
      return Math.sqrt(m_SumSqrErr / m_WithClass);
    }
    if (m_CostMatrix == null) {
      return m_Incorrect / m_WithClass;
    } else {
      return avgCost();
    }
  }

  /**
   * Returns the value of the kappa statistic if the class is nominal.
   *
   * @return the value of the kappa statistic
   */
  public final double kappa() {

    double[] sumRows = new double[m_ConfusionMatrix.length];
    double[] sumColumns = new double[m_ConfusionMatrix.length];
    double sumOfWeights = 0;
    for (int i = 0; i < m_ConfusionMatrix.length; i++) {
      for (int j = 0; j < m_ConfusionMatrix.length; j++) {
        sumRows[i] += m_ConfusionMatrix[i][j];
        sumColumns[j] += m_ConfusionMatrix[i][j];
        sumOfWeights += m_ConfusionMatrix[i][j];
      }
    }

    // Kappa compares observed agreement with chance agreement:
    // kappa = (Po - Pe) / (1 - Pe). The completion below is a sketch that
    // assumes the standard definition, using the row/column sums
    // accumulated above.
    double correct = 0, chanceAgreement = 0;
    for (int i = 0; i < m_ConfusionMatrix.length; i++) {
      chanceAgreement += (sumRows[i] * sumColumns[i]);
      correct += m_ConfusionMatrix[i][i];
    }
    chanceAgreement /= (sumOfWeights * sumOfWeights);
    correct /= sumOfWeights;

    if (chanceAgreement < 1) {
      return (correct - chanceAgreement) / (1 - chanceAgreement);
    } else {
      return 1;
    }
  }