📄 decisiontable.java
字号:
* @param instA feature values of the selected features for the instance
   * @return the classification of the instance
   */
  double classifyInstanceLeaveOneOut(Instance instance, double [] instA)
    throws Exception {

    hashKey thekey;
    double [] tempDist;
    double [] normDist;

    thekey = new hashKey(instA);
    if (m_classIsNominal) {
      // if this one is not in the table
      if ((tempDist = (double [])m_entries.get(thekey)) == null) {
        throw new Error("This should never happen!");
      } else {
        // work on a copy so the stored class distribution is not modified
        normDist = new double [tempDist.length];
        System.arraycopy(tempDist,0,normDist,0,tempDist.length);
        // temporarily subtract this instance's weight from its own class
        // count -> leave-one-out estimate
        normDist[(int)instance.classValue()] -= instance.weight();

        // update the table
        // first check to see if the class counts are all zero now
        boolean ok = false;
        for (int i=0;i<normDist.length;i++) {
          if (!Utils.eq(normDist[i],0.0)) {
            ok = true;
            break;
          }
        }
        if (ok) {
          Utils.normalize(normDist);
          return Utils.maxIndex(normDist);
        } else {
          // this instance was the only occupant of its table cell; fall
          // back to the global majority class
          return m_majority;
        }
      }
      // return Utils.maxIndex(tempDist);
    } else {
      // numeric class: a table entry holds
      //   [0] = weighted sum of class values, [1] = sum of weights
      // see if this one is already in the table
      if ((tempDist = (double[])m_entries.get(thekey)) != null) {
        normDist = new double [tempDist.length];
        System.arraycopy(tempDist,0,normDist,0,tempDist.length);
        normDist[0] -= (instance.classValue() * instance.weight());
        normDist[1] -= instance.weight();
        if (Utils.eq(normDist[1],0.0)) {
          // no weight left in the cell -> predict the global value
          return m_majority;
        } else {
          // weighted mean of the remaining class values in the cell
          return (normDist[0] / normDist[1]);
        }
      } else {
        throw new Error("This should never happen!");
      }
    }
    // shouldn't get here
    // return 0.0;
  }

  /**
   * Calculates the accuracy on a test fold for internal cross validation
   * of feature sets
   *
   * @param fold set of instances to be "left out" and classified
   * @param fs currently selected feature set
   * @return the accuracy for the fold
   */
  double classifyFoldCV(Instances fold, int [] fs) throws Exception {

    int i;
    int ruleCount = 0;
    int numFold = fold.numInstances();
    int numCl = m_theInstances.classAttribute().numValues();
    double [][] class_distribs = new double [numFold][numCl];
    double [] instA = new double [fs.length];
    double [] normDist;
    hashKey thekey;
    double acc = 0.0;
    int classI = m_theInstances.classIndex();
    Instance inst;

    if (m_classIsNominal) {
      normDist = new double [numCl];
    } else {
      // numeric class: {weighted sum of class values, sum of weights}
      normDist = new double [2];
    }

    // first *remove* instances
    for (i=0;i<numFold;i++) {
      inst = fold.instance(i);
      // build the hash key from the selected feature values
      for (int j=0;j<fs.length;j++) {
        if (fs[j] == classI) {
          instA[j] = Double.MAX_VALUE; // missing for the class
        } else if (inst.isMissing(fs[j])) {
          instA[j] = Double.MAX_VALUE;
        } else{
          instA[j] = inst.value(fs[j]);
        }
      }
      thekey = new hashKey(instA);
      if ((class_distribs[i] = (double [])m_entries.get(thekey)) == null) {
        throw new Error("This should never happen!");
      } else {
        // NOTE: class_distribs[i] aliases the live table entry, so these
        // updates modify the hash table in place; they are undone in the
        // re-insert loop below
        if (m_classIsNominal) {
          class_distribs[i][(int)inst.classValue()] -= inst.weight();
        } else {
          class_distribs[i][0] -= (inst.classValue() * inst.weight());
          class_distribs[i][1] -= inst.weight();
        }
        ruleCount++;
      }
    }

    // now classify instances
    for (i=0;i<numFold;i++) {
      inst = fold.instance(i);
      System.arraycopy(class_distribs[i],0,normDist,0,normDist.length);
      if (m_classIsNominal) {
        boolean ok = false;
        for (int j=0;j<normDist.length;j++) {
          if (!Utils.eq(normDist[j],0.0)) {
            ok = true;
            break;
          }
        }
        if (ok) {
          Utils.normalize(normDist);
          if (Utils.maxIndex(normDist) == inst.classValue())
            acc += inst.weight();
        } else {
          // empty table cell: fall back to the majority class
          if (inst.classValue() == m_majority) {
            acc += inst.weight();
          }
        }
      } else {
        // numeric class: accumulate the weighted squared error
        if (Utils.eq(normDist[1],0.0)) {
          acc += ((inst.weight() * (m_majority - inst.classValue())) *
                  (inst.weight() * (m_majority - inst.classValue())));
        } else {
          double t = (normDist[0] / normDist[1]);
          acc += ((inst.weight() * (t - inst.classValue())) *
                  (inst.weight() * (t - inst.classValue())));
        }
      }
    }

    // now re-insert instances
    for (i=0;i<numFold;i++) {
      inst = fold.instance(i);
      if (m_classIsNominal) {
        class_distribs[i][(int)inst.classValue()] += inst.weight();
      } else {
        class_distribs[i][0] += (inst.classValue() * inst.weight());
        class_distribs[i][1] += inst.weight();
      }
    }
    return acc;
  }

  /**
   * Evaluates a feature subset by cross validation
   *
   * @param feature_set the subset to be evaluated
   * @param num_atts the number of attributes in the subset
   * @return the estimated accuracy
   * @exception Exception if subset can't be evaluated
   */
  private double estimateAccuracy(BitSet feature_set, int num_atts)
    throws Exception {

    int i;
    Instances newInstances;
    int [] fs = new int [num_atts];
    double acc = 0.0;
    double [][] evalArray;
    double [] instA = new double [num_atts];
    int classI = m_theInstances.classIndex();
    int index = 0;

    // convert the BitSet into an array of selected attribute indexes
    for (i=0;i<m_numAttributes;i++) {
      if (feature_set.get(i)) {
        fs[index++] = i;
      }
    }

    // create new hash table
    m_entries = new Hashtable((int)(m_theInstances.numInstances() * 1.5));

    // insert instances into the hash table
    for (i=0;i<m_numInstances;i++) {
      Instance inst = m_theInstances.instance(i);
      for (int j=0;j<fs.length;j++) {
        if (fs[j] == classI) {
          instA[j] = Double.MAX_VALUE; // missing for the class
        } else if (inst.isMissing(fs[j])) {
          instA[j] = Double.MAX_VALUE;
        } else {
          instA[j] = inst.value(fs[j]);
        }
      }
      insertIntoTable(inst, instA);
    }

    if (m_CVFolds == 1) {
      // calculate leave one out error
      for (i=0;i<m_numInstances;i++) {
        Instance inst = m_theInstances.instance(i);
        for (int j=0;j<fs.length;j++) {
          if (fs[j] == classI) {
            instA[j] = Double.MAX_VALUE; // missing for the class
          } else if (inst.isMissing(fs[j])) {
            instA[j] = Double.MAX_VALUE;
          } else {
            instA[j] = inst.value(fs[j]);
          }
        }
        double t = classifyInstanceLeaveOneOut(inst, instA);
        if (m_classIsNominal) {
          if (t == inst.classValue()) {
            acc+=inst.weight();
          }
        } else {
          acc += ((inst.weight() * (t - inst.classValue())) *
                  (inst.weight() * (t - inst.classValue())));
        }
        // weight_sum += inst.weight();
      }
    } else {
      m_theInstances.randomize(m_rr);
      m_theInstances.stratify(m_CVFolds);
      // calculate 10 fold cross validation error
      for (i=0;i<m_CVFolds;i++) {
        Instances insts = m_theInstances.testCV(m_CVFolds,i);
        acc += classifyFoldCV(insts, fs);
      }
    }

    if (m_classIsNominal) {
      return (acc / m_theInstances.sumOfWeights());
    } else {
      // numeric class: return the negated RMSE so that a larger value is
      // always better, matching the nominal-class accuracy
      return -(Math.sqrt(acc / m_theInstances.sumOfWeights()));
    }
  }
/**
   * Returns a String representation of a feature subset
   *
   * @param sub BitSet representation of a subset
   * @return String containing subset
   */
  private String printSub(BitSet sub) {

    int i;
    String s="";
    for (int jj=0;jj<m_numAttributes;jj++) {
      if (sub.get(jj)) {
        s += " "+(jj+1); // attribute indexes are reported 1-based
      }
    }
    return s;
  }

  /**
   * Does a best first search
   */
  private void best_first() throws Exception {

    int i,j,classI,count=0,fc,tree_count=0;
    int evals=0;
    BitSet best_group, temp_group;
    int [] stale;
    double [] best_merit;
    double merit;
    boolean z;
    boolean added;
    Link tl;
    Hashtable lookup = new Hashtable((int)(200.0*m_numAttributes*1.5));
    LinkedList bfList = new LinkedList();
    best_merit = new double[1];
    best_merit[0] = 0.0;
    stale = new int[1];
    stale[0] = 0;
    best_group = new BitSet(m_numAttributes);

    // Add class to initial subset
    classI = m_theInstances.classIndex();
    best_group.set(classI);
    best_merit[0] = estimateAccuracy(best_group, 1);
    if (m_debug)
      System.out.println("Accuracy of initial subset: "+best_merit[0]);

    // add the initial group to the list
    bfList.addToList(best_group,best_merit[0]);
    // add initial subset to the hashtable
    lookup.put(best_group,"");

    while (stale[0] < m_maxStale) {
      added = false;
      // finished search?
      if (bfList.size()==0) {
        stale[0] = m_maxStale;
        break;
      }
      // copy the feature set at the head of the list
      tl = bfList.getLinkAt(0);
      temp_group = (BitSet)(tl.getGroup().clone());
      // remove the head of the list
      bfList.removeLinkAt(0);

      // try adding each attribute (except the class) not yet in the subset
      for (i=0;i<m_numAttributes;i++) {
        // if (search_direction == 1)
        z = ((i != classI) && (!temp_group.get(i)));
        if (z) {
          // set the bit (feature to add/delete)
          temp_group.set(i);
          /* if this subset has been seen before, then it is already in the
             list (or has been fully expanded) */
          BitSet tt = (BitSet)temp_group.clone();
          if (lookup.containsKey(tt) == false) {
            fc = 0;
            for (int jj=0;jj<m_numAttributes;jj++) {
              if (tt.get(jj)) {
                fc++;
              }
            }
            merit = estimateAccuracy(tt, fc);
            if (m_debug) {
              System.out.println("evaluating: "+printSub(tt)+" "+merit);
            }

            // is this better than the best?
            // if (search_direction == 1)
            z = ((merit - best_merit[0]) > 0.00001);
            // else
            // z = ((best_merit[0] - merit) > 0.00001);
            if (z) {
              if (m_debug) {
                System.out.println("new best feature set: "+printSub(tt)+
                                   " "+merit);
              }
              added = true;
              stale[0] = 0;
              best_merit[0] = merit;
              best_group = (BitSet)(temp_group.clone());
            }

            // insert this one in the list and the hash table
            bfList.addToList(tt, merit);
            lookup.put(tt,"");
            count++;
          }
          // unset this addition(deletion)
          temp_group.clear(i);
        }
      }
      /* if we haven't added a new feature subset then full expansion
         of this node hasn't resulted in anything better */
      if (!added) {
        stale[0]++;
      }
    }

    // set selected features: count them, then collect their indexes
    for (i=0,j=0;i<m_numAttributes;i++) {
      if (best_group.get(i)) {
        j++;
      }
    }
    m_decisionFeatures = new int[j];
    for (i=0,j=0;i<m_numAttributes;i++) {
      if (best_group.get(i)) {
        m_decisionFeatures[j++] = i;
      }
    }
  }

  /**
   * Resets the options.
*/ protected void resetOptions() { m_entries = null; m_decisionFeatures = null; m_debug = false; m_useIBk = false; m_CVFolds = 1; m_maxStale = 5; m_displayRules = false; } /** * Constructor for a DecisionTable */ public DecisionTable() { resetOptions(); } /** * Returns an enumeration describing the available options * * @return an enumeration of all the available options */ public Enumeration listOptions() { Vector newVector = new Vector(5); newVector.addElement(new Option( "\tNumber of fully expanded non improving subsets to consider\n" + "\tbefore terminating a best first search.\n" + "\tUse in conjunction with -B. (Default = 5)", "S", 1, "-S <number of non improving nodes>")); newVector.addElement(new Option( "\tUse cross validation to evaluate features.\n" + "\tUse number of folds = 1 for leave one out CV.\n" + "\t(Default = leave one out CV)", "X", 1, "-X <number of folds>")); newVector.addElement(new Option( "\tUse nearest neighbour instead of global table majority.\n", "I", 0, "-I")); newVector.addElement(new Option( "\tDisplay decision table rules.\n", "R", 0, "-R")); return newVector.elements(); } /** * Sets the number of folds for cross validation (1 = leave one out) * * @param folds the number of folds */ public void setCrossVal(int folds) { m_CVFolds = folds;
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -