
perfestdispatch.java
Java data mining algorithms (Java)
   public double estimate_performance(BaseInducer baseInducer,
			  InstanceList trainInstList,
			  InstanceList testInstList,
			  PerfData providedPerfData)
{
   if (trainInstList == null)
      Error.fatalErr("PerfEstDispatch.estimate_performance: Null training instList");
   if (trainInstList.no_weight())
      Error.fatalErr("PerfEstDispatch.estimate_performance: "
	 +"no weight in training InstanceList");
   
   if (testInstList != null && testInstList.no_weight())
      Error.fatalErr("PerfEstDispatch.estimate_performance: "
	 +"no weight in test InstanceList");

   boolean usedAutomaticMethod = false;

   if (perfEstimationMethod == automatic) {
      usedAutomaticMethod = true;
      perfEstimationMethod = trainInstList.total_weight() < cvWeightThreshold
	 ? cv : holdOut;
      logOptions.LOG(1, "Using "
	  +perfEstimationMethodEnum.name_from_value(perfEstimationMethod)
	  +" for performance estimation.\n");
   }
   
   // This reference keeps track of the PerfEstimator to use.
   PerfEstimator perfEstimator = null;
   
   // Set up a performance estimator.  The one to use depends on the method
   // we selected in the options.  Set the options for each estimator
   // here, too.
   if (perfEstimationMethod == stratCV) {
      StratifiedCV crossValidator = new StratifiedCV();
      crossValidator.set_log_level(get_log_level());
      crossValidator.init_rand_num_gen(get_seed());
      crossValidator.set_folds(Math.min(get_cv_folds(),
					trainInstList.num_instances()));
      crossValidator.set_fraction(get_cv_fraction());
      if (get_cv_times() == 0)
         crossValidator.auto_estimate_performance(baseInducer,
						   trainInstList,
						   desStdDev,
						   maxTimes);
      else {
         crossValidator.set_times(get_cv_times());
	 crossValidator.estimate_performance(baseInducer, trainInstList);
      }
      actualTimes = crossValidator.get_times();
      perfEstimator = crossValidator;
   } else if (perfEstimationMethod == cv) {
      CrossValidator crossValidator;
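      // An inducer that supports incremental updates gets the incremental
      // cross-validator, which can presumably update the model between
      // folds instead of retraining from scratch.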
      if (baseInducer.can_cast_to_incr_inducer())
	 crossValidator = new CVIncremental();
      else
	 crossValidator = new CrossValidator();

      crossValidator.set_log_level(get_log_level());
      crossValidator.set_fraction(get_cv_fraction());

      crossValidator.init_rand_num_gen(get_seed());
      crossValidator.set_folds(Math.min(get_cv_folds(),
					trainInstList.num_instances()));

      if (get_cv_times() == 0)      
         crossValidator.auto_estimate_performance(baseInducer,
						trainInstList,
						desStdDev,
						maxTimes);
      else {
 	 crossValidator.set_times(get_cv_times());
	 crossValidator.estimate_performance(baseInducer, trainInstList);
      }
      actualTimes = crossValidator.get_times();
      perfEstimator = crossValidator;
   } else if (perfEstimationMethod == bootstrap){
      actualTimes = get_bootstrap_times();
      Bootstrap bootstrap = new Bootstrap(get_bootstrap_times());
      bootstrap.set_fraction(get_bootstrap_fraction());
      bootstrap.set_type(Bootstrap.fractional);
      bootstrap.set_log_level(get_log_level());
      bootstrap.init_rand_num_gen(get_seed());
      bootstrap.estimate_performance(baseInducer, trainInstList);
      perfEstimator = bootstrap;
   } else if (perfEstimationMethod == holdOut) {
      HoldOut holdout = new HoldOut(get_holdout_times(),
				     get_holdout_number(),
				     get_holdout_percent());
      holdout.set_log_level(get_log_level());
      holdout.init_rand_num_gen(get_seed());
      holdout.estimate_performance(baseInducer, trainInstList);
      perfEstimator = holdout;
   } else if (perfEstimationMethod == testSet){
      actualTimes = 0; // We didn't do any estimation runs.

      // It is an error to use testSet when no test set is available
      // to measure real performance on.
      if(testInstList == null) {
	 if(providedPerfData != null) {
	    // If an ErrorData was provided, but we have no test set,
	    // then just return the value we had before rather than
	    // aborting.  We do this to fix a design bug in SASearch
	    // which would cause an abort every time a node is
	    // reevaluated under testSet.
	    // Given the control flow, the net effect here is
	    // to do nothing.
	 }
	 else {
	    Error.fatalErr("PerfEstDispatch::estimate_performance: cannot use test "
	       +"set error when no testing set is provided");
	 }
      }
 
   } else
      Error.fatalErr("PerfEstDispatch::estimate_performance: "
	 +"invalid performance estimator");


   // Clear the current PerfData first.
   perfData.clear();
   
   if(perfEstimator != null) {
      // If a PerfData was provided, append its results first.
      if(providedPerfData != null)
	 perfData.append(providedPerfData);

      // Now accumulate results into perfData.
      perfData.append(perfEstimator.get_perf_data());
   }

   // Set real statistics if we have a test InstanceList.
   if(testInstList != null) {
      boolean saveDribble = GlobalOptions.dribble;
      GlobalOptions.dribble = false;

      // if full testing is supported, make sure ALL error metrics
      // in PerfData have their test set information present.
      if(baseInducer.supports_full_testing()) {
	 CatTestResult result = baseInducer.train_and_perf(trainInstList,
							    testInstList);
	 perfData.set_test_set(result);
	 if (perfEstimationMethod == testSet) {
	    MLJ.ASSERT(perfData.size() == 0,"PerfEstDispatch.estimate_performance: perfData.size() != 0.");
	    perfData.insert(result);
	 }
	 result = null;
      }
      else
	 perfData.set_test_set(baseInducer.train_and_test(trainInstList,
							  testInstList),
			       testInstList.total_weight());
      
      GlobalOptions.dribble = saveDribble;
   }

   perfEstimator = null;
   trainInstList = null;
   testInstList = null;

   // If a PerfData was provided, accumulate new results into it.
   if(providedPerfData != null) {
      providedPerfData.clear();
      providedPerfData.append(perfData);
   }   

   if (usedAutomaticMethod)
      perfEstimationMethod = automatic;
   
   // call get_error_data() to take errorType into account when
   // returning results.
   return get_error_data().error(errTrim);

}



   public double estimate_performance(BaseInducer inducer,
			  InstanceList trainList) {
      return estimate_performance(inducer, trainList, (PerfData)null);
   }

   public double estimate_performance(BaseInducer baseInducer,
			  InstanceList trainInstList,
			  PerfData pPerfData)
   {
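      // Clone so the estimator can reorder and split the data without
      // mutating the caller's list.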
      InstanceList newTrainInstList = (InstanceList)trainInstList.clone();
      return estimate_performance(baseInducer, newTrainInstList, null, pPerfData);
   }

      
   // generic parameters
   public void set_perf_estimator(int perfEstimation ) //perfEstimation is PerfEstimationMethod enum
{
   if (perfEstimation != cv && perfEstimation != stratCV &&
	   perfEstimation != testSet && perfEstimation != bootstrap &&
           perfEstimation != holdOut && perfEstimation != automatic)
           Error.fatalErr("PerfEstDispatch::set_performance_estimator:  performance "
                  +"estimator must be either cv, stratCV, testSet, "
                  +"bootstrap, holdOut, or automatic");
   perfEstimationMethod = perfEstimation;
}

   public int get_perf_estimator() { return perfEstimationMethod; } //perfEstimationMethod is PerfEstimationMethod enum

   public void set_error_type(int errType) { errorType = errType; } //errType is ErrorType enum
   public int get_error_type() { return errorType; } //returns ErrorType enum
   public void set_error_trim(double trim) { errTrim = trim; }
   public double get_error_trim() { return errTrim; }
   public void set_seed(int seed) { randSeed = seed; }

   public int get_seed(){ return randSeed; }

   public void set_cv_weight_threshold(double threshold) { cvWeightThreshold = threshold; }
   public double get_cv_weight_threshold() { return cvWeightThreshold; }
   
   // Parameters for cross validation.
   public void set_cv_folds(int folds)
{
   if(folds == 0 || folds == 1)
      Error.fatalErr("PerfEstDispatch::set_cv_folds: picking "+folds+" folds "
	 +"is illegal");
   cvFolds = folds;
}

   public int  get_cv_folds() { return cvFolds; }
   public void set_cv_times(int times) { cvTimes = times; }
   public int  get_cv_times() { return cvTimes; }
   public void set_cv_fraction(double fract) { cvFraction = fract; }
   public double get_cv_fraction() { return cvFraction; }
   public void set_max_times(int times) { maxTimes = times; }
   public int  get_max_times() { return maxTimes; }
   public void set_desired_std_dev(double dstd) { desStdDev = dstd; }
   public double get_desired_std_dev() { return desStdDev; }

   // Parameters for bootstrap.
   public void set_bootstrap_times(int times)
{
   if (times < 1)
      Error.fatalErr("PerfEstDispatchimator::set_bootstrap_times: times ("+times
	  +") must be at least 1");
   bootstrapTimes = times;
}

   public int  get_bootstrap_times() { return bootstrapTimes; }
   public void set_bootstrap_fraction(double fract)
{
   if(fract <= 0.0 || fract > 1.0)
      Error.fatalErr("PerfEstDispatchimator::set_bootstrap_fraction: fraction ("+
	 fract+") must be between 0.0 and 1.0");
   bootstrapFraction = fract;
}

   public double get_bootstrap_fraction() { return bootstrapFraction; }

   // Parameters for holdout
   public void set_holdout_times(int times) { hoTimes = times; }
   public int get_holdout_times() { return hoTimes; }
   public void set_holdout_number(int num) { hoNumber = num; }
   public int get_holdout_number() { return hoNumber; }
   public void set_holdout_percent(double pct) { hoPercent = pct; }
   public double get_holdout_percent() { return hoPercent; }

   // get results; get_error_data returns your chosen error data based
   // on errorType.  get_perf_data gets the full suite of error stats.
   public ErrorData get_error_data()
{
   switch(errorType) {
      case classError:
	 return perfData.get_error_data();
      case meanSquaredError:
	 return perfData.get_mean_squared_error_data();
      case meanAbsoluteError:
	 return perfData.get_mean_absolute_error_data();
      case classLoss:
	 return perfData.get_loss_data();
      default:
	 Error.fatalErr("PerfEstDispatch::get_error_data: errorType has "
	    +"illegal value: "+(int)errorType);
	 return perfData.get_error_data();
   }
}

   public PerfData get_perf_data()
{
   return perfData;
}


   // get the number of estimation runs actually performed
   public int get_actual_times() { return actualTimes; }
   
   // display:  we can display settings, performance information, or both
   public void display_settings(Writer stream)
{
   try{
   // display generic parameters first
   int acMethod = get_perf_estimator(); //PerfEstimationMethod  enum
   String acMethodString = perfEstimationMethodEnum.name_from_value(acMethod);
   stream.write("Method: "+acMethodString+"\n");
   String errorTypeString = errorTypeEnum.name_from_value(errorType);
   stream.write("Error type: "+errorTypeString+"\n");
   if (acMethod == automatic)
      stream.write("CV weight threshold: "+cvWeightThreshold+"\n");
   stream.write("Trim: "+get_error_trim()+"\n");
   stream.write("Seed: "+get_seed()+"\n");

   // CV parameters
   if(acMethod == cv || acMethod == stratCV || acMethod == automatic) {
      stream.write("Folds: "+get_cv_folds());
      if(cvTimes >= 1)
	 stream.write(",  Times: "+get_cv_times()+"\n");
      else {
	 stream.write(",  Desired std dev: "+get_desired_std_dev()+
	           ",  Max times: "+get_max_times()+"\n");
      }
   }

   // bootstrap parameters
   if(acMethod == bootstrap) {
      stream.write("Times: "+get_bootstrap_times()+",  Fraction: "+
	 get_bootstrap_fraction()+"\n");
   }

   // holdout parameters
   if(acMethod == holdOut || acMethod == automatic) {
      stream.write("Times: "+get_holdout_times());
      if(get_holdout_number() == 0)
	 stream.write(",  Percent: "+get_holdout_percent()+"\n");
      else
	 stream.write(",  Number: "+get_holdout_number()+"\n");
   }
   } catch (IOException e) {
      e.printStackTrace();
      System.exit(1);
   }
}
/* NOTE: the following display methods are commented out in this version.
   public void display_performance(Writer stream)
{
   perfData.display_error(stream, errTrim);
}

   public void display(Writer stream)
{
   display_settings(stream);
   stream.write("Error: ");
   display_performance(stream);
   stream.write("\n");
   if (!perfData.perf_empty()) {
      perfData.display_non_error(stream,
				 CatTestResult.get_compute_log_loss());
      stream.write("\n");
   }
}
*/

}
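
Below is a minimal usage sketch, not part of the original file: it assumes
the PerfEstimationMethod constants (cv, holdOut, and so on) are public
static fields of PerfEstDispatch, and inducer / train stand in for whatever
concrete BaseInducer and InstanceList your MLJ build provides.

   // Minimal usage sketch (hypothetical; placeholders are marked below).
   PerfEstDispatch dispatch = new PerfEstDispatch();
   dispatch.set_perf_estimator(PerfEstDispatch.cv); // assumes cv is a public constant
   dispatch.set_cv_folds(10);
   dispatch.set_cv_times(1);  // run CV once; 0 would trigger auto_estimate_performance
   dispatch.set_seed(42);

   // inducer and train are placeholders for a concrete BaseInducer and a
   // loaded training InstanceList.
   double error = dispatch.estimate_performance(inducer, train);
   System.out.println("Estimated error: " + error);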
