📄 adtree.java
字号:
for (int i = 0; i < instances.numInstances() - (numMissing + 1); i++) {
Instance inst = instances.instance(i);
Instance instPlusOne = instances.instance(i + 1);
distribution[0][(int)inst.classValue()] += inst.weight();
distribution[1][(int)inst.classValue()] -= inst.weight();
if (Utils.sm(inst.value(attIndex), instPlusOne.value(attIndex))) {
currCutPoint = (inst.value(attIndex) + instPlusOne.value(attIndex)) / 2.0;
currVal = conditionedZOnRows(distribution);
if (currVal < bestVal) {
splitPoint = currCutPoint;
bestVal = currVal;
}
}
}
double[] splitAndZ = new double[2];
splitAndZ[0] = splitPoint;
splitAndZ[1] = bestVal;
return splitAndZ;
}
/**
 * Computes the Z-value of a two-row weight distribution.
 *
 * Each of the four (row, class) cells is smoothed by adding 1.0. The result is
 * twice the sum, over both rows, of the square root of the product of that
 * row's two class weights, plus the training weight that the two rows do not
 * account for. A row dominated by one class yields a small square-root term,
 * so lower values correspond to purer rows.
 *
 * @param distribution the weight distribution, indexed [row][class] with two
 *                     rows and two classes
 * @return the Z-value
 */
private double conditionedZOnRows(double [][] distribution) {
  // add-one smoothing on every cell
  double w00 = distribution[0][0] + 1.0;
  double w01 = distribution[0][1] + 1.0;
  double w10 = distribution[1][0] + 1.0;
  double w11 = distribution[1][1] + 1.0;

  // weight of training instances outside both rows (the +4.0 cancels the smoothing)
  double smoothedTotal = w00 + w01 + w10 + w11;
  double uncovered = m_trainTotalWeight + 4.0 - smoothedTotal;

  double firstRow = Math.sqrt(w00 * w01);
  double secondRow = Math.sqrt(w10 * w11);
  return (2.0 * (firstRow + secondRow)) + uncovered;
}
/**
 * Returns the class probability distribution for an instance.
 *
 * The tree's summed prediction value (vote) is mapped through the logistic
 * function. A negative vote favours class index 0 and a positive vote favours
 * class index 1. The two entries always sum to one.
 *
 * @param instance the instance to be classified
 * @return the distribution the tree generates for the instance
 */
public double[] distributionForInstance(Instance instance) {
  double predVal = predictionValueForInstance(instance, m_root, 0.0);
  double[] distribution = new double[2];
  // Math.exp computes e^x directly, rather than raising the rounded constant
  // Math.E to a power as Math.pow(Math.E, x) would.
  distribution[0] = 1.0 / (1.0 + Math.exp(predVal));
  distribution[1] = 1.0 / (1.0 + Math.exp(-predVal));
  return distribution;
}
/**
 * Accumulates the prediction value (vote) of an instance over a subtree.
 *
 * The value of the given prediction node is added to the running total.
 * Then, for every splitter below it, the instance is routed down the branch
 * the splitter selects and that child's subtree is accumulated as well. A
 * negative branch index means the instance follows none of that splitter's
 * branches, so the splitter contributes nothing.
 *
 * @param inst the instance
 * @param currentNode the root of the subtree to collect values from
 * @param currentValue the running total before this subtree is added
 * @return the running total including every reachable prediction node in the subtree
 */
protected double predictionValueForInstance(Instance inst, PredictionNode currentNode,
                                            double currentValue) {
  double total = currentValue + currentNode.getValue();
  Enumeration splitters = currentNode.children();
  while (splitters.hasMoreElements()) {
    Splitter splitter = (Splitter) splitters.nextElement();
    int branchIndex = splitter.branchInstanceGoesDown(inst);
    if (branchIndex < 0) {
      continue;
    }
    total = predictionValueForInstance(inst, splitter.getChildForBranch(branchIndex), total);
  }
  return total;
}
/**
 * Describes the classifier in human-readable form.
 *
 * @return the textual tree followed by its legend, the total node count and the
 *         number of prediction nodes, or a placeholder message if no tree has
 *         been built yet
 */
public String toString() {
  if (m_root != null) {
    StringBuffer description = new StringBuffer("Alternating decision tree:\n\n");
    description.append(toString(m_root, 1));
    description.append("\nLegend: ").append(legend());
    description.append("\nTree size (total number of nodes): ")
               .append(numOfAllNodes(m_root));
    description.append("\nLeaves (number of predictor nodes): ")
               .append(numOfPredictionNodes(m_root));
    return description.toString();
  }
  return "ADTree not built yet";
}
/**
 * Renders a subtree as indented text.
 *
 * The prediction node's value (3 decimals) is written first. Each non-null
 * child of each splitter beneath it then goes on its own line, prefixed by one
 * "| " per level, the splitter's order number in parentheses, the split
 * attribute and the branch comparison. Each child's own rendering follows,
 * one level deeper.
 *
 * @param currentNode the prediction node at the top of the subtree
 * @param level the indentation depth of the children of this node
 * @return the text describing the subtree
 */
protected String toString(PredictionNode currentNode, int level) {
  StringBuffer out = new StringBuffer();
  out.append(": ").append(Utils.doubleToString(currentNode.getValue(), 3));
  Enumeration splitters = currentNode.children();
  while (splitters.hasMoreElements()) {
    Splitter splitter = (Splitter) splitters.nextElement();
    for (int branch = 0; branch < splitter.getNumOfBranches(); branch++) {
      PredictionNode child = splitter.getChildForBranch(branch);
      if (child == null) {
        continue;
      }
      out.append('\n');
      for (int depth = 0; depth < level; depth++) {
        out.append("| ");
      }
      out.append('(').append(splitter.orderAdded).append(')');
      out.append(splitter.attributeString(m_trainInstances))
         .append(' ')
         .append(splitter.comparisonString(branch, m_trainInstances));
      out.append(toString(child, level + 1));
    }
  }
  return out.toString();
}
/**
 * Tells a graph viewer which kind of structure graph() produces.
 *
 * @return Drawable.TREE, because the model is rendered as a tree
 */
public int graphType() {
  final int kind = Drawable.TREE;
  return kind;
}
/**
 * Returns a graph describing the tree.
 *
 * The output is a "digraph ADTree" in dotty format. Every prediction node and
 * splitter of the tree is emitted by a recursive traversal starting at the root.
 *
 * @return the graph of the tree in dotty format
 * @exception Exception if the tree has not been built yet, or if something
 *            goes wrong during traversal
 */
public String graph() throws Exception {
  if (m_root == null) {
    // Traversing a null root would only surface as an uninformative
    // NullPointerException; report the real cause instead.
    throw new Exception("ADTree not built yet");
  }
  StringBuffer text = new StringBuffer();
  text.append("digraph ADTree {\n");
  graphTraverse(m_root, text, 0, 0, m_trainInstances);
  text.append("}\n");
  return text.toString();
}
/**
 * Appends the dotty description of a prediction node and, recursively, of every
 * splitter and prediction node beneath it.
 *
 * Prediction nodes are named "S<splitOrder>P<predOrder>" and drawn as filled
 * boxes labelled with their value. When instances reach the node, they are
 * embedded as its data attribute. Nodes with splitOrder 0 (the root, as invoked
 * by graph()) also carry the legend. Splitters are named "S<orderAdded>" and
 * joined to their parent by a dotted edge. Each non-null branch child is joined
 * to its splitter by an edge labelled with the branch comparison.
 *
 * @param currentNode the prediction node under investigation
 * @param text the buffer the graph description is appended to
 * @param splitOrder the order the parent splitter was added to the tree
 * @param predOrder the branch index of this prediction node within its parent splitter
 * @param instances the instances reaching this node; passed down each branch
 *                  via the splitter's instancesDownBranch
 * @exception Exception if something goes wrong
 */
protected void graphTraverse(PredictionNode currentNode, StringBuffer text,
                             int splitOrder, int predOrder, Instances instances)
  throws Exception {
  String nodeId = "S" + splitOrder + "P" + predOrder;

  // the prediction node itself
  text.append(nodeId + " [label=\"");
  text.append(Utils.doubleToString(currentNode.getValue(), 3));
  if (splitOrder == 0) {
    String legendText = " (" + legend() + ")";
    text.append(legendText);
  }
  text.append("\" shape=box style=filled");
  if (instances.numInstances() > 0) {
    String data = " data=\n" + instances + "\n,\n";
    text.append(data);
  }
  text.append("]\n");

  // the splitters hanging off it, and their children
  Enumeration splitters = currentNode.children();
  while (splitters.hasMoreElements()) {
    Splitter splitter = (Splitter) splitters.nextElement();
    String splitId = "S" + splitter.orderAdded;
    text.append(nodeId + "->" + splitId + " [style=dotted]\n");
    text.append(splitId + " [label=\"" + splitter.orderAdded + ": "
                + splitter.attributeString(m_trainInstances) + "\"]\n");
    for (int branch = 0; branch < splitter.getNumOfBranches(); branch++) {
      PredictionNode child = splitter.getChildForBranch(branch);
      if (child == null) {
        continue;
      }
      text.append(splitId + "->" + splitId + "P" + branch + " [label=\""
                  + splitter.comparisonString(branch, m_trainInstances) + "\"]\n");
      graphTraverse(child, text, splitter.orderAdded, branch,
                    splitter.instancesDownBranch(branch, instances));
    }
  }
}
/**
 * Returns the legend of the tree, describing how results are to be interpreted.
 *
 * A negative prediction value favours the first class value and a positive one
 * favours the second.
 *
 * @return a string naming the class values behind negative and positive
 *         predictions, or an empty string when there is no training header or
 *         its class attribute cannot be obtained
 */
public String legend() {
  if (m_trainInstances == null) {
    return "";
  }
  Attribute classAttribute;
  try {
    classAttribute = m_trainInstances.classAttribute();
  } catch (Exception e) {
    // Without a class attribute there is nothing to describe; dereferencing
    // the unset attribute would only raise a NullPointerException.
    return "";
  }
  if (classAttribute == null) {
    return "";
  }
  return ("-ve = " + classAttribute.value(0) +
          ", +ve = " + classAttribute.value(1));
}
/**
 * Returns the tip text for the number-of-boosting-iterations property.
 *
 * @return tip text for this property suitable for
 * displaying in the explorer/experimenter gui
 */
public String numOfBoostingIterationsTipText() {
  // The literal pieces are joined without a stray extra space ("more accurate").
  return "Sets the number of boosting iterations to perform. You will need to manually "
    + "tune this parameter to suit the dataset and the desired complexity/accuracy "
    + "tradeoff. More boosting iterations will result in larger (potentially more "
    + "accurate) trees, but will make learning slower. Each iteration will add 3 nodes "
    + "(1 split + 2 prediction) to the tree unless merging occurs.";
}
/**
 * Returns how many boosting iterations the learner is configured to run.
 *
 * @return the number of boosting iterations
 */
public int getNumOfBoostingIterations() {
  return this.m_boostingIterations;
}
/**
 * Configures how many boosting iterations the learner runs.
 *
 * @param b the number of boosting iterations to use
 */
public void setNumOfBoostingIterations(int b) {
  this.m_boostingIterations = b;
}
/**
 * Returns the tip text for the search-path property.
 *
 * @return tip text for this property suitable for
 * displaying in the explorer/experimenter gui
 */
public String searchPathTipText() {
  String tip =
      "Sets the type of search to perform when building the tree. "
    + "The default option (Expand all paths) will do an exhaustive search. "
    + "The other search methods are heuristic, so they are not guaranteed to find "
    + "an optimal solution but they are much faster. "
    + "Expand the heaviest path: searches the path with the most heavily weighted instances. "
    + "Expand the best z-pure path: searches the path determined by the best z-pure estimate. "
    + "Expand a random path: the fastest method, simply searches down a single random path "
    + "on each iteration.";
  return tip;
}
/**
 * Reports the strategy used to search the tree for a new insertion, as a tag
 * drawn from TAGS_SEARCHPATH (SEARCHPATH_ALL, SEARCHPATH_HEAVIEST,
 * SEARCHPATH_ZPURE or SEARCHPATH_RANDOM).
 *
 * @return the tree searching mode
 */
public SelectedTag getSearchPath() {
  SelectedTag current = new SelectedTag(m_searchPath, TAGS_SEARCHPATH);
  return current;
}
/**
 * Chooses the strategy used to search the tree for a new insertion. Will be one
 * of SEARCHPATH_ALL, SEARCHPATH_HEAVIEST, SEARCHPATH_ZPURE, SEARCHPATH_RANDOM.
 * A tag built from any tag set other than TAGS_SEARCHPATH is ignored.
 *
 * @param newMethod the new tree searching mode
 */
public void setSearchPath(SelectedTag newMethod) {
  if (newMethod.getTags() != TAGS_SEARCHPATH) {
    return;
  }
  m_searchPath = newMethod.getSelectedTag().getID();
}
/**
 * Returns the tip text for the random-seed property.
 *
 * @return tip text for this property suitable for
 * displaying in the explorer/experimenter gui
 */
public String randomSeedTipText() {
  String tip = "Sets the random seed " + "to use for a random search.";
  return tip;
}
/**
 * Returns the seed used when a random walk is chosen as the search path.
 *
 * @return the random seed
 */
public int getRandomSeed() {
  return this.m_randomSeed;
}
/**
 * Sets the random seed for a random walk.
 *
 * Only the seed is stored here; the random number generator itself is created
 * when the tree is initialized.
 *
 * @param seed the random seed
 */
public void setRandomSeed(int seed) {
  // the actual random object is created when the tree is initialized
  m_randomSeed = seed;
}
/**
 * Returns the tip text for the save-instance-data property.
 *
 * @return tip text for this property suitable for
 * displaying in the explorer/experimenter gui
 */
public String saveInstanceDataTipText() {
  String tip =
      "Sets whether the tree is to save instance data - "
    + "the model will take up more memory if it does. "
    + "If enabled you will be able to visualize the instances "
    + "at the prediction nodes when visualizing the tree.";
  return tip;
}
/**
 * Gets whether the tree is to save instance data.
 *
 * @return true if instance data is to be saved with the model, false otherwise
 */
public boolean getSaveInstanceData() {
  return m_saveInstanceData;
}
/**
 * Sets whether the tree is to save instance data.
 *
 * Saving the data makes the model take up more memory, but lets the instances
 * at the prediction nodes be visualized.
 *
 * @param v true to save instance data with the model, false otherwise
 */
public void setSaveInstanceData(boolean v) {
  m_saveInstanceData = v;
}
/**
 * Describes the command-line options this classifier accepts: -B (number of
 * boosting iterations), -E (node expansion mode or random-walk seed) and -D
 * (save instance data).
 *
 * @return an enumeration of all the available options.
 */
public Enumeration listOptions() {
  String boostingHelp = "\tNumber of boosting iterations.\n\t(Default = 10)";
  String expandHelp = "\tExpand nodes: -3(all), -2(weight), -1(z_pure), >=0 seed for random walk\n"
                    + "\t(Default = -3)";
  String saveDataHelp = "\tSave the instance data with the model";

  Vector options = new Vector(3);
  options.add(new Option(boostingHelp, "B", 1, "-B <number of boosting iterations>"));
  options.add(new Option(expandHelp, "E", 1, "-E <-3|-2|-1|>=0>"));
  options.add(new Option(saveDataHelp, "D", 0, "-D"));
  return options.elements();
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -