📄 cattestresult.java
字号:
/** Appends the confusion matrix display to the given report string.
 * NOTE(review): the original C++ supported several output modes (none,
 * ascii, scatterviz, both) selected via the DISP_CONFUSION_MAT option;
 * in this translation only the ascii mode is live — everything else is
 * retained below as commented-out reference code.
 * @param stream The report text accumulated so far (despite the name,
 *               this is a String, not an output stream).
 * @return The report text with the ascii confusion matrix appended.
 */
public String display_confusion_matrix(String stream) {
// DispConfMatType defaultDispConfMat = no;
String dispConfMatHelp =
"Display confusion matrix when displaying results on test";
// static DispConfMatType dispConfusionMatrix =
// get_option_enum("DISP_CONFUSION_MAT", dispConfMatEnum,
// defaultDispConfMat, dispConfMatHelp, TRUE);
// switch (dispConfusionMatrix) {
// case no:
// break;
// case ascii:
// Unconditional in this port: always emit the ascii matrix.
return display_ascii_confusion_matrix(stream);
// break;
// case both:
// display_ascii_confusion_matrix(stream);
// case scatterviz:
// MString outputSVizfile =
// get_option_string("SCATTERVIZ_FILE", EMPTY_STRING,
// "File name for ScatterViz confusion matrix config file", TRUE);
//
// // if the name is provided, use it to generate a permanent file.
// if(outputSVizfile != EMPTY_STRING) {
// outputSVizfile += ".conf_matrix.scatterviz";
// MLCOStream out(outputSVizfile);
// MString dataName = out.description() + ".data";
// MLCOStream data(dataName);
// display_scatterviz_confusion_matrix(out, data);
// }
// otherwise, use a TEMPORARY scatterviz file only.
// else {
// Array<MString> suffixes(2);
// suffixes[0] = ".scatterviz";
// suffixes[1] = ".scatterviz.data";
// PtrArray<TmpFileName *> *tempNames = gen_temp_file_names(suffixes);
// const TmpFileName& tmpSVizConf = *tempNames->index(0);
// const TmpFileName& tmpSVizData = *tempNames->index(1);
// MLCOStream SVizConf(tmpSVizConf);
// MLCOStream SVizData(tmpSVizData);
// display_scatterviz_confusion_matrix(SVizConf, SVizData);
// SVizConf.close();
// SVizData.close();
// if(system(*GlobalOptions::scattervizUtil + " " +
// tmpSVizConf))
// Mcerr << "CatTestResult::display: "
// "Call to ScatterViz returns an error." << endl;
// delete tempNames;
// }
// break;
// }
}
/** Writes the full statistics report (the toString() text) to a writer.
 * Any I/O failure is reported to stderr rather than propagated.
 * @param stream The writer to which the statistics will be displayed.
 */
public void display(BufferedWriter stream) {
    String report = this.toString();
    try {
        stream.write(report);
    } catch (IOException e) {
        e.printStackTrace();
    }
}
/** Converts information in this CatTestResult object to a string for display:
 * instance/weight counts, generalization and memorization error, the
 * confidence interval, scoring metrics, optional loss figures, and the
 * confusion matrix.
 * @return The String containing the display.
 */
public String toString() {
    // Build the report with a StringBuilder instead of the original
    // `rtrn = rtrn + ...` chains, which reallocated a fresh String on
    // every statement.  The emitted text is identical: append(float),
    // append(double) and append(char) produce the same characters as
    // String concatenation does.
    StringBuilder rtrn = new StringBuilder();
    GetEnv getenv = new GetEnv();
    // Read for its option-registration side effect; the consumer
    // (display_incorrect_instances) is commented out below.
    boolean dispMisclassifications = getenv.get_option_bool("DISP_MISCLASS", false,
        "Display misclassified instances", true);
    boolean isWeighted = trainIL.is_weighted() || testIL.is_weighted();
    // No perf_display method yet -JL
    // rtrn = rtrn + push_opts + perf_display;
    if (isWeighted) {
        rtrn.append("Total training weight: ").append((float) total_train_weight())
            .append(" (").append(num_train_instances()).append(" instances)").append("\n");
        rtrn.append("Total test weight: ").append((float) total_test_weight())
            .append(" (").append(num_test_instances()).append(" instances)").append("\n");
        rtrn.append("     seen: ").append((float) total_weight_on_train())
            .append(" (").append(num_on_train()).append(" instances)").append("\n");
        rtrn.append("     unseen: ").append((float) total_weight_off_train())
            .append(" (").append(num_off_train()).append(" instances)").append("\n");
        rtrn.append("     correct: ").append((float) total_correct_weight())
            .append(" (").append(num_correct()).append(" instances)").append("\n");
        rtrn.append("     incorrect: ").append((float) total_incorrect_weight())
            .append(" (").append(num_incorrect()).append(" instances)").append("\n");
    }
    else {
        rtrn.append("Number of training instances: ").append(num_train_instances())
            .append("\n");
        rtrn.append("Number of test instances: ").append(num_test_instances())
            .append(". Unseen: ").append(num_off_train())
            .append(", seen ").append(num_on_train()).append(".\n");
        rtrn.append("Number correct: ").append(num_correct())
            .append(". Number incorrect: ").append(num_incorrect()).append("\n");
    }
    // DBG(ASSERT(num_off_train() + num_on_train() == num_test_instances()));
    DoubleRef confLow = new DoubleRef();
    DoubleRef confHigh = new DoubleRef();
    confidence(confLow, confHigh, error(), num_test_instances());
    // Generalization error is only meaningful on instances not seen in
    // training; memorization error only on instances that were.
    rtrn.append("Generalization error: ");
    if (num_off_train() > 0)
        rtrn.append(error(Generalized) * 100).append('%');
    else
        rtrn.append("unknown");
    rtrn.append(". Memorization error: ");
    if (num_on_train() > 0)
        rtrn.append(error(Memorized) * 100).append('%').append("\n");
    else
        rtrn.append("unknown").append("\n");
    rtrn.append("Error: ").append(error() * 100).append("% +- ");
    // Theoretical std-dev is undefined for a single test instance.
    if (total_test_weight() > 1)
        rtrn.append(theoretical_std_dev(error(), total_test_weight()) * 100)
            .append("%");
    else
        rtrn.append("Undefined");
    rtrn.append(" [").append(confLow.value * 100).append("% - ")
        .append(confHigh.value * 100).append("%]").append("\n");
    // display scoring metrics if the categorizer supports scoring
    rtrn.append("Average Normalized Mean Squared Error: ")
        .append(100 * total_mean_squared_error() / total_test_weight())
        .append("%").append("\n");
    rtrn.append("Average Normalized Mean Absolute Error: ")
        .append(100 * total_mean_absolute_error() / total_test_weight())
        .append("%").append("\n");
    if (computeLogLoss)
        rtrn.append("Average log-loss: ")
            .append(total_log_loss() / total_test_weight()).append("\n");
    // display the total and average loss if the test set had a loss matrix
    if (testIL.get_schema().has_loss_matrix()) {
        rtrn.append("Total Loss: ").append(metrics.totalLoss).append("\n");
        rtrn.append("Average Loss: ").append(metrics.totalLoss / total_test_weight())
            .append("\n");
    }
    /* stream.write(pop_opts);*/
    /*
    if (dispMisclassifications)
    display_incorrect_instances();
    // dispay a confusion scattergram in VRML if requested
    if(get_option_bool("DISP_CONF_SCATTERGRAM", false, "", true)) {
    MString fileName =
    get_option_string_no_default("CONF_SCATTERGRAM_NAME", "", false)
    + ".wrl";
    display_vrml_scattergram(fileName);
    }
    */
    // The confusion-matrix helper takes and returns the report String.
    return display_confusion_matrix(rtrn.toString());
}
/** Accessor for the training data.
 * @return The InstanceList used for training.
 */
public InstanceList get_training_instance_list() {
    return trainIL;
}
/** Accessor for the test data.
 * @return The InstanceList used for testing.
 */
public InstanceList get_testing_instance_list() {
    return testIL;
}
/** Accessor for the per-instance test outcomes.
 * @return The individual results from testing.
 */
public CatOneTestResult[] get_results() {
    return results;
}
/** Accessor for the aggregated scoring metrics.
 * @return The scoring metrics collected from the test results.
 */
public ScoringMetrics get_metrics() {
    return metrics;
}
/** Accessor for the confusion matrix built during testing.
 * @return The confusion matrix of the results of testing.
 */
public double[][] get_confusion_matrix() {
    return confusionMatrix;
}
/** Accessor for the total loss accumulated in the scoring metrics.
 * @return The total loss value from the scoring metrics.
 */
public double total_loss() {
    return metrics.totalLoss;
}
/** Calculates a normalized loss value.
 * The total loss is rescaled into the interval spanned by the minimum
 * and maximum possible loss.
 * @return The loss value normalized by the loss value range.
 */
public double normalized_loss() {
    double offset = metrics.totalLoss - metrics.minimumLoss;
    double range = metrics.maximumLoss - metrics.minimumLoss;
    return offset / range;
}
/** Enables or disables the computation of log loss.
 * @param b The new setting of the log loss option.
 */
public static void set_compute_log_loss(boolean b) {
    computeLogLoss = b;
}
/** Reports whether log-loss computation is enabled.
 * @return TRUE if the log loss option is set, FALSE otherwise.
 */
public static boolean get_compute_log_loss() {
    return computeLogLoss;
}
/*
protected:
// Protected methods
void initialize();
public:
void OK(int level = 1) const;
CatTestResult(const Categorizer& cat,
const InstanceList& trainILSource,
const MString& testFile,
const MString& namesExtension = DEFAULT_NAMES_EXT,
const MString& testExtension = DEFAULT_TEST_EXT);
CatTestResult(const Categorizer& cat,
const InstanceList& trainILSource,
const InstanceList& testILSource);
CatTestResult(const Array<CatOneTestResult>& resultsArray,
const InstanceList& trainILSource,
const InstanceList& testILSource);
virtual ~CatTestResult();
void assign_instance_lists(InstanceList *& train,
InstanceList *& test);
// metrics
Real total_label_weight(Category label) const;
Real true_label_weight(Category label) const;
Real false_label_weight(Category label) const;
Real accuracy(ErrorType errType = Normal) const;
Real minimum_loss() const;
Real maximum_loss() const;
const StatData& loss_stats() const { return lossStats; }
Real mean_loss() const { return total_loss() / total_test_weight(); }
Real std_dev_loss() const { return loss_stats().std_dev(); }
void confidence_loss(Real& confLow, Real& confHigh,
Real z = CONFIDENCE_INTERVAL_Z) const
{ loss_stats().percentile(z, confLow, confHigh); }
virtual InstanceRC get_instance(int num) const;
virtual const AugCategory& label(int num) const;
virtual const AugCategory& predicted_label(int num) const;
// display_* show the instance and both labels (except for
// display_correct_instances() which shows only one label).
// Instances which were in the training set say "(In TS)" on
// the display line.
virtual void display_all_instances(MLCOStream& stream = Mcout) const;
virtual void display_incorrect_instances(MLCOStream& stream = Mcout) const;
virtual void display_correct_instances(MLCOStream& stream = Mcout) const;
// category distribution displays an array where each cell i
// consists of: (1) number of test instances in class i,
// (2) number of test instances correctly predicted as class i,
// (3) number of test instances incorrectly predicted as class i.
virtual void display_category_distrib(MLCOStream& stream = Mcout) const;
virtual void display_scatterviz_confusion_matrix(MLCOStream& stream,
MLCOStream& data) const;
// display_all dumps everything (display + display_all_instances).
virtual void display_all(MLCOStream& stream = Mcout) const;
virtual void display_scatterviz_lift_curve(const Category& labelValue,
MString configFileName) const;
virtual void display_vrml_scattergram(const MString& fileName) const;
*/
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -