// CatTestResult.java
package shared;
import java.util.*;
import java.io.*;
/** The CatTestResult class provides summaries of running
* categorizers on test data. This includes the option of
* loading the test data from a file (or giving an existing
* InstanceList), running the categorizer on all instances,
* and storing the results. Information can then be extracted
* quickly.<P>
* The training set and test set (if given as opposed to
* loading it here) must not be altered as long
* as calls to this class are being made, because references
* are kept to those structures.<P>
* The complexity for construction of the CatTestResult
* is O(n1 n2), where n1 is the size of the training-set
* InstanceList and n2 is the size of the test set. All
* display routines take time proportional to the
* number of displayed numbers.<P>
* The CatTestResult class also computes the log-evidence metric,
* i.e., the total evidence against the correct category.<P>
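* A minimal usage sketch (the Categorizer <code>cat</code> and the
* InstanceLists <code>train</code> and <code>test</code> are assumed
* to have been built elsewhere):<P>
* <PRE>
* CatTestResult res = new CatTestResult(cat, train, test);
* double err = res.error(); // overall error rate
* double genErr = res.error(CatTestResult.Generalized); // error off the training set
* </PRE>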
* @author James Louis 12/06/2000 Java Implementations
* JavaDocumentation
* @author Jim Kelly 11/08/96 Further strengthening of
* display_confusion_matrix(): now can
* display matrix in scatterviz.
* @author Yeogirl Yun 12/27/94 Implemented NOT_IMPLEMENTED
* parts. Strengthened display_confusion_matrix()
* @author Robert Allen 12/10/94 Add generalized vs memorized error
* @author Richard Long 10/01/93 Initial revision (.c)
* @author Ronny Kohavi 9/13/93 Initial revision (.h)
*/
public class CatTestResult{
// LOG_OPTIONS;
/*ENUM ErrorType*/
/** The Normal partition for error reporting.**/
static final public int Normal = 0; /*ENUM ErrorType*/
/** The Generalized partition for error reporting.**/
static final public int Generalized = 1; /*ENUM ErrorType*/
/** The Memorized partition for error reporting.**/
static final public int Memorized = 2; /*ENUM ErrorType*/
/*ENUM ErrorType*/
// Member data
/** The InstanceList containing the training data set.**/
InstanceList trainIL;
/** The InstanceList containing the testing data set.**/
InstanceList testIL;
/** The number of test Instances that also appear in the training set.**/
int numOnTrain;
/** The total weight of the test Instances that also appear in the training set.**/
double weightOnTrain;
/** The scores of a Categorizer's test.**/
ScoringMetrics metrics;
/** Statistical data on a Categorizer.**/
StatData lossStats;
/** The results of a Categorizer's test.**/
CatOneTestResult[] results;
/** The distribution of categories over a data set.**/
CatCounters[] catDistrib;
/** The confusion matrix produced by a test of a Categorizer.**/
double[][] confusionMatrix;
/** The indicator that this CatTestResult owns the test data set.
True if this CatTestResult does own the set.**/
boolean ownsTestIL;
/** The indicator that this CatTestResult owns the training data set.
True if this CatTestResult does own the set.**/
boolean ownsTrainIL;
/** The indicator that the inTrainIL flags of the results (marking test
instances that also appear in the training set) have been computed.**/
boolean inTrainingSetInitialized;
/** The indicator that this CatTestResult should compute LogLoss. True indicates
this CatTestResult should do so.**/
static boolean computeLogLoss;
/** Logging options for this class. **/
protected LogOptions logOptions = new LogOptions();
/** Sets the logging level for this object.
* @param level The new logging level.
*/
public void set_log_level(int level){logOptions.set_log_level(level);}
/** Returns the logging level for this object.
* @return The log level for this object.
*/
public int get_log_level(){return logOptions.get_log_level();}
/** Sets the stream to which logging options are displayed.
* @param strm The stream to which logs will be written.
*/
public void set_log_stream(Writer strm)
{logOptions.set_log_stream(strm);}
/** Returns the stream to which logs for this object are written.
* @return The stream to which logs for this object are written.
*/
public Writer get_log_stream(){return logOptions.get_log_stream();}
/** Returns the LogOptions object for this object.
* @return The LogOptions object for this object.
*/
public LogOptions get_log_options(){return logOptions;}
/** Sets the LogOptions object for this object.
* @param opt The new LogOptions object.
*/
public void set_log_options(LogOptions opt)
{logOptions.set_log_options(opt);}
/** Sets the logging message prefix for this object.
* @param file The file name to be displayed in the prefix of log messages.
* @param line The line number to be displayed in the prefix of log messages.
* @param lvl1 The log level of the statement being logged.
* @param lvl2 The level of log messages being displayed.
*/
public void set_log_prefixes(String file, int line,int lvl1, int lvl2)
{logOptions.set_log_prefixes(file, line, lvl1, lvl2);}
/** Private copy constructor: CatTestResult objects may not be copied.
* @param source The CatTestResult object to be copied.
*/
private CatTestResult(CatTestResult source){}
/** Private assign method: CatTestResult objects may not be assigned.
* @param source The CatTestResult containing data to be copied into this CatTestResult object.
*/
private void assign(CatTestResult source){}
/** Constructor.
* @param cat The Categorizer to be run over the test data.
* @param trainILSource The training data set.
* @param testILSource The test data set.
*/
public CatTestResult(Categorizer cat,
InstanceList trainILSource,
InstanceList testILSource) {
logOptions = new LogOptions("CTR");
trainIL = trainILSource;
testIL = testILSource;
results = new CatOneTestResult[testIL.num_instances()];
for(int y = 0; y < results.length; y++) results[y]= new CatOneTestResult();
catDistrib = new CatCounters[testIL.num_categories() + 1];//(Globals.UNKNOWN_CATEGORY_VAL, testIL.num_categories() + 1);
for(int z = 0; z < catDistrib.length; z++) catDistrib[z] = new CatCounters();
confusionMatrix = new double[testILSource.num_categories()+1][testILSource.num_categories()+1];
// (Globals.UNKNOWN_CATEGORY_VAL, Globals.UNKNOWN_CATEGORY_VAL,
// testILSource.num_categories()+1,
// testILSource.num_categories()+1,
// 0);
metrics = new ScoringMetrics();
lossStats = new StatData();
ownsTestIL = false;
ownsTrainIL = false;
initialize(cat);
inTrainingSetInitialized = false;
}
/** Computes the pessimistic error correction used by C4.5-style
* pruning (Quinlan): given the observed number of errors on the
* training set, returns the pessimistic number of errors, using the
* standard normal distribution approximation. Here's a derivation
* showing that this is the same as C4.5, at least for errors >= 1. <BR>
* err = (2ne+z^2+z*sqrt(4ne+z^2-4ne^2))/(2*(n+z^2)) <BR>
* where n is the number of records, e is the prob of error, and z is the z-value.
* Let E = count of errors, i.e., ne.<BR>
* err = (2E + z^2 + z*sqrt(4E+z^2-4E^2/n))/(2*(n+z^2)) <BR>
* err = (E + z^2/2 + z*sqrt(E-E^2/n+z^2/4))/(n+z^2) <BR>
* err = (E + z^2/2 + z*sqrt(E(1-E/n)+z^2/4))/(n+z^2) <BR>
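* <P>
* A worked example (values rounded): plugging E = 1 observed error in
* n = 16 records with z = 1 into the last line gives <BR>
* err = (1 + 0.5 + 1*sqrt(1*(1 - 1/16) + 0.25))/(16 + 1) = 2.59/17 = 0.152, <BR>
* i.e., about n*err = 2.44 pessimistic errors. Note that the method
* below first applies a +0.5 continuity correction to numErrors
* (probError = (numErrors + 0.5)/totalWeight), so calling
* pessimistic_error_correction(1, 16, 1) itself returns roughly 3.08.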
*
* @return The pessimistic number of errors on the training set.
* @param numErrors The observed number of errors.
* @param totalWeight The total weight of all Instances tested.
* @param zValue The z-value (standard normal quantile) that determines
* the confidence level; 0 disables the correction.
*/
static public double pessimistic_error_correction(double numErrors,
double totalWeight,
double zValue) {
MLJ.verify_strictly_greater(totalWeight, 0, "CatTestResult::"
+"pessimistic_error_correction: "
+"zero total weight");
if (zValue == 0)
return numErrors;
// This can be strictly less than if we're guaranteed to have majority,
// but when classifying instances in another subtree, this may not
// hold any more (e.g., when asserting whether to replace a node
// with one of its subtrees).
//@@ this check may go away with loss functions
//@@ Dan, add maximum loss here under DBG
// ASSERT(numErrors <= totalWeight);
double probError = (numErrors + 0.5) / totalWeight;
if (probError > 1)
probError = 1;
DoubleRef optimisticProb = new DoubleRef(0);
DoubleRef pessimisticProb = new DoubleRef(0);
confidence(optimisticProb, pessimisticProb,
probError, totalWeight, zValue);
MLJ.clamp_below(pessimisticProb, 1, "CatTestResult::"
+"pessimistic_error_correction: too many errors");
// ASSERT(pessimisticProb >= 0);
return pessimisticProb.value * totalWeight;
}
/** Compute the confidence interval according to the binomial
* model. Source is Devijver and Kittler.
* @param confLow Low bound of confidence interval. This value is altered.
* @param confHigh High bound of confidence interval. This value is altered.
* @param error The error value for which the confidence interval is requested.
* @param n Number of samples.
* @param z The confidence coefficient.
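* <P>
* A worked example (values rounded): error = 0.2 on n = 100 samples
* with z = 1.96 (a 95% confidence coefficient) gives
* numer = 2*100*0.2 + 1.96^2 = 43.84, sqrtTerm = 1.96*sqrt(80 + 3.84 - 16) = 16.14,
* and denom = 2*(100 + 3.84) = 207.68, so the interval is [0.133, 0.289].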
*/
static void confidence(DoubleRef confLow, DoubleRef confHigh,
double error, double n, double z) {
double z2 = z*z;
double sqrtTerm = z*Math.sqrt(4*n*error+z2 - 4*n*error*error);
double numer = 2*n*error + z2;
double denom = 2*(n+z2);
confLow.value = (numer - sqrtTerm)/denom;
confHigh.value = (numer + sqrtTerm)/denom;
}
/** Compute the confidence interval according to the binomial
* model. Source is Devijver and Kittler.
* @param confLow Low bound of confidence interval. This value is altered.
* @param confHigh High bound of confidence interval. This value is altered.
* @param error The error value for which the confidence interval is requested.
* @param n Number of samples.
*/
static void confidence(DoubleRef confLow, DoubleRef confHigh,
double error, double n) {
confidence(confLow,confHigh,error,n,Globals.CONFIDENCE_INTERVAL_Z);
}
/** Returns the weight of incorrectly categorized test instances
* divided by the total test weight. Defaults to considering
* all test instances (the Normal error type).
* @return The fraction of test weight that was incorrectly classified,
* without partitioning.
*/
public double error(){return error(Normal);}
/** Returns the weight of incorrectly categorized test instances
* divided by the total test weight. Defaults to considering
* all test instances. The ErrorType argument can be used to
* partition test cases into those occurring in the training
* set and those that do not.
* @return The fraction of test weight that was incorrectly classified.
* @param errType The type of error used to partition test cases. Possible
* values are CatTestResult.Normal, CatTestResult.Generalized,
* CatTestResult.Memorized.
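* <P>For example, given a CatTestResult <code>res</code>:
* <PRE>
* double genErr = res.error(CatTestResult.Generalized); // instances not seen during training
* double memErr = res.error(CatTestResult.Memorized); // instances also in the training set
* </PRE>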
*/
public double error(int errType) {
if (testIL.no_instances())
Error.fatalErr("CatTestResult::error: No test instances. This causes "
+"division by 0");
switch (errType) {
case Normal:
return total_incorrect_weight() / total_test_weight();
case Generalized: {
if (total_weight_off_train() == 0)
Error.fatalErr("CatTestResult::error(Generalized): All test instances "
+"are also in training set. This causes division by 0");
double weightOffTrainIncorrect = 0;
// for (int i = results.low(); i <= results.high(); i++)
for (int i = 0; i < results.length; i++)
if ( (!results[i].inTrainIL) &&
( results[i].augCat.num() != results[i].correctCat ) )
weightOffTrainIncorrect += results[i].instance.get_weight();
return weightOffTrainIncorrect / total_weight_off_train();
}
case Memorized: {
if (total_weight_on_train() == 0)
Error.fatalErr("CatTestResult::error(Memorized): No test instances "
+"appear in the training set. This causes division by 0");
double weightOnTrainIncorrect = 0;
// Mirrors the Generalized case, restricted to test instances that
// also appear in the training set (reconstructed by symmetry; the
// original listing is truncated here).
for (int i = 0; i < results.length; i++)
if ( results[i].inTrainIL &&
( results[i].augCat.num() != results[i].correctCat ) )
weightOnTrainIncorrect += results[i].instance.get_weight();
return weightOnTrainIncorrect / total_weight_on_train();
}
default:
Error.fatalErr("CatTestResult::error: invalid error type " + errType);
return Double.NaN; // not reached; fatalErr aborts
}
}