// tddtinducer.java
/** Returns the number of non-trivial leaves in the decision tree, i.e.
* leaves that have training instances leading to them.
* @return The number of non-trivial leaves.
*/
public int num_nontrivial_leaves()
{
was_trained(true);
return decisionTreeCat.num_nontrivial_leaves();
}
/** Checks if this inducer has a valid decision tree.
* @return True iff the class has a valid decisionTree categorizer.
* @param fatalOnFalse TRUE if a fatal error should be raised when the inducer
* is not trained, FALSE otherwise.
*/
public boolean was_trained(boolean fatalOnFalse)
{
if(fatalOnFalse && decisionTreeCat == null)
Error.err("TDDTInducer.was_trained: No decision tree categorizer. "
+ " Call train() to create categorizer -->fatal_error");
return decisionTreeCat != null;
}
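// A minimal usage sketch (hypothetical caller, not part of the original
// source): guard accesses to the induced tree with was_trained().
//
//   TDDTInducer inducer = ...;   // some concrete subclass
//   inducer.train();
//   if (inducer.was_trained(false))
//       System.out.println(inducer.num_nontrivial_leaves()
//                          + " non-trivial leaves");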
/** Induce a decision tree in the given graph.
* @param aCgraph The graph that will contain the decision tree.
* @param tieBreakingOrder The tie breaking order for breaking distribution ties.
* @param numSubtreeErrors Set to the number of errors made by the induced subtree.
* @param pessimisticSubtreeErrors Set to the pessimistic error estimate for the induced subtree.
* @param numLeaves Set to the number of leaves in the induced subtree.
* @param remainingSiblings Siblings that have not been induced yet.
* @return The root node of the decision tree.
*/
public Node induce_decision_tree(CGraph aCgraph, int[] tieBreakingOrder,
DoubleRef numSubtreeErrors, DoubleRef pessimisticSubtreeErrors,
IntRef numLeaves, int remainingSiblings)
{
if (TS.no_weight())
Error.fatalErr("TDDTInducer.induce_decision_tree: list has zero weight");
// DBG(pessimisticSubtreeErrors = -1);
// Create a decision tree object to allow building nodes in the CGraph.
DecisionTree decisionTree = new DecisionTree(aCgraph);
// Display training InstanceList. -JL
logOptions.LOG(4, "Training set ="+'\n'+TS.out(false)+'\n');
LinkedList catNames = new LinkedList();
// catNames[0] = null;
NodeCategorizer[] rootCat = new NodeCategorizer[1];
rootCat[0] = best_split(catNames);
if (rootCat[0] == null) {
rootCat[0] = create_leaf_categorizer(TS.total_weight(),
tieBreakingOrder, numSubtreeErrors,
pessimisticSubtreeErrors);
// We catch empty leaves in induce_tree_from_split. Hence, this can't be
// a trivial leaf.
// MLJ.ASSERT(MLJ.approx_greater(rootCat[0].total_weight(), 0),"TDDTInducer.induce_decision_tree: MLJ.approx_greater(rootCat[0].total_weight(), 0) == false");
numLeaves.value = 1;
decisionTree.set_root(decisionTree.create_node(rootCat, get_level()));
MLJ.ASSERT(rootCat[0] == null,"TDDTInducer.induce_decision_tree: rootCat[0] != null"); // create_node gets ownership
// IFDRIBBLE(dribble_level(level, "Leaf node", remainingSiblings));
} else {
NodeCategorizer splitCat = rootCat[0];
decisionTree.set_root(decisionTree.create_node(rootCat, get_level()));
MLJ.ASSERT(rootCat[0] == null,"TDDTInducer.induce_decision_tree: rootCat[0] != null"); // create_node gets ownership
induce_tree_from_split(decisionTree, splitCat, catNames,
tieBreakingOrder, numSubtreeErrors,
pessimisticSubtreeErrors, numLeaves,
remainingSiblings);
}
catNames = null;
// DBG(catNames = null);
logOptions.LOG(6, "TDDT returning " + decisionTree.get_root() +'\n');
MLJ.ASSERT(pessimisticSubtreeErrors.value >= 0,"TDDTInducer.induce_decision_tree: pessimisticSubtreeErrors.value < 0");
return decisionTree.get_root();
}
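// Illustrative sketch of the calling contract above (not the actual MLJ
// train() code): the caller supplies reference objects that the recursion
// fills in, then reads them back once the root node is returned.
//
//   CGraph graph = new CGraph();            // assumed no-arg constructor
//   DoubleRef subtreeErrors = new DoubleRef(0);
//   DoubleRef pessimistic = new DoubleRef(0);
//   IntRef leaves = new IntRef(0);
//   int[] order = ...;                      // tie-breaking order over categories
//   Node root = induce_decision_tree(graph, order, subtreeErrors,
//                                    pessimistic, leaves, 0);
//   // subtreeErrors, pessimistic and leaves now describe the induced tree.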
/** Builds a decision tree categorizer for the given DecisionTree.
* @param dTree The DecisionTree to use for creating the categorizer.
*/
protected void build_categorizer(DecisionTree dTree)
{
decisionTreeCat = null; // release any previously built categorizer
decisionTreeCat = new DTCategorizer(dTree, description(),
TS.num_categories(),
TS.get_schema());
decisionTreeCat.set_leaf_dist_params(tddtOptions.leafDistType,
tddtOptions.MEstimateFactor,
tddtOptions.evidenceFactor);
//ASSERT(dtree==null);
}
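// Note: the three parameters forwarded to set_leaf_dist_params() above map
// onto the leafDistType cases handled in create_leaf_categorizer() below:
// frequencyCounts uses raw counts, laplaceCorrection uses MEstimateFactor,
// and evidenceProjection uses evidenceFactor.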
/** Finds the best split at this node and returns a categorizer
* implementing it. The names of the resulting categories are filled
* into catNames.
*
* @param catNames The list to be filled with the names of the resulting categories.
* @return The Categorizer implementing the best split, or null if no split
* improves on making this node a leaf.
*/
abstract public NodeCategorizer best_split(LinkedList catNames);
/** Computes the number of errors this node would make as a leaf. If
* totalWeight is zero, the distribution is ignored, else totalWeight
* must be the sum of the distribution counts.
* @return The number of errors this node would make if it were a leaf
* on the decision tree.
* @param cat The Categorizer for the node being checked.
* @param predictClass The category for which this node is being
* checked.
* @param totalWeight The weight of all instances in a data set.
*/
protected static double num_cat_errors(Categorizer cat, int predictClass, double totalWeight)
{
double numErrors = 0;
if (!cat.has_distr())
Error.fatalErr("TDDTInducer.num_cat_errors: Categorizer has no distribution");
// DBG(
// const double[]& dist = cat.get_distr();
// double sum = dist.sum();
// // if (numInstances > 0) @@ will this fail? If yes, comment why
// MLJ.verify_approx_equal((StoredReal)sum, (StoredReal)totalWeight,
// "TDDTInducer.num_cat_errors: summation of "
// "distribution fails to equal number of "
// "instances", 100);
// );
if (totalWeight > 0) { // we're not an empty leaf
double numPredict = cat.get_distr()[predictClass - Globals.FIRST_NOMINAL_VAL];
double nodeErrors = totalWeight - numPredict; // error count
// ASSERT(nodeErrors >= 0);
numErrors = nodeErrors; // increment parent's count of errors
}
return numErrors;
}
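// Worked example (illustrative numbers): for a distribution of
// {A=2.0, B=5.0, C=3.0} with totalWeight 10.0 and predictClass = B,
// numPredict is 5.0, so the node as a leaf would make 10.0 - 5.0 = 5.0 errors.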
/** Creates a leaf categorizer (has no children). We currently create
* a ConstCategorizer with a description and the majority category.
* Note that the augCategory will contain the correct string,
* but the description will contain more information which may
* help when displaying the graph. The augCategory string must
* be the same for CatTestResult to work properly (it compares
* the actual string for debugging purposes).
* @return The LeafCategorizer created.
* @param totalWeight The total weight of the training data set.
* @param tieBreakingOrder Order for breaking distribution ties.
* @param numErrors Set to the number of errors this LeafCategorizer
* will produce.
* @param pessimisticErrors Set to the pessimistic error estimate for this leaf.
*/
public LeafCategorizer create_leaf_categorizer(double totalWeight,
int[] tieBreakingOrder,
DoubleRef numErrors, DoubleRef pessimisticErrors)
{
return create_leaf_categorizer(totalWeight, tieBreakingOrder, numErrors,
pessimisticErrors, null);
}
/** Creates a leaf categorizer (has no children). We currently create
* a ConstCategorizer with a description and the majority category.
* If the distrArray is given, we don't reference the training set,
* except for its schema (used in pruning).<P>
* Note that the augCategory will contain the correct string,
* but the description will contain more information which may
* help when displaying the graph. The augCategory string must
* be the same for CatTestResult to work properly (it compares
* the actual string for debugging purposes).
* @return The LeafCategorizer created.
* @param totalWeight The total weight of the training data set.
* @param tieBreakingOrder Order for breaking distribution ties.
* @param numErrors Set to the number of errors this LeafCategorizer
* will produce.
* @param pessimisticErrors Set to the pessimistic error estimate for this leaf.
* @param distrArray Distribution of weight over labels.
*/
public LeafCategorizer create_leaf_categorizer(double totalWeight,
int[] tieBreakingOrder,
DoubleRef numErrors, DoubleRef pessimisticErrors,
double[] distrArray)
{
// Find tiebreaking order.
int[] myTiebreakingOrder = null;
double[] weightDistribution = (distrArray!=null) ? distrArray : TS.counters().label_counts();
// ASSERT(weightDistribution.low() == Globals.UNKNOWN_CATEGORY_VAL);
if (tddtOptions.parentTieBreaking)
myTiebreakingOrder = CatDist.merge_tiebreaking_order(tieBreakingOrder,
weightDistribution);
else
myTiebreakingOrder = CatDist.tiebreaking_order(weightDistribution);
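// Example (illustrative): with weightDistribution {A=3.0, B=3.0}, the A/B tie
// is broken by myTiebreakingOrder; when parentTieBreaking is on, the parent's
// order is merged in so that sibling leaves break such ties consistently.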
ConstCategorizer leafCat = null;
// @@ this is silly. We compute the majority category, make a ConstCat for
// it, then turn around and predict a different category (the one that
// produces the least loss). We use this majority to compute the number of
// errors, even if we don't predict it!
int majority = CatDist.majority_category(weightDistribution,
myTiebreakingOrder);
if(tddtOptions.leafDistType == allOrNothing) {
AugCategory augMajority = new AugCategory(majority,
TS.get_schema().category_to_label_string(majority));
logOptions.LOG(3, "All-or-nothing Leaf is: ");
leafCat = new ConstCategorizer(" ", augMajority, TS.get_schema());
AugCategory bestPrediction = leafCat.get_category();
logOptions.LOG(3, ""+bestPrediction.toString()+'\n');
String myDescr = bestPrediction.description();//.read_rep();
leafCat.set_description(myDescr);
} else {
double[] fCounts = weightDistribution;
CatDist cDist = null;
switch (tddtOptions.leafDistType) {
case frequencyCounts:
cDist = new CatDist(TS.get_schema(), fCounts, CatDist.none);
logOptions.LOG(3, "Frequency-count Leaf is: ");
break;
case laplaceCorrection:
cDist = new CatDist(TS.get_schema(), fCounts, CatDist.laplace,
tddtOptions.MEstimateFactor);
logOptions.LOG(3, "Laplace Leaf is: ");
break;
case evidenceProjection:
cDist = new CatDist(TS.get_schema(), fCounts, CatDist.evidence,
tddtOptions.evidenceFactor);
logOptions.LOG(3, "Evidence Leaf is: ");
break;
default:
MLJ.Abort();
}
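// Worked example for the Laplace case (illustrative, assuming the standard
// m-estimate form (count + m) / (total + m * numCategories)): with counts
// {A=2.0, B=5.0, C=3.0} and MEstimateFactor m = 1, the estimate for A is
// (2 + 1) / (10 + 3) = 3/13, about 0.23, instead of the raw 2/10.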
logOptions.LOG(3, ""+cDist+'\n');
cDist.set_tiebreaking_order(myTiebreakingOrder);
AugCategory bestCategory = cDist.best_category();
String myDescr = bestCategory.description();//.read_rep();
leafCat = new ConstCategorizer(myDescr, cDist, TS.get_schema());
// DBG(ASSERT(cDist == null));
}
myTiebreakingOrder = null; //delete myTiebreakingOrder;
// ASSERT(leafCat);
// DBGSLOW(
// InstanceRC dummy(leafCat.get_schema());
// MStringRC predDescr = leafCat.categorize(dummy).description();
// if (predDescr != leafCat.description())
// Error.fatalErr("cat descriptions don't match: I picked "
// +leafCat.description()+", leafCat predicted "
// +predDescr+". CatDist is "+leafCat.get_cat_dist());
// );
if (distrArray != null) {
// Skip the unknown-category slot at index 0 so newDistr lines up with the
// schema's nominal categories (distrArray is based at UNKNOWN_CATEGORY_VAL,
// per the assertion above).
double[] newDistr = new double[distrArray.length - 1];
for (int i = 0; i < newDistr.length; i++)
newDistr[i] = distrArray[i + 1];
leafCat.set_distr(newDistr);
} else {
// Use coarser granularity when approx_equal invoked with floats.
if (MLJ.approx_equal((float)totalWeight,0.0)
&& !tddtOptions.emptyNodeParentDist) {
double[] disArray = new double[TS.num_categories()];// (0, TS.num_categories(), 0);
leafCat.set_distr(disArray);
} else
leafCat.build_distr(instance_list());
}
// If there are no instances, we predict like the parent and
// the penalty for pessimistic errors comes from the other children.
// Note that we can't just call num_cat because the distribution
// may be the parent's distribution
// Use coarser granularity when approx_equal invoked with floats.
if (MLJ.approx_equal((float)totalWeight,0.0)) {
numErrors.value = 0;
pessimisticErrors.value = 0;
} else {
numErrors.value = num_cat_errors(leafCat, majority, totalWeight);
pessimisticErrors.value = CatTestResult.pessimistic_error_correction(
numErrors.value, totalWeight, get_pruning_factor());
}
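// pessimistic_error_correction() inflates the observed error count as a
// function of totalWeight and the pruning factor, in the spirit of C4.5's
// confidence-based upper bound; the exact formula lives in CatTestResult.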
// ASSERT(numErrors >= 0);
// ASSERT(pessimisticErrors >= 0);
/*
if (get_debug()) {
int numChars = 128;
char buffer[numChars];
for (int chr = 0; chr < numChars; chr++)
buffer[chr] = '\0';
MLCOStream *stream = new MLCOStream(EMPTY_STRING, buffer, numChars);
CatDist score = leafCat.get_cat_dist();
*stream << score.get_scores();
String pDist = stream.mem_buf();
stream = null; //delete stream;
stream = new MLCOStream(EMPTY_STRING, buffer, numChars);
*stream << leafCat.get_distr();