RaceSearch.java
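// in hillClimbRace: when a ranking is requested, record the attribute by which
// the winning subset differs from the current base set, together with the
// winner's error, before the winner is adopted as the new base set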
if (m_rankingRequested) {
for (int i = 0; i < baseSet.length; i++) {
if (win.charAt(i) != bs.charAt(i)) {
m_rankedAtts[m_rankedSoFar][0] = i;
m_rankedAtts[m_rankedSoFar][1] = winnerInfo[1];
m_rankedSoFar++;
}
}
}
baseSet = (char [])raceSets[(int)winnerInfo[0]].clone();
} else {
// Will get here for a subset whose error is outside the delta
// threshold but is not *significantly* worse than the base
// subset
//throw new Exception("RaceSearch: problem in hillClimbRace");
}
}
}
return attributeList(baseSet);
}
/**
* Convert an attribute set to an array of attribute indices
* @param list the attribute set, specified as an array of '0'/'1' characters
* @return the indices of the attributes marked '1'
*/
private int [] attributeList(char [] list) {
int count = 0;
for (int i=0;i<m_numAttribs;i++) {
if (list[i] == '1') {
count++;
}
}
int [] rlist = new int[count];
count = 0;
for (int i=0;i<m_numAttribs;i++) {
if (list[i] == '1') {
rlist[count++] = i;
}
}
return rlist;
}
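// For illustration: with m_numAttribs == 4, attributeList(new char [] {'1','0','1','1'})
// returns {0, 2, 3}, the indices of the attributes marked '1'.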
/**
* Races the leave-one-out (or k-fold) cross validation errors of a set of
* attribute subsets on a set of instances.
* @param raceSets a set of attribute subset specifications
* @param data the instances to use when cross validating
* @param baseSetIncluded true if the first attribute set is a
* base set generated from the previous race
* @param random a random number generator
* @return a two element array containing the index of the winning subset
* and its error
* @exception Exception if an error occurs during cross validation
*/
private double [] raceSubsets(char [][]raceSets, Instances data,
boolean baseSetIncluded, Random random)
throws Exception {
// the evaluators --- one for each subset
ASEvaluation [] evaluators =
ASEvaluation.makeCopies(m_theEvaluator, raceSets.length);
// array of subsets eliminated from the race
boolean [] eliminated = new boolean [raceSets.length];
// individual statistics
Stats [] individualStats = new Stats [raceSets.length];
// pairwise statistics
PairedStats [][] testers =
new PairedStats[raceSets.length][raceSets.length];
// skip the base set (at index 0) when a ranking has been requested
int startPt = m_rankingRequested ? 1 : 0;
for (int i=0;i<raceSets.length;i++) {
individualStats[i] = new Stats();
for (int j=i+1;j<raceSets.length;j++) {
testers[i][j] = new PairedStats(m_sigLevel);
}
}
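// convert each char-array subset specification into a BitSet for use with the
// hold-out subset evaluators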
BitSet [] raceBitSets = new BitSet[raceSets.length];
for (int i=0;i<raceSets.length;i++) {
raceBitSets[i] = new BitSet(m_numAttribs);
for (int j=0;j<m_numAttribs;j++) {
if (raceSets[i][j] == '1') {
raceBitSets[i].set(j);
}
}
}
// now loop over the data points collecting leave-one-out errors for
// each attribute set
Instances trainCV;
Instances testCV;
Instance testInst;
double [] errors = new double [raceSets.length];
int eliminatedCount = 0;
int processedCount = 0;
// if there is one set left in the race then we need to continue to
// evaluate it for the remaining instances in order to get an
// accurate error estimate
Stats clearWinner = null;
int foldSize=1;
processedCount = 0;
race: for (int i=0;i<m_numFolds;i++) {
trainCV = data.trainCV(m_numFolds, i, random);
testCV = data.testCV(m_numFolds, i);
foldSize = testCV.numInstances();
// loop over the surviving attribute sets building classifiers for this
// training set
for (int j=startPt;j<raceSets.length;j++) {
if (!eliminated[j]) {
evaluators[j].buildEvaluator(trainCV);
}
}
for (int z=0;z<testCV.numInstances();z++) {
testInst = testCV.instance(z);
processedCount++;
// loop over surviving attribute sets computing errors for this
// test point; the evaluator returns a merit (higher is better), so its
// negation is recorded as an error (lower is better)
for (int zz=startPt;zz<raceSets.length;zz++) {
if (!eliminated[zz]) {
if (z == 0) {// first test instance---make sure classifier is built
errors[zz] = -((HoldOutSubsetEvaluator)evaluators[zz]).
evaluateSubset(raceBitSets[zz],
testInst,
true);
} else { // must be k fold rather than leave one out
errors[zz] = -((HoldOutSubsetEvaluator)evaluators[zz]).
evaluateSubset(raceBitSets[zz],
testInst,
false);
}
}
}
// now update the stats
for (int j=startPt;j<raceSets.length;j++) {
if (!eliminated[j]) {
individualStats[j].add(errors[j]);
for (int k=j+1;k<raceSets.length;k++) {
if (!eliminated[k]) {
testers[j][k].add(errors[j], errors[k]);
}
}
}
}
// test for near identical models and models that are significantly
// worse than some other model
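// for each surviving pair (j, k): a positive differencesSignificance means
// subset j's errors are significantly higher than subset k's (so j is
// eliminated below), a negative value means the reverse, and zero means no
// significant difference at the m_sigLevel level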
if (processedCount > m_samples-1 &&
(eliminatedCount < raceSets.length-1)) {
for (int j=0;j<raceSets.length;j++) {
if (!eliminated[j]) {
for (int k=j+1;k<raceSets.length;k++) {
if (!eliminated[k]) {
testers[j][k].calculateDerived();
// near identical ?
if ((testers[j][k].differencesSignificance == 0) &&
(Utils.eq(testers[j][k].differencesStats.mean, 0.0) ||
(Utils.gr(m_delta, Math.abs(testers[j][k].
differencesStats.mean))))) {
// if they're exactly the same and there is a base set
// in this race, make sure that the base set is NOT the
// one eliminated.
if (Utils.eq(testers[j][k].differencesStats.mean, 0.0)) {
if (baseSetIncluded) {
if (j != 0) {
eliminated[j] = true;
} else {
eliminated[k] = true;
}
eliminatedCount++;
} else {
// no base set in this race, so just eliminate one of the identical pair
eliminated[j] = true;
eliminatedCount++;
}
if (m_debug) {
System.err.println("Eliminating (identical) "
+j+" "+raceBitSets[j].toString()
+" vs "+k+" "
+raceBitSets[k].toString()
+" after "
+processedCount
+" evaluations\n"
+"\nerror "+j+" : "
+testers[j][k].xStats.mean
+" vs "+k+" : "
+testers[j][k].yStats.mean
+" diff : "
+testers[j][k].differencesStats
.mean);
}
} else {
// eliminate the one with the higher error
if (testers[j][k].xStats.mean >
testers[j][k].yStats.mean) {
eliminated[j] = true;
eliminatedCount++;
if (m_debug) {
System.err.println("Eliminating (near identical) "
+j+" "+raceBitSets[j].toString()
+" vs "+k+" "
+raceBitSets[k].toString()
+" after "
+processedCount
+" evaluations\n"
+"\nerror "+j+" : "
+testers[j][k].xStats.mean
+" vs "+k+" : "
+testers[j][k].yStats.mean
+" diff : "
+testers[j][k].differencesStats
.mean);
}
break;
} else {
eliminated[k] = true;
eliminatedCount++;
if (m_debug) {
System.err.println("Eliminating (near identical) "
+k+" "+raceBitSets[k].toString()
+" vs "+j+" "
+raceBitSets[j].toString()
+" after "
+processedCount
+" evaluations\n"
+"\nerror "+k+" : "
+testers[j][k].yStats.mean
+" vs "+j+" : "
+testers[j][k].xStats.mean
+" diff : "
+testers[j][k].differencesStats
.mean);
}
}
}
} else {
// significantly worse ?
if (testers[j][k].differencesSignificance != 0) {
if (testers[j][k].differencesSignificance > 0) {
eliminated[j] = true;
eliminatedCount++;
if (m_debug) {
System.err.println("Eliminating (-worse) "
+j+" "+raceBitSets[j].toString()
+" vs "+k+" "
+raceBitSets[k].toString()
+" after "
+processedCount
+" evaluations"
+"\nerror "+j+" : "
+testers[j][k].xStats.mean
+" vs "+k+" : "
+testers[j][k].yStats.mean);
}
break;
} else {
eliminated[k] = true;
eliminatedCount++;
if (m_debug) {
System.err.println("Eliminating (worse) "
+k+" "+raceBitSets[k].toString()
+" vs "+j+" "
+raceBitSets[j].toString()
+" after "
+processedCount
+" evaluations"
+"\nerror "+k+" : "
+testers[j][k].yStats.mean
+" vs "+j+" : "
+testers[j][k].xStats.mean);
}
}
}
}
}
}
}
}
}
// if there is a base set from the previous race and it's the
// only remaining subset then terminate the race.
if (eliminatedCount == raceSets.length-1 && baseSetIncluded &&
!eliminated[0] && !m_rankingRequested) {
break race;
}
}
}
if (m_debug) {
System.err.println("*****eliminated count: "+eliminatedCount);
}
double bestError = Double.MAX_VALUE;
int bestIndex=0;
// return the index of the winner
for (int i=startPt;i<raceSets.length;i++) {
if (!eliminated[i]) {
individualStats[i].calculateDerived();
if (m_debug) {
System.err.println("Remaining error: "+raceBitSets[i].toString()
+" "+individualStats[i].mean);
}
if (individualStats[i].mean < bestError) {
bestError = individualStats[i].mean;
bestIndex = i;
}
}
}
double [] retInfo = new double[2];
retInfo[0] = bestIndex;
retInfo[1] = bestError;
if (m_debug) {
System.err.print("Best set from race : ");
for (int i=0;i<m_numAttribs;i++) {
if (raceSets[bestIndex][i] == '1') {
System.err.print('1');
} else {
System.err.print('0');
}
}
System.err.println(" :"+bestError+" Processed : "+(processedCount)
+"\n"+individualStats[bestIndex].toString());
}
return retInfo;
}
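// Minimal usage sketch (illustrative only; "candidates" is a hypothetical
// name): a caller such as the forward hill-climbing race places the current
// base set at index 0 of the candidate array and reads the result as
// {winning index, winning error}, e.g.
//
//   double [] winnerInfo = raceSubsets(candidates, data, true, random);
//   char [] newBase = (char [])candidates[(int)winnerInfo[0]].clone();
//   double bestError = winnerInfo[1];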
public String toString() {
StringBuffer text = new StringBuffer();
text.append("\tRaceSearch.\n\tRace type : ");
switch (m_raceType) {
case FORWARD_RACE:
text.append("forward selection race\n\tBase set : no attributes");
break;
case BACKWARD_RACE:
text.append("backward elimination race\n\tBase set : all attributes");
break;
case SCHEMATA_RACE:
text.append("schemata race\n\tBase set : no attributes");
break;
case RANK_RACE:
text.append("rank race\n\tBase set : no attributes\n\t");
text.append("Attribute evaluator : "
+ getAttributeEvaluator().getClass().getName() +" ");
if (m_ASEval instanceof OptionHandler) {
String[] evaluatorOptions = new String[0];
evaluatorOptions = ((OptionHandler)m_ASEval).getOptions();
for (int i=0;i<evaluatorOptions.length;i++) {
text.append(evaluatorOptions[i]+' ');
}
}
text.append("\n");
text.append("\tAttribute ranking : \n");
int rlength = (int)(Math.log(m_Ranking.length) / Math.log(10) + 1);
for (int i=0;i<m_Ranking.length;i++) {
text.append("\t "+Utils.doubleToString((double)(m_Ranking[i]+1),
rlength,0)
+" "+m_Instances.attribute(m_Ranking[i]).name()+'\n');
}
break;
}
text.append("\n\tCross validation mode : ");
if (m_xvalType == TEN_FOLD) {
text.append("10 fold");
} else {
text.append("Leave-one-out");
}
text.append("\n\tMerit of best subset found : ");
int fieldwidth = 3;
double precision = (m_bestMerit - (int)m_bestMerit);
if (Math.abs(m_bestMerit) > 0) {
fieldwidth = (int)Math.abs((Math.log(Math.abs(m_bestMerit)) /
Math.log(10)))+2;
}
if (Math.abs(precision) > 0) {
precision = Math.abs((Math.log(Math.abs(precision)) / Math.log(10)))+3;
} else {
precision = 2;
}
text.append(Utils.doubleToString(Math.abs(m_bestMerit),
fieldwidth+(int)precision,
(int)precision)+"\n");
return text.toString();
}
/**
* Reset the search method.
*/
protected void resetOptions () {
m_sigLevel = 0.001;
m_delta = 0.001;
m_ASEval = new GainRatioAttributeEval();
m_Ranking = null;
m_raceType = FORWARD_RACE;
m_debug = false;
m_theEvaluator = null;
m_bestMerit = -Double.MAX_VALUE;
m_numFolds = 10;
}
}