⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 decisiontable.java

📁 《数据挖掘——实用机器学习技术及Java实现》一书的配套源程序
💻 JAVA
📖 第 1 页 / 共 3 页
字号:
   * @param instA feature values of the selected features for the instance   * @return the classification of the instance   */  double classifyInstanceLeaveOneOut(Instance instance, double [] instA)       throws Exception {    hashKey thekey;    double [] tempDist;    double [] normDist;    thekey = new hashKey(instA);    if (m_classIsNominal) {      // if this one is not in the table      if ((tempDist = (double [])m_entries.get(thekey)) == null) {	throw new Error("This should never happen!");      } else {	normDist = new double [tempDist.length];	System.arraycopy(tempDist,0,normDist,0,tempDist.length);	normDist[(int)instance.classValue()] -= instance.weight();	// update the table	// first check to see if the class counts are all zero now	boolean ok = false;	for (int i=0;i<normDist.length;i++) {	  if (!Utils.eq(normDist[i],0.0)) {	    ok = true;	    break;	  }	}	if (ok) {	  Utils.normalize(normDist);	  return Utils.maxIndex(normDist);	} else {	  return m_majority;	}      }      //      return Utils.maxIndex(tempDist);    } else {      // see if this one is already in the table      if ((tempDist = (double[])m_entries.get(thekey)) != null) {	normDist = new double [tempDist.length];	System.arraycopy(tempDist,0,normDist,0,tempDist.length);	normDist[0] -= (instance.classValue() * instance.weight());	normDist[1] -= instance.weight();	if (Utils.eq(normDist[1],0.0)) {	    return m_majority;	} else {	  return (normDist[0] / normDist[1]);	}      } else {	throw new Error("This should never happen!");      }    }        // shouldn't get here     // return 0.0;  }  /**   * Calculates the accuracy on a test fold for internal cross validation   * of feature sets   *   * @param fold set of instances to be "left out" and classified   * @param fs currently selected feature set   * @return the accuracy for the fold   */  double classifyFoldCV(Instances fold, int [] fs) throws Exception {    int i;    int ruleCount = 0;    int numFold = fold.numInstances();    int numCl = 
m_theInstances.classAttribute().numValues();    double [][] class_distribs = new double [numFold][numCl];    double [] instA = new double [fs.length];    double [] normDist;    hashKey thekey;    double acc = 0.0;    int classI = m_theInstances.classIndex();    Instance inst;    if (m_classIsNominal) {      normDist = new double [numCl];    } else {      normDist = new double [2];    }    // first *remove* instances    for (i=0;i<numFold;i++) {      inst = fold.instance(i);      for (int j=0;j<fs.length;j++) {	if (fs[j] == classI) {	  instA[j] = Double.MAX_VALUE; // missing for the class	} else if (inst.isMissing(fs[j])) {	  instA[j] = Double.MAX_VALUE;	} else{	  instA[j] = inst.value(fs[j]);	}      }      thekey = new hashKey(instA);      if ((class_distribs[i] = (double [])m_entries.get(thekey)) == null) {	throw new Error("This should never happen!");      } else {	if (m_classIsNominal) {	  class_distribs[i][(int)inst.classValue()] -= inst.weight();	} else {	  class_distribs[i][0] -= (inst.classValue() * inst.weight());	  class_distribs[i][1] -= inst.weight();	}	ruleCount++;      }    }    // now classify instances    for (i=0;i<numFold;i++) {      inst = fold.instance(i);      System.arraycopy(class_distribs[i],0,normDist,0,normDist.length);      if (m_classIsNominal) {	boolean ok = false;	for (int j=0;j<normDist.length;j++) {	  if (!Utils.eq(normDist[j],0.0)) {	    ok = true;	    break;	  }	}	if (ok) {	  Utils.normalize(normDist);	  if (Utils.maxIndex(normDist) == inst.classValue())	    acc += inst.weight();	} else {	  if (inst.classValue() == m_majority) {	    acc += inst.weight();	  }	}      } else {	if (Utils.eq(normDist[1],0.0)) {	    acc += ((inst.weight() * (m_majority - inst.classValue())) * 		    (inst.weight() * (m_majority - inst.classValue())));	} else {	  double t = (normDist[0] / normDist[1]);	  acc += ((inst.weight() * (t - inst.classValue())) * 		  (inst.weight() * (t - inst.classValue())));	}      }    }    // now re-insert instances    for 
(i=0;i<numFold;i++) {      inst = fold.instance(i);      if (m_classIsNominal) {	class_distribs[i][(int)inst.classValue()] += inst.weight();      } else {	class_distribs[i][0] += (inst.classValue() * inst.weight());	class_distribs[i][1] += inst.weight();      }    }    return acc;  }  /**   * Evaluates a feature subset by cross validation   *   * @param feature_set the subset to be evaluated   * @param num_atts the number of attributes in the subset   * @return the estimated accuracy   * @exception Exception if subset can't be evaluated   */  private double estimateAccuracy(BitSet feature_set, int num_atts)    throws Exception {    int i;    Instances newInstances;    int [] fs = new int [num_atts];    double acc = 0.0;    double [][] evalArray;    double [] instA = new double [num_atts];    int classI = m_theInstances.classIndex();        int index = 0;    for (i=0;i<m_numAttributes;i++) {      if (feature_set.get(i)) {	fs[index++] = i;      }    }    // create new hash table    m_entries = new Hashtable((int)(m_theInstances.numInstances() * 1.5));    // insert instances into the hash table    for (i=0;i<m_numInstances;i++) {      Instance inst = m_theInstances.instance(i);      for (int j=0;j<fs.length;j++) {	if (fs[j] == classI) {	  instA[j] = Double.MAX_VALUE; // missing for the class	} else if (inst.isMissing(fs[j])) {	  instA[j] = Double.MAX_VALUE;	} else {	  instA[j] = inst.value(fs[j]);	}      }      insertIntoTable(inst, instA);    }            if (m_CVFolds == 1) {      // calculate leave one out error      for (i=0;i<m_numInstances;i++) {	Instance inst = m_theInstances.instance(i);	for (int j=0;j<fs.length;j++) {	  if (fs[j] == classI) {	    instA[j] = Double.MAX_VALUE; // missing for the class	  } else if (inst.isMissing(fs[j])) {	    instA[j] = Double.MAX_VALUE;	  } else {	    instA[j] = inst.value(fs[j]);	  }	}	double t = classifyInstanceLeaveOneOut(inst, instA);	if (m_classIsNominal) {	  if (t == inst.classValue()) {	    acc+=inst.weight();	  }	} 
else {	  acc += ((inst.weight() * (t - inst.classValue())) * 		  (inst.weight() * (t - inst.classValue())));	}	// weight_sum += inst.weight();      }    } else {      m_theInstances.randomize(m_rr);      m_theInstances.stratify(m_CVFolds);      // calculate 10 fold cross validation error      for (i=0;i<m_CVFolds;i++) {	Instances insts = m_theInstances.testCV(m_CVFolds,i);	acc += classifyFoldCV(insts, fs);      }    }      if (m_classIsNominal) {      return (acc / m_theInstances.sumOfWeights());    } else {      return -(Math.sqrt(acc / m_theInstances.sumOfWeights()));       }  }  /**   * Returns a String representation of a feature subset   *   * @param sub BitSet representation of a subset   * @return String containing subset   */  private String printSub(BitSet sub) {    int i;    String s="";    for (int jj=0;jj<m_numAttributes;jj++) {      if (sub.get(jj)) {	s += " "+(jj+1);      }    }    return s;  }      /**   * Does a best first search    */  private void best_first() throws Exception {    int i,j,classI,count=0,fc,tree_count=0;    int evals=0;    BitSet best_group, temp_group;    int [] stale;    double [] best_merit;    double merit;    boolean z;    boolean added;    Link tl;      Hashtable lookup = new Hashtable((int)(200.0*m_numAttributes*1.5));    LinkedList bfList = new LinkedList();    best_merit = new double[1]; best_merit[0] = 0.0;    stale = new int[1]; stale[0] = 0;    best_group = new BitSet(m_numAttributes);    // Add class to initial subset    classI = m_theInstances.classIndex();    best_group.set(classI);    best_merit[0] = estimateAccuracy(best_group, 1);    if (m_debug)      System.out.println("Accuracy of initial subset: "+best_merit[0]);    // add the initial group to the list    bfList.addToList(best_group,best_merit[0]);    // add initial subset to the hashtable    lookup.put(best_group,"");    while (stale[0] < m_maxStale) {      added = false;      // finished search?      
if (bfList.size()==0) {	stale[0] = m_maxStale;	break;      }      // copy the feature set at the head of the list      tl = bfList.getLinkAt(0);      temp_group = (BitSet)(tl.getGroup().clone());      // remove the head of the list      bfList.removeLinkAt(0);      for (i=0;i<m_numAttributes;i++) {	// if (search_direction == 1)	z = ((i != classI) && (!temp_group.get(i)));	if (z) {	  // set the bit (feature to add/delete) */	  temp_group.set(i);	  	  /* if this subset has been seen before, then it is already in 	     the list (or has been fully expanded) */	  BitSet tt = (BitSet)temp_group.clone();	  if (lookup.containsKey(tt) == false) {	    fc = 0;	    for (int jj=0;jj<m_numAttributes;jj++) {	      if (tt.get(jj)) {		fc++;	      }	    }	    merit = estimateAccuracy(tt, fc);	    if (m_debug) {	      System.out.println("evaluating: "+printSub(tt)+" "+merit); 	    }	    	    // is this better than the best?	    // if (search_direction == 1)	    z = ((merit - best_merit[0]) > 0.00001);	 	    // else	    // z = ((best_merit[0] - merit) > 0.00001);	    if (z) {	      if (m_debug) {		System.out.println("new best feature set: "+printSub(tt)+				   " "+merit);	      }	      added = true;	      stale[0] = 0;	      best_merit[0] = merit;	      best_group = (BitSet)(temp_group.clone());	    }	    // insert this one in the list and the hash table	    bfList.addToList(tt, merit);	    lookup.put(tt,"");	    count++;	  }	  // unset this addition(deletion)	  temp_group.clear(i);	}      }      /* if we haven't added a new feature subset then full expansion 	 of this node hasn't resulted in anything better */      if (!added) {	stale[0]++;      }    }       // set selected features    for (i=0,j=0;i<m_numAttributes;i++) {      if (best_group.get(i)) {	j++;      }    }        m_decisionFeatures = new int[j];    for (i=0,j=0;i<m_numAttributes;i++) {      if (best_group.get(i)) {	m_decisionFeatures[j++] = i;          }    }  }   /**   * Resets the options.   
*/  protected void resetOptions()  {    m_entries = null;    m_decisionFeatures = null;    m_debug = false;    m_useIBk = false;    m_CVFolds = 1;    m_maxStale = 5;    m_displayRules = false;  }   /**   * Constructor for a DecisionTable   */  public DecisionTable() {    resetOptions();  }  /**   * Returns an enumeration describing the available options   *   * @return an enumeration of all the available options   */  public Enumeration listOptions() {    Vector newVector = new Vector(5);    newVector.addElement(new Option(              "\tNumber of fully expanded non improving subsets to consider\n" +	      "\tbefore terminating a best first search.\n" +	      "\tUse in conjunction with -B. (Default = 5)",              "S", 1, "-S <number of non improving nodes>"));        newVector.addElement(new Option(              "\tUse cross validation to evaluate features.\n" +	      "\tUse number of folds = 1 for leave one out CV.\n" +	      "\t(Default = leave one out CV)",              "X", 1, "-X <number of folds>"));     newVector.addElement(new Option(              "\tUse nearest neighbour instead of global table majority.\n",              "I", 0, "-I"));     newVector.addElement(new Option(              "\tDisplay decision table rules.\n",              "R", 0, "-R"));     return newVector.elements();  }  /**   * Sets the number of folds for cross validation (1 = leave one out)   *   * @param folds the number of folds   */  public void setCrossVal(int folds) {    m_CVFolds = folds;

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -