⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 id3inducer.java

📁 Decision Tree 决策树算法ID3 数据挖掘 分类
💻 JAVA
📖 第 1 页 / 共 2 页
字号:
package id3;
import java.lang.*;
import java.util.*;
import shared.*;
import shared.Error;

/** The ID3Inducer class is the Java implementation of the ID3 algorithm. The
 * ID3 algorithm is a top-down decision-tree induction algorithm. This
 * algorithm uses mutual information (the original gain criterion), and
 * not the more recent information gain ratio.<P>
 * Complexity:<P>
 * Our split() method uses entropy and takes time O(vy) where v is
 * the total number of attribute values (over all attributes) and y
 * is the number of label values. This can be derived by noting that
 * mutual_info is computed for each attribute.<P>
 * Node categorizers (for predict) are AttrCategorizer and take
 * constant time, thus the overall prediction time is O(path-length).<P>
 * See TDDTInducer for more complexity information.<P>
 * Enhancements:<P>
 * Compute the entropy once for the node and pass it along, to
 * avoid the multiple computations we do now.<P>
 *
 * @author James Louis 12/7/2000 Ported to Java
 * @author Clay Kunz 10/22/96 Changed bestSi to a pointer everywhere so
 * that we don't copy lots of split objects
 * around.
 * @author Yeogirl Yun 7/4/95 Added copy constructor.
 * @author Ronny Kohavi 9/08/93 Initial revision (.h,.c)
 */
public class ID3Inducer extends TDDTInducer
{
    /** Constructor.
     * @param dscr    The description of this inducer.
     * @param aCgraph A previously developed Cgraph.
     */
   public ID3Inducer(String dscr, CGraph aCgraph) {
      // No extra initialization here: both arguments are forwarded unchanged
      // to the superclass constructor.
      super(dscr, aCgraph);
   }

   /** Constructor.
    * @param dscr The description of this inducer.
    */
   public ID3Inducer(String dscr) {
      // No extra initialization here: the description is forwarded unchanged
      // to the superclass constructor.
      super(dscr);
   }

   /** Copy Constructor.
    * @param source The original ID3Inducer that is being copied.
    */
   public ID3Inducer(ID3Inducer source) {
      // Copying is delegated entirely to the superclass copy constructor.
      super(source);
   }

   /** Returns the AttrCategorizer that splits on the best attribute found using
    * mutual information(information gain). Returns null if there is nothing
    * good to split on. Ties are broken in favor of the earlier
    * attribute.
    * @param catNames The names of the categories that each instance may be
    * categorized under.
    * @return The NodeCategorizer that splits on the best attribute found. May be
    * null if no good attribute split is found.
    */
   public NodeCategorizer best_split(LinkedList catNames)
   {
      final Schema schema = TS.get_schema();

      // @@ change these to return an index instead of bestSplit.
      // One-element array used as an out-parameter: find_splits leaves the
      // chosen split in slot 0.
      final SplitAttr[] bestSplit = { new SplitAttr() };

      // One pre-allocated candidate per attribute in the schema.
      final SplitAttr[] splits = new SplitAttr[schema.num_attr()];
      for (int i = 0; i < splits.length; i++) {
         splits[i] = new SplitAttr();
      }

      // @@ Call routine to initialize splits - sets penalty, minSplit
      if (!find_splits(bestSplit, splits)) {
         // Nothing worth splitting on at this node.
         return null;
      }

      // Sanity check: a successful find_splits must leave a usable split.
      final SplitAttr chosen = bestSplit[0];
      final boolean usable = (chosen != null)
            && (chosen.split_type() != SplitAttr.noReasonableSplit);
      MLJ.ASSERT(usable,
            "ID3Inducer:best_split--(bestSplit == null)"+
            "or(bestSplit.split_type() == noReasonableSplit)");

      // Turn the chosen split into the categorizer for this node.
      final NodeCategorizer bestCat = split_to_cat(chosen, catNames);
      MLJ.ASSERT(bestCat != null,"ID3Inducer:best_split--bestCat == null");

      final String msg = "Created split on attribute "+chosen.get_attr_num()+" ("+
            schema.attr_name(chosen.get_attr_num())+") at level "+
            get_level()+'\n';
      logOptions.LOG(2, msg);

      // Build the categorizer's distribution from this node's instances.
      bestCat.build_distr(instance_list());
      return bestCat;
   }

   /** Fills in the array of splits for current subtree. It does very
    * little itself and is rarely overridden, whereas best_split_info is
    * overridden by subclasses.
    * @return False if there is only one label value, the maximum number
    * of splits is reached, or if there is no reasonable split
    * available.
    * @param bestSplit This is an array of the best splits found during the
    * splitting process.
    * @param splits This is an array of all splits found during the
    * splitting process.
    */
   public boolean find_splits(SplitAttr[] bestSplit, SplitAttr[] splits)
   {
      // If we have only one label value, we're done.
      final boolean singleLabel = (TS.counters().label_num_vals() == 1);
      if (singleLabel) {
         return false;
      }

      // The level limit is enforced only when the maximum level is positive.
      if (get_max_level() > 0) {
         if (get_level() >= get_max_level()) {
            logOptions.LOG(2, "Maximum level "+get_max_level()+" reached "+'\n');
            return false;
         }
      }

      logOptions.LOG(3, TS.counters().toString());

      // Evaluate every attribute and leave the winner in bestSplit[0].
      best_split_info(bestSplit, splits);

      final boolean reasonable =
            (bestSplit[0].split_type() != SplitAttr.noReasonableSplit);
      return reasonable;
   }

   /** Fills in the array of SplitAttr for current subtree. This function
    * is a good candidate to override in subclasses.
    * @param bestSplit	This is an array of the best splits found during the
    * splitting process.
    * @param splits	This is an array of all splits found during the
    * splitting process.
    */
   public void best_split_info(SplitAttr[] bestSplit, SplitAttr[] splits)
   {
      final int numAttributes = TS.get_schema().num_attr();

      // Accumulators for the mutual information of attributes with a
      // reasonable split: one over all of them, one over only those that
      // are not multi-valued.
      final StatData allMutualInfo = new StatData();
      final StatData allNonMultiValMutualInfo = new StatData();

      // The transposed columns are built only when continuous attributes
      // are present; otherwise split_info is handed null.
      RealAndLabelColumn[] realColumns = null;
      if (get_have_continuous_attributes()) {
         final boolean[] mask = new boolean[numAttributes];
         java.util.Arrays.fill(mask, true);   // select every attribute
         realColumns = TS.transpose(mask);
      }

      for (int attr = 0; attr < numAttributes; attr++) {
         final SplitAttr candidate = splits[attr];
         split_info(attr, candidate, realColumns);

         // Attributes without a reasonable split contribute nothing.
         if (candidate.split_type() == SplitAttr.noReasonableSplit) {
            continue;
         }

         // Find the mean of the mutual information over all attributes
         //   with reasonable splits.  From c4.5, we accumulate separately
         //   the mutual information that originates from attributes that
         //   do not have "too many" values.  Unless ALL attributes fail
         //   this criterion we use only those from the "smaller" attributes.
         // @@ We may want to compute the mean only when it's needed, i.e.,
         // @@ for gain-ratio emulation
         final double mi = candidate.get_mutual_info(false, true);
         MLJ.ASSERT(mi >= 0,"ID3Inducer.best_split_info(SplitAttr,SplitAttr[])--"+
               " mi < 0");
         logOptions.LOG(3, "Adding mutualInfo "+mi+" to mean.");
         allMutualInfo.insert(mi);

         final boolean multiVal = multi_val_attribute(attr);
         if (!multiVal) {
            allNonMultiValMutualInfo.insert(mi);
            logOptions.LOG(3, "  It's not multi-val.");
         }
         logOptions.LOG(3,'\n');
      }

      // Drop the reference before the selection step; the columns are not
      // used again (presumably so they can be reclaimed early).
      realColumns = null;

      pick_best_split(bestSplit, splits, allMutualInfo, allNonMultiValMutualInfo);
   }

   /** Return true if the attribute has many values according to
    * the C4.5 definition.
    * @return True if this attribute has many values, False otherwise.
    * @param attrNum	The number of the attribute being checked.
    */
   public boolean multi_val_attribute(int attrNum)
   {
      final double totalWeight = get_total_inst_weight();
      MLJ.ASSERT(totalWeight >= 0,"ID3Inducer.multi_val_attribute(int)--"+
            " totalWeight < 0");

      final Schema schema = TS.get_schema();

      // Only attributes that can be cast to nominal are candidates.
      if (!schema.attr_info(attrNum).can_cast_to_nominal()) {
         return false;
      }

      // "Many values" means the attribute's value count is at least 30% of
      // the total instance weight.
      final double threshold = 0.3 * totalWeight;
      return schema.num_attr_values(attrNum) >= threshold;
   }

   /** Chooses the best attribute to split on from all possible splits.
    * @param bestSplit	The array of the best splits found during splitting
    * process.
    * @param splits	The array of all splits found during the splitting
    * process.
    * @param allMutualInfo	Mutual information statistics gathered over all
    * attributes that have a reasonable split.
    * @param allNonMultiValMutualInfo	The same statistics restricted to
    * attributes that are not multi-valued (see multi_val_attribute).
    */
   public void pick_best_split(SplitAttr[] bestSplit,
					SplitAttr[] splits,
					StatData allMutualInfo,
					StatData allNonMultiValMutualInfo) 
   {
      Schema schema = TS.get_schema();
      int numAttributes = schema.num_attr();

      if (get_split_score_criterion() != SplitScore.gainRatio) {
         for (int attrNum = 0; attrNum < numAttributes; attrNum++) {
            SplitAttr split = splits[attrNum];
            if (split.split_type() != SplitAttr.noReasonableSplit) {
      	    // Remember the best.  MLJ.realEpsilon is added because on
      	    //   monk1, the difference is 1e-16, and we want to tie break
      	    //   exactly as C4.5 does.
      	    // First half of test is because bestSplit might be unset, in
      	    //   which case we can't get its criterion score.
               if (bestSplit[0].split_type() == SplitAttr.noReasonableSplit
                  || split.score() > (bestSplit[0].score() + MLJ.realEpsilon))
                  bestSplit[0] = split;
            }
         }
      } else { // gain ratio
         double meanMutualInfo = Globals.UNDEFINED_REAL;
         if (allMutualInfo.size() > 0) 
         if (all_attributes_multi_val() || allNonMultiValMutualInfo.size() == 0) {
            meanMutualInfo = allMutualInfo.mean();
            if (all_attributes_multi_val()) logOptions.LOG(3, "All attributes are multi-val."+'\n');
         }
         else
            meanMutualInfo = allNonMultiValMutualInfo.mean();      
         logOptions.LOG(3,"Mean mutual info is "+meanMutualInfo+'\n');
   
         // Look at the criterion score for each attribute.  Any time an
         //   attribute has a mutual info greater than the mean mutual info
         //   it's a candidate for chosing as best.  If its score is
         //   greater than the max so far, pick it.
         double maxScore = Globals.UNDEFINED_REAL;
         boolean foundScoreAboveMean = false;
         for (int attrNum = 0; attrNum < numAttributes; attrNum++) {
            SplitAttr split = splits[attrNum];
            logOptions.LOG(3,"For attribute "+attrNum+", checking for reasonable split");
            if (split.split_type() == SplitAttr.noReasonableSplit){
               logOptions.LOG(3,"...Sorry, no reasonable split"+'\n');
            }
            else {
               boolean mutualInfoAboveMean = split.get_mutual_info(false,true) >
               meanMutualInfo + MLJ.realEpsilon;
   	    // was || maxScore == Globals.UNDEFINED_REAL)
//               if (maxScore == Globals.UNDEFINED_REAL) MLJ.ASSERT(!foundScoreAboveMean);

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -