⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 crossvalidator.java

📁 java数据挖掘算法
💻 JAVA
📖 第 1 页 / 共 2 页
字号:

   if (perfData.size() != numTimes * actualFolds)
      Error.fatalErr("CrossValidator:estimate_performance error size "
	  + perfData.size() + " does not match expected size "
	  + (numTimes * actualFolds)
	  + ". Probable cause: train_and_test not updating perfData");
   
   logOptions.LOG(1, "Untrimmed error ");
   if(get_log_level() >= 1) display_error_data(get_log_stream());
   logOptions.LOG(1, "\n");

   // set cost within this accData
   perfData.insert_cost(actualFolds * numTimes);
   return error();
}

/***************************************************************************
  Automatically estimates the error to the given std-dev level.
  Cross-validation runs (each one a full pass over actualFolds folds) are
repeated until the standard deviation of the mean error is at or below
desiredStdDev, or until maxTimes runs have been made. At least one run is
always performed. Leave-one-out (numFolds == -1) is rejected because
repeating it is not meaningful.
  On return, get_times() reports the number of runs performed, and perfData
is expected to hold one entry per fold per run.
  @param inducer       the inducer to train and test.
  @param dataList      the instances to cross-validate over.
  @param desiredStdDev stop once error_std_dev() is at or below this value.
  @param maxTimes      upper bound on the number of cross-validation runs.
  @return the estimated error, as reported by error().
***************************************************************************/

public double auto_estimate_performance(BaseInducer inducer,
					    InstanceList dataList,
					    double desiredStdDev,
					    int  maxTimes)
{
   if (numFolds == -1)
      Error.fatalErr("CrossValidator::auto_estimate_performance: it does not make "
	     +" sense to do leave-one-out multiple times");

   int totalInstances = dataList.num_instances();
   int actualFolds = numFolds; // cannot be negative
   // NOTE(review): Java passes ints by value, so compute_folds() cannot
   //   update actualFolds through this argument (presumably it was a C++
   //   reference parameter before the port). Confirm compute_folds' Java
   //   signature; if it returns the adjusted fold count, assign the result.
   compute_folds(actualFolds, totalInstances);

   logOptions.LOG(1, "Inducer: " + inducer.description() + "\n");
   logOptions.DRIBBLE("Number of folds: " + actualFolds + ". Looping until "
          +"std-dev of mean =" + desiredStdDev + "\n");
   // Start from a fresh PerfData so results of any earlier estimate are
   //   discarded.
   perfData = new PerfData();
   int time = 0;
   do {
      set_times(time + 1); // so asserts will work.
      estimate_time_performance(inducer, dataList, time, actualFolds);
      logOptions.LOG(2, "Overall error: " + this + "\n");
   } while (++time < maxTimes && error_std_dev() > desiredStdDev);

   set_times(time); // so the caller can query the run count via get_times()

   // One PerfData entry is expected per fold of every run. The same value
   //   is used for the check and for the message so they cannot disagree.
   int expectedSize = actualFolds * numTimes;
   if (perfData.size() != expectedSize)
      Error.fatalErr("CrossValidator:auto_estimate_performance error size "
	  + perfData.size() + " does not match expected size "
	  + expectedSize
	  + ". Probable cause: train_and_test not updating perfData");
   
   logOptions.LOG(1, "Untrimmed error " + this + "\n");

   perfData.insert_cost(actualFolds * get_times());
   
   return error();
}

public double auto_estimate_performance(BaseInducer inducer,
					    InstanceList dataList,
					    double desiredStdDev)
{
   // Same as the four-argument form, capping the number of runs at the
   //   class default defaultMaxTimes.
   return this.auto_estimate_performance(inducer,
                                         dataList,
                                         desiredStdDev,
                                         defaultMaxTimes);
}

public double auto_estimate_performance(BaseInducer inducer,
					    InstanceList dataList)
{
   // Same as the four-argument form, using the class defaults for both the
   //   target std-dev (defaultAutoStdDev) and the run cap (defaultMaxTimes).
   return this.auto_estimate_performance(inducer,
                                         dataList,
                                         defaultAutoStdDev,
                                         defaultMaxTimes);
}

/***************************************************************************
  Trains and tests the inducer using files instead of an in-memory list.
  For every run T in [0, numTimes-1], estimate_file_performance() is called
with numFolds and the stem "fileStem-T". The files involved are
fileStem.names, fileStem-T-F.data and fileStem-T-F.test, where F is an
integer in [0, numFolds-1].
  NOTE(review): the original header states this is unimplemented in this
version; presumably estimate_file_performance() is a stub -- confirm.
  @param inducer  the inducer to train and test.
  @param fileStem the common prefix of the data files.
  @return the estimated error, as reported by error().
***************************************************************************/

public double estimate_performance(BaseInducer inducer, String fileStem)
{
   // Drop the previous results, then collect this estimate's results.
   perfData = null;
   perfData = new PerfData();

   int run = 0;
   while (run < numTimes) {
      String runStem = fileStem + "-" + run;
      double runError =
	 estimate_file_performance(inducer, numFolds, runStem, perfData);
      logOptions.LOG(2, "fold error for set " + run + " is " + runError + "\n");
      run++;
   }

   logOptions.LOG(1, "error: " + this + "\n");
   return error();
}

/***************************************************************************
  Attempts to find a good value for numFolds for the current training-set
size. This is an expensive operation that is useful if you intend to do many
cross-validations on variants of this dataset.
  Starting from maxFolds (capped at the number of instances), the number of
folds is repeatedly halved (never below 2). The search stops and restores
the previous setting as soon as the error rises by at least accEpsilon or
the std-dev rises by at least stdDevEpsilon. The number of times
(get_times()) is restored before returning.
  @param inducer       the inducer to evaluate.
  @param dataList      the instances to cross-validate over.
  @param maxFolds      largest fold count to try; must be at least 2.
  @param maxTimes      maximum runs per auto_estimate_performance() call.
  @param accEpsilon    tolerated increase in error; must be non-negative.
  @param stdDevEpsilon tolerated increase in std-dev; must be non-negative.
***************************************************************************/

public void auto_set_folds(BaseInducer inducer,
				    InstanceList dataList,
				    int maxFolds,
				    int maxTimes,
				    double accEpsilon,
				    double stdDevEpsilon)

{
   // Validate arguments before touching any state.
   if (maxFolds < 2)
      Error.fatalErr("CrossValidator::auto_set_folds: maxFolds: " + maxFolds
	  + " less than 2");
   
   if (stdDevEpsilon < 0)
      Error.fatalErr("CrossValidator::auto_set_folds: stdDevEpsilon: "
	  + stdDevEpsilon + " negative");

   if (accEpsilon < 0)
      Error.fatalErr("CrossValidator::auto_set_folds: accEpsilon: "
	  + accEpsilon + " negative");

   // Cap the fold count at the number of instances. This must happen before
   //   set_folds() below, otherwise the cap would have no effect.
   if (maxFolds > dataList.num_instances()) {
      maxFolds = dataList.num_instances();
      logOptions.LOG(1, "Max folds set to number of instances: " + maxFolds + "\n");
      if (maxFolds < 2)
         Error.fatalErr("CrossValidator::auto_set_folds: only " + maxFolds
	     + " instance(s); at least 2 are needed to cross-validate");
   }

   int saveLogLevel = get_log_level();
   int saveNumTimes = get_times();
   set_folds(maxFolds); 
   double prevStdDev = 1;       // high number for StdDev
   double prevErr = 1.0;        // worst possible error
   int  prevFolds = maxFolds; // in case we stop on first iteration.

   do {
      logOptions.LOG(2, "auto_set_folds trying " + get_folds() + " folds...");
      set_log_level(saveLogLevel - 2); // estimate at loglevel-2
      try {
         auto_estimate_performance(inducer, dataList, defaultAutoStdDev,
				maxTimes);
      } finally {
         // Restore the caller's log level even if estimation fails.
         set_log_level(saveLogLevel);
      }

      // The real std-dev is higher since we're supposed to find
      //   a good value for 1xfolds, not for the mean over get_times() runs.
      double stdDev = error_std_dev() * Math.sqrt(get_times());
      double errorRate = error();

      // Report as percentages with two decimals (scale by 100 before
      //   rounding; rounding the raw fraction would print only 0% or 100%).
      logOptions.LOG(2, "mean " + Math.round(errorRate * 10000) / 100.0 + '%' 
          + " +- " + Math.round(stdDev * 10000) / 100.0 + '%' 
	  + " (average of " + get_times() + ")" +  "\n");

      // If one of the following is true, we stop and return previous number
      // of folds:
      //   1. Error increases by at least accEpsilon.
      //   2. Std-dev increases by at least stdDevEpsilon.
      
      if (errorRate >= prevErr + accEpsilon ||
	  stdDev >= prevStdDev + stdDevEpsilon) {
	 set_times(saveNumTimes);
	 MLJ.ASSERT(prevFolds >= 2,"CrossValidator.auto_set_folds: prevFolds < 2.");
	 set_folds(prevFolds);
	 logOptions.LOG(1, "Fold setting set to " + get_folds()
	     + " (significant deterioration or above std-dev threshold)"
	     + "\n");
	 return;
      }

      prevStdDev = stdDev;
      prevErr = errorRate;
      prevFolds = get_folds();
      int newFolds = Math.max(2, get_folds() / 2); // never below 2 folds
      set_folds(newFolds);
   } while (prevFolds > 2);

   set_times(saveNumTimes);
   logOptions.LOG(1, "Fold setting set to " +  get_folds()
           + " (minimum possible)" + "\n");
}

public void auto_set_folds(BaseInducer inducer,
				    InstanceList dataList)

{
   // Full search using every class default: defaultMaxFolds,
   //   defaultMaxTimes, defaultAccEpsilon and defaultStdDevEpsilon.
   this.auto_set_folds(inducer,
                       dataList,
                       defaultMaxFolds,
                       defaultMaxTimes,
                       defaultAccEpsilon,
                       defaultStdDevEpsilon);
}

public void auto_set_folds(BaseInducer inducer,
				    InstanceList dataList,
				    int maxFolds)

{
   // Caller-chosen starting fold count; everything else uses the class
   //   defaults.
   this.auto_set_folds(inducer,
                       dataList,
                       maxFolds,
                       defaultMaxTimes,
                       defaultAccEpsilon,
                       defaultStdDevEpsilon);
}

public void auto_set_folds(BaseInducer inducer,
				    InstanceList dataList,
				    int maxFolds,
				    int maxTimes)

{
   // Caller-chosen fold count and run cap; the default epsilons decide
   //   when the search stops.
   this.auto_set_folds(inducer,
                       dataList,
                       maxFolds,
                       maxTimes,
                       defaultAccEpsilon,
                       defaultStdDevEpsilon);
}

public void auto_set_folds(BaseInducer inducer,
				    InstanceList dataList,
				    int maxFolds,
				    int maxTimes,
				    double accEpsilon)

{
   // Only the std-dev tolerance falls back to its class default.
   this.auto_set_folds(inducer,
                       dataList,
                       maxFolds,
                       maxTimes,
                       accEpsilon,
                       defaultStdDevEpsilon);
}

/*
public:
   // Public data

   // Methods
   // Auto-setting attempts to set numFolds, using an ad-hoc heuristic, to a
   //   reasonable number such that the variance resulting from the
   //   training-set size is not too big.
   virtual void auto_set_folds(BaseInducer& inducer, InstanceList& dataList,
			       int maxFolds = defaultMaxFolds,
			       int maxTimes = defaultMaxTimes,
			       double AccEpsilon  = defaultAccEpsilon,
			       double stdDevEpsilon = defaultStdDevEpsilon);
*/
}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -