📄 perfestdispatch.java
字号:
{
// Estimate the performance of baseInducer using the configured
// estimation method, accumulating results into perfData.
// NOTE(review): the method signature is above this chunk; the body reads
// baseInducer, trainInstList, testInstList and providedPerfData, and
// returns the trimmed error as a double (see the final return).
// Validate inputs: a training list is required, and both lists must
// carry instance weights.
if (trainInstList == null)
Error.fatalErr("PerfEstDispatch.estimate_performance: Null training instList");
if (trainInstList.no_weight())
Error.fatalErr("PerfEstDispatch.estimate_performance: "
+"no weight in training InstanceList");
if (testInstList != null && testInstList.no_weight())
Error.fatalErr("PerfEstDispatch.estimate_performance: "
+"no weight in test InstanceList");
// "automatic" resolves to cv (small training weight) or holdOut (large).
// Remember that we resolved it so the setting can be restored before
// returning (see the end of this method).
boolean usedAutomaticMethod = false;
if (perfEstimationMethod == automatic) {
usedAutomaticMethod = true;
perfEstimationMethod = trainInstList.total_weight() < cvWeightThreshold
? cv : holdOut;
logOptions.LOG(1, "Using "
+perfEstimationMethodEnum.name_from_value(perfEstimationMethod)
+" for performance estimation.\n");
}
// This reference keeps track of the PerfEstimator to use.
PerfEstimator perfEstimator = null;
// Set up a performance estimator. The one to use depends on the method
// we selected in the options. Set the options for each estimator
// here, too.
if (perfEstimationMethod == stratCV) {
// Stratified cross-validation.
StratifiedCV crossValidator = new StratifiedCV();
crossValidator.set_log_level(get_log_level());
crossValidator.init_rand_num_gen(get_seed());
// Never request more folds than there are instances.
crossValidator.set_folds(Math.min(get_cv_folds(),
trainInstList.num_instances()));
crossValidator.set_fraction(get_cv_fraction());
// cv_times == 0 means "repeat automatically until the desired std
// dev is reached", capped at maxTimes.
if (get_cv_times() == 0)
crossValidator.auto_estimate_performance(baseInducer,
trainInstList,
desStdDev,
maxTimes);
else {
crossValidator.set_times(get_cv_times());
crossValidator.estimate_performance(baseInducer, trainInstList);
}
actualTimes = crossValidator.get_times();
perfEstimator = crossValidator;
} else if (perfEstimationMethod == cv) {
// Plain cross-validation; use the incremental variant when the
// inducer supports incremental induction.
CrossValidator crossValidator;
if (baseInducer.can_cast_to_incr_inducer())
crossValidator = new CVIncremental();
else
crossValidator = new CrossValidator();
crossValidator.set_log_level(get_log_level());
crossValidator.set_fraction(get_cv_fraction());
crossValidator.init_rand_num_gen(get_seed());
crossValidator.set_folds(Math.min(get_cv_folds(),
trainInstList.num_instances()));
// Same auto-vs-fixed repetition logic as the stratCV branch above.
if (get_cv_times() == 0)
crossValidator.auto_estimate_performance(baseInducer,
trainInstList,
desStdDev,
maxTimes);
else {
crossValidator.set_times(get_cv_times());
crossValidator.estimate_performance(baseInducer, trainInstList);
}
actualTimes = crossValidator.get_times();
perfEstimator = crossValidator;
} else if (perfEstimationMethod == bootstrap){
// Fractional bootstrap with a fixed number of runs.
actualTimes = get_bootstrap_times();
Bootstrap bootstrap = new Bootstrap(get_bootstrap_times());
bootstrap.set_fraction(get_bootstrap_fraction());
bootstrap.set_type(Bootstrap.fractional);
bootstrap.set_log_level(get_log_level());
bootstrap.init_rand_num_gen(get_seed());
bootstrap.estimate_performance(baseInducer, trainInstList);
perfEstimator = bootstrap;
} else if (perfEstimationMethod == holdOut) {
// Repeated holdout with the configured times/number/percent.
HoldOut holdout = new HoldOut(get_holdout_times(),
get_holdout_number(),
get_holdout_percent());
holdout.set_log_level(get_log_level());
holdout.init_rand_num_gen(get_seed());
holdout.estimate_performance(baseInducer, trainInstList);
perfEstimator = holdout;
} else if (perfEstimationMethod == testSet){
actualTimes = 0; // We didn't do any estimation runs.
// It is an error to use testSet if we have no real performance.
if(testInstList == null) {
if(providedPerfData != null) {
// If an ErrorData was provided, but we have no test set,
// then just return the value we had before rather than
// aborting. We do this to fix a design bug in SASearch
// which would cause an abort every time a node is
// reevaluated under testSet.
// Based on the control flow here, the net effect here is
// to do nothing.
}
else {
Error.fatalErr("PerfEstDispatch::estimate_performance: cannot use test "
+"set error when no testing set is provided");
}
}
} else
Error.fatalErr("PerfEstDispatch::estimate_performance: "
+"invalid performance estimator");
// Clear the current PerfData first.
perfData.clear();
if(perfEstimator != null) {
// If a PerfData was provided, append its results first.
if(providedPerfData != null)
perfData.append(providedPerfData);
// Now accumulate results into perfData.
perfData.append(perfEstimator.get_perf_data());
}
// Set real statistics if we have a test InstanceList.
if(testInstList != null) {
// Suppress dribble output during the final train-and-test run.
boolean saveDribble = GlobalOptions.dribble;
GlobalOptions.dribble = false;
// if full testing is supported, make sure ALL error metrics
// in PerfData have their test set information present.
if(baseInducer.supports_full_testing()) {
CatTestResult result = baseInducer.train_and_perf(trainInstList,
testInstList);
perfData.set_test_set(result);
if (perfEstimationMethod == testSet) {
// For testSet no estimator ran, so perfData must still be empty.
MLJ.ASSERT(perfData.size() == 0,"PerfEstDispatch.estimate_performance: perfData.size() != 0.");
perfData.insert(result);
}
result = null;
}
else
perfData.set_test_set(baseInducer.train_and_test(trainInstList,
testInstList),
testInstList.total_weight());
GlobalOptions.dribble = saveDribble;
}
// Drop local references (MLC++-translation ownership hygiene).
perfEstimator = null;
trainInstList = null;
testInstList = null;
// If a PerfData was provided, accumulate new results into it.
if(providedPerfData != null) {
providedPerfData.clear();
providedPerfData.append(perfData);
}
// Restore "automatic" so a later call re-resolves the method.
if (usedAutomaticMethod)
perfEstimationMethod = automatic;
// call get_error_data() to take errorType into account when
// returning results.
return get_error_data().error(errTrim);
}
/** Estimates performance without a caller-provided PerfData accumulator.
 *  Delegates to the three-argument overload with a null PerfData. */
public double estimate_performance(BaseInducer inducer,
InstanceList trainList)
{
return estimate_performance(inducer, trainList, (PerfData)null);
}
/** Estimates performance, accumulating results into pPerfData when non-null.
 *  The training list is cloned so the caller's copy is left untouched
 *  (the four-argument overload nulls out its list references). */
public double estimate_performance(BaseInducer baseInducer,
InstanceList trainInstList,
PerfData pPerfData)
{
InstanceList trainCopy = (InstanceList)trainInstList.clone();
return estimate_performance(baseInducer, trainCopy, null, pPerfData);
}
// generic parameters
/** Sets the performance-estimation method.
 *  @param perfEstimation a PerfEstimationMethod enum value; must be one of
 *         cv, stratCV, testSet, bootstrap, holdOut or automatic, otherwise
 *         a fatal error is raised. */
public void set_perf_estimator(int perfEstimation ) //perfEstimation is PerfEstimationMethod enum
{
boolean valid = perfEstimation == cv || perfEstimation == stratCV ||
perfEstimation == testSet || perfEstimation == bootstrap ||
perfEstimation == holdOut || perfEstimation == automatic;
if (!valid)
Error.fatalErr("PerfEstDispatch::set_performance_estimator: performance "
+"estimator must be either cv, stratCV, testSet, "
+"bootstrap, holdOut, or automatic");
perfEstimationMethod = perfEstimation;
}
// Generic accessors. perfEstimationMethod and errorType are int-coded
// enum values (PerfEstimationMethod and ErrorType respectively); errTrim
// is the trim fraction passed to ErrorData.error(); randSeed seeds the
// estimators' random number generators; cvWeightThreshold is the total
// training weight below which "automatic" selects cv instead of holdOut.
public int get_perf_estimator() { return perfEstimationMethod; } //perfEstimationMethod is PerfEstimationMethod enum
public void set_error_type(int errType) { errorType = errType; } //errType is ErrorType enum
public int get_error_type() { return errorType; } //errType is ErrorType enum
public void set_error_trim(double trim) { errTrim = trim; }
public double get_error_trim() { return errTrim; }
public void set_seed(int seed) { randSeed = seed; }
public int get_seed(){ return randSeed; }
public void set_cv_weight_threshold(double threshold) { cvWeightThreshold = threshold; }
public double get_cv_weight_threshold() { return cvWeightThreshold; }
// Parameters for cross validation.
/** Sets the number of cross-validation folds.
 *  0 and 1 folds are rejected as meaningless; other values (including
 *  negatives -- presumably the MLC++ leave-k-out convention, TODO confirm)
 *  are stored as-is. */
public void set_cv_folds(int folds)
{
boolean illegal = (folds == 0) || (folds == 1);
if (illegal)
Error.fatalErr("PerfEstDispatch::set_cv_folds: picking "+folds+" folds "
+"is illegal");
cvFolds = folds;
}
// Cross-validation accessors. cvTimes == 0 selects automatic repetition
// in estimate_performance (repeat until desStdDev is reached, capped at
// maxTimes); cvFraction is forwarded to the CrossValidator.
public int get_cv_folds() { return cvFolds; }
public void set_cv_times(int times) { cvTimes = times; }
public int get_cv_times() { return cvTimes; }
public void set_cv_fraction(double fract) { cvFraction = fract; }
public double get_cv_fraction() { return cvFraction; }
public void set_max_times(int times) { maxTimes = times; }
public int get_max_times() { return maxTimes; }
public void set_desired_std_dev(double dstd) { desStdDev = dstd; }
public double get_desired_std_dev() { return desStdDev; }
// Parameters for bootstrap.
/** Sets the number of bootstrap repetitions.
 *  @param times number of runs; must be at least 1, otherwise a fatal
 *         error is raised. */
public void set_bootstrap_times(int times)
{
if (times < 1)
// Fixed garbled class name in the message ("PerfEstDispatchimator").
Error.fatalErr("PerfEstDispatch::set_bootstrap_times: times ("+times
+") must be at least 1");
bootstrapTimes = times;
}
public int get_bootstrap_times() { return bootstrapTimes; }
/** Sets the fraction of the sample used by fractional bootstrap.
 *  @param fract must lie in (0.0, 1.0], otherwise a fatal error is
 *         raised. */
public void set_bootstrap_fraction(double fract)
{
if(fract <= 0.0 || fract > 1.0)
// Fixed garbled class name in the message ("PerfEstDispatchimator").
Error.fatalErr("PerfEstDispatch::set_bootstrap_fraction: fraction ("+
fract+") must be between 0.0 and 1.0");
bootstrapFraction = fract;
}
public double get_bootstrap_fraction() { return bootstrapFraction; }
// Parameters for holdout
// Holdout accessors. hoNumber == 0 makes HoldOut use hoPercent instead
// (see display_settings below, which branches on get_holdout_number()).
public void set_holdout_times(int times) { hoTimes = times; }
public int get_holdout_times() { return hoTimes; }
public void set_holdout_number(int num) { hoNumber = num; }
public int get_holdout_number() { return hoNumber; }
public void set_holdout_percent(double pct) { hoPercent = pct; }
public double get_holdout_percent() { return hoPercent; }
// get results; get_error_data returns your chosen error data based
// on errorType. get_perf_data gets the full suite of error stats.
/** Returns the ErrorData selected by errorType from the full PerfData.
 *  Raises a fatal error on an unrecognized errorType; the trailing
 *  return only satisfies the compiler after fatalErr. */
public ErrorData get_error_data()
{
if (errorType == classError)
return perfData.get_error_data();
if (errorType == meanSquaredError)
return perfData.get_mean_squared_error_data();
if (errorType == meanAbsoluteError)
return perfData.get_mean_absolute_error_data();
if (errorType == classLoss)
return perfData.get_loss_data();
Error.fatalErr("PerfEstDispatch::get_error_data: errorType has "
+"illegal value: "+(int)errorType);
return perfData.get_error_data();
}
// Returns the full suite of performance statistics accumulated by the
// last estimate_performance call. NOTE(review): exposes the internal
// mutable PerfData directly -- callers can modify it.
public PerfData get_perf_data()
{
return perfData;
}
// get number of times auto-cv was actually run
// (set by estimate_performance; 0 when the testSet method was used).
public int get_actual_times() { return actualTimes; }
// display: we can display settings, performance information, or both
/** Writes the current estimator settings to the given stream, showing
 *  only the parameter groups relevant to the selected method.
 *  On IOException the process prints the trace and exits (original
 *  behavior preserved). */
public void display_settings(Writer stream)
{
try{
// Generic parameters first.
int method = get_perf_estimator(); //PerfEstimationMethod enum
stream.write("Method: "+perfEstimationMethodEnum.name_from_value(method)+"\n");
stream.write("Error type: "+errorTypeEnum.name_from_value(errorType)+"\n");
if (method == automatic)
stream.write("CV weight threshold: "+cvWeightThreshold+"\n");
stream.write("Trim: "+get_error_trim()+"\n");
stream.write("Seed: "+get_seed()+"\n");
// Cross-validation parameters (also shown for automatic).
if(method == cv || method == stratCV || method == automatic) {
stream.write("Folds: "+get_cv_folds());
if(cvTimes >= 1)
stream.write(", Times: "+get_cv_times()+"\n");
else
stream.write(", Desired std dev: "+get_desired_std_dev()+
", Max times: "+get_max_times()+"\n");
}
// Bootstrap parameters.
if(method == bootstrap)
stream.write("Times: "+get_bootstrap_times()+", Fraction: "+
get_bootstrap_fraction()+"\n");
// Holdout parameters (also shown for automatic); a holdout number
// of 0 means the percent setting is in effect.
if(method == holdOut || method == automatic) {
stream.write("Times: "+get_holdout_times());
if(get_holdout_number() == 0)
stream.write(", Percent: "+get_holdout_percent()+"\n");
else
stream.write(", Number: "+get_holdout_number()+"\n");
}
}catch(IOException e){
e.printStackTrace();
System.exit(1);
}
}
/*bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb 3
public void display_performance(Writer stream)
{
perfData.display_error(stream, errTrim);
}
public void display(Writer stream)
{
display_settings(stream);
stream.write("Error: ");
display_performance(stream);
stream.write("\n");
if (!perfData.perf_empty()) {
perfData.display_non_error(stream,
CatTestResult.get_compute_log_loss());
stream.write("\n");
}
}
eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee 3*/
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -