// crossvalidator.java
// (code-viewer page residue: "Font size:")
if (perfData.size() != numTimes * actualFolds)
Error.fatalErr("CrossValidator:estimate_performance error size "
+ perfData.size() + " does not match expected size "
+ (numTimes * actualFolds)
+ ". Probable cause: train_and_test not updating perfData");
logOptions.LOG(1, "Untrimmed error ");
if(get_log_level() >= 1) display_error_data(get_log_stream());
logOptions.LOG(1, "\n");
// set cost within this accData
perfData.insert_cost(actualFolds * numTimes);
return error();
}
/***************************************************************************
Automatically estimate error to the given std-dev level.
***************************************************************************/
/**
 * Repeats cross-validation until the standard deviation of the mean error
 * is no longer above {@code desiredStdDev}, or until {@code maxTimes}
 * repetitions have been run. At least one repetition is always performed.
 * The previous contents of {@code perfData} are discarded.
 *
 * @param inducer       inducer handed to estimate_time_performance
 * @param dataList      instance list to cross-validate on
 * @param desiredStdDev looping continues while error_std_dev() exceeds this
 * @param maxTimes      upper bound on the number of repetitions
 * @return the value of error() after the last repetition
 */
public double auto_estimate_performance(BaseInducer inducer,
                                        InstanceList dataList,
                                        double desiredStdDev,
                                        int maxTimes)
{
    // numFolds == -1 is rejected: per the message, that is leave-one-out,
    // which is not meant to be repeated.
    if (numFolds == -1) {
        Error.fatalErr("CrossValidator::auto_estimate_performance: it does not make "
            +" sense to do leave-one-out multiple times");
    }

    int totalInstances = dataList.num_instances();
    int actualFolds = numFolds; // cannot be negative
    // NOTE(review): Java passes int by value, so this call cannot change
    // actualFolds. If compute_folds was ported from a by-reference C++
    // signature, its result is lost here -- confirm.
    compute_folds(actualFolds, totalInstances);

    logOptions.LOG(1, "Inducer: " + inducer.description() + "\n");
    logOptions.DRIBBLE("Number of folds: " + actualFolds + ". Looping until "
        +"std-dev of mean =" + desiredStdDev + "\n");

    // Start from an empty result set.
    perfData = new PerfData();

    int time = 0;
    do {
        set_times(time + 1); // so asserts will work.
        estimate_time_performance(inducer, dataList, time, actualFolds);
        logOptions.LOG(2, "Overall error: " + this + "\n");
    } while (++time < maxTimes && error_std_dev() > desiredStdDev);
    set_times(time); // so caller can do get_times();

    // Report exactly the value that is being compared against.
    int expectedSize = actualFolds * numTimes;
    if (perfData.size() != expectedSize) {
        Error.fatalErr("CrossValidator:auto_estimate_performance error size "
            + perfData.size() + " does not match expected size "
            + expectedSize
            + ". Probable cause: train_and_test not updating perfData");
    }

    logOptions.LOG(1, "Untrimmed error " + this + "\n");
    perfData.insert_cost(actualFolds * get_times());
    return error();
}
/**
 * Overload of auto_estimate_performance that uses {@code defaultMaxTimes}
 * as the repetition limit.
 *
 * @param inducer       inducer to evaluate
 * @param dataList      instance list to cross-validate on
 * @param desiredStdDev target standard deviation passed through unchanged
 * @return whatever the four-argument overload returns
 */
public double auto_estimate_performance(BaseInducer inducer,
                                        InstanceList dataList,
                                        double desiredStdDev)
{
    final double estimate =
        auto_estimate_performance(inducer, dataList, desiredStdDev, defaultMaxTimes);
    return estimate;
}
/**
 * Overload of auto_estimate_performance that uses {@code defaultAutoStdDev}
 * as the target standard deviation and {@code defaultMaxTimes} as the
 * repetition limit.
 *
 * @param inducer  inducer to evaluate
 * @param dataList instance list to cross-validate on
 * @return whatever the four-argument overload returns
 */
public double auto_estimate_performance(BaseInducer inducer,
                                        InstanceList dataList)
{
    final double estimate =
        auto_estimate_performance(inducer, dataList, defaultAutoStdDev, defaultMaxTimes);
    return estimate;
}
/***************************************************************************
Trains and tests the inducer using files. Uses the files fileStem.names,
fileStem-T-F.data, and fileStem-T-F.test, where T is an integer in the range
[0, numTimes-1] and F is an integer in the range [0, numFolds-1].
Unimplemented in this version.
***************************************************************************/
/**
 * File-based estimate: for each repetition index {@code t} in
 * {@code [0, numTimes)} calls estimate_file_performance with the stem
 * {@code fileStem + "-" + t}, accumulating into a fresh {@code perfData}.
 *
 * @param inducer  inducer handed to estimate_file_performance
 * @param fileStem common prefix of the per-repetition file names
 * @return the value of error() after all repetitions
 */
public double estimate_performance(BaseInducer inducer, String fileStem)
{
    // Start from an empty result set.
    perfData = new PerfData();

    for (int time = 0; time < numTimes; time++) {
        String timeStem = fileStem + "-" + time;
        double err = estimate_file_performance(inducer, numFolds, timeStem, perfData);
        logOptions.LOG(2, "fold error for set " + time + " is " + err + "\n");
    }

    logOptions.LOG(1, "error: " + this + "\n");
    return error();
}
/***************************************************************************
Attempt to find good values for the numFolds that will be suitable for
the current training set size. This is an expensive operation that is
useful if you intend to do many cross validations on variants of this
dataset. It decreases the number of folds until the error or std-dev
deteriorates.
***************************************************************************/
/**
 * Searches for a fold count by starting at {@code maxFolds} and halving
 * (never below 2) after each automatic estimate. The search stops, and the
 * previous fold count is restored, as soon as the error rises by at least
 * {@code accEpsilon} or the scaled std-dev rises by at least
 * {@code stdDevEpsilon} relative to the previous trial. If no trial
 * deteriorates, the fold count ends at 2. The number of times is restored
 * to its value on entry before returning.
 *
 * @param inducer       inducer handed to auto_estimate_performance
 * @param dataList      instance list to cross-validate on
 * @param maxFolds      starting fold count; must be at least 2 and is
 *                      clamped to the number of instances
 * @param maxTimes      repetition limit for each automatic estimate
 * @param accEpsilon    tolerated error increase; must not be negative
 * @param stdDevEpsilon tolerated std-dev increase; must not be negative
 */
public void auto_set_folds(BaseInducer inducer,
                           InstanceList dataList,
                           int maxFolds,
                           int maxTimes,
                           double accEpsilon,
                           double stdDevEpsilon)
{
    // Validate and clamp the arguments BEFORE any state is changed, so
    // that the clamped maxFolds is the value actually used as the start.
    if (maxFolds < 2) {
        Error.fatalErr("CrossValidator::auto_set_folds: maxFolds: " + maxFolds
            + " less than 2");
    }
    if (stdDevEpsilon < 0) {
        Error.fatalErr("CrossValidator::auto_set_folds: stdDevEpsilon: "
            + stdDevEpsilon + " negative");
    }
    if (accEpsilon < 0) {
        Error.fatalErr("CrossValidator::auto_set_folds: accEpsilon: "
            + accEpsilon + " negative");
    }
    if (maxFolds > dataList.num_instances()) {
        maxFolds = dataList.num_instances();
        logOptions.LOG(1, "Max folds set to number of instances: " + maxFolds + "\n");
    }

    int saveLogLevel = get_log_level();
    int saveNumTimes = get_times();
    set_folds(maxFolds);

    double prevStdDev = 1;    // high number for StdDev
    double prevErr = 1.0;     // worst possible error
    int prevFolds = maxFolds; // in case we stop on first iteration.

    do {
        logOptions.LOG(2, "auto_set_folds trying " + get_folds() + " folds...");
        set_log_level(saveLogLevel - 2); // estimate at loglevel-2
        auto_estimate_performance(inducer, dataList, defaultAutoStdDev,
                                  maxTimes);
        set_log_level(saveLogLevel);

        // The real std-dev is higher since we're supposed to find
        // a good value for 1xfolds.
        double stdDev = error_std_dev() * Math.sqrt(get_times());
        double errorRate = error();

        // Scale to percent before rounding; rounding the raw fraction
        // first would only ever print 0% or 100%.
        logOptions.LOG(2, "mean " + Math.round(errorRate * 100) + "%"
            + " +- " + Math.round(stdDev * 100) + "%"
            + " (average of " + get_times() + ")" + "\n");

        // If one of the following is true, we stop and return previous number
        // of folds:
        //   1. Error increases by accEpsilon.
        //   2. Std-dev goes up by stdDevEpsilon.
        if (errorRate >= prevErr + accEpsilon ||
            stdDev >= prevStdDev + stdDevEpsilon) {
            set_times(saveNumTimes);
            MLJ.ASSERT(prevFolds >= 2,"CrossValidator.auto_set_folds: prevFolds < 2.");
            set_folds(prevFolds);
            logOptions.LOG(1, "Fold setting set to " + get_folds()
                + " (significant deterioration or above std-dev threshold)"
                + "\n");
            return;
        }

        prevStdDev = stdDev;
        prevErr = errorRate;
        prevFolds = get_folds();

        // Halve the fold count for the next trial, but never go below 2.
        int newFolds = get_folds() / 2;
        if (newFolds < 2) {
            newFolds = 2;
        }
        set_folds(newFolds);
    } while (prevFolds > 2);

    set_times(saveNumTimes);
    logOptions.LOG(1, "Fold setting set to " + get_folds()
        + " (minimum possible)" + "\n");
}
/**
 * Overload of auto_set_folds that supplies {@code defaultMaxFolds},
 * {@code defaultMaxTimes}, {@code defaultAccEpsilon} and
 * {@code defaultStdDevEpsilon} for the remaining arguments.
 *
 * @param inducer  inducer to evaluate
 * @param dataList instance list to cross-validate on
 */
public void auto_set_folds(BaseInducer inducer,
                           InstanceList dataList)
{
    this.auto_set_folds(inducer,
                        dataList,
                        defaultMaxFolds,
                        defaultMaxTimes,
                        defaultAccEpsilon,
                        defaultStdDevEpsilon);
}
/**
 * Overload of auto_set_folds that supplies {@code defaultMaxTimes},
 * {@code defaultAccEpsilon} and {@code defaultStdDevEpsilon} for the
 * remaining arguments.
 *
 * @param inducer  inducer to evaluate
 * @param dataList instance list to cross-validate on
 * @param maxFolds starting fold count, passed through unchanged
 */
public void auto_set_folds(BaseInducer inducer,
                           InstanceList dataList,
                           int maxFolds)
{
    this.auto_set_folds(inducer,
                        dataList,
                        maxFolds,
                        defaultMaxTimes,
                        defaultAccEpsilon,
                        defaultStdDevEpsilon);
}
/**
 * Overload of auto_set_folds that supplies {@code defaultAccEpsilon} and
 * {@code defaultStdDevEpsilon} for the remaining arguments.
 *
 * @param inducer  inducer to evaluate
 * @param dataList instance list to cross-validate on
 * @param maxFolds starting fold count, passed through unchanged
 * @param maxTimes repetition limit, passed through unchanged
 */
public void auto_set_folds(BaseInducer inducer,
                           InstanceList dataList,
                           int maxFolds,
                           int maxTimes)
{
    this.auto_set_folds(inducer,
                        dataList,
                        maxFolds,
                        maxTimes,
                        defaultAccEpsilon,
                        defaultStdDevEpsilon);
}
/**
 * Overload of auto_set_folds that supplies {@code defaultStdDevEpsilon}
 * for the remaining argument.
 *
 * @param inducer    inducer to evaluate
 * @param dataList   instance list to cross-validate on
 * @param maxFolds   starting fold count, passed through unchanged
 * @param maxTimes   repetition limit, passed through unchanged
 * @param accEpsilon tolerated error increase, passed through unchanged
 */
public void auto_set_folds(BaseInducer inducer,
                           InstanceList dataList,
                           int maxFolds,
                           int maxTimes,
                           double accEpsilon)
{
    this.auto_set_folds(inducer,
                        dataList,
                        maxFolds,
                        maxTimes,
                        accEpsilon,
                        defaultStdDevEpsilon);
}
/*
public:
// Public data
// Methods
// Auto-setting attempts to set numFolds, by an ad-hoc method, to a
// reasonable number such that the variance resulting from the training-set
// size is not too big.
virtual void auto_set_folds(BaseInducer& inducer, InstanceList& dataList,
int maxFolds = defaultMaxFolds,
int maxTimes = defaultMaxTimes,
double AccEpsilon = defaultAccEpsilon,
double stdDevEpsilon = defaultStdDevEpsilon);
*/
}
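// (code-viewer page residue, not part of the source file)
// Keyboard shortcuts:
//   Copy code            Ctrl + C
//   Search code          Ctrl + F
//   Full-screen mode     F11
//   Toggle theme         Ctrl + Shift + D
//   Show shortcuts       ?
//   Increase font size   Ctrl + =
//   Decrease font size   Ctrl + -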