📄 tddtinducer.java
String wDist = stream.mem_buf();
stream = null; //delete stream;
MString& newDescr = leafCat.description();
String dbgDescr = newDescr + " (#=" + MString(totalWeight,0) +
" Err=" + MString(numErrors, 0) + "/" +
String(pessimisticErrors, 2) + ")\\npDist=" + pDist +
"\\nwDist=" + wDist;
leafCat.set_description(dbgDescr);
}
*/
Categorizer cat = leafCat;
LeafCategorizer leafCategorizer = new LeafCategorizer(cat);
// DBG(ASSERT(cat == null));
return leafCategorizer;
}
/** Induces a decision tree from a given split. The split is provided
* in the form of a categorizer, which picks the subtree a given
* instance will follow.
* @param decisionTree The decision tree being induced.
* @param splitCat The categorizer for this split.
* @param catNames List of category names.
* @param tieBreakingOrder Order for breaking distribution ties.
* @param numSubtreeErrors Number of errors this subtree makes when categorizing instances.
* @param pessimisticSubtreeErrors Pessimistic error estimate if this node were a leaf.
* @param numLeaves Number of leaves in the subtree.
* @param remainingSiblings Siblings that have not been induced yet.
*/
protected void induce_tree_from_split(DecisionTree decisionTree,
                                      NodeCategorizer splitCat,
                                      LinkedList catNames,
                                      int[] tieBreakingOrder,
                                      DoubleRef numSubtreeErrors,
                                      DoubleRef pessimisticSubtreeErrors,
                                      IntRef numLeaves,
                                      int remainingSiblings)
{
int[] myTiebreakingOrder =
CatDist.merge_tiebreaking_order(tieBreakingOrder,
TS.counters().label_counts());
InstanceList[] instLists =
splitCat.split_instance_list(instance_list());
// Add one if we have unknown instances
// IFDRIBBLE(dribble_level(level, splitCat.description(), remainingSiblings));
numSubtreeErrors.value = 0;
pessimisticSubtreeErrors.value = 0;
numLeaves.value = 0;
DoubleRef numChildErrors = new DoubleRef(0);
DoubleRef childPessimisticErrors = new DoubleRef(0);
Node largestChild = null; // with the most instances (weight)
DoubleRef maxChildWeight = new DoubleRef(-1);
for (int cat = 0; cat < instLists.length; cat++) {
if (instLists[cat].num_instances() >= instance_list().num_instances())
    Error.fatalErr("TDDTInducer.induce_tree_from_split: the most recent split "
        +splitCat.description()+" resulted in no reduction of the "
        +"instance list total weight (from "
        +instance_list().total_weight()+" to "
        +instLists[cat].total_weight()+")");
int remainingChildren = instLists.length - cat;
Node child;
if (instLists[cat].no_weight()) {
// No weight of instances with this value. Make it a leaf (majority),
// unless category unknown.
// if (cat != UNKNOWN_CATEGORY_VAL)
// IFDRIBBLE(dribble_level(level+1, "Leaf node",
// remainingChildren));
if (get_unknown_edges() || cat != Globals.UNKNOWN_CATEGORY_VAL) {
logOptions.LOG(3, "Category: " + (cat - 1)//-1 added to match MLC output -JL
+" empty. Assigning majority"+'\n');
NodeCategorizer[] constCat = new NodeCategorizer[1];
constCat[0] = create_leaf_categorizer(0, myTiebreakingOrder,
numChildErrors, childPessimisticErrors);
if (cat != Globals.UNKNOWN_CATEGORY_VAL)
++numLeaves.value; // don't count trivial leaves
MLJ.ASSERT(numChildErrors.value == 0,"TDDTInducer.induce_tree_from_split: numChildErrors.value != 0");
MLJ.ASSERT(childPessimisticErrors.value == 0,"TDDTInducer.induce_tree_from_split: childPessimisticErrors.value != 0");
child = decisionTree.create_node(constCat, get_level() + 1);
MLJ.ASSERT(constCat[0] == null,"TDDTInducer.induce_tree_from_split: constCat != null");
//create_node gets ownership
logOptions.LOG(6, "Created child leaf "+child+'\n');
logOptions.LOG(6, "Connecting root "+decisionTree.get_root()
+" to child "+child
+" with string '"+(String)catNames.get(cat)+"'"+'\n');
connect(decisionTree, decisionTree.get_root(), child,
cat, (String)catNames.get(cat));
}
} else { // Solve the problem recursively.
CGraph aCgraph = decisionTree.get_graph();
logOptions.LOG(3, "Recursive call"+'\n');
double totalChildWeight = instLists[cat].total_weight();
TDDTInducer childInducer =
create_subinducer(name_sub_inducer(splitCat.description(), cat),
aCgraph);
childInducer.set_total_inst_weight(get_total_inst_weight());
childInducer.assign_data(instLists[cat]);
IntRef numChildLeaves = new IntRef(0);
child = childInducer.induce_decision_tree(aCgraph,
myTiebreakingOrder,
numChildErrors,
childPessimisticErrors,
numChildLeaves,
remainingChildren);
numSubtreeErrors.value += numChildErrors.value;
pessimisticSubtreeErrors.value += childPessimisticErrors.value;
numLeaves.value += numChildLeaves.value;
if (totalChildWeight > maxChildWeight.value) {
maxChildWeight.value = totalChildWeight;
largestChild = child;
}
childInducer = null; //delete childInducer;
Node root = decisionTree.get_root();
logOptions.LOG(6, "Connecting child "+child+" to root "
+root+", using "+cat
+" with string '"+(String)catNames.get(cat)+"'"+'\n');
connect(decisionTree, decisionTree.get_root(), child,
cat, (String)catNames.get(cat));
}
}
MLJ.clamp_above(maxChildWeight, 0, "TDDTInducer.induce_tree_from_split: "
+"maximum child's weight must be non-negative");
MLJ.ASSERT(largestChild != null,"TDDTInducer.induce_tree_from_split: largestChild == null");
// DBGSLOW(decisionTree.OK(1));
instLists = null; //delete &instLists;
/* prune_subtree(decisionTree, myTiebreakingOrder,
       largestChild, numSubtreeErrors, pessimisticSubtreeErrors,
       numLeaves);
*/
myTiebreakingOrder = null; //delete myTiebreakingOrder;
/*
if (get_debug()) {
// Cast away constness for modifying the name.
Categorizer splitC = (Categorizer)decisionTree.
get_categorizer(decisionTree.get_root());
String name = splitC.description();
double[] distribution = splitC.get_distr();
int numChars = 128;
char buffer[numChars];
for (int chr = 0; chr < numChars; chr++)
buffer[chr] = '\0';
MLCOStream stream(EMPTY_STRING, buffer, numChars);
stream << distribution;
String distDescrip = stream.mem_buf();
String newName = name + "\\nErr=" + String(numSubtreeErrors, 3) +
"/" + String(pessimisticSubtreeErrors, 3);
if (splitC.class_id() != CLASS_CONST_CATEGORIZER)
newName += "\\nwDist=" + distDescrip;
splitC.set_description(newName);
}
*/
// if (get_level() == 0)
// DRIBBLE(endl);
}
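// Illustration with hypothetical numbers (not from any actual run): a
// 3-way split over an instance list of total weight 30 might yield
// sublists of weight 14, 10, and 6. The loop above recurses on each
// non-empty sublist, accumulates the children's errors, pessimistic
// errors, and leaf counts into the out-parameters, and remembers the
// weight-14 child as largestChild for the (currently commented-out)
// prune_subtree call.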
/** Connects two nodes in the specified CatGraph.
* @param catGraph The CatGraph containing these nodes.
* @param from The node from which the edge originates.
* @param to The node to which the edge connects.
* @param edgeVal The value of the AugCategory associated with that edge.
* @param edgeName The name of the edge.
*/
protected void connect(CatGraph catGraph, Node from, Node to, int edgeVal, String edgeName)
{
AugCategory edge = new AugCategory(edgeVal, edgeName);
logOptions.GLOBLOG(6, "TDDTInducer's connect(), given string '" +edgeName
+"', is using '" + edge.description()
+"' as an edge description\n");
catGraph.connect(from, to, edge);
// ASSERT(edge == NULL); // connect() gets ownership
// catGraph.OK(1);
}
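// Example call (values illustrative), mirroring the call sites in
// induce_tree_from_split above:
//   connect(decisionTree, decisionTree.get_root(), child,
//           cat, (String)catNames.get(cat));
// Note that catGraph.connect() takes ownership of the AugCategory edge.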
/** Create a string to name the subinducer. We just append some basic info.
* @return The name of the subinducer.
* @param catDescr The description of the split categorizer.
* @param catNum The category number for which this subinducer is
* inducing.
*/
public String name_sub_inducer(String catDescr, int catNum)
{
String CAT_EQUAL = " Cat=";
String CHILD_EQUAL = " child =";
return description() + CAT_EQUAL + catDescr + CHILD_EQUAL + catNum;
}
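// For example (inputs hypothetical), with description() == "ID3",
// catDescr == "outlook", and catNum == 2, this returns
// "ID3 Cat=outlook child =2" (the space before '=' comes from the
// CHILD_EQUAL constant).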
/** Create_subinducer creates the inducer used for the recursive call. Note that
 * since this is an abstract class, it cannot create a copy of itself.
 *
 * @param dscr The description for the subinducer.
* @param aCgraph The categorizer graph to use for the subinducer.
* @return The new subinducer.
*/
abstract public TDDTInducer create_subinducer(String dscr, CGraph aCgraph);
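/* A minimal sketch of what a concrete subclass's override might look like.
 * The subclass name and constructor signature below are assumptions for
 * illustration only; the real subclasses live elsewhere in the library:
 *
 *   public TDDTInducer create_subinducer(String dscr, CGraph aCgraph) {
 *       // Build the child inducer on the same CGraph so that the
 *       // recursively induced subtree is added to this tree's graph.
 *       return new MyTDDTInducer(dscr, aCgraph); // hypothetical subclass
 *   }
 */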
/** When the subtree rooted from the current node does not improve
* the error, the subtree may be replaced by a leaf or by its largest
* child. This serves as a collapsing mechanism if the pruning factor
* is 0, i.e., we collapse the subtree if it has the same number of
* errors as all children.<P>
* "Confidence" pruning is based on C4.5's pruning method. "Penalty"
* pruning is based on "Pessimistic Decision tree pruning based on tree
* size" by Yishay Mansour, ICML-97. "Linear" pruning is used to implement
* cost-complexity pruning as described in CART. Its use is not
* recommended otherwise. "KLdistance" pruning uses the Kullback-Leibler
* distance metric to determine whether to prune.<P>
* This function is divided into three main parts. First, initial
* checks are performed and values are set. Second, the test specific
* to each pruning method is performed. Last, if pruning is
* necessary, do it.
* @param decisionTree Tree to be pruned.
* @param tieBreakingOrder Order for breaking distribution ties.
* @param largestChild The largest child node.
* @param numSubtreeErrors Number of errors this subtree produces in categorization of Instances.
* @param pessimisticSubtreeErrors Error estimate if this was a leaf node.
* @param numLeaves Number of leaves on a subtree.
*/
public void prune_subtree(DecisionTree decisionTree,
int[] tieBreakingOrder,
Node largestChild,
DoubleRef numSubtreeErrors,
DoubleRef pessimisticSubtreeErrors,
IntRef numLeaves)
{
logOptions.LOG(0,"Pruning is taking place.\n");
MLJ.ASSERT(numSubtreeErrors.value >= 0,"TDDTInducer:prune_subtree:"
+" numSubtreeErrors < 0");
MLJ.ASSERT(pessimisticSubtreeErrors.value >= 0,"TDDTInducer:prune_subtree:"
+" pessimisticSubtreeErrors < 0");
Node treeRoot = decisionTree.get_root(true);
// @@ CatDTInducers can't prune, but we don't want to check
// get_prune_tree() here because even if we're not doing pruning, this code
// does some useful safety checks. The checks aren't valid on
// CatDTInducers, because they do not compute pessimisticSubtreeErrors.
// if (this instanceof CatDTInducer) return;
// if (class_id() == CatDT_INDUCER)
// return;
// DBGSLOW(if (numLeaves != decisionTree.num_nontrivial_leaves())
// Error.fatalErr("TDDTInducer.prune_subtree: number of leaves given "
// +numLeaves+" is not the same as the number counted "
// +decisionTree.num_nontrivial_leaves()));
// DBGSLOW(
// // We don't want any side effect logging only in debug level
// logOptions logOpt(logOptions.get_log_options());
// logOpt.set_log_level(0);
// double pess_err =
// pessimistic_subtree_errors(logOpt, decisionTree, treeRoot, *TS,
// get_pruning_factor(), tieBreakingOrder);
// MLJ.verify_approx_equal(pess_err, pessimisticSubtreeErrors,
// "TDDTInducer.prune_subtree: pessimistic error"
// " differs from expected value");
// );
// How many errors (weighted) would we make with a leaf here?
int myMajority = TS.majority_category(tieBreakingOrder);
double numMajority = TS.counters().label_count(myMajority);
double totalWeight = TS.total_weight();
double myErrors = totalWeight - numMajority;
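// Worked example (hypothetical numbers): with totalWeight == 100 and a
// majority category of weight 70, a leaf at this node would make
// myErrors == 100 - 70 == 30 weighted errors.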
if (!(MLJ.approx_greater(myErrors, numSubtreeErrors.value) ||
MLJ.approx_equal(myErrors, numSubtreeErrors.value)))
Error.fatalErr("TDDTInducer.prune_subtree: myErrors is not >= numSubtreeErrors"
+": myErrors - numSubtreeErrors = "+(myErrors - numSubtreeErrors.value));
int numChildren = decisionTree.num_children(treeRoot);
// test if a leaf; if so, we can exit immediately
if (numChildren == 0) {
numSubtreeErrors.value = totalWeight - numMajority;
numLeaves.value = 1;
return;
}
logOptions.LOG(3, "Testing at "
+decisionTree.get_categorizer(treeRoot).description()
+" (weight "+decisionTree.get_categorizer(treeRoot).total_weight()
+')'+'\n');
boolean pruneSubtree = false;
boolean pruneChild = false;
// We need to declare these here, as we use them during pruning
double myPessimisticErrors = CatTestResult.pessimistic_error_correction(
myErrors, TS.total_weight(), get_pruning_factor());
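// The correction inflates the raw (resubstitution) error count by an
// amount controlled by the pruning factor; with a pruning factor of 0
// the correction should be a no-op, which the check below verifies.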
DoubleRef childPessimisticErrors = new DoubleRef(0);
if (get_pruning_factor() == 0)
MLJ.verify_approx_equal(myPessimisticErrors, myErrors,
"TDDTInducer.prune_subtree:pessimistic error "
+"when computed for leaf, "
+"differs from expected value");
switch (get_pruning_method()) {
case confidence:
//@@ replace "100 * MLC.real_epsilon()" with "0.1" for
//@@ C4.5 functionality
if (myPessimisticErrors - pessimis