⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 racesearch.java

📁 6个特征提取的机器学习方法
💻 JAVA
📖 第 1 页 / 共 4 页
字号:
		j = k+1;
		break;
	      }
	    }
	  }
	}
      }

      if (m_debug) {
	System.err.println("Next set : \n"+printSets(raceSets));
      }
      improved = false;
      winnerInfo = raceSubsets(raceSets, data, true, random);
      String bs = new String(baseSet);
      String win = new String(raceSets[(int)winnerInfo[0]]);
      if (bs.compareTo(win) == 0) {
	// race finished: the winning subset is the base set itself
      } else {
	if (winnerInfo[1] < baseSetError || m_rankingRequested) {
	  improved = true;
	  baseSetError = winnerInfo[1];
	  m_bestMerit = baseSetError;
	  // find which att is different
	  if (m_rankingRequested) {
	    for (int i = 0; i < baseSet.length; i++) {
	      if (win.charAt(i) != bs.charAt(i)) {
		m_rankedAtts[m_rankedSoFar][0] = i;
		m_rankedAtts[m_rankedSoFar][1] = winnerInfo[1];
		m_rankedSoFar++;
	      }
	    }
	  }
	  // the winner becomes the base set for the next round
	  baseSet = (char [])raceSets[(int)winnerInfo[0]].clone();
	} else {
	  // Will get here for a subset whose error is outside the delta
	  // threshold but is not *significantly* worse than the base
	  // subset
	  //throw new Exception("RaceSearch: problem in hillClimbRace");
	}
      }
    }
    return attributeList(baseSet);
  }

  /**
   * Converts an attribute set specification to an array of the indices
   * of the selected attributes.
   *
   * @param list one character per attribute, where '1' marks a selected
   * attribute (only the first m_numAttribs entries are examined)
   * @return the indices i, in ascending order, for which list[i] == '1'
   */
  private int [] attributeList(char [] list) {
    int count = 0;

    // first pass: count the selected attributes so the result can be
    // allocated with the exact size
    for (int i=0;i<m_numAttribs;i++) {
      if (list[i] == '1') {
	count++;
      }
    }

    int [] rlist = new int[count];
    count = 0;
    // second pass: record the index of each selected attribute
     for (int i=0;i<m_numAttribs;i++) {
       if (list[i] == '1') {
	 rlist[count++] = i;
       }
     }

     return rlist;
  }

  /**
   * Races the cross-validation errors (over m_numFolds folds) of a set of
   * attribute subsets on a set of instances. Subsets are eliminated as soon
   * as they prove to be (near) identical to, or significantly worse than,
   * another surviving subset.
* @param raceSets a set of attribute subset specifications   * @param data the instances to use when cross validating   * @param baseSetIncluded true if the first attribute set is a   * base set generated from the previous race   * @param random a random number generator   * @return the index of the winning subset   * @throws Exception if an error occurs during cross validation   */  private double [] raceSubsets(char [][]raceSets, Instances data,				boolean baseSetIncluded, Random random)     throws Exception {    // the evaluators --- one for each subset    ASEvaluation [] evaluators =       ASEvaluation.makeCopies(m_theEvaluator, raceSets.length);    // array of subsets eliminated from the race    boolean [] eliminated = new boolean [raceSets.length];    // individual statistics    Stats [] individualStats = new Stats [raceSets.length];    // pairwise statistics    PairedStats [][] testers =       new PairedStats[raceSets.length][raceSets.length];    /** do we ignore the base set or not? */    int startPt = m_rankingRequested ? 
1 : 0;    for (int i=0;i<raceSets.length;i++) {      individualStats[i] = new Stats();      for (int j=i+1;j<raceSets.length;j++) {	testers[i][j] = new PairedStats(m_sigLevel);      }    }        BitSet [] raceBitSets = new BitSet[raceSets.length];    for (int i=0;i<raceSets.length;i++) {      raceBitSets[i] = new BitSet(m_numAttribs);      for (int j=0;j<m_numAttribs;j++) {	if (raceSets[i][j] == '1') {	  raceBitSets[i].set(j);	}      }    }    // now loop over the data points collecting leave-one-out errors for    // each attribute set    Instances trainCV;    Instances testCV;    Instance testInst;    double [] errors = new double [raceSets.length];    int eliminatedCount = 0;    int processedCount = 0;    // if there is one set left in the race then we need to continue to    // evaluate it for the remaining instances in order to get an    // accurate error estimate    processedCount = 0;    race: for (int i=0;i<m_numFolds;i++) {      // We want to randomize the data the same way for every       // learning scheme.      trainCV = data.trainCV(m_numFolds, i, new Random (1));      testCV = data.testCV(m_numFolds, i);            // loop over the surviving attribute sets building classifiers for this      // training set      for (int j=startPt;j<raceSets.length;j++) {	if (!eliminated[j]) {	  evaluators[j].buildEvaluator(trainCV);	}      }      for (int z=0;z<testCV.numInstances();z++) {	testInst = testCV.instance(z);	processedCount++;	// loop over surviving attribute sets computing errors for this	// test point	for (int zz=startPt;zz<raceSets.length;zz++) {	  if (!eliminated[zz]) {	    if (z == 0) {// first test instance---make sure classifier is built	      errors[zz] = -((HoldOutSubsetEvaluator)evaluators[zz]).		evaluateSubset(raceBitSets[zz], 			       testInst,			       true);	    } else { // must be k fold rather than leave one out	      errors[zz] = -((HoldOutSubsetEvaluator)evaluators[zz]).		
evaluateSubset(raceBitSets[zz], 			       testInst,			       false);	    }	  }	}	// now update the stats	for (int j=startPt;j<raceSets.length;j++) {	  if (!eliminated[j]) {	    individualStats[j].add(errors[j]);	    for (int k=j+1;k<raceSets.length;k++) {	      if (!eliminated[k]) {		testers[j][k].add(errors[j], errors[k]);	      }	    }	  }	}      	// test for near identical models and models that are significantly	// worse than some other model	if (processedCount > m_samples-1 && 	    (eliminatedCount < raceSets.length-1)) {	  for (int j=0;j<raceSets.length;j++) {	    if (!eliminated[j]) {	      for (int k=j+1;k<raceSets.length;k++) {		if (!eliminated[k]) {		  testers[j][k].calculateDerived();		  // near identical ?		  if ((testers[j][k].differencesSignificance == 0) && 		      (Utils.eq(testers[j][k].differencesStats.mean, 0.0) ||		      (Utils.gr(m_delta, Math.abs(testers[j][k].						  differencesStats.mean))))) {		    // if they're exactly the same and there is a base set		    // in this race, make sure that the base set is NOT the		    // one eliminated.		    
if (Utils.eq(testers[j][k].differencesStats.mean, 0.0)) {		      if (baseSetIncluded) { 			if (j != 0) {			  eliminated[j] = true;			} else {			  eliminated[k] = true;			}			eliminatedCount++;		      } else {			eliminated[j] = true;		      }		      if (m_debug) {			System.err.println("Eliminating (identical) "					   +j+" "+raceBitSets[j].toString()					   +" vs "+k+" "					   +raceBitSets[k].toString()					   +" after "					   +processedCount					   +" evaluations\n"					   +"\nerror "+j+" : "					   +testers[j][k].xStats.mean					   +" vs "+k+" : "					   +testers[j][k].yStats.mean					   +" diff : "					   +testers[j][k].differencesStats					   .mean);		      }		    } else {		      // eliminate the one with the higer error		      if (testers[j][k].xStats.mean > 			  testers[j][k].yStats.mean) {			eliminated[j] = true;			eliminatedCount++;			if (m_debug) {			  System.err.println("Eliminating (near identical) "					   +j+" "+raceBitSets[j].toString()					   +" vs "+k+" "					   +raceBitSets[k].toString()					   +" after "					   +processedCount					   +" evaluations\n"					   +"\nerror "+j+" : "					   +testers[j][k].xStats.mean					   +" vs "+k+" : "					   +testers[j][k].yStats.mean					   +" diff : "					   +testers[j][k].differencesStats					   .mean);			}			break;		      } else {			eliminated[k] = true;			eliminatedCount++;			if (m_debug) {			  System.err.println("Eliminating (near identical) "					   +k+" "+raceBitSets[k].toString()					   +" vs "+j+" "					   +raceBitSets[j].toString()					   +" after "					   +processedCount					   +" evaluations\n"					   +"\nerror "+k+" : "					   +testers[j][k].yStats.mean					   +" vs "+j+" : "					   +testers[j][k].xStats.mean					   +" diff : "					   +testers[j][k].differencesStats					     .mean);			}		      }		    }		  } else {		    // significantly worse ?		    
if (testers[j][k].differencesSignificance != 0) {		      if (testers[j][k].differencesSignificance > 0) {			eliminated[j] = true;			eliminatedCount++;			if (m_debug) {			  System.err.println("Eliminating (-worse) "					   +j+" "+raceBitSets[j].toString()					   +" vs "+k+" "					   +raceBitSets[k].toString()					   +" after "					   +processedCount					   +" evaluations"					   +"\nerror "+j+" : "					   +testers[j][k].xStats.mean					   +" vs "+k+" : "					   +testers[j][k].yStats.mean);			}			break;		      } else {			eliminated[k] = true;			eliminatedCount++;			if (m_debug) {			  System.err.println("Eliminating (worse) "					   +k+" "+raceBitSets[k].toString()					   +" vs "+j+" "					   +raceBitSets[j].toString()					   +" after "					   +processedCount					   +" evaluations"					   +"\nerror "+k+" : "					   +testers[j][k].yStats.mean					   +" vs "+j+" : "					   +testers[j][k].xStats.mean);			}		      }		    }		  }		}    	      }	    }	  }	}	// if there is a base set from the previous race and it's the	// only remaining subset then terminate the race.	
if (eliminatedCount == raceSets.length-1 && baseSetIncluded &&	    !eliminated[0] && !m_rankingRequested) {	  break race;	}      }    }    if (m_debug) {      System.err.println("*****eliminated count: "+eliminatedCount);    }    double bestError = Double.MAX_VALUE;    int bestIndex=0;    // return the index of the winner    for (int i=startPt;i<raceSets.length;i++) {      if (!eliminated[i]) {	individualStats[i].calculateDerived();	if (m_debug) {	  System.err.println("Remaining error: "+raceBitSets[i].toString()			     +" "+individualStats[i].mean);	}	if (individualStats[i].mean < bestError) {	  bestError = individualStats[i].mean;	  bestIndex = i;	}      }    }    double [] retInfo = new double[2];    retInfo[0] = bestIndex;    retInfo[1] = bestError;        if (m_debug) {      System.err.print("Best set from race : ");            for (int i=0;i<m_numAttribs;i++) {	if (raceSets[bestIndex][i] == '1') {	  System.err.print('1');	} else {	  System.err.print('0');	}      }      System.err.println(" :"+bestError+" Processed : "+(processedCount)			 +"\n"+individualStats[bestIndex].toString());    }    return retInfo;  }  /**   * Returns a string represenation   *    * @return a string representation   */  public String toString() {    StringBuffer text = new StringBuffer();        text.append("\tRaceSearch.\n\tRace type : ");    switch (m_raceType) {    case FORWARD_RACE:       text.append("forward selection race\n\tBase set : no attributes");      break;    case BACKWARD_RACE:      text.append("backward elimination race\n\tBase set : all attributes");      break;    case SCHEMATA_RACE:      text.append("schemata race\n\tBase set : no attributes");      break;    case RANK_RACE:      text.append("rank race\n\tBase set : no attributes\n\t");      text.append("Attribute evaluator : "		  + getAttributeEvaluator().getClass().getName() +" ");      if (m_ASEval instanceof OptionHandler) {	String[] evaluatorOptions = new String[0];	evaluatorOptions = 
((OptionHandler)m_ASEval).getOptions();	for (int i=0;i<evaluatorOptions.length;i++) {	  text.append(evaluatorOptions[i]+' ');	}      }      text.append("\n");      text.append("\tAttribute ranking : \n");      int rlength = (int)(Math.log(m_Ranking.length) / Math.log(10) + 1);      for (int i=0;i<m_Ranking.length;i++) {	text.append("\t "+Utils.doubleToString((double)(m_Ranking[i]+1),					       rlength,0)		    +" "+m_Instances.attribute(m_Ranking[i]).name()+'\n');      }      break;    }    text.append("\n\tCross validation mode : ");    if (m_xvalType == TEN_FOLD) {      text.append("10 fold");    } else {      text.append("Leave-one-out");    }    text.append("\n\tMerit of best subset found : ");    int fieldwidth = 3;    double precision = (m_bestMerit - (int)m_bestMerit);    if (Math.abs(m_bestMerit) > 0) {      fieldwidth = (int)Math.abs((Math.log(Math.abs(m_bestMerit)) / 				  Math.log(10)))+2;    }    if (Math.abs(precision) > 0) {      precision = Math.abs((Math.log(Math.abs(precision)) / Math.log(10)))+3;    } else {      precision = 2;    }    text.append(Utils.doubleToString(Math.abs(m_bestMerit),				     fieldwidth+(int)precision,				     (int)precision)+"\n");    return text.toString();      }  /**   * Reset the search method.   */  protected void resetOptions () {    m_sigLevel = 0.001;    m_delta = 0.001;    m_ASEval = new GainRatioAttributeEval();    m_Ranking = null;    m_raceType = FORWARD_RACE;    m_debug = false;    m_theEvaluator = null;    m_bestMerit = -Double.MAX_VALUE;    m_numFolds = 10;  }}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -