📄 crossvalidator.java
字号:
package fss;import shared.*;import shared.Error;import java.lang.*;
public class CrossValidator extends PerfEstimator {
// Number of folds per cross-validation run. set_folds only rejects 0; a
// negative value is treated by compute_folds as leave-k-out with
// k = Math.abs(numFolds).
private int numFolds;
// Number of times the whole cross-validation is repeated; set_times
// requires it to be > 0.
private int numTimes;
// Portion of the instances remaining after the test fold is split off that
// is actually used for training; set_fraction checks it against (0,1].
protected double fraction;
// Defaults picked up by the constructors that omit folds and/or times.
public static int defaultNumFolds = 10;
public static int defaultNumTimes = 1;
// The defaults below are not read anywhere in this section of the file; the
// trailing notes name the routines they are meant for.
public static int defaultMaxFolds = 20; // for auto_set_folds
// No initializer, so this starts at 0 (Java default for static ints).
// NOTE(review): confirm that 0 is the intended default.
public static int defaultAutoFoldTimes; // for auto_set_folds
public static double defaultStdDevEpsilon = 0.001; // for auto_set_folds
public static double defaultAccEpsilon = 0.005; // for auto_set_folds
public static int defaultMaxTimes = 10; // for auto_estimate
public static double defaultAutoStdDev = 0.01; // for auto_estimate
/***************************************************************************
This class has no access to a copy constructor.
***************************************************************************/
/**
 * Private, empty copy constructor: it exists only so that code outside this
 * class cannot copy-construct a CrossValidator.
 *
 * @param source ignored
 */
private CrossValidator(CrossValidator source)
{
    // intentionally empty
}
/***************************************************************************
This class has no access to an assign method.
***************************************************************************/
/**
 * Private, empty assignment method: it exists only so that code outside this
 * class cannot assign one CrossValidator to another through it.
 *
 * @param source ignored
 */
private void assign(CrossValidator source)
{
    // intentionally empty
}
/***************************************************************************
Description : Estimate error for a single time (multi fold).
Comments : protected function
***************************************************************************/
/**
 * Runs one complete pass of cross-validation (all folds, a single "time").
 *
 * The working list comes from dataList.shuffle(rand_num_gen()). For every
 * fold a prefix of that list is split off as the test set; the remainder
 * (or, when fraction < 1.0, only a prefix of the remainder) is the training
 * set. After train_and_test the pieces are united again with the test set
 * appended last, so the next fold takes a different prefix.
 *
 * @param inducer  inducer handed to train_and_test
 * @param dataList instances to cross-validate on
 * @param time     index of the current time; used only to build the
 *                 "-time-fold" suffix passed to train_and_test
 * @param folds    number of folds to run
 * @return mean of the per-fold errors gathered during this pass
 */
protected double estimate_time_performance(BaseInducer inducer,
                                           InstanceList dataList,
                                           int time, int folds)
{
    int totalInstances = dataList.num_instances();
    InstanceList shuffledList = dataList.shuffle(rand_num_gen());
    PerfData foldData = new PerfData();
    logOptions.DRIBBLE(folds + " folds: ");
    for (int fold = 0; fold < folds; fold++) {
        logOptions.DRIBBLE(fold + 1 + " ");
        // The first (totalInstances % folds) folds get one extra instance so
        // that the fold sizes add up to totalInstances.
        int numInSplit = totalInstances / folds +
            ((totalInstances % folds > fold) ? 1 : 0);
        logOptions.LOG(3, "Number of instances in fold " + fold + ": " +
                       numInSplit + ".");
        InstanceList testList = shuffledList.split_prefix(numInSplit);
        InstanceList fractList = shuffledList;
        // Logged before any fractional split, so the "train" weight here is
        // that of everything left after removing the test fold.
        logOptions.LOG(3, "Total weights of fold " + fold + " (train/test): "
                       + fractList.total_weight() + "/" + testList.total_weight()
                       + "\n");
        if (fraction < 1.0) {
            double numInFract = fraction * (double)(shuffledList.num_instances());
            int intNumInFract = (int)(numInFract + 0.5); // round to nearest
            if (intNumInFract == 0)
                Error.fatalErr("CrossValidator.estimate_time_performance: "
                               + "No instances left in cv fraction");
            fractList = shuffledList.split_prefix(intNumInFract);
        }
        logOptions.LOG(6, "Training set is:" + "\n" + fractList + "\n");
        logOptions.LOG(6, "Test set is:" + "\n" + testList + "\n");

        // Dribble output is switched off while training/testing. The restore
        // sits in a finally block so that an exception thrown by
        // train_and_test cannot leave the global flag disabled.
        boolean saveDribble = GlobalOptions.dribble;
        GlobalOptions.dribble = false;
        double error;
        try {
            error = train_and_test(inducer, fractList, testList,
                                   "-" + time + "-" + fold,
                                   perfData);
        } finally {
            GlobalOptions.dribble = saveDribble;
        }
        logOptions.LOG(3, "Error " + error + "\n");
        foldData.insert_error(error);

        if (fraction < 1.0) {
            // clean up fractional list: put the unused remainder back behind
            // the fractional training prefix
            fractList.unite(shuffledList);
            shuffledList = fractList;
        }
        // Append the test fold at the end of the working list.
        shuffledList.unite(testList);
        testList = null; //testList is returned null from InstanceList.unite -JL
        MLJ.ASSERT(testList == null,"CrossValidator.estimate_time_performance: testList != null.");
    }
    logOptions.DRIBBLE("\n");
    logOptions.LOG(2, "fold " + foldData + ". ");
    shuffledList = null;
    return foldData.get_error_data().mean();
}
/***************************************************************************
Constructor for cross-validation. See estimate_performance with same
arguments.
***************************************************************************/
/**
 * Creates a cross-validator with explicit fold and repetition counts; the
 * training fraction is set to 1.0.
 *
 * @param nFolds number of folds, validated by set_folds
 * @param nTimes number of times, validated by set_times
 */
public CrossValidator(int nFolds, int nTimes)
{
    this.set_folds(nFolds);
    this.set_times(nTimes);
    this.set_fraction(1.0);
}
/***************************************************************************
Constructor for cross-validation. See estimate_performance with same
arguments.
***************************************************************************/
/**
 * Creates a cross-validator with the given number of folds, defaultNumTimes
 * repetitions and a training fraction of 1.0.
 *
 * @param nFolds number of folds, validated by set_folds
 */
public CrossValidator(int nFolds)
{
    // Same three setter calls, in the same order, via the (folds, times) constructor.
    this(nFolds, defaultNumTimes);
}
/***************************************************************************
Constructor for cross-validation. See estimate_performance with same
arguments.
***************************************************************************/
/**
 * Creates a cross-validator with defaultNumFolds folds, defaultNumTimes
 * repetitions and a training fraction of 1.0.
 */
public CrossValidator()
{
    // Same three setter calls, in the same order, via the (folds, times) constructor.
    this(defaultNumFolds, defaultNumTimes);
}
/**
 * Sets the number of folds. Zero is reported through Error.fatalErr;
 * negative values are accepted.
 *
 * @param num requested number of folds
 */
public void set_folds(int num)
{
    final boolean zeroFolds = (0 == num);
    if (zeroFolds) {
        Error.fatalErr( "CrossValidator.set_folds: num folds (" + num + ") == 0");
    }
    this.numFolds = num;
}
/**
 * Sets how many times cross-validation is repeated. Values that are not
 * strictly positive are reported through Error.fatalErr.
 *
 * @param num requested number of times
 */
public void set_times(int num)
{
    final boolean positive = (num > 0);
    if (!positive) {
        Error.fatalErr( "CrossValidator.set_times: num times (" + num + ") <= 0");
    }
    this.numTimes = num;
}
/**
 * Sets the fraction of the non-test instances that is used for training.
 * Anything outside the half-open range (0,1], including NaN, is reported
 * through Error.fatalErr.
 *
 * @param fract requested training fraction
 */
public void set_fraction(double fract)
{
    // Expressed as a negated in-range test: every comparison involving NaN is
    // false, so "fract <= 0 || fract > 1.0" would let NaN slip through.
    if (!(fract > 0 && fract <= 1.0)) {
        Error.fatalErr( "CrossValidator.set_fraction: " + fract + " is out of the range (0,1]");
    }
    fraction = fract;
}
/**
 * @return the configured number of folds
 */
public int get_folds()
{
    return this.numFolds;
}
/**
 * Returns the configured number of times after asserting that it is
 * positive.
 *
 * @return the configured number of times
 */
public int get_times()
{
    MLJ.ASSERT(numTimes > 0, "CrossValidator.get_times: numTimes <= 0");
    return this.numTimes;
}
/***************************************************************************
Prints identifying string for this estimator. This includes number of
times and folds.
***************************************************************************/
/**
 * Builds the identifying string for this estimator in the form
 * "<times>x<folds> Cross validator".
 *
 * @return the identifying string
 */
public String description()
{
    StringBuilder text = new StringBuilder();
    text.append(numTimes);
    text.append("x");
    text.append(numFolds);
    text.append(" Cross validator");
    return text.toString();
}
/***************************************************************************
Trains and tests the inducer using "numFolds"-fold cross-validation, and
repeated numTimes times. If numFolds is negative, it means leave-k-out,
where k is Math.abs(numFolds). Shuffles the trainList before each time cross
validation is performed. Use init_rand_num_gen(seed) to achieve
reproducible results; otherwise results may vary because of variations in
the shuffling of the data.
***************************************************************************/
// Helper function to help support leave-k-out
/**
 * Validates a fold request against the data size and returns a printable
 * description of the fold count.
 *
 * A positive numFolds is returned as its decimal string. A negative
 * numFolds means leave-k-out with k = Math.abs(numFolds); the result is
 * then "<n> (leave <k> out)", where n is ceil(totalInstances / k).
 *
 * Invalid requests are reported through Error.fatalErr.
 *
 * NOTE(review): numFolds is passed by value, so the fold count computed for
 * leave-k-out is visible to the caller only inside the returned string.
 *
 * @param numFolds       requested folds (negative for leave-k-out)
 * @param totalInstances number of instances in the data
 * @return printable description of the number of folds
 */
public static String compute_folds(int numFolds, int totalInstances)
{
    if (totalInstances == 0) {
        Error.fatalErr("CrossValidator::estimate_performance: 0 instances in dataList");
    }
    if (totalInstances == 1) {
        Error.fatalErr("CrossValidator::estimate_performance: Cannot estimate error for 1 instance" );
    }
    final boolean foldsOutOfRange = (numFolds == 0) || (numFolds > totalInstances);
    if (foldsOutOfRange) {
        Error.fatalErr("CrossValidator::estimate_performance: number of folds ("
                       + numFolds + ") is invalid for data with "
                       + totalInstances + " instances" );
    }

    // Plain k-fold: the request is already the fold count.
    if (numFolds > 0) {
        return String.valueOf(numFolds);
    }

    // Leave-k-out: derive the fold count from the number left out per fold.
    final int leaveOut = Math.abs(numFolds);
    final double exactFolds = Math.ceil((double) totalInstances / leaveOut);
    final int actualFolds = (int) (exactFolds + 0.5);
    if (actualFolds < 2) {
        Error.fatalErr("CrossValidator::estimate_performance: number of folds ("
                       + numFolds + ") will leave no training instances");
    }
    final String label = actualFolds + " (leave " + leaveOut + " out)";
    MLJ.ASSERT(actualFolds > 1, "CrossValidator.compute_folds: numFolds <= 1.");
    return label;
}
public double estimate_performance(BaseInducer inducer,
InstanceList dataList)
{
// copy dataList in slow debug mode to check for ordering problems
InstanceList dataListPtr = dataList;
if (Basics.DBGSLOW) dataListPtr = (InstanceList)dataList.clone();
int totalInstances = dataList.num_instances();
int actualFolds = numFolds; // cannot be negative
String foldsStr = compute_folds(actualFolds, totalInstances);
logOptions.LOG(1, "Inducer: " + inducer.description() + "\n");
logOptions.LOG(1, "Number of folds: " + foldsStr + ", Number of times: "
+ numTimes + "\n");
perfData = null;
perfData = new PerfData();
for (int time = 0; time < numTimes; time++) {
if (numTimes > 1)
logOptions.DRIBBLE("Time " + time + "\n");
estimate_time_performance(inducer, dataListPtr, time, actualFolds);
if (numTimes > 1) // Don't print this if it's only once because we get
// the same output below
logOptions.LOG(2, "Overall: " + this + "\n");
}
// check the copied data list ordering in slow debug mode to
// make sure CValidator does not mess up list ordering
if (Basics.DBGSLOW){
if(!(dataList == dataListPtr))
Error.err( "CrossValidator::estimate_performance: ordering of dataList "
+"changed during cross-validation" + "\n");
dataListPtr = null;
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -