📄 id3inducer.java
字号:
package id3;
import java.lang.*;
import java.util.*;
import shared.*;
import shared.Error;
/** The ID3Inducer class is the Java implementation of the ID3 algorithm. The
* ID3 algorithm is a top-down decision-tree induction algorithm. This
* algorithm uses mutual information (the original gain criterion), and
* not the more recent information gain ratio.<P>
* Complexity:<P>
* Our split() method uses entropy and takes time O(vy) where v is
* the total number of attribute values (over all attributes) and y
* is the number of label values. This can be derived by noting that
* mutual_info is computed for each attribute.<P>
* Node categorizers (for predict) are AttrCategorizer and take
* constant time, thus the overall prediction time is O(path-length).<P>
* See TDDTInducer for more complexity information.<P>
* Enhancements:<P>
* Compute entropy once for the node, and pass it along to
* avoid multiple computations like we do now.<P>
*
* @author James Louis 12/7/2000 Ported to Java
* @author Clay Kunz 10/22/96 Changed bestSi to a pointer everywhere so
* that we don't copy lots of split objects
* around.
* @author Yeogirl Yun 7/4/95 Added copy constructor.
* @author Ronny Kohavi 9/08/93 Initial revision (.h,.c)
*/
public class ID3Inducer extends TDDTInducer
{
/** Constructor.
* @param dscr The description of this inducer.
* @param aCgraph A previously developed Cgraph.
*/
public ID3Inducer(String dscr, CGraph aCgraph) {
    // Delegates to the TDDTInducer constructor with the same two arguments.
    super(dscr, aCgraph);
}
/** Constructor.
* @param dscr The description of this inducer.
*/
public ID3Inducer(String dscr) {
    // Delegates to the TDDTInducer constructor that takes only a description.
    super(dscr);
}
/** Copy Constructor.
* @param source The original ID3Inducer that is being copied.
*/
public ID3Inducer(ID3Inducer source) {
    // Copying is handled entirely by the TDDTInducer copy constructor.
    super(source);
}
/** Returns the AttrCategorizer that splits on the best attribute found using
* mutual information(information gain). Returns null if there is nothing
* good to split on. Ties between this attribute and earlier attributes are
* broken in favor of the earlier attribute.
* @param catNames The names of the categories that each instance may be
* categorized under.
* @return The NodeCategorizer that splits on the best attribute found. May be
* null if no good attribute split is found.
*/
public NodeCategorizer best_split(LinkedList catNames) {
    Schema schema = TS.get_schema();

    // One-element array used as an in/out holder: the split search started
    // by find_splits() may replace element 0 with the winning split.
    SplitAttr[] chosen = { new SplitAttr() };

    // One candidate split per attribute in the schema.
    SplitAttr[] candidates = new SplitAttr[schema.num_attr()];
    int slot = 0;
    while (slot < candidates.length) {
        candidates[slot] = new SplitAttr();
        slot++;
    }

    // find_splits() reports false when there is nothing to split on.
    if (!find_splits(chosen, candidates)) {
        return null;
    }

    SplitAttr winner = chosen[0];
    MLJ.ASSERT((winner != null) && (winner.split_type() != SplitAttr.noReasonableSplit),
            "ID3Inducer:best_split--(bestSplit == null)"+
            "or(bestSplit.split_type() == noReasonableSplit)");

    // Turn the winning split into the categorizer for this node.
    NodeCategorizer bestCat = split_to_cat(winner, catNames);
    MLJ.ASSERT(bestCat != null,"ID3Inducer:best_split--bestCat == null");

    String created = "Created split on attribute "+winner.get_attr_num()+" ("+
            schema.attr_name(winner.get_attr_num())+") at level "+
            get_level()+'\n';
    logOptions.LOG(2, created);

    // Build the categorizer's distribution from this node's instance list.
    bestCat.build_distr(instance_list());
    return bestCat;
}
/** Fills in the array of splits for the current subtree. It does very
* little itself and is rarely overridden, whereas best_split_info is
* the method intended to be overridden by subclasses.
* @return False if there is only one label value, the maximum number
* of splits is reached, or if there is no reasonable split
* available.
* @param bestSplit This is an array of the best splits found during the
* splitting process.
* @param splits This is an array of all splits found during the
* splitting process.
*/
public boolean find_splits(SplitAttr[] bestSplit,
        SplitAttr[] splits) {
    // With a single label value there is nothing left to separate.
    boolean singleLabel = (TS.counters().label_num_vals() == 1);
    if (singleLabel) {
        return false;
    }

    // A non-positive maximum level means no depth limit is applied.
    boolean depthLimitHit = (get_max_level() > 0) && (get_level() >= get_max_level());
    if (depthLimitHit) {
        logOptions.LOG(2, "Maximum level "+get_max_level()+" reached "+'\n');
        return false;
    }

    logOptions.LOG(3, TS.counters().toString());

    // Evaluate every attribute and let the best one land in bestSplit[0].
    best_split_info(bestSplit, splits);

    if (bestSplit[0].split_type() == SplitAttr.noReasonableSplit) {
        return false;
    }
    return true;
}
/** Fills in the array of SplitAttr for current subtree. This function
* is a good candidate to override in subclasses.
* @param bestSplit This is an array of the best splits found during the
* splitting process.
* @param splits This is an array of all splits found during the
* splitting process.
*/
public void best_split_info(SplitAttr[] bestSplit, SplitAttr[] splits) {
    Schema schema = TS.get_schema();
    int attrCount = schema.num_attr();

    // Mutual-information accumulators: one over every attribute with a
    // reasonable split, one restricted to attributes that are not multi-valued.
    StatData miAllAttrs = new StatData();
    StatData miNonMultiVal = new StatData();

    // The transposed columns are only built when continuous attributes are
    // present; otherwise split_info() is handed null.
    RealAndLabelColumn[] columns = null;
    if (get_have_continuous_attributes()) {
        boolean[] everyAttribute = new boolean[attrCount];
        java.util.Arrays.fill(everyAttribute, true);
        columns = TS.transpose(everyAttribute);
    }

    for (int attr = 0; attr < attrCount; attr++) {
        SplitAttr candidate = splits[attr];
        split_info(attr, candidate, columns);

        // Attributes without a reasonable split contribute nothing.
        if (candidate.split_type() == SplitAttr.noReasonableSplit) {
            continue;
        }

        // Following c4.5, mutual information from attributes that do not
        // have "too many" values is accumulated separately so that the
        // mean can later be taken over just those attributes.
        // @@ We may want to compute the mean only when it's needed, i.e.,
        // @@ for gain-ratio emulation
        double gain = candidate.get_mutual_info(false, true);
        MLJ.ASSERT(gain >= 0,"ID3Inducer.best_split_info(SplitAttr,SplitAttr[])--"+
                " mi < 0");
        logOptions.LOG(3, "Adding mutualInfo "+gain+" to mean.");
        miAllAttrs.insert(gain);
        if (!multi_val_attribute(attr)) {
            miNonMultiVal.insert(gain);
            logOptions.LOG(3, " It's not multi-val.");
        }
        logOptions.LOG(3,'\n');
    }

    // Drop the reference to the transposed columns before choosing the split
    // (presumably so they can be reclaimed early -- confirm).
    columns = null;
    pick_best_split(bestSplit, splits, miAllAttrs, miNonMultiVal);
}
/** Return true if the attribute has many values according to
* the C4.5 definition.
* @return True if this attribute has many values, False otherwise.
* @param attrNum The number of the attribute being checked.
*/
public boolean multi_val_attribute(int attrNum) {
    double totalWeight = get_total_inst_weight();
    MLJ.ASSERT(totalWeight >= 0,"ID3Inducer.multi_val_attribute(int)--"+
            " totalWeight < 0");

    Schema schema = TS.get_schema();

    // Only attributes that can be cast to nominal are candidates; the value
    // count is not queried for any other attribute.
    if (!schema.attr_info(attrNum).can_cast_to_nominal()) {
        return false;
    }

    // "Many values" means at least 30% of the total instance weight.
    double manyValuesThreshold = 0.3 * totalWeight;
    return schema.num_attr_values(attrNum) >= manyValuesThreshold;
}
/** Chooses the best attribute to split on from all possible splits.
* @param bestSplit The array of the best splits found during splitting
* process.
* @param splits The array of all splits found during the splitting
* process.
* @param allMutualInfo Mutual information collected over every attribute
* that has a reasonable split.
* @param allNonMultiValMutualInfo Mutual information collected over only
* those attributes with a reasonable split
* that are not multi-valued (see multi_val_attribute).
*/
public void pick_best_split(SplitAttr[] bestSplit,
SplitAttr[] splits,
StatData allMutualInfo,
StatData allNonMultiValMutualInfo)
{
Schema schema = TS.get_schema();
int numAttributes = schema.num_attr();
if (get_split_score_criterion() != SplitScore.gainRatio) {
for (int attrNum = 0; attrNum < numAttributes; attrNum++) {
SplitAttr split = splits[attrNum];
if (split.split_type() != SplitAttr.noReasonableSplit) {
// Remember the best. MLJ.realEpsilon is added because on
// monk1, the difference is 1e-16, and we want to tie break
// exactly as C4.5 does.
// First half of test is because bestSplit might be unset, in
// which case we can't get its criterion score.
if (bestSplit[0].split_type() == SplitAttr.noReasonableSplit
|| split.score() > (bestSplit[0].score() + MLJ.realEpsilon))
bestSplit[0] = split;
}
}
} else { // gain ratio
double meanMutualInfo = Globals.UNDEFINED_REAL;
if (allMutualInfo.size() > 0)
if (all_attributes_multi_val() || allNonMultiValMutualInfo.size() == 0) {
meanMutualInfo = allMutualInfo.mean();
if (all_attributes_multi_val()) logOptions.LOG(3, "All attributes are multi-val."+'\n');
}
else
meanMutualInfo = allNonMultiValMutualInfo.mean();
logOptions.LOG(3,"Mean mutual info is "+meanMutualInfo+'\n');
// Look at the criterion score for each attribute. Any time an
// attribute has a mutual info greater than the mean mutual info
// it's a candidate for chosing as best. If its score is
// greater than the max so far, pick it.
double maxScore = Globals.UNDEFINED_REAL;
boolean foundScoreAboveMean = false;
for (int attrNum = 0; attrNum < numAttributes; attrNum++) {
SplitAttr split = splits[attrNum];
logOptions.LOG(3,"For attribute "+attrNum+", checking for reasonable split");
if (split.split_type() == SplitAttr.noReasonableSplit){
logOptions.LOG(3,"...Sorry, no reasonable split"+'\n');
}
else {
boolean mutualInfoAboveMean = split.get_mutual_info(false,true) >
meanMutualInfo + MLJ.realEpsilon;
// was || maxScore == Globals.UNDEFINED_REAL)
// if (maxScore == Globals.UNDEFINED_REAL) MLJ.ASSERT(!foundScoreAboveMean);
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -