📄 racesearch.java
字号:
// individual statistics Stats [] individualStats = new Stats [raceSets.length]; // pairwise statistics PairedStats [][] testers = new PairedStats[raceSets.length][raceSets.length]; /** do we ignore the base set or not? */ int startPt = m_rankingRequested ? 1 : 0; for (int i=0;i<raceSets.length;i++) { individualStats[i] = new Stats(); for (int j=i+1;j<raceSets.length;j++) { testers[i][j] = new PairedStats(m_sigLevel); } } BitSet [] raceBitSets = new BitSet[raceSets.length]; for (int i=0;i<raceSets.length;i++) { raceBitSets[i] = new BitSet(m_numAttribs); for (int j=0;j<m_numAttribs;j++) { if (raceSets[i][j] == '1') { raceBitSets[i].set(j); } } } // now loop over the data points collecting leave-one-out errors for // each attribute set Instances trainCV; Instances testCV; Instance testInst; double [] errors = new double [raceSets.length]; int eliminatedCount = 0; int processedCount = 0; // if there is one set left in the race then we need to continue to // evaluate it for the remaining instances in order to get an // accurate error estimate processedCount = 0; race: for (int i=0;i<m_numFolds;i++) { // We want to randomize the data the same way for every // learning scheme. trainCV = data.trainCV(m_numFolds, i, new Random (1)); testCV = data.testCV(m_numFolds, i); // loop over the surviving attribute sets building classifiers for this // training set for (int j=startPt;j<raceSets.length;j++) { if (!eliminated[j]) { evaluators[j].buildEvaluator(trainCV); } } for (int z=0;z<testCV.numInstances();z++) { testInst = testCV.instance(z); processedCount++; // loop over surviving attribute sets computing errors for this // test point for (int zz=startPt;zz<raceSets.length;zz++) { if (!eliminated[zz]) { if (z == 0) {// first test instance---make sure classifier is built errors[zz] = -((HoldOutSubsetEvaluator)evaluators[zz]). evaluateSubset(raceBitSets[zz], testInst, true); } else { // must be k fold rather than leave one out errors[zz] = -((HoldOutSubsetEvaluator)evaluators[zz]). evaluateSubset(raceBitSets[zz], testInst, false); } } } // now update the stats for (int j=startPt;j<raceSets.length;j++) { if (!eliminated[j]) { individualStats[j].add(errors[j]); for (int k=j+1;k<raceSets.length;k++) { if (!eliminated[k]) { testers[j][k].add(errors[j], errors[k]); } } } } // test for near identical models and models that are significantly // worse than some other model if (processedCount > m_samples-1 && (eliminatedCount < raceSets.length-1)) { for (int j=0;j<raceSets.length;j++) { if (!eliminated[j]) { for (int k=j+1;k<raceSets.length;k++) { if (!eliminated[k]) { testers[j][k].calculateDerived(); // near identical ? if ((testers[j][k].differencesSignificance == 0) && (Utils.eq(testers[j][k].differencesStats.mean, 0.0) || (Utils.gr(m_delta, Math.abs(testers[j][k]. differencesStats.mean))))) { // if they're exactly the same and there is a base set // in this race, make sure that the base set is NOT the // one eliminated. if (Utils.eq(testers[j][k].differencesStats.mean, 0.0)) { if (baseSetIncluded) { if (j != 0) { eliminated[j] = true; } else { eliminated[k] = true; } eliminatedCount++; } else { eliminated[j] = true; } if (m_debug) { System.err.println("Eliminating (identical) " +j+" "+raceBitSets[j].toString() +" vs "+k+" " +raceBitSets[k].toString() +" after " +processedCount +" evaluations\n" +"\nerror "+j+" : " +testers[j][k].xStats.mean +" vs "+k+" : " +testers[j][k].yStats.mean +" diff : " +testers[j][k].differencesStats .mean); } } else { // eliminate the one with the higer error if (testers[j][k].xStats.mean > testers[j][k].yStats.mean) { eliminated[j] = true; eliminatedCount++; if (m_debug) { System.err.println("Eliminating (near identical) " +j+" "+raceBitSets[j].toString() +" vs "+k+" " +raceBitSets[k].toString() +" after " +processedCount +" evaluations\n" +"\nerror "+j+" : " +testers[j][k].xStats.mean +" vs "+k+" : " +testers[j][k].yStats.mean +" diff : " +testers[j][k].differencesStats .mean); } break; } else { eliminated[k] = true; eliminatedCount++; if (m_debug) { System.err.println("Eliminating (near identical) " +k+" "+raceBitSets[k].toString() +" vs "+j+" " +raceBitSets[j].toString() +" after " +processedCount +" evaluations\n" +"\nerror "+k+" : " +testers[j][k].yStats.mean +" vs "+j+" : " +testers[j][k].xStats.mean +" diff : " +testers[j][k].differencesStats .mean); } } } } else { // significantly worse ? if (testers[j][k].differencesSignificance != 0) { if (testers[j][k].differencesSignificance > 0) { eliminated[j] = true; eliminatedCount++; if (m_debug) { System.err.println("Eliminating (-worse) " +j+" "+raceBitSets[j].toString() +" vs "+k+" " +raceBitSets[k].toString() +" after " +processedCount +" evaluations" +"\nerror "+j+" : " +testers[j][k].xStats.mean +" vs "+k+" : " +testers[j][k].yStats.mean); } break; } else { eliminated[k] = true; eliminatedCount++; if (m_debug) { System.err.println("Eliminating (worse) " +k+" "+raceBitSets[k].toString() +" vs "+j+" " +raceBitSets[j].toString() +" after " +processedCount +" evaluations" +"\nerror "+k+" : " +testers[j][k].yStats.mean +" vs "+j+" : " +testers[j][k].xStats.mean); } } } } } } } } } // if there is a base set from the previous race and it's the // only remaining subset then terminate the race. if (eliminatedCount == raceSets.length-1 && baseSetIncluded && !eliminated[0] && !m_rankingRequested) { break race; } } } if (m_debug) { System.err.println("*****eliminated count: "+eliminatedCount); } double bestError = Double.MAX_VALUE; int bestIndex=0; // return the index of the winner for (int i=startPt;i<raceSets.length;i++) { if (!eliminated[i]) { individualStats[i].calculateDerived(); if (m_debug) { System.err.println("Remaining error: "+raceBitSets[i].toString() +" "+individualStats[i].mean); } if (individualStats[i].mean < bestError) { bestError = individualStats[i].mean; bestIndex = i; } } } double [] retInfo = new double[2]; retInfo[0] = bestIndex; retInfo[1] = bestError; if (m_debug) { System.err.print("Best set from race : "); for (int i=0;i<m_numAttribs;i++) { if (raceSets[bestIndex][i] == '1') { System.err.print('1'); } else { System.err.print('0'); } } System.err.println(" :"+bestError+" Processed : "+(processedCount) +"\n"+individualStats[bestIndex].toString()); } return retInfo; } /** * Returns a string represenation * * @return a string representation */ public String toString() { StringBuffer text = new StringBuffer(); text.append("\tRaceSearch.\n\tRace type : "); switch (m_raceType) { case FORWARD_RACE: text.append("forward selection race\n\tBase set : no attributes"); break; case BACKWARD_RACE: text.append("backward elimination race\n\tBase set : all attributes"); break; case SCHEMATA_RACE: text.append("schemata race\n\tBase set : no attributes"); break; case RANK_RACE: text.append("rank race\n\tBase set : no attributes\n\t"); text.append("Attribute evaluator : " + getAttributeEvaluator().getClass().getName() +" "); if (m_ASEval instanceof OptionHandler) { String[] evaluatorOptions = new String[0]; evaluatorOptions = ((OptionHandler)m_ASEval).getOptions(); for (int i=0;i<evaluatorOptions.length;i++) { text.append(evaluatorOptions[i]+' '); } } text.append("\n"); text.append("\tAttribute ranking : \n"); int rlength = (int)(Math.log(m_Ranking.length) / Math.log(10) + 1); for (int i=0;i<m_Ranking.length;i++) { text.append("\t "+Utils.doubleToString((double)(m_Ranking[i]+1), rlength,0) +" "+m_Instances.attribute(m_Ranking[i]).name()+'\n'); } break; } text.append("\n\tCross validation mode : "); if (m_xvalType == TEN_FOLD) { text.append("10 fold"); } else { text.append("Leave-one-out"); } text.append("\n\tMerit of best subset found : "); int fieldwidth = 3; double precision = (m_bestMerit - (int)m_bestMerit); if (Math.abs(m_bestMerit) > 0) { fieldwidth = (int)Math.abs((Math.log(Math.abs(m_bestMerit)) / Math.log(10)))+2; } if (Math.abs(precision) > 0) { precision = Math.abs((Math.log(Math.abs(precision)) / Math.log(10)))+3; } else { precision = 2; } text.append(Utils.doubleToString(Math.abs(m_bestMerit), fieldwidth+(int)precision, (int)precision)+"\n"); return text.toString(); } /** * Reset the search method. */ protected void resetOptions () { m_sigLevel = 0.001; m_delta = 0.001; m_ASEval = new GainRatioAttributeEval(); m_Ranking = null; m_raceType = FORWARD_RACE; m_debug = false; m_theEvaluator = null; m_bestMerit = -Double.MAX_VALUE; m_numFolds = 10; }}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -