⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 racesearch.java

📁 为了下东西 随便发了个 datamining 的源代码
💻 JAVA
📖 第 1 页 / 共 4 页
字号:
	  if (m_rankingRequested) {
	    for (int i = 0; i < baseSet.length; i++) {
	      if (win.charAt(i) != bs.charAt(i)) {
		m_rankedAtts[m_rankedSoFar][0] = i;
		m_rankedAtts[m_rankedSoFar][1] = winnerInfo[1];
		m_rankedSoFar++;
	      }
	    }
	  }
	  baseSet = (char [])raceSets[(int)winnerInfo[0]].clone();
	} else {
	  // Will get here for a subset whose error is outside the delta
	  // threshold but is not *significantly* worse than the base
	  // subset
	  //throw new Exception("RaceSearch: problem in hillClimbRace");
	}
      }
    }
    return attributeList(baseSet);
  }

  /**
   * Translates a character-encoded attribute subset into the indices of
   * the attributes it selects.
   *
   * @param list the subset encoding; position i holds '1' when attribute
   * i belongs to the subset. Only the first m_numAttribs positions are read.
   * @return the indices of the selected attributes, in ascending order
   */
  private int [] attributeList(char [] list) {
    // first pass: find out how many attributes are switched on
    int selected = 0;
    int att = 0;
    while (att < m_numAttribs) {
      selected += (list[att] == '1') ? 1 : 0;
      att++;
    }

    // second pass: record the position of each selected attribute
    int [] indices = new int[selected];
    int next = 0;
    for (att = 0; att < m_numAttribs; att++) {
      if (list[att] != '1') {
        continue;
      }
      indices[next] = att;
      next++;
    }

    return indices;
  }

  /**
   * Races the cross validation errors of a set of attribute subsets on a
   * set of instances. Every surviving subset is evaluated on each test
   * instance of each fold; once enough instances have been processed,
   * subsets that are (near) identical to, or significantly worse than,
   * another surviving subset are eliminated from the race.
   *
   * @param raceSets a set of attribute subset specifications ('1' marks
   * a selected attribute)
   * @param data the instances to use when cross validating
   * @param baseSetIncluded true if the first attribute set is a
   * base set generated from the previous race
   * @param random a random number generator
   * @return a two element array: element 0 holds the index of the winning
   * subset and element 1 holds its mean error
   * @exception Exception if an error occurs during cross validation
   */
  private double [] raceSubsets(char [][]raceSets, Instances data,
                                boolean baseSetIncluded, Random random)
    throws Exception {
    // the evaluators --- one for each subset
    ASEvaluation [] evaluators =
      ASEvaluation.makeCopies(m_theEvaluator, raceSets.length);

    // array of subsets eliminated from the race
    boolean [] eliminated = new boolean [raceSets.length];

    // individual statistics
    Stats [] individualStats = new Stats [raceSets.length];

    // pairwise statistics (only the upper triangle j < k is used)
    PairedStats [][] testers =
      new PairedStats[raceSets.length][raceSets.length];

    // when a ranking is requested the subset at index 0 is neither
    // evaluated nor considered when picking the winner
    int startPt = m_rankingRequested ? 1 : 0;

    for (int i=0;i<raceSets.length;i++) {
      individualStats[i] = new Stats();
      for (int j=i+1;j<raceSets.length;j++) {
        testers[i][j] = new PairedStats(m_sigLevel);
      }
    }

    // convert the character encoded subsets to the BitSets expected by
    // the evaluators
    BitSet [] raceBitSets = new BitSet[raceSets.length];
    for (int i=0;i<raceSets.length;i++) {
      raceBitSets[i] = new BitSet(m_numAttribs);
      for (int j=0;j<m_numAttribs;j++) {
        if (raceSets[i][j] == '1') {
          raceBitSets[i].set(j);
        }
      }
    }

    // now loop over the data points collecting cross validation errors
    // for each attribute set
    Instances trainCV;
    Instances testCV;
    Instance testInst;
    double [] errors = new double [raceSets.length];
    // number of subsets knocked out so far; each elimination must be
    // counted exactly once because the stopping tests below rely on it
    int eliminatedCount = 0;
    int processedCount = 0;
    race: for (int i=0;i<m_numFolds;i++) {
      trainCV = data.trainCV(m_numFolds, i, random);
      testCV = data.testCV(m_numFolds, i);

      // loop over the surviving attribute sets building classifiers for this
      // training set
      for (int j=startPt;j<raceSets.length;j++) {
        if (!eliminated[j]) {
          evaluators[j].buildEvaluator(trainCV);
        }
      }

      for (int z=0;z<testCV.numInstances();z++) {
        testInst = testCV.instance(z);
        processedCount++;

        // loop over surviving attribute sets computing errors for this
        // test point. The evaluator's result is negated so that smaller
        // values are better (the winner is the subset with the minimum
        // mean).
        for (int zz=startPt;zz<raceSets.length;zz++) {
          if (!eliminated[zz]) {
            if (z == 0) {// first test instance---make sure classifier is built
              errors[zz] = -((HoldOutSubsetEvaluator)evaluators[zz]).
                evaluateSubset(raceBitSets[zz],
                               testInst,
                               true);
            } else { // must be k fold rather than leave one out
              errors[zz] = -((HoldOutSubsetEvaluator)evaluators[zz]).
                evaluateSubset(raceBitSets[zz],
                               testInst,
                               false);
            }
          }
        }

        // now update the stats
        for (int j=startPt;j<raceSets.length;j++) {
          if (!eliminated[j]) {
            individualStats[j].add(errors[j]);
            for (int k=j+1;k<raceSets.length;k++) {
              if (!eliminated[k]) {
                testers[j][k].add(errors[j], errors[k]);
              }
            }
          }
        }

        // test for near identical models and models that are significantly
        // worse than some other model. Testing only starts once at least
        // m_samples test instances have been processed and stops when a
        // single subset is left.
        if (processedCount > m_samples-1 &&
            (eliminatedCount < raceSets.length-1)) {
          for (int j=0;j<raceSets.length;j++) {
            if (!eliminated[j]) {
              for (int k=j+1;k<raceSets.length;k++) {
                if (!eliminated[k]) {
                  testers[j][k].calculateDerived();
                  // near identical ?
                  if ((testers[j][k].differencesSignificance == 0) &&
                      (Utils.eq(testers[j][k].differencesStats.mean, 0.0) ||
                      (Utils.gr(m_delta, Math.abs(testers[j][k].
                                                  differencesStats.mean))))) {
                    // if they're exactly the same and there is a base set
                    // in this race, make sure that the base set is NOT the
                    // one eliminated.
                    if (Utils.eq(testers[j][k].differencesStats.mean, 0.0)) {
                      // the base set lives at index 0, so k is only
                      // dropped when j is the base set
                      boolean dropK = baseSetIncluded && (j == 0);
                      if (dropK) {
                        eliminated[k] = true;
                      } else {
                        eliminated[j] = true;
                      }
                      // count the elimination in every case (it used to be
                      // skipped when no base set was in the race)
                      eliminatedCount++;
                      if (m_debug) {
                        System.err.println("Eliminating (identical) "
                                           +j+" "+raceBitSets[j].toString()
                                           +" vs "+k+" "
                                           +raceBitSets[k].toString()
                                           +" after "
                                           +processedCount
                                           +" evaluations\n"
                                           +"\nerror "+j+" : "
                                           +testers[j][k].xStats.mean
                                           +" vs "+k+" : "
                                           +testers[j][k].yStats.mean
                                           +" diff : "
                                           +testers[j][k].differencesStats
                                           .mean);
                      }
                      if (!dropK) {
                        // j has left the race: stop comparing it, otherwise
                        // it could be eliminated (and counted) a second time
                        break;
                      }
                    } else {
                      // eliminate the one with the higher error
                      if (testers[j][k].xStats.mean >
                          testers[j][k].yStats.mean) {
                        eliminated[j] = true;
                        eliminatedCount++;
                        if (m_debug) {
                          System.err.println("Eliminating (near identical) "
                                           +j+" "+raceBitSets[j].toString()
                                           +" vs "+k+" "
                                           +raceBitSets[k].toString()
                                           +" after "
                                           +processedCount
                                           +" evaluations\n"
                                           +"\nerror "+j+" : "
                                           +testers[j][k].xStats.mean
                                           +" vs "+k+" : "
                                           +testers[j][k].yStats.mean
                                           +" diff : "
                                           +testers[j][k].differencesStats
                                           .mean);
                        }
                        break;
                      } else {
                        eliminated[k] = true;
                        eliminatedCount++;
                        if (m_debug) {
                          System.err.println("Eliminating (near identical) "
                                           +k+" "+raceBitSets[k].toString()
                                           +" vs "+j+" "
                                           +raceBitSets[j].toString()
                                           +" after "
                                           +processedCount
                                           +" evaluations\n"
                                           +"\nerror "+k+" : "
                                           +testers[j][k].yStats.mean
                                           +" vs "+j+" : "
                                           +testers[j][k].xStats.mean
                                           +" diff : "
                                           +testers[j][k].differencesStats
                                             .mean);
                        }
                      }
                    }
                  } else {
                    // significantly worse ?
                    if (testers[j][k].differencesSignificance != 0) {
                      // a positive significance knocks out j, a negative
                      // one knocks out k
                      if (testers[j][k].differencesSignificance > 0) {
                        eliminated[j] = true;
                        eliminatedCount++;
                        if (m_debug) {
                          System.err.println("Eliminating (-worse) "
                                           +j+" "+raceBitSets[j].toString()
                                           +" vs "+k+" "
                                           +raceBitSets[k].toString()
                                           +" after "
                                           +processedCount
                                           +" evaluations"
                                           +"\nerror "+j+" : "
                                           +testers[j][k].xStats.mean
                                           +" vs "+k+" : "
                                           +testers[j][k].yStats.mean);
                        }
                        break;
                      } else {
                        eliminated[k] = true;
                        eliminatedCount++;
                        if (m_debug) {
                          System.err.println("Eliminating (worse) "
                                           +k+" "+raceBitSets[k].toString()
                                           +" vs "+j+" "
                                           +raceBitSets[j].toString()
                                           +" after "
                                           +processedCount
                                           +" evaluations"
                                           +"\nerror "+k+" : "
                                           +testers[j][k].yStats.mean
                                           +" vs "+j+" : "
                                           +testers[j][k].xStats.mean);
                        }
                      }
                    }
                  }
                }
              }
            }
          }
        }
        // if there is a base set from the previous race and it's the
        // only remaining subset then terminate the race.
        if (eliminatedCount == raceSets.length-1 && baseSetIncluded &&
            !eliminated[0] && !m_rankingRequested) {
          break race;
        }
      }
    }

    if (m_debug) {
      System.err.println("*****eliminated count: "+eliminatedCount);
    }
    double bestError = Double.MAX_VALUE;
    int bestIndex=0;
    // pick the surviving subset with the lowest mean error as the winner
    for (int i=startPt;i<raceSets.length;i++) {
      if (!eliminated[i]) {
        individualStats[i].calculateDerived();
        if (m_debug) {
          System.err.println("Remaining error: "+raceBitSets[i].toString()
                             +" "+individualStats[i].mean);
        }
        if (individualStats[i].mean < bestError) {
          bestError = individualStats[i].mean;
          bestIndex = i;
        }
      }
    }

    // element 0: index of the winner, element 1: its mean error
    double [] retInfo = new double[2];
    retInfo[0] = bestIndex;
    retInfo[1] = bestError;

    if (m_debug) {
      System.err.print("Best set from race : ");

      for (int i=0;i<m_numAttribs;i++) {
        if (raceSets[bestIndex][i] == '1') {
          System.err.print('1');
        } else {
          System.err.print('0');
        }
      }
      System.err.println(" :"+bestError+" Processed : "+(processedCount)
                         +"\n"+individualStats[bestIndex].toString());
    }
    return retInfo;
  }

  /**
   * Returns a textual description of the search: the race type and its
   * base set, the attribute evaluator and ranking (rank race only), the
   * cross validation mode and the merit of the best subset found.
   *
   * @return a description of the search as a string
   */
  public String toString() {
    StringBuffer text = new StringBuffer();

    text.append("\tRaceSearch.\n\tRace type : ");
    switch (m_raceType) {
    case FORWARD_RACE:
      text.append("forward selection race\n\tBase set : no attributes");
      break;
    case BACKWARD_RACE:
      text.append("backward elimination race\n\tBase set : all attributes");
      break;
    case SCHEMATA_RACE:
      text.append("schemata race\n\tBase set : no attributes");
      break;
    case RANK_RACE:
      text.append("rank race\n\tBase set : no attributes\n\t");
      text.append("Attribute evaluator : "
                  + getAttributeEvaluator().getClass().getName() +" ");
      if (m_ASEval instanceof OptionHandler) {
        String[] evaluatorOptions = ((OptionHandler)m_ASEval).getOptions();
        for (int i=0;i<evaluatorOptions.length;i++) {
          text.append(evaluatorOptions[i]+' ');
        }
      }
      text.append("\n");
      text.append("\tAttribute ranking : \n");
      // m_Ranking is null until a ranking has been produced (it is
      // cleared by resetOptions), so describing the search before then
      // must not dereference it
      if (m_Ranking != null) {
        // width needed to print the largest (1-based) attribute index
        int rlength = (int)(Math.log(m_Ranking.length) / Math.log(10) + 1);
        for (int i=0;i<m_Ranking.length;i++) {
          text.append("\t "+Utils.doubleToString((double)(m_Ranking[i]+1),
                                                 rlength,0)
                      +" "+m_Instances.attribute(m_Ranking[i]).name()+'\n');
        }
      }
      break;
    }
    text.append("\n\tCross validation mode : ");
    if (m_xvalType == TEN_FOLD) {
      text.append("10 fold");
    } else {
      text.append("Leave-one-out");
    }

    text.append("\n\tMerit of best subset found : ");
    // derive the field width from the magnitude of the merit and the
    // number of decimals from the magnitude of its fractional part
    int fieldwidth = 3;
    double precision = (m_bestMerit - (int)m_bestMerit);
    if (Math.abs(m_bestMerit) > 0) {
      fieldwidth = (int)Math.abs((Math.log(Math.abs(m_bestMerit)) /
                                  Math.log(10)))+2;
    }
    if (Math.abs(precision) > 0) {
      precision = Math.abs((Math.log(Math.abs(precision)) / Math.log(10)))+3;
    } else {
      precision = 2;
    }

    text.append(Utils.doubleToString(Math.abs(m_bestMerit),
                                     fieldwidth+(int)precision,
                                     (int)precision)+"\n");
    return text.toString();
  }

  /**
   * Puts the search back into its initial configuration: forward
   * selection race, 10 folds, significance level and delta of 0.001,
   * debugging off, a fresh GainRatioAttributeEval as attribute evaluator,
   * and no evaluator, ranking or merit carried over.
   */
  protected void resetOptions () {
    // race configuration
    m_raceType = FORWARD_RACE;
    m_numFolds = 10;
    m_delta = 0.001;
    m_sigLevel = 0.001;
    m_debug = false;

    // default attribute evaluator
    m_ASEval = new GainRatioAttributeEval();

    // nothing is carried over: no subset evaluator, no ranking, and the
    // lowest possible merit
    m_theEvaluator = null;
    m_Ranking = null;
    m_bestMerit = -Double.MAX_VALUE;
  }
}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -