// DecisionTable.java (fragment) -- page-scrape banner removed
/**
 * Classifies an instance for internal leave one out cross validation
 * of feature sets
 *
 * @param instance instance to be "left out" and classified
 * @param instA feature values of the selected features for the instance
 * @return the classification of the instance
 */
double classifyInstanceLeaveOneOut(Instance instance, double [] instA)
throws Exception {

  hashKey thekey = new hashKey(instA);
  double [] stored = (double [])m_entries.get(thekey);

  // every training instance was inserted into the table beforehand,
  // so a lookup miss indicates an internal bug
  if (stored == null) {
    throw new Error("This should never happen!");
  }

  // work on a copy so the table entry itself is left untouched
  double [] withheld = new double [stored.length];
  System.arraycopy(stored, 0, withheld, 0, stored.length);

  if (m_classIsNominal) {
    // subtract this instance's contribution to its own class count
    withheld[(int)instance.classValue()] -= instance.weight();

    // if any class weight remains in the cell, predict from what is left
    for (int k = 0; k < withheld.length; k++) {
      if (!Utils.eq(withheld[k], 0.0)) {
        Utils.normalize(withheld);
        return Utils.maxIndex(withheld);
      }
    }
    // the withheld instance was the only occupant of this cell
    return m_majority;
  } else {
    // numeric class: entry holds {sum of class*weight, sum of weights}
    withheld[0] -= (instance.classValue() * instance.weight());
    withheld[1] -= instance.weight();

    if (Utils.eq(withheld[1], 0.0)) {
      // no weight remains in the cell; fall back to the global mean
      return m_majority;
    }
    return (withheld[0] / withheld[1]);
  }
}
/**
 * Calculates the accuracy on a test fold for internal cross validation
 * of feature sets.
 *
 * Works in three ordered passes over the fold: (1) subtract each fold
 * instance's contribution from its table entry, (2) classify every fold
 * instance from the reduced entries, (3) add the contributions back.
 *
 * @param fold set of instances to be "left out" and classified
 * @param fs currently selected feature set (attribute indexes)
 * @return the accuracy for the fold: for a nominal class the summed
 * weight of correctly classified instances; for a numeric class the
 * summed squared weighted error
 */
double classifyFoldCV(Instances fold, int [] fs) throws Exception {
int i;
int ruleCount = 0; // counted but not otherwise used in this method
int numFold = fold.numInstances();
int numCl = m_theInstances.classAttribute().numValues();
double [][] class_distribs = new double [numFold][numCl];
double [] instA = new double [fs.length];
double [] normDist;
hashKey thekey;
double acc = 0.0;
int classI = m_theInstances.classIndex();
Instance inst;
// scratch buffer: a class distribution (nominal) or a
// {weighted class sum, weight sum} pair (numeric)
if (m_classIsNominal) {
normDist = new double [numCl];
} else {
normDist = new double [2];
}
// first *remove* instances.
// NOTE: class_distribs[i] is assigned the SAME array object stored in
// m_entries, so the -= updates below mutate the table entries in place;
// the final pass restores them.
for (i=0;i<numFold;i++) {
inst = fold.instance(i);
// build the lookup key from the selected feature values; the class
// attribute and missing values are both encoded as Double.MAX_VALUE
for (int j=0;j<fs.length;j++) {
if (fs[j] == classI) {
instA[j] = Double.MAX_VALUE; // missing for the class
} else if (inst.isMissing(fs[j])) {
instA[j] = Double.MAX_VALUE;
} else{
instA[j] = inst.value(fs[j]);
}
}
thekey = new hashKey(instA);
// every training instance was inserted, so a miss indicates a bug
if ((class_distribs[i] = (double [])m_entries.get(thekey)) == null) {
throw new Error("This should never happen!");
} else {
if (m_classIsNominal) {
class_distribs[i][(int)inst.classValue()] -= inst.weight();
} else {
class_distribs[i][0] -= (inst.classValue() * inst.weight());
class_distribs[i][1] -= inst.weight();
}
ruleCount++;
}
}
// now classify instances against the reduced table entries
for (i=0;i<numFold;i++) {
inst = fold.instance(i);
// copy so the shared table entry itself is not normalized below
System.arraycopy(class_distribs[i],0,normDist,0,normDist.length);
if (m_classIsNominal) {
// does any class weight remain in this table cell?
boolean ok = false;
for (int j=0;j<normDist.length;j++) {
if (!Utils.eq(normDist[j],0.0)) {
ok = true;
break;
}
}
if (ok) {
Utils.normalize(normDist);
if (Utils.maxIndex(normDist) == inst.classValue())
acc += inst.weight();
} else {
// empty cell: fall back to the global majority class
if (inst.classValue() == m_majority) {
acc += inst.weight();
}
}
} else {
// numeric class: accumulate squared weighted error, predicting the
// cell mean, or the global mean when no weight remains in the cell
if (Utils.eq(normDist[1],0.0)) {
acc += ((inst.weight() * (m_majority - inst.classValue())) *
(inst.weight() * (m_majority - inst.classValue())));
} else {
double t = (normDist[0] / normDist[1]);
acc += ((inst.weight() * (t - inst.classValue())) *
(inst.weight() * (t - inst.classValue())));
}
}
}
// now re-insert instances: restore the contributions subtracted above
for (i=0;i<numFold;i++) {
inst = fold.instance(i);
if (m_classIsNominal) {
class_distribs[i][(int)inst.classValue()] += inst.weight();
} else {
class_distribs[i][0] += (inst.classValue() * inst.weight());
class_distribs[i][1] += inst.weight();
}
}
return acc;
}
/**
 * Evaluates a feature subset by cross validation
 *
 * @param feature_set the subset to be evaluated
 * @param num_atts the number of attributes in the subset
 * @return the estimated accuracy (for a numeric class, the negated
 * root mean squared error, so that larger is always better)
 * @exception Exception if subset can't be evaluated
 */
private double estimateAccuracy(BitSet feature_set, int num_atts)
  throws Exception {

  int i;
  int [] fs = new int [num_atts];
  double acc = 0.0;
  double [] instA = new double [num_atts];
  int classI = m_theInstances.classIndex();

  // collect the indexes of the selected features
  int index = 0;
  for (i = 0; i < m_numAttributes; i++) {
    if (feature_set.get(i)) {
      fs[index++] = i;
    }
  }

  // create a new hash table and insert all instances, keyed on the
  // selected feature values only
  m_entries = new Hashtable((int)(m_theInstances.numInstances() * 1.5));
  for (i = 0; i < m_numInstances; i++) {
    Instance inst = m_theInstances.instance(i);
    copyFeatureValues(inst, fs, classI, instA);
    insertIntoTable(inst, instA);
  }

  if (m_CVFolds == 1) {
    // calculate leave one out error
    for (i = 0; i < m_numInstances; i++) {
      Instance inst = m_theInstances.instance(i);
      copyFeatureValues(inst, fs, classI, instA);
      double t = classifyInstanceLeaveOneOut(inst, instA);
      if (m_classIsNominal) {
        if (t == inst.classValue()) {
          acc += inst.weight();
        }
      } else {
        // accumulate squared weighted error
        acc += ((inst.weight() * (t - inst.classValue())) *
                (inst.weight() * (t - inst.classValue())));
      }
    }
  } else {
    m_theInstances.randomize(m_rr);
    m_theInstances.stratify(m_CVFolds);
    // calculate m_CVFolds-fold cross validation error
    for (i = 0; i < m_CVFolds; i++) {
      Instances insts = m_theInstances.testCV(m_CVFolds, i);
      acc += classifyFoldCV(insts, fs);
    }
  }

  if (m_classIsNominal) {
    return (acc / m_theInstances.sumOfWeights());
  } else {
    // numeric class: acc holds summed squared errors; report negated RMS
    return -(Math.sqrt(acc / m_theInstances.sumOfWeights()));
  }
}

/**
 * Copies the values of the selected features for an instance into instA.
 * The class attribute and any missing value are both encoded as
 * Double.MAX_VALUE, the table's "missing" marker.
 *
 * @param inst the instance to read values from
 * @param fs indexes of the selected features
 * @param classI the index of the class attribute
 * @param instA destination array, one slot per selected feature
 */
private void copyFeatureValues(Instance inst, int [] fs, int classI,
                               double [] instA) {
  for (int j = 0; j < fs.length; j++) {
    if (fs[j] == classI) {
      instA[j] = Double.MAX_VALUE; // missing for the class
    } else if (inst.isMissing(fs[j])) {
      instA[j] = Double.MAX_VALUE;
    } else {
      instA[j] = inst.value(fs[j]);
    }
  }
}
/**
 * Returns a String representation of a feature subset as a space-separated
 * list of 1-based attribute numbers (e.g. " 1 3 7").
 *
 * @param sub BitSet representation of a subset
 * @return String containing subset
 */
private String printSub(BitSet sub) {
  // StringBuilder avoids quadratic String concatenation in the loop;
  // the unused local 'i' from the original has been dropped
  StringBuilder s = new StringBuilder();
  for (int jj = 0; jj < m_numAttributes; jj++) {
    if (sub.get(jj)) {
      s.append(" ").append(jj + 1); // attributes reported 1-based
    }
  }
  return s.toString();
}
/**
 * Does a best first forward search through the space of feature subsets.
 * The search starts from the subset containing only the class attribute
 * and repeatedly expands the best unexpanded subset by adding one feature
 * at a time; it terminates after m_maxStale consecutive full expansions
 * yield no improvement (or when the open list is exhausted). On return,
 * m_decisionFeatures holds the indexes of the selected features.
 *
 * @exception Exception if a subset cannot be evaluated
 */
private void best_first() throws Exception {
  int i, j, classI, fc;
  BitSet best_group, temp_group;
  int [] stale;
  double [] best_merit;
  double merit;
  boolean z;
  boolean added;
  Link tl;
  // "seen" set of already generated subsets (value is unused)
  Hashtable lookup = new Hashtable((int)(200.0*m_numAttributes*1.5));
  // open list of subsets awaiting expansion, ordered by merit
  LinkedList bfList = new LinkedList();
  best_merit = new double[1]; best_merit[0] = 0.0;
  stale = new int[1]; stale[0] = 0;
  best_group = new BitSet(m_numAttributes);

  // Add class to initial subset
  classI = m_theInstances.classIndex();
  best_group.set(classI);
  best_merit[0] = estimateAccuracy(best_group, 1);
  if (m_debug)
    System.out.println("Accuracy of initial subset: "+best_merit[0]);

  // seed the open list and the "seen" table with the initial subset
  bfList.addToList(best_group, best_merit[0]);
  lookup.put(best_group, "");

  while (stale[0] < m_maxStale) {
    added = false;

    // finished search?
    if (bfList.size() == 0) {
      stale[0] = m_maxStale;
      break;
    }

    // expand the subset at the head of the list
    tl = bfList.getLinkAt(0);
    temp_group = (BitSet)(tl.getGroup().clone());
    bfList.removeLinkAt(0);

    for (i = 0; i < m_numAttributes; i++) {
      // try adding each feature not already present (never the class)
      z = ((i != classI) && (!temp_group.get(i)));
      if (z) {
        // set the bit (feature to add)
        temp_group.set(i);

        /* if this subset has been seen before, then it is already in
           the list (or has been fully expanded) */
        BitSet tt = (BitSet)temp_group.clone();
        if (lookup.containsKey(tt) == false) {
          // count the number of features in the candidate subset
          fc = 0;
          for (int jj = 0; jj < m_numAttributes; jj++) {
            if (tt.get(jj)) {
              fc++;
            }
          }
          merit = estimateAccuracy(tt, fc);
          if (m_debug) {
            System.out.println("evaluating: "+printSub(tt)+" "+merit);
          }

          // is this better than the best (beyond a small tolerance)?
          z = ((merit - best_merit[0]) > 0.00001);
          if (z) {
            if (m_debug) {
              System.out.println("new best feature set: "+printSub(tt)+
                                 " "+merit);
            }
            added = true;
            stale[0] = 0;
            best_merit[0] = merit;
            best_group = (BitSet)(temp_group.clone());
          }

          // insert this one in the list and the hash table
          bfList.addToList(tt, merit);
          lookup.put(tt, "");
        }
        // unset this addition
        temp_group.clear(i);
      }
    }

    /* if we haven't added a new feature subset then full expansion
       of this node hasn't resulted in anything better */
    if (!added) {
      stale[0]++;
    }
  }

  // record the selected features: first count them, then copy indexes
  for (i = 0, j = 0; i < m_numAttributes; i++) {
    if (best_group.get(i)) {
      j++;
    }
  }
  m_decisionFeatures = new int[j];
  for (i = 0, j = 0; i < m_numAttributes; i++) {
    if (best_group.get(i)) {
      m_decisionFeatures[j++] = i;
    }
  }
}
/**
 * Resets all options to their default values.
 */
protected void resetOptions() {
  // search / evaluation settings
  m_CVFolds = 1;   // 1 fold == leave one out
  m_maxStale = 5;
  // behaviour flags
  m_useIBk = false;
  m_debug = false;
  m_displayRules = false;
  // learned state
  m_entries = null;
  m_decisionFeatures = null;
}
/**
 * Constructor for a DecisionTable. Initialises all options to their
 * default values via resetOptions().
 */
public DecisionTable() {
resetOptions();
}
/**
 * Returns an enumeration describing the available options.
 *
 * @return an enumeration of all the available options.
 */
public Enumeration listOptions() {
  // four scheme-specific options: -S, -X, -I, -R
  Vector options = new Vector(5);
  options.addElement(new Option(
      "\tNumber of fully expanded non improving subsets to consider\n" +
      "\tbefore terminating a best first search.\n" +
      "\tUse in conjunction with -B. (Default = 5)",
      "S", 1, "-S <number of non improving nodes>"));
  options.addElement(new Option(
      "\tUse cross validation to evaluate features.\n" +
      "\tUse number of folds = 1 for leave one out CV.\n" +
      "\t(Default = leave one out CV)",
      "X", 1, "-X <number of folds>"));
  options.addElement(new Option(
      "\tUse nearest neighbour instead of global table majority.\n",
      "I", 0, "-I"));
  options.addElement(new Option(
      "\tDisplay decision table rules.\n",
      "R", 0, "-R"));
  return options.elements();
}
/**
 * Returns the tip text for the crossVal property
 * @return tip text for this property suitable for
 * displaying in the explorer/experimenter gui
 */
public String crossValTipText() {
return "Sets the number of folds for cross validation (1 = leave one out).";
}
/**
 * Sets the number of folds for internal cross validation
 * (1 = leave one out)
 *
 * @param folds the number of folds
 */
public void setCrossVal(int folds) {
m_CVFolds = folds;
}
/**
* Gets the number of folds for cross validation
*
* @return the number of cross validation folds
*/
public int getCrossVal() {
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -