⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 decisiontable.java

📁 一个数据挖掘软件ALPHAMINERR的整个过程的JAVA版源代码
💻 JAVA
📖 第 1 页 / 共 3 页
字号:
  /**
   * Classifies an instance for internal leave-one-out cross-validation of
   * a candidate feature set. The instance's own contribution is subtracted
   * from a private copy of its table entry, so the stored table itself is
   * left untouched.
   *
   * @param instance the instance to be "left out" and classified
   * @param instA the values of the selected features for the instance; these
   * form the lookup key into m_entries
   * @return for a nominal class, the index of the most frequent remaining
   * class; for a numeric class, the remaining weighted mean
   * (entry[0] / entry[1]). In both cases m_majority is returned when
   * nothing is left in the entry once the instance has been removed.
   * @exception Exception declared for callers; the body itself throws
   * {@link Error} when the instance's key is not present in the table
   */
  double classifyInstanceLeaveOneOut(Instance instance, double [] instA)
       throws Exception {

    double [] stored = (double []) m_entries.get(new hashKey(instA));
    if (stored == null) {
      // the instance is expected to have been inserted into the table already
      throw new Error("This should never happen!");
    }

    // work on a copy so the shared table entry is not modified
    double [] remaining = (double []) stored.clone();

    if (m_classIsNominal) {
      remaining[(int) instance.classValue()] -= instance.weight();
      for (int c = 0; c < remaining.length; c++) {
        if (!Utils.eq(remaining[c], 0.0)) {
          // some class still has support in this cell
          Utils.normalize(remaining);
          return Utils.maxIndex(remaining);
        }
      }
      // all class counts vanished once the instance was removed
      return m_majority;
    }

    // numeric class: slot 0 holds the weighted sum of class values,
    // slot 1 the sum of weights
    remaining[0] -= (instance.classValue() * instance.weight());
    remaining[1] -= instance.weight();
    return Utils.eq(remaining[1], 0.0)
      ? m_majority
      : (remaining[0] / remaining[1]);
  }

  /**
   * Scores the decision table on one test fold during internal
   * cross-validation of a feature set.
   *
   * Each fold instance is temporarily subtracted from the shared count
   * arrays stored in m_entries. Every instance is then classified from
   * what remains, and the subtracted counts are added back. The restore
   * step runs in a finally block, so the table is never left with
   * instances missing, even if classification fails part way through or
   * a key is unexpectedly absent.
   *
   * @param fold set of instances to be "left out" and classified
   * @param fs indices of the currently selected features
   * @return for a nominal class, the summed weight of correctly classified
   * instances; for a numeric class, the summed squared weighted error
   * @exception Exception if an instance cannot be classified
   */
  double classifyFoldCV(Instances fold, int [] fs) throws Exception {

    int numFold = fold.numInstances();
    int classI = m_theInstances.classIndex();
    // rows are references to the arrays stored in m_entries (not copies),
    // so no per-row storage is allocated here
    double [][] class_distribs = new double [numFold][];
    double [] instA = new double [fs.length];
    double [] normDist = new double [m_classIsNominal
                                     ? m_theInstances.classAttribute().numValues()
                                     : 2];
    double acc = 0.0;
    // number of leading fold instances currently subtracted from the table
    int removed = 0;

    try {
      // first *remove* instances from the shared table entries
      for (int i = 0; i < numFold; i++) {
        Instance inst = fold.instance(i);
        for (int j = 0; j < fs.length; j++) {
          if (fs[j] == classI || inst.isMissing(fs[j])) {
            instA[j] = Double.MAX_VALUE; // class attribute or missing value
          } else {
            instA[j] = inst.value(fs[j]);
          }
        }
        double [] entry = (double []) m_entries.get(new hashKey(instA));
        if (entry == null) {
          throw new Error("This should never happen!");
        }
        if (m_classIsNominal) {
          entry[(int) inst.classValue()] -= inst.weight();
        } else {
          entry[0] -= (inst.classValue() * inst.weight());
          entry[1] -= inst.weight();
        }
        class_distribs[i] = entry;
        removed++;
      }

      // now classify instances against the reduced table
      for (int i = 0; i < numFold; i++) {
        Instance inst = fold.instance(i);
        System.arraycopy(class_distribs[i], 0, normDist, 0, normDist.length);
        if (m_classIsNominal) {
          boolean ok = false;
          for (int j = 0; j < normDist.length; j++) {
            if (!Utils.eq(normDist[j], 0.0)) {
              ok = true;
              break;
            }
          }
          if (ok) {
            Utils.normalize(normDist);
            if (Utils.maxIndex(normDist) == inst.classValue()) {
              acc += inst.weight();
            }
          } else if (inst.classValue() == m_majority) {
            // empty cell: fall back to the global majority
            acc += inst.weight();
          }
        } else {
          double t = Utils.eq(normDist[1], 0.0)
            ? m_majority
            : (normDist[0] / normDist[1]);
          // NOTE(review): the weight is squared together with the residual;
          // kept as-is to stay consistent with the leave-one-out path
          double err = inst.weight() * (t - inst.classValue());
          acc += err * err;
        }
      }
    } finally {
      // re-insert exactly the instances that were removed above
      for (int i = 0; i < removed; i++) {
        Instance inst = fold.instance(i);
        if (m_classIsNominal) {
          class_distribs[i][(int) inst.classValue()] += inst.weight();
        } else {
          class_distribs[i][0] += (inst.classValue() * inst.weight());
          class_distribs[i][1] += inst.weight();
        }
      }
    }
    return acc;
  }


  /**
   * Evaluates a feature subset by cross-validation.
   *
   * Rebuilds m_entries from all training instances projected onto the
   * subset. It then runs leave-one-out evaluation when m_CVFolds == 1, and
   * otherwise randomizes and stratifies m_theInstances (using m_rr) and
   * runs m_CVFolds-fold cross-validation.
   *
   * @param feature_set the subset to be evaluated
   * @param num_atts the number of attributes in the subset, within the
   * first m_numAttributes positions (sizes the index array)
   * @return for a nominal class, the weighted fraction classified correctly;
   * for a numeric class, the negated root of the accumulated squared error
   * divided by the sum of weights (so that larger is better in both cases)
   * @exception Exception if subset can't be evaluated
   */
  private double estimateAccuracy(BitSet feature_set, int num_atts)
    throws Exception {

    int [] fs = new int [num_atts];
    double [] instA = new double [num_atts];
    int classI = m_theInstances.classIndex();
    double acc = 0.0;

    // collect the indices of the selected attributes
    int index = 0;
    for (int i = 0; i < m_numAttributes; i++) {
      if (feature_set.get(i)) {
        fs[index++] = i;
      }
    }

    // create a new hash table and insert every training instance into it
    m_entries = new Hashtable((int)(m_theInstances.numInstances() * 1.5));
    for (int i = 0; i < m_numInstances; i++) {
      Instance inst = m_theInstances.instance(i);
      fillSelectedValues(inst, fs, classI, instA);
      insertIntoTable(inst, instA);
    }

    if (m_CVFolds == 1) {

      // calculate leave one out error
      for (int i = 0; i < m_numInstances; i++) {
        Instance inst = m_theInstances.instance(i);
        fillSelectedValues(inst, fs, classI, instA);
        double t = classifyInstanceLeaveOneOut(inst, instA);
        if (m_classIsNominal) {
          if (t == inst.classValue()) {
            acc += inst.weight();
          }
        } else {
          // NOTE(review): the weight enters the squared error twice
          // (w^2 (t - y)^2) while the total is normalized by the sum of
          // weights. classifyFoldCV does the same, so both paths are left
          // unchanged; confirm whether w (t - y)^2 was intended.
          double err = inst.weight() * (t - inst.classValue());
          acc += err * err;
        }
      }
    } else {
      m_theInstances.randomize(m_rr);
      m_theInstances.stratify(m_CVFolds);

      // calculate m_CVFolds-fold cross validation error
      for (int i = 0; i < m_CVFolds; i++) {
        Instances insts = m_theInstances.testCV(m_CVFolds, i);
        acc += classifyFoldCV(insts, fs);
      }
    }

    if (m_classIsNominal) {
      return (acc / m_theInstances.sumOfWeights());
    } else {
      return -(Math.sqrt(acc / m_theInstances.sumOfWeights()));
    }
  }

  /**
   * Writes the values of the selected features of an instance into instA,
   * which is used as a hash-table key. The class attribute and missing
   * values are both encoded as Double.MAX_VALUE.
   *
   * @param inst the instance to read values from
   * @param fs indices of the selected features
   * @param classI index of the class attribute
   * @param instA output array, same length as fs
   */
  private static void fillSelectedValues(Instance inst, int [] fs, int classI,
                                         double [] instA) {
    for (int j = 0; j < fs.length; j++) {
      if (fs[j] == classI) {
        instA[j] = Double.MAX_VALUE; // missing for the class
      } else if (inst.isMissing(fs[j])) {
        instA[j] = Double.MAX_VALUE;
      } else {
        instA[j] = inst.value(fs[j]);
      }
    }
  }

  /**
   * Returns a String representation of a feature subset: the 1-based
   * index of every set bit among the first m_numAttributes positions,
   * each preceded by a single space (e.g. " 1 3 4").
   *
   * @param sub BitSet representation of a subset
   * @return String containing subset (empty if no bit is set)
   */
  private String printSub(BitSet sub) {

    // StringBuffer avoids re-copying the string on every append
    StringBuffer s = new StringBuffer();
    for (int jj = 0; jj < m_numAttributes; jj++) {
      if (sub.get(jj)) {
        s.append(' ').append(jj + 1);
      }
    }
    return s.toString();
  }
    
  /**
   * Does a forward best-first search over feature subsets and stores the
   * best subset found in m_decisionFeatures.
   *
   * The search starts from the subset containing only the class attribute.
   * Each step removes the head of the open list and evaluates every subset
   * formed by adding one attribute it does not yet contain; subsets already
   * generated are skipped. A child counts as an improvement only if its merit
   * beats the best merit so far by more than 0.00001. The search stops after
   * m_maxStale consecutive expansions without improvement, or when the
   * open list is empty.
   *
   * @exception Exception if a subset cannot be evaluated
   */
  private void best_first() throws Exception {

    int classI = m_theInstances.classIndex();

    // every subset generated so far (still open or already expanded)
    Hashtable visited = new Hashtable((int)(200.0*m_numAttributes*1.5));
    // open list; its head is the next subset to expand
    LinkedList open = new LinkedList();

    // add class to initial subset
    BitSet bestSubset = new BitSet(m_numAttributes);
    bestSubset.set(classI);
    double bestMerit = estimateAccuracy(bestSubset, 1);
    if (m_debug) {
      System.out.println("Accuracy of initial subset: "+bestMerit);
    }
    open.addToList(bestSubset, bestMerit);
    visited.put(bestSubset, "");

    int staleCount = 0;
    while (staleCount < m_maxStale && open.size() != 0) {

      // take a copy of the head of the list, then drop it from the list
      BitSet parent = (BitSet) (open.getLinkAt(0).getGroup().clone());
      open.removeLinkAt(0);

      boolean improved = false;
      for (int attr = 0; attr < m_numAttributes; attr++) {
        if (attr == classI || parent.get(attr)) {
          continue; // nothing to add
        }

        BitSet child = (BitSet) parent.clone();
        child.set(attr);
        if (visited.containsKey(child)) {
          continue; // already in the list or fully expanded
        }

        int childSize = 0;
        for (int k = 0; k < m_numAttributes; k++) {
          if (child.get(k)) {
            childSize++;
          }
        }
        double merit = estimateAccuracy(child, childSize);
        if (m_debug) {
          System.out.println("evaluating: "+printSub(child)+" "+merit);
        }

        if ((merit - bestMerit) > 0.00001) {
          if (m_debug) {
            System.out.println("new best feature set: "+printSub(child)+
                               " "+merit);
          }
          improved = true;
          staleCount = 0;
          bestMerit = merit;
          bestSubset = (BitSet) child.clone();
        }

        open.addToList(child, merit);
        visited.put(child, "");
      }

      // a full expansion that found nothing better makes the search staler
      if (!improved) {
        staleCount++;
      }
    }

    // record the selected features in ascending attribute order
    int numSelected = 0;
    for (int k = 0; k < m_numAttributes; k++) {
      if (bestSubset.get(k)) {
        numSelected++;
      }
    }
    m_decisionFeatures = new int[numSelected];
    int pos = 0;
    for (int k = 0; k < m_numAttributes; k++) {
      if (bestSubset.get(k)) {
        m_decisionFeatures[pos++] = k;
      }
    }
  }
 

  /**
   * Restores every option to its default value and discards any
   * previously built table and selected feature set.
   */
  protected void resetOptions()  {

    // learned state
    this.m_entries = null;
    this.m_decisionFeatures = null;

    // search / evaluation options
    this.m_CVFolds = 1;
    this.m_maxStale = 5;
    this.m_useIBk = false;

    // output options
    this.m_displayRules = false;
    this.m_debug = false;
  }

   /**
   * Creates a DecisionTable with all options set to their defaults
   * (see resetOptions).
   */
  public DecisionTable() {

    super();
    resetOptions();
  }

  /**
   * Returns an enumeration describing the available command-line options
   * (-S, -X, -I and -R).
   *
   * @return an enumeration of all the available options.
   */
  public Enumeration listOptions() {

    Option [] available = {
      new Option("\tNumber of fully expanded non improving subsets to consider\n"
                 + "\tbefore terminating a best first search.\n"
                 + "\tUse in conjunction with -B. (Default = 5)",
                 "S", 1, "-S <number of non improving nodes>"),
      new Option("\tUse cross validation to evaluate features.\n"
                 + "\tUse number of folds = 1 for leave one out CV.\n"
                 + "\t(Default = leave one out CV)",
                 "X", 1, "-X <number of folds>"),
      new Option("\tUse nearest neighbour instead of global table majority.\n",
                 "I", 0, "-I"),
      new Option("\tDisplay decision table rules.\n",
                 "R", 0, "-R")
    };

    Vector result = new Vector(5);
    for (int k = 0; k < available.length; k++) {
      result.addElement(available[k]);
    }
    return result.elements();
  }

  /**
   * Returns the GUI tip text for the crossVal property.
   *
   * @return tip text for this property suitable for
   * displaying in the explorer/experimenter gui
   */
  public String crossValTipText() {
    String tip = "Sets the number of folds for cross validation "
      + "(1 = leave one out).";
    return tip;
  }

  /**
   * Sets how many folds the internal cross-validation uses; a value of 1
   * selects leave-one-out evaluation.
   *
   * @param folds the number of folds
   */
  public void setCrossVal(int folds) {
    this.m_CVFolds = folds;
  }

  /**
   * Gets the number of folds for cross validation
   *
   * @return the number of cross validation folds
   */
  public int getCrossVal() {

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -