⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 crossvalidator.java

📁 java数据挖掘算法
💻 JAVA
📖 第 1 页 / 共 2 页
字号:
package fss;import shared.*;import shared.Error;import java.lang.*;

public class CrossValidator extends PerfEstimator {

   // Number of folds per cross-validation run. set_folds rejects 0; a
   // negative value means leave-|numFolds|-out cross-validation.
   private int numFolds;
   // How many times the whole cross-validation is repeated; set_times
   // rejects values <= 0.
   private int numTimes;

   // Fraction in (0, 1] of each fold's training set that is actually used for
   // training; 1.0 (the constructors' value) means the whole training set.
   protected double fraction;

   // Defaults picked up by the constructors that omit folds / times.
   public static int  defaultNumFolds = 10;
   public static int  defaultNumTimes = 1;
   // Tuning values for auto_set_folds / auto_estimate (not visible in this
   // part of the file).
   public static int  defaultMaxFolds = 20;        // for auto_set_folds
   // NOTE(review): no initializer, so this is 0 — confirm that is intended.
   public static int  defaultAutoFoldTimes;   // for auto_set_folds
   public static double defaultStdDevEpsilon = 0.001;   // for auto_set_folds
   public static double defaultAccEpsilon = 0.005;      // for auto_set_folds
   public static int  defaultMaxTimes = 10;        // for auto_estimate
   public static double defaultAutoStdDev = 0.01;      // for auto_estimate


/***************************************************************************
  Copy construction is not supported. This private constructor is empty:
it copies no state from source and exists only so that no accessible copy
constructor is available.
***************************************************************************/
   private CrossValidator(CrossValidator source)
   {
      super();
   }

/***************************************************************************
  Assignment between CrossValidator objects is not supported. This private
method is an intentional no-op; source is ignored.
***************************************************************************/
   private void assign(CrossValidator source)
   {
      return; // intentionally does nothing
   }


/***************************************************************************
  Description : Estimate error for a single time (multi fold).
                The data is shuffled with rand_num_gen() and cut into
                `folds` consecutive test partitions whose sizes differ by
                at most one. Each partition is tested against an inducer
                trained on the remaining instances, or on only the first
                round(fraction * remaining) of them when fraction < 1.0.
                train_and_test also receives the shared perfData
                (presumably accumulating overall statistics).
                A negative `folds` means leave-|folds|-out. It is converted
                to ceil(numInstances / |folds|) folds, the same count
                compute_folds reports.
  Comments    : protected function
  Returns     : the mean of the per-fold errors of this time.
***************************************************************************/
protected double estimate_time_performance(BaseInducer inducer,
					    InstanceList dataList, 
					    int time, int folds)
{
   int totalInstances = dataList.num_instances();

   // set_folds accepts negative values (leave-k-out), and estimate_performance
   // forwards numFolds unchanged. compute_folds cannot return the converted
   // count through its int parameter, because Java passes it by value.
   // Without this conversion a negative count would skip the fold loop
   // entirely.
   if (folds < 0)
      folds = (int)Math.ceil((double)totalInstances / Math.abs((double)folds));

   InstanceList shuffledList = dataList.shuffle(rand_num_gen());
   PerfData foldData = new PerfData();
   logOptions.DRIBBLE(folds + " folds: ");
   for (int fold = 0; fold < folds; fold++) {
      logOptions.DRIBBLE((fold + 1) + " ");
      // The first (totalInstances % folds) folds get one extra test instance.
      int numInSplit = totalInstances / folds +
	 ((totalInstances % folds > fold)? 1:0);
      logOptions.LOG(3, "Number of instances in fold " + fold + ": " +
	  numInSplit + ".");
      // Test set = the first numInSplit instances of shuffledList. They are
      // united back into shuffledList at the end of this iteration.
      InstanceList testList = shuffledList.split_prefix(numInSplit);
      InstanceList fractList = shuffledList;

      if(fraction < 1.0) {
	 // Train on only the first round(fraction * remaining) instances.
	 double numInFract = fraction * (double)(shuffledList.num_instances());
	 int intNumInFract = (int)(numInFract + 0.5);
	 if(intNumInFract == 0)
	    Error.fatalErr("CrossValidator.estimate_time_performance: "
	       +"No instances left in cv fraction");
	 fractList = shuffledList.split_prefix(intNumInFract);
      }

      // Logged after the fraction split, so the reported train weight is that
      // of the list actually handed to train_and_test.
      logOptions.LOG(3, "Total weights of fold " + fold + " (train/test): "
	  + fractList.total_weight() + "/" + testList.total_weight()
	  + "\n");
      logOptions.LOG(6, "Training set is:" + "\n" + fractList + "\n");
      logOptions.LOG(6, "Test set is:" + "\n" + testList + "\n");

      // Suppress dribble output during the nested train/test run, and
      // restore the caller's setting even if train_and_test fails.
      boolean saveDribble = GlobalOptions.dribble;
      GlobalOptions.dribble = false;
      double error;
      try {
	 error = train_and_test(inducer, fractList, testList,
				"-" + time + "-" + fold,
				perfData);
      } finally {
	 GlobalOptions.dribble = saveDribble;
      }
      logOptions.LOG(3, "Error " + error + "\n");
      foldData.insert_error(error);

      if(fraction < 1.0) {
	 // Put the unused remainder back behind the fractional training list.
	 fractList.unite(shuffledList);
	 shuffledList = fractList;
      }

      // Re-attach this fold's test instances (presumably appended at the end),
      // so that the next split_prefix takes the next block of the shuffled
      // order.
      shuffledList.unite(testList);
   }
   logOptions.DRIBBLE("\n");
   logOptions.LOG(2, "fold " + foldData + ". ");

   return foldData.get_error_data().mean();
}

/***************************************************************************
  Constructor for cross-validation with an explicit number of folds and of
repetitions. The training fraction starts at 1.0 (the whole training set).
See estimate_performance for the meaning of the arguments.
***************************************************************************/
public CrossValidator(int nFolds, int nTimes) {
   set_folds(nFolds);    // 0 is rejected; negative means leave-|nFolds|-out
   set_times(nTimes);    // values <= 0 are rejected
   set_fraction(1.0);    // use every training instance
}

/***************************************************************************
  Constructor for cross-validation with nFolds folds, repeated
defaultNumTimes times, training on the whole training set (fraction 1.0).
See estimate_performance for the meaning of the arguments.
***************************************************************************/
public CrossValidator(int nFolds) {
   this(nFolds, defaultNumTimes);
}

/***************************************************************************
  Constructor for cross-validation using defaultNumFolds folds, repeated
defaultNumTimes times, training on the whole training set (fraction 1.0).
***************************************************************************/
public CrossValidator() {
   this(defaultNumFolds, defaultNumTimes);
}


/***************************************************************************
  Sets the number of folds. Zero is reported through Error.fatalErr. A
negative value is accepted and selects leave-|num|-out cross-validation.
***************************************************************************/
public void set_folds(int num)
{
   boolean isZero = (num == 0);
   if (isZero) {
      Error.fatalErr( "CrossValidator.set_folds: num folds (" + num + ") == 0");
   }
   numFolds = num;
}
   
/***************************************************************************
  Sets how many times the whole cross-validation is repeated. Values below
one are reported through Error.fatalErr.
***************************************************************************/
public void set_times(int num)
{
   if (!(num > 0)) {
      Error.fatalErr( "CrossValidator.set_times: num times (" + num + ") <= 0");
   }
   numTimes = num;
}

/***************************************************************************
  Sets the fraction of each fold's training set that is actually used for
training. It must lie in (0, 1]; 1.0 trains on the whole training set.
Values outside that range, including NaN, are reported through
Error.fatalErr.
***************************************************************************/
public void set_fraction(double fract)
{
   // Written as "not in range" because every ordered comparison with NaN is
   // false. The form "fract <= 0 || fract > 1.0" would silently accept NaN.
   if (!(fract > 0 && fract <= 1.0))
      Error.fatalErr( "CrossValidator.set_fraction: " + fract + " is out of the range (0,1]");
   fraction = fract;
}

// Returns the configured fold count (a negative value means leave-|folds|-out).
public int get_folds()
{
   final int folds = numFolds;
   return folds;
}
// Returns the number of repetitions, asserting (via MLJ.ASSERT) that it is positive.
public int get_times()
{
   MLJ.ASSERT(numTimes > 0,"CrossValidator.get_times: numTimes <= 0");
   return numTimes;
}

/***************************************************************************
  Returns the identifying string of this estimator in the form
"<times>x<folds> Cross validator". For example, 10 folds run once gives
"1x10 Cross validator".
***************************************************************************/
public String description()
{
  StringBuilder label = new StringBuilder();
  label.append(numTimes).append('x').append(numFolds).append(" Cross validator");
  return label.toString();
}

/***************************************************************************
  estimate_performance(inducer, dataList): trains and tests the inducer using
numFolds-fold cross-validation, repeated numTimes times. If numFolds is
negative, it means leave-k-out, where k is Math.abs(numFolds). The data list
is reshuffled before each repetition of the cross-validation. Use
init_rand_num_gen(seed) to achieve reproducible results; otherwise results
may vary because the shuffling of the data differs between runs.
***************************************************************************/

/***************************************************************************
  Helper to support leave-k-out. Validates the fold specification (see
compute_actual_folds) and returns a display string for it. The string is
the fold count itself when numFolds > 0, otherwise
"<actual folds> (leave <k> out)" with k = Math.abs(numFolds).
  This method returns only the string. Callers that need the numeric fold
count must call compute_actual_folds, because an int parameter cannot carry
a converted value back to the caller in Java.
***************************************************************************/
public static String compute_folds(int numFolds, int totalInstances)
{
   int actualFolds = compute_actual_folds(numFolds, totalInstances);
   if (numFolds > 0)
      return String.valueOf(numFolds);
   // leave-Math.abs(numFolds)-out
   return actualFolds + " (leave "
          + Math.abs(numFolds) + " out)";
}

/***************************************************************************
  Validates a fold specification for a data list with totalInstances
instances and returns the number of folds that will actually be run.
A positive numFolds is returned unchanged. A negative numFolds means
leave-k-out with k = Math.abs(numFolds) and yields ceil(totalInstances / k)
folds. The following are reported through Error.fatalErr:
  - no instances, or a single instance;
  - zero folds, or more folds than instances;
  - a leave-k-out setting that leaves fewer than two folds.
***************************************************************************/
public static int compute_actual_folds(int numFolds, int totalInstances)
{
   if (totalInstances == 0)
	 Error.fatalErr("CrossValidator::estimate_performance: 0 instances in dataList");

   if (totalInstances == 1)
      Error.fatalErr("CrossValidator::estimate_performance: Cannot estimate error for 1 instance" ); 

   if (numFolds == 0 || numFolds > totalInstances)
      Error.fatalErr("CrossValidator::estimate_performance: number of folds ("
	  + numFolds + ") is invalid for data with "
	  + totalInstances + " instances" );

   if (numFolds > 0)
      return numFolds;

   // leave-Math.abs(numFolds)-out: enough folds that each test set holds
   // about Math.abs(numFolds) instances.
   int actualFolds = (int)(Math.ceil((double)(totalInstances) / Math.abs(numFolds)) + 0.5);
   if (actualFolds < 2)
      Error.fatalErr("CrossValidator::estimate_performance: number of folds ("
	     + numFolds + ") will leave no training instances");
   MLJ.ASSERT(actualFolds > 1,"CrossValidator.compute_folds: numFolds <= 1.");
   return actualFolds;
}

public double estimate_performance(BaseInducer inducer, 
				       InstanceList dataList)
{
   // copy dataList in slow debug mode to check for ordering problems
   InstanceList dataListPtr = dataList;
   if (Basics.DBGSLOW) dataListPtr = (InstanceList)dataList.clone();
   
   int totalInstances = dataList.num_instances();
   int actualFolds = numFolds; // cannot be negative
   String foldsStr = compute_folds(actualFolds, totalInstances);

   logOptions.LOG(1, "Inducer: " + inducer.description() + "\n");
   logOptions.LOG(1, "Number of folds: " + foldsStr + ", Number of times: " 
       + numTimes + "\n");
   perfData = null;
   perfData = new PerfData();
   for (int time = 0; time < numTimes; time++) {
      if (numTimes > 1)
	 logOptions.DRIBBLE("Time " + time + "\n");
      estimate_time_performance(inducer, dataListPtr, time, actualFolds);
      if (numTimes > 1) // Don't print this if it's only once because we get
			// the same output below
         logOptions.LOG(2, "Overall: " + this + "\n");
   }

   // check the copied data list ordering in slow debug mode to
   // make sure CValidator does not mess up list ordering
   if (Basics.DBGSLOW){ 
      if(!(dataList == dataListPtr))
        Error.err( "CrossValidator::estimate_performance: ordering of dataList "
               +"changed during cross-validation" + "\n");
      dataListPtr = null;
   }

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -