// CatTestResult.java (extracted from a web code listing; page-header junk removed)
Error.fatalErr("CatTestResult::error(Memorized): No test instances "
+"are in the training set. This causes division by 0");
double weightOnTrainIncorrect = 0;
// for (int i = results.low(); i <= results.high(); i++)
for (int i = 0; i < results.length; i++)
if ( (results[i].inTrainIL) &&
( results[i].augCat.num() != results[i].correctCat ) )
weightOnTrainIncorrect += results[i].instance.get_weight();
return weightOnTrainIncorrect / total_weight_on_train();
}
default:
Error.fatalErr("CatTestResult::error unexpected error type"
+(int)errType);
return 1.0;
} // end switch
}
/** Weight of the test instances that also appear in the training data.
 * Lazily initializes the per-instance inTrainIL flags the first time it
 * is called.
 * @return The total weight of the test instances found in the training set.
 */
public double total_weight_on_train() {
    // Build the training-set lookup table on first use only.
    if (!inTrainingSetInitialized)
        initializeTrainTable();
    return weightOnTrain;
}
/** Weight of the test instances that do NOT appear in the training data.
 * Triggers the same lazy initialization as total_weight_on_train().
 * @return The total weight of test instances absent from the training set.
 */
public double total_weight_off_train() {
    // The off-train weight is the complement of the on-train weight
    // with respect to the total test weight.
    double onTrain = total_weight_on_train();
    return total_test_weight() - onTrain;
}
/** Uses a TableCategorizer as a hash-table interface for quick lookup of
 * whether each test instance occurs in the training set. Only called when
 * the inTrainIL data is needed. Marks results[i].inTrainIL for each match
 * and records the match count and total matched weight in numOnTrain and
 * weightOnTrain.
 */
protected void initializeTrainTable() {
    // Note (from the C++ original): this temporary was hoisted here to
    // work around an MString leak compiler bug.
    String tableName = "Training Set Lookup Table";
    TableCategorizer trainTable =
        new TableCategorizer(trainIL, Globals.UNKNOWN_CATEGORY_VAL, tableName);
    int matchCount = 0;
    double matchWeight = 0;
    int idx = 0;
    ListIterator iter = testIL.instance_list().listIterator();
    while (iter.hasNext()) {
        Instance inst = (Instance) iter.next();
        // A category other than UNKNOWN means the instance was found
        // in the training set.
        if (trainTable.categorize(inst).Category() != Globals.UNKNOWN_CATEGORY_VAL) {
            results[idx].inTrainIL = true; // constructor sets to FALSE
            matchCount++;
            matchWeight += inst.get_weight();
        }
        idx++;
    }
    numOnTrain = matchCount;
    weightOnTrain = matchWeight;
    inTrainingSetInitialized = true;
}
/** Returns the total weight in the test list.
 * @return The total weight of the test data set.
 */
public double total_test_weight() {
    // Straight delegation to the test instance list.
    return testIL.total_weight();
}
/** Returns the total weight of instances that were incorrectly classified.
 * @return The total weight of incorrectly classified instances.
 */
public double total_incorrect_weight() {
    // weightIncorrect is accumulated into the aggregate metrics during
    // initialize() via compute_scoring_metrics().
    return metrics.weightIncorrect;
}
/** Initializes this CatTestResult by categorizing the test data set with the
 * given Categorizer. For each test instance this fills in results[i]
 * (correct category, a copy of the instance, predicted/correct category
 * distributions, predicted category), updates the per-category distribution
 * (catDistrib), the confusion matrix, the aggregate metrics, and the loss
 * statistics (lossStats).
 * @param cat The categorizer with which the test data set will be categorized.
 */
protected void initialize(Categorizer cat) {
    // DBG(trainIL.get_schema().compatible_with(testIL.get_schema(), true));
    int i = 0;
    // numProcessed/numTotal/tenthsDone only feed the disabled DRIBBLE
    // progress display below; kept so the commented code stays revivable.
    int numProcessed = 0;
    int numTotal = testIL.num_instances();
    int tenthsDone = 0;
    // just for now JWP logOptions.DRIBBLE("Classifying (% done): ");
    for (ListIterator pixIL = testIL.instance_list().listIterator(); pixIL.hasNext(); ++i) {
        Instance instance = (Instance)pixIL.next();
        int newTenthsDone = (10 * ++numProcessed) / numTotal;
        // If the outer loop is not in an IFDRIBBLE, we get into
        // an infinite loop here!
        // IFDRIBBLE(while (newTenthsDone > tenthsDone)
        //    logOptions.DRIBBLE(10 * ++tenthsDone + "% " + flush));
        logOptions.LOG(2, "Instance: " + instance);
        // BUG FIX: validate the label BEFORE using it. The original read the
        // label via get_nominal_val() and stored it into results[] first, and
        // only then checked for an unknown label; an unknown label is an
        // error and must be reported before the value is consumed.
        if (instance.label_info().is_unknown(instance.get_label()))
            Error.fatalErr("Test instance (" +instance +") has an unknown "
                +"label value");
        int correctCat = instance.label_info().get_nominal_val(instance.get_label());
        results[i].correctCat = correctCat;
        results[i].instance = new Instance(instance);
        // Change label to UNKNOWN, so that categorizers don't cheat.
        // DBG(AttrValue_ av;
        //     instance.label_info().set_unknown(av);
        //     instance.set_label(av));
        // Categorize the instance. If the categorizer is capable of
        // scoring, compute the probability distribution of the label
        // as well. If the categorizer is NOT capable of scoring,
        // build all-or-nothing distributions.
        if (cat.supports_scoring()) {
            results[i].predDist = cat.score(instance);
            results[i].correctDist = new CatDist(cat.get_schema(), correctCat);
            results[i].augCat = new AugCategory(results[i].predDist.best_category());
        }
        else {
            AugCategory predicted = cat.categorize(instance);
            results[i].augCat = new AugCategory(predicted);
            results[i].predDist = new CatDist(cat.get_schema(), predicted);
            results[i].correctDist = new CatDist(cat.get_schema(), correctCat);
        }
        // Accumulate this instance's results into a ScoringMetrics structure.
        ScoringMetrics metric = new ScoringMetrics();
        compute_scoring_metrics(metric,
                                results[i].augCat,
                                new AugCategory(correctCat, "correct"),
                                results[i].predDist,
                                results[i].correctDist,
                                instance.get_weight(),
                                testIL.get_schema(),
                                computeLogLoss);
        AugCategory predCat = results[i].augCat;
        // Per-predicted-category tallies and the (correct x predicted)
        // confusion matrix are weighted by the instance weight.
        catDistrib[predCat.num()].numTestSet++;
        accumulate_scoring_metrics(catDistrib[predCat.num()].metrics, metric);
        confusionMatrix[correctCat][predCat.num()] += instance.get_weight();
        accumulate_scoring_metrics(metrics, metric);
        // accumulate the loss into the loss statistics statData
        lossStats.insert(metric.totalLoss);
        logOptions.LOG(2, "Correct label: "
            +instance.get_schema().category_to_label_string(correctCat)
            +" Predicted label: " +predCat.description() +". ");
        if (predCat.num() == correctCat) {
            // DBG check from the C++ original (labels matching by number but
            // not by name) omitted; it was already commented out.
            logOptions.LOG(2, "Correct. ");
        } else {
            logOptions.LOG(2, "Incorrect. ");
        }
        logOptions.LOG(2, "No correct: " +metrics.numCorrect +'/' +(i + 1)+'\n');
    }
    // IFDRIBBLE(while (tenthsDone < 10)
    //    logOptions.DRIBBLE(10 * ++tenthsDone +"% " + flush));
    // just for now JWP logOptions.DRIBBLE("done." +'\n');
}
/** Adds every field of one ScoringMetrics into another, in place.
 * All ten metric fields are additive, so accumulation is a plain
 * field-by-field sum.
 * @param dest The ScoringMetrics receiving the accumulated totals.
 * @param src The ScoringMetrics whose values are added to dest.
 */
private void accumulate_scoring_metrics(ScoringMetrics dest,
                                        ScoringMetrics src) {
    // Counts.
    dest.numCorrect        += src.numCorrect;
    dest.numIncorrect      += src.numIncorrect;
    // Weights.
    dest.weightCorrect     += src.weightCorrect;
    dest.weightIncorrect   += src.weightIncorrect;
    // Losses.
    dest.totalLoss         += src.totalLoss;
    dest.minimumLoss       += src.minimumLoss;
    dest.maximumLoss       += src.maximumLoss;
    dest.totalLogLoss      += src.totalLogLoss;
    // Distribution errors.
    dest.meanSquaredError  += src.meanSquaredError;
    dest.meanAbsoluteError += src.meanAbsoluteError;
}
/** Useful functions for computing scoring metrics.
* All metrics (probabilistic or normal) are computed
* within this function.
* @param metrics The ScoringMetrics
* @param predictedCat The category determined by a Categorizer.
* @param correctCat The correct category for a test Instance.
* @param predictedDist The distribution of categories determined by a Categorizer.
* @param correctDist The distribution of the correct categories for a test data set.
* @param weight The weight of the instance currently being added.
* @param testSchema The Schema for the test Instance.
* @param computeLogLoss Indicator of whether LogLoss should be computed. True indicates
* LogLoss should be computed.
*/
private void compute_scoring_metrics(ScoringMetrics metrics,
AugCategory predictedCat,
AugCategory correctCat,
CatDist predictedDist,
CatDist correctDist,
double weight,
Schema testSchema,
boolean computeLogLoss) {
double[] predScores = predictedDist.get_scores();
double[] corrScores = correctDist.get_scores();
if(predScores.length != corrScores.length)
Error.fatalErr("compute_scoring_metrics: correct and predicted "
+"distributions have different sizes" );
// Compare only the categories, because the predictedCat and correctCat
// may get different descriptions if one was created by a CatDist
// based on a scoring inducer.
if(predictedCat.Category() == correctCat.Category()) {
metrics.weightCorrect += weight;
metrics.numCorrect++;
}
else {
metrics.weightIncorrect += weight;
metrics.numIncorrect++;
}
// If a loss matrix is defined for the TEST set, look up the loss
// and add it to our totalLoss. Then find the minimum and maximum
// losses given this correctCat. Add these to the minimum and
// maximum loss metrics respectively.
// All losses must be scaled by the instance's weight.
if(testSchema.has_loss_matrix()) {
metrics.totalLoss += testSchema.get_loss_matrix()[correctCat.num()][predictedCat.num()] * weight;
int index = 0;
metrics.minimumLoss += Matrix.min_in_row(correctCat.num(), index, testSchema.get_loss_matrix()) * weight;
metrics.maximumLoss += Matrix.max_in_row(correctCat.num(), index, testSchema.get_loss_matrix()) * weight;
}
else {
if(predictedCat.notequal(correctCat))
metrics.totalLoss += weight;
metrics.maximumLoss += weight;
}
// Compute the mean squared and mean absolute error between the
// predicted and correct distributions.
// The correct distribution will generally be an
// all-or-nothing distribution in favor of the correct category.
// Normalize the mean squared/mean absolute errors to be always
// between 0 and 1. This involves dividing by 2.
double mse = 0;
double mae = 0;
for(int i=0; i<predScores.length; i++) {
double diff = predScores[i] - corrScores[i];
mse += diff*diff;
mae += Math.abs(diff);
}
mse /= 2;
mae /= 2;
metrics.meanSquaredError += weight*mse;
metrics.meanAbsoluteError += weight*mae;
// NOTE: the extract is truncated here — the remainder of
// compute_scoring_metrics (and any following members) is missing.
// Web-viewer page footer (keyboard-shortcut help) removed.