tddtinducer.java
* @return The evidence correction factor.
*/
public double get_evidence_factor() { return tddtOptions.evidenceFactor; }
/** Returns whether there are continuous attributes present in the data.
* @return TRUE indicates there are continuous attributes in the data, FALSE otherwise.
*/
public boolean get_have_continuous_attributes()
{ return haveContinuousAttributes; }
/** Sets whether there are continuous attributes present in the data.
* @param val TRUE indicates there are continuous attributes in the data, FALSE otherwise.
*/
public void set_have_continuous_attributes(boolean val)
{ haveContinuousAttributes = val; }
/** Accesses the debug variable.
* @return TRUE if debugging statements are active, FALSE otherwise.
*/
public boolean get_debug() { return tddtOptions.debug; }
/** Sets the debugging option.
* @param val TRUE if debugging statements are active, FALSE otherwise.
*/
public void set_debug(boolean val) { tddtOptions.debug = val; }
/** Sets the evaluation metric used.
* @param metric The evaluation metric to be used.
*/
public void set_evaluation_metric(byte metric)
{ tddtOptions.evaluationMetric = metric; }
/** Accesses the evaluation metric used.
* @return The evaluation metric to be used.
*/
public byte get_evaluation_metric()
{ return tddtOptions.evaluationMetric; }
/** Accumulates statistical information about the induced tree. This information
 * consists of the total number of nontrivial nodes and the total number of attributes.
 */
protected void accumulate_tree_stats()
{
//ASSERT(decisionTreeCat);
totalNodesNum += num_nontrivial_nodes();
if(class_id() == ID3_INDUCER || class_id() == SGI_DT_INDUCER)
    totalAttr +=
        decisionTreeCat.rooted_cat_graph().num_attr(TS.num_attr());
callCount++;
}
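// Illustrative note (not in the original source): totalNodesNum, totalAttr and
// callCount accumulate across repeated training runs, so a hypothetical reporting
// helper could derive per-run averages along these lines:
//   double avgNodes = (double) totalNodesNum / callCount;   // guard: callCount > 0
//   double avgAttr  = (double) totalAttr     / callCount;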
/** Copies the option settings from the given TDDTInducer.
* @param inducer The TDDTInducer with the options to be copied.
*/
public void copy_options(TDDTInducer inducer)
{
// Copy the continuous attributes flag
set_have_continuous_attributes(inducer.get_have_continuous_attributes());
logOptions.set_log_options(inducer.logOptions.get_log_options());
set_max_level(inducer.get_max_level());
set_lower_bound_min_split_weight(
inducer.get_lower_bound_min_split_weight());
set_upper_bound_min_split_weight(
inducer.get_upper_bound_min_split_weight());
set_min_split_weight_percent(inducer.get_min_split_weight_percent());
set_nominal_lbound_only(inducer.get_nominal_lbound_only());
set_debug(inducer.get_debug());
set_unknown_edges(inducer.get_unknown_edges());
set_split_score_criterion(inducer.get_split_score_criterion());
set_empty_node_parent_dist(inducer.get_empty_node_parent_dist());
set_parent_tie_breaking(inducer.get_parent_tie_breaking());
set_pruning_method(inducer.get_pruning_method());
set_pruning_branch_replacement(inducer.get_pruning_branch_replacement());
set_adjust_thresholds(inducer.get_adjust_thresholds());
set_pruning_factor(inducer.get_pruning_factor());
set_cont_mdl_adjust(inducer.get_cont_mdl_adjust());
set_smooth_inst(inducer.get_smooth_inst());
set_smooth_factor(inducer.get_smooth_factor());
set_leaf_dist_type(inducer.get_leaf_dist_type());
set_m_estimate_factor(inducer.get_m_estimate_factor());
set_evidence_factor(inducer.get_evidence_factor());
set_evaluation_metric(inducer.get_evaluation_metric());
}
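// Usage sketch (illustrative; assumes a concrete subclass such as ID3Inducer is
// available elsewhere in the library): copy_options() is typically used to make a
// second inducer behave exactly like an already-configured one, e.g. an inducer
// built to grow a subtree.
//   TDDTInducer configured = ...;      // options already set via set_user_options()
//   TDDTInducer child = ...;           // hypothetical sub-inducer
//   child.copy_options(configured);    // pruning, smoothing, metric, etc. now match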
/** Sets the user options according to the option file.
* @param prefix The prefix for the option names.
*/
public void set_user_options(String prefix){
tddtOptions.maxLevel = getEnv.get_option_int(prefix + "MAX_LEVEL",
tddtOptions.maxLevel, MAX_LEVEL_HELP, true, false);
/* tddtOptions.pruningMethod = get_option_enum(prefix + "PRUNING_METHOD",
pruningMethodEnum,tddtOptions.pruningMethod, PRUNING_METHOD_HELP, false);
*/
if (tddtOptions.pruningMethod != TDDTInducer.none) {
tddtOptions.pruningFactor = getEnv.get_option_real(prefix + "PRUNING_FACTOR",
tddtOptions.pruningFactor, PRUNING_FACTOR_HELP, true, false);
}
tddtOptions.lowerBoundMinSplitWeight = getEnv.get_option_real(
prefix + "LBOUND_MIN_SPLIT",tddtOptions.lowerBoundMinSplitWeight,
LB_MSW_HELP, true);
tddtOptions.upperBoundMinSplitWeight = getEnv.get_option_real(
prefix + "UBOUND_MIN_SPLIT",Math.max(tddtOptions.upperBoundMinSplitWeight,
tddtOptions.lowerBoundMinSplitWeight),UB_MSW_HELP, true);
if (tddtOptions.upperBoundMinSplitWeight <
tddtOptions.lowerBoundMinSplitWeight)
Error.fatalErr("TDDTInducer.set_user_options: upper bound must be >= "
+"lower bound");
tddtOptions.minSplitWeightPercent =
getEnv.get_option_real(prefix + "MIN_SPLIT_WEIGHT",
tddtOptions.minSplitWeightPercent, MS_WP_HELP, true);
tddtOptions.nominalLBoundOnly =
getEnv.get_option_bool(prefix + "NOMINAL_LBOUND_ONLY",
tddtOptions.nominalLBoundOnly, NOM_LBO_HELP);
tddtOptions.debug =
getEnv.get_option_bool(prefix + "DEBUG",
tddtOptions.debug, DEBUG_HELP, true);
tddtOptions.unknownEdges =
getEnv.get_option_bool(prefix + "UNKNOWN_EDGES",
tddtOptions.unknownEdges, UNKNOWN_EDGES_HELP, true);
/* tddtOptions.splitScoreCriterion =
getEnv.get_option_enum(prefix + "SPLIT_BY", SplitScore.splitScoreCriterionEnum,
tddtOptions.splitScoreCriterion,
SplitScore.splitScoreCriterionHelp, true);
// The following is a rare option. It may be supported by uncommenting
// the following. It allows an empty node to get the distribution
// of the parent, which is what C4.5 does. However, in C4.5,
// you can determine that it's an empty node by looking at the
// instance count, which is what we use when reading C4.5 trees.
//tddtOptions.emptyNodeParentDist =
// get_option_bool(prefix + "EMPTY_NODE_PARENT_DIST",
// tddtOptions.emptyNodeParentDist, EMPTY_NODE_PARENT_DIST_HELP,
// false);
*/
tddtOptions.parentTieBreaking =
getEnv.get_option_bool(prefix + "PARENT_TIE_BREAKING",
tddtOptions.parentTieBreaking,
PARENT_TIE_BREAKING_HELP, false);
// @@ tddtOptions.pruningBranchReplacement =
// get_option_bool(prefix + "PRUNING_BRANCH_REPLACEMENT",
// tddtOptions.pruningBranchReplacement,
// PRUNING_BRANCH_REPLACEMENT_HELP,
// false);
tddtOptions.pruningBranchReplacement = false;
tddtOptions.adjustThresholds =
getEnv.get_option_bool(prefix + "ADJUST_THRESHOLDS",
tddtOptions.adjustThresholds, ADJUST_THRESHOLDS_HELP,
false);
tddtOptions.contMDLAdjust =
getEnv.get_option_bool(prefix + "CONT_MDL_ADJUST",
tddtOptions.contMDLAdjust, CONT_MDL_ADJUST_HELP);
tddtOptions.smoothInst = getEnv.get_option_int(prefix + "SMOOTH_INST",
tddtOptions.smoothInst,
SMOOTH_INST_HELP);
if (tddtOptions.smoothInst != 0) {
  // SMOOTH_FACTOR is currently disabled; the braces keep this if from
  // swallowing the leafDistType check below.
  /* tddtOptions.smoothFactor = getEnv.get_option_real(prefix + "SMOOTH_FACTOR",
       tddtOptions.smoothFactor,
       SMOOTH_FACTOR_HELP); */
}
/* tddtOptions.leafDistType = getEnv.get_option_enum(prefix + "LEAF_DIST_TYPE",
     leafDistTypeMEnum,
     tddtOptions.leafDistType,
     LEAF_DIST_TYPE_HELP,
     false);
*/
if(tddtOptions.leafDistType == laplaceCorrection)
tddtOptions.MEstimateFactor =
getEnv.get_option_real(prefix + "M_ESTIMATE_FACTOR",
tddtOptions.MEstimateFactor,
M_ESTIMATE_FACTOR_HELP, true);
else if(tddtOptions.leafDistType == evidenceProjection)
tddtOptions.evidenceFactor =
getEnv.get_option_real(prefix + "EVIDENCE_FACTOR",
tddtOptions.evidenceFactor,
EVIDENCE_FACTOR_HELP, true);
/* tddtOptions.evaluationMetric =
getEnv.get_option_enum(prefix + "EVAL_METRIC", evalMetricEnum,
tddtOptions.evaluationMetric, EVAL_METRIC_HELP, false);
*/
}
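// Usage sketch (illustrative; the option-name prefix "DT_" below is hypothetical):
// set_user_options() reads each option from the surrounding option environment via
// getEnv using the given prefix, so a caller expecting options named DT_MAX_LEVEL,
// DT_PRUNING_FACTOR, DT_LBOUND_MIN_SPLIT, etc. would simply do:
//   inducer.set_user_options("DT_");
// Options left unset keep their current defaults, since each get_option_* call
// above passes the current value in tddtOptions as the default.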
/** Trains the inducer on a stored dataset and collects statistics.
*/
public void train()
{
train_no_stats();
accumulate_tree_stats();
}
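// Usage sketch (illustrative): once a training set has been attached to the inducer
// (the attachment mechanism is outside this listing), train() induces the tree and
// records statistics, after which the tree-size accessors below become valid:
//   inducer.train();
//   int nodes = inducer.num_nontrivial_nodes();   // valid only after training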
/** Trains the inducer on a stored dataset. No statistics about the induced
 * categorizer are collected.
 * @return The number of attributes used in the induced tree, or -1 if this
 * statistic is not computed for the inducer's class.
 */
public int train_no_stats()
{
//IFLOG(3, display_options(get_log_stream()));
has_data();
//DBG(OK());
//ASSERT(get_level() == 0); //should never be modified for user created
//inducers.
decisionTreeCat = null; //remove any existing tree categorizer.
//OK must be ignored until done. otherwise
//we get a freed-memory read (???)
boolean usedAutoLBoundMinSplit = false;
if(tddtOptions.lowerBoundMinSplitWeight == 0.0) {
usedAutoLBoundMinSplit = true;
tddtOptions.lowerBoundMinSplitWeight =
Entropy.auto_lbound_min_split(TS.total_weight());
logOptions.LOG(2, "Auto-setting lbound minSplit to "
+ tddtOptions.lowerBoundMinSplitWeight);
}
boolean foundReal = false;
Schema schema = TS.get_schema(); //SchemaRC
// Checking for real(continuous) attributes. -JL
for(int i=0;i<schema.num_attr() && !foundReal; i++)
foundReal = schema.attr_info(i).can_cast_to_real();
set_have_continuous_attributes(foundReal);
// The decision Tree either creates a new graph, or gets ours.
DecisionTree dTree = null;
if(cgraph!=null)
dTree = new DecisionTree(cgraph);
else dTree = new DecisionTree();
// Display the training InstanceList. -JL
logOptions.LOG(4, "Training with instance list\n" + instance_list().out(false));
set_total_inst_weight(TS.total_weight());
boolean saveDribble = GlobalOptions.dribble;
if(TS.num_instances() <= MIN_INSTLIST_DRIBBLE)
GlobalOptions.dribble = false;
else
logOptions.DRIBBLE("There are over " + MIN_INSTLIST_DRIBBLE
+" instances in training set. Showing progress");
DoubleRef numErrors =new DoubleRef(0);
DoubleRef pessimisticErrors =new DoubleRef(0);
IntRef numLeaves =new IntRef(0);
int[] tieBreakingOrder = TS.get_distribution_order();
// Induce_decision_tree returns the root of the created decision tree.
// The root is set into this TDDTInducer.
dTree.set_root(induce_decision_tree(dTree.get_graph(), tieBreakingOrder,
numErrors, pessimisticErrors,
numLeaves,-1));
tieBreakingOrder = null;
logOptions.DRIBBLE("Decision tree classifier has been induced\n");
GlobalOptions.dribble = saveDribble;
if (usedAutoLBoundMinSplit)
tddtOptions.lowerBoundMinSplitWeight = 0.0;
int numAttr = -1;
if (class_id() == ID3_INDUCER || class_id() == SGI_DT_INDUCER)
numAttr = dTree.num_attr(TS.num_attr());
// Creating a reusable categorizer from the induced tree. -JL
build_categorizer(dTree);
if (errorprune) {prune(dTree, dTree.get_root(), instance_list(), true);}
// MLJ.ASSERT(dTree == null,"TDDTInducer.train_no_stats: dTree != null.");
// Display some information about the induced tree. -JL
if (numAttr != -1)
logOptions.LOG(1, "Tree has " + num_nontrivial_nodes() + " nodes, "
+ num_nontrivial_leaves()
+ " leaves, and " + numAttr + " attributes."+'\n');
RootedCatGraph rcg = decisionTreeCat.rooted_cat_graph();
//IFLOG(1, show_hist(get_log_stream(), rcg));
if (get_adjust_thresholds())
decisionTreeCat.adjust_real_thresholds(instance_list());
return numAttr;
}
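// Note (illustrative): as the code above shows, a lowerBoundMinSplitWeight of exactly
// 0.0 is treated as "auto": train_no_stats() derives a bound from the total training
// weight via Entropy.auto_lbound_min_split() and restores 0.0 afterwards, so the bound
// is recomputed on every run. A caller that wants this behaviour just leaves the
// option at zero:
//   inducer.set_lower_bound_min_split_weight(0.0);   // 0 => auto-derived per training set
//   inducer.train();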
/** Returns the number of nontrivial nodes in the tree, i.e. nodes other than those
 * that contain no Instances and are reached only by unknown edges.
 * @return The number of nontrivial nodes in the tree.
 */
public int num_nontrivial_nodes()
{
was_trained(true);
return decisionTreeCat.num_nontrivial_nodes();
}
/** Returns the number of nontrivial leaves in the tree, i.e. leaves other than those
 * that contain no Instances and are reached only by unknown edges.
 * @return The number of nontrivial leaves in the tree