📄 decisiontree.java
字号:
* @param node The base node where level assignment will start.
* @param baseLevel The initial level for the base Node.
*/
public void assign_subtree_levels(Node node, int baseLevel) {
    // Record the level being overwritten, then stamp the new one on this node.
    final NodeInfo info = cGraph.node_info(node);
    logOptions.LOG(5, "Replacing level " + info.level() + " with " + baseLevel + '\n');
    cGraph.node_info(node).set_level(baseLevel);

    // Every child is placed exactly one level below its parent.
    // NOTE(review): a commented-out variant kept children of multi-split
    // categorizers on the parent's level; that case is not handled here.
    final int childLevel = baseLevel + 1;

    // Walk the outgoing edges, grabbing the successor before descending so the
    // recursive call starts from the edge's target node.
    for (Edge edge = node.first_adj_edge(); edge != null; ) {
        final Edge visiting = edge;
        edge = visiting.adj_succ();
        assign_subtree_levels(visiting.target(), childLevel);
    }
}
/***************************************************************************
Display this DecisionTree.
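@param hasNodeLosses true if node (test-set) losses should be included in the display.
@param hasLossMatrix true if node losses were computed with a loss matrix rather than as classification error.
@param stream The Writer that receives the rendered tree.
@param dp The display preferences to use when rendering.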
***************************************************************************
public void display(boolean hasNodeLosses, boolean hasLossMatrix,
Writer stream, DisplayPref dp)
{
stream.write(display(hasNodeLosses, hasLossMatrix, dp));
}
/***************************************************************************
Display this DecisionTree.
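@param hasNodeLosses true if node (test-set) losses should be included in the display.
@param hasLossMatrix true if node losses were computed with a loss matrix rather than as classification error.
@param dp The display preferences to use when rendering.
@return The rendered tree as a String.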
***************************************************************************
public String display(boolean hasNodeLosses, boolean hasLossMatrix,
DisplayPref dp)
{
String return_value = new String();
// Note that if the display is XStream, our virtual function gets it
// if (stream.output_type() == XStream ||
// dp.preference_type() != DisplayPref::TreeVizDisplay)
// RootedCatGraph.display(hasNodeLosses, hasLossMatrix, stream, dp);
else
{
String dataName = stream.description() + ".data";
MLCOStream data(dataName);
convertToTreeVizFormat(stream, data, dp, hasNodeLosses, hasLossMatrix);
}
}
*/
/*
/***************************************************************************
Displays the DecisionTree in TreeVizFormat.
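@param conf Receives the TreeViz configuration (the input, hierarchy, and view sections).
@param data Receives the tree data rows; its description() is used as the data file name in the configuration.
@param displayPref Display preferences; its TreeViz settings decide whether backfit disks are drawn.
@param hasNodeLosses true to emit the test-set subtree weight, test-set loss/error, and std-dev columns.
@param hasLossMatrix true to label those columns as loss rather than error.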
***************************************************************************
public void convertToTreeVizFormat(Writer conf, Writer data,
DisplayPref displayPref,
boolean hasNodeLosses,
boolean hasLossMatrix) throws IOException
{
Node rootNode = get_root(true);
NodeCategorizer cat = get_categorizer(rootNode);
Schema schema = cat.get_schema();
NominalAttrInfo nai = schema.nominal_label_info();
int numLabelValues = nai.num_values();
MLJ.ASSERT(numLabelValues >= 1,"DecisionTree::convertToTreeVizFormat:numLabelValues < 1");
// Avoid log of 1, which is a scale of zero, and causes division
// by zero.
double scale = MLJ.log_bin(Math.max(numLabelValues, 2));
int[] permLabels = schema.sort_labels(); // permuted labels
boolean dispBackfitDisks =
displayPref.typecast_to_treeViz().get_display_backfit_disks();
write_subtree(get_log_options(), scale, data, permLabels,
Globals.EMPTY_STRING, Globals.EMPTY_STRING, this, rootNode,
dispBackfitDisks, hasNodeLosses, hasLossMatrix);
String protectedLabelName = new String(Globals.SINGLE_QUOTE + MLJ.protect(nai.name(),"`\\")
+ Globals.SINGLE_QUOTE);
conf.write(minesetVersionStr + "\n");
conf.write("# MLC++ generated file for MineSet Tree Visualizer.\n"
+ "input {\n"
+ "\t file \"" + data.description() + "\";\n"
+ "\t options backslash on;\n"
+ "\t key string " + protectedLabelName + " {\n");
for (int i = 0; i < numLabelValues; i++)
{
conf.write("\t\t " + nai.get_value(permLabels[i]).quote());
if (i != numLabelValues - 1)
conf.write(",");
conf.write("\n");
}
permLabels = null;
conf.write("\t };\n"
+ "\t expression `Node label`[] separator ':';\n"
+ "\t string `Test attribute`;\n"
+ "\t string `Test value`;\n"
+ "\t float `Subtree weight` [" + protectedLabelName
+ "] separator ',' ;\n"
+ "\t float Percent [" + protectedLabelName + "] separator ',' ;\n");
if (dispBackfitDisks)
conf.write("\t float OriginalDist [" + protectedLabelName
+ "] separator ',' ;\n");
conf.write("\t float Purity;\n");
if (hasNodeLosses)
{
conf.write("\t float `Test-set subtree weight`;\n");
if (hasLossMatrix)
conf.write("\t float `Test-set loss`;\n"
+ "\t float `Mean loss std-dev`;\n");
else
conf.write("\t float `Test-set error`;\n"
+ "\t float `Mean err std-dev`;\n");
}
conf.write("}\n\n");
conf.write("hierarchy {\n"
+ "\t levels `Node label`;\n"
+ "\t key `Subtree weight`;\n"
+ "\t aggregate base {\n"
+ "\t\t sum `Subtree weight`;\n");
if (dispBackfitDisks)
conf.write("\t\t sum `OriginalDist`;\n");
conf.write("\t\t any Purity;\n"
+ "\t\t any `Test attribute`;\n"
+ "\t\t any `Test value`;\n");
if (hasNodeLosses)
{
conf.write("\t\t any `Test-set subtree weight`;\n");
if (hasLossMatrix)
conf.write("\t\t any `Test-set loss`;\n"
+ "\t\t any `Mean loss std-dev`;\n");
else
conf.write("\t\t any `Test-set error`;\n"
+ "\t\t any `Mean err std-dev`;\n");
}
conf.write("\t }\n"
+ "\t options organization same;\n"
+ "}\n");
// Pick the midpoint entropy color to be 3/4 versus 1/4 for two class probs.
// This makes the color scale much better than a midpoint of 50, which requires
// 89% versus 11% to be the middle color.
double[] typicalMix = new double[2];
typicalMix[0] = 3;
typicalMix[1] = 1;
DoubleRef midPointEnt = new DoubleRef(100 - Entropy.entropy(typicalMix)*100 / scale);
MLJ.clamp_to_range(midPointEnt, 0, 100,
"DecisionTree::convertToTreeVizFormat: mid-point does "
+ "not clamp to range [0-100]");
MLJ.ASSERT(schema.num_label_values() > 0,"DecisionTree::"
+ "convertToTreeVizFormat:schema.num_label_values() <= 0");
// Even though nulls are never used, we want to distinguish
// them in case somebody changes anything. They're therefore hidden.
conf.write("view hierarchy landscape {\n"
+ "\t height `Subtree weight`, normalize, max 5.0;\n");
if (dispBackfitDisks)
conf.write("\t disk height `OriginalDist`;\n");
conf.write("\t base height max 2.0;\n"
+ "\t base label `Test attribute`;\n"
+ "\t line label `Test value`;\n"
+ "\t color key;\n");
// "\t base color legend label \"Purity\";\n"
// "\t base color Purity, "
// "colors \"red\" \"yellow\" \"green\""
// ", scale 0 " << midPointEnt << " 100, legend on;\n"
// "\t base color legend \"impure\" \"mixed\" \"pure\";\n"
if (hasNodeLosses)
{
double min = 0;
double max = 0;
loss_min_max(this, min, max);
if (max - min < 0.01)
max += 0.01; // Avoid cases where both are zero and we rely
// on a treeviz tiebreaker (happens in mushroom).
NodeLoss rootLoss = get_categorizer(rootNode).get_loss();
double medColor = suggest_mid(min, max, rootLoss.totalWeight,
rootLoss.totalLoss);
if (!hasLossMatrix)
{
min *= 100;
max *= 100;
medColor *=100;
}
if (hasLossMatrix)
conf.write("\t base color legend label \"Test-set loss\";\n"
+ "\t base color `Test-set loss`, ");
else
conf.write("\t base color legend label \"Test-set error\";\n"
+ "\t base color `Test-set error`, ");
conf.write("colors \"green\" \"yellow\" \"red\""
+ ", scale " + min + " " + medColor
+ " " + max + ", legend on;\n"
+ "\t base color legend \"low ("
+ MLJ.numberToString(min,2) + ")\" \"medium ("
+ MLJ.numberToString(medColor,2) + ")\" \"high ("
+ MLJ.numberToString(max,2) + ")\";\n");
}
conf.write("\t options rows 1;\n"
+ "\t options root label \"\";\n"
+ "\t options initial depth 4;\n"
// Don't show bar labels, so the level of details is far
+ "\t options lod bar label 10000;\n"
+ "\t options zero outline;\n"
+ "\t options null hidden;\n");
conf.write("\t base message \"Subtree weight:%.2f, ");
String lossMetric = hasLossMatrix ? "loss" : "error";
String shortLossMetric = hasLossMatrix ? "loss" : "err";
if (hasNodeLosses)
conf.write("test-set " + lossMetric + ":%.2f+-%.2f, "
+ " test-set weight:%.2f, ");
if (dispBackfitDisks)
conf.write("training-set weight: %.2f, ");
conf.write("purity:%.2f\", `Subtree weight`, ");
if (hasNodeLosses)
conf.write("`Test-set " + lossMetric + "`, "
+ "`Mean " + shortLossMetric + " std-dev`, "
+ "`Test-set subtree weight`, ");
if (dispBackfitDisks)
conf.write("`OriginalDist`, ");
conf.write("Purity;\n");
conf.write("\t message \"Subtree weight for label value:%.2f, percent:%.2f");
if (dispBackfitDisks)
conf.write(", training-set weight:%.2f");
conf.write("\", `Subtree weight`, Percent");
if (dispBackfitDisks)
conf.write(", `OriginalDist`");
conf.write(";\n}\n");
}
*/
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -