📄 adtree.java
字号:
/**
 * Calculates the Z-value from the rows of a weight distribution array.
 * Every one of the four cells has 1.0 added to it before it is used, and
 * the weight not accounted for by the (adjusted) cells is added on top of
 * the square-root terms.
 *
 * @param distribution the weight distribution; cells [0][0], [0][1],
 * [1][0] and [1][1] are read
 * @return the Z-value
 */
private double conditionedZOnRows(double [][] distribution) {

  double w1 = distribution[0][0] + 1.0;
  double w2 = distribution[0][1] + 1.0;
  double w3 = distribution[1][0] + 1.0;
  double w4 = distribution[1][1] + 1.0;
  // the 4.0 balances the four 1.0 increments applied to the cells above
  double wRemainder = m_trainTotalWeight + 4.0 - (w1 + w2 + w3 + w4);
  return (2.0 * (Math.sqrt(w1 * w2) + Math.sqrt(w3 * w4))) + wRemainder;
}

/**
 * Returns the class probability distribution for an instance. The
 * prediction value (vote) of the tree is mapped onto two probabilities with
 * the logistic function: index 0 gets 1 / (1 + e^vote) and index 1 gets
 * 1 / (1 + e^-vote).
 *
 * @param instance the instance to be classified
 * @return the distribution the tree generates for the instance
 */
public double[] distributionForInstance(Instance instance) {

  double predVal = predictionValueForInstance(instance, m_root, 0.0);
  double[] distribution = new double[2];
  // Math.exp computes e^x directly; Math.pow(Math.E, x) raises the rounded
  // double constant to a power and loses accuracy for no benefit
  distribution[0] = 1.0 / (1.0 + Math.exp(predVal));
  distribution[1] = 1.0 / (1.0 + Math.exp(-predVal));
  return distribution;
}

/**
 * Returns the class prediction value (vote) for an instance. The value of
 * the given node is added to the running total, then every splitter below
 * the node is asked which branch the instance goes down and the subtree on
 * that branch is accumulated recursively.
 *
 * @param inst the instance
 * @param currentNode the root of the tree to get the values from
 * @param currentValue the current value before adding the value contained in the
 * subtree
 * @return the class prediction value (vote)
 */
protected double predictionValueForInstance(Instance inst, PredictionNode currentNode,
                                            double currentValue) {

  currentValue += currentNode.getValue();
  for (Enumeration e = currentNode.children(); e.hasMoreElements(); ) {
    Splitter split = (Splitter) e.nextElement();
    int branch = split.branchInstanceGoesDown(inst);
    // a negative branch index means no branch of this splitter is followed
    if (branch >= 0) {
      currentValue = predictionValueForInstance(inst, split.getChildForBranch(branch),
                                                currentValue);
    }
  }
  return currentValue;
}

/**
 * Returns a description of the classifier.
* * @return a string containing a description of the classifier */ public String toString() { if (m_root == null) return ("ADTree not built yet"); else { return ("Alternating decision tree:\n\n" + toString(m_root, 1) + "\nLegend: " + legend() + "\nTree size (total number of nodes): " + numOfAllNodes(m_root) + "\nLeaves (number of predictor nodes): " + numOfPredictionNodes(m_root) ); } } /** * Traverses the tree, forming a string that describes it. * * @param currentNode the current node under investigation * @param level the current level in the tree * @return the string describing the subtree */ protected String toString(PredictionNode currentNode, int level) { StringBuffer text = new StringBuffer(); text.append(": " + Utils.doubleToString(currentNode.getValue(),3)); for (Enumeration e = currentNode.children(); e.hasMoreElements(); ) { Splitter split = (Splitter) e.nextElement(); for (int j=0; j<split.getNumOfBranches(); j++) { PredictionNode child = split.getChildForBranch(j); if (child != null) { text.append("\n"); for (int k = 0; k < level; k++) { text.append("| "); } text.append("(" + split.orderAdded + ")"); text.append(split.attributeString(m_trainInstances) + " " + split.comparisonString(j, m_trainInstances)); text.append(toString(child, level + 1)); } } } return text.toString(); } /** * Returns graph describing the tree. * * @return the graph of the tree in dotty format * @exception Exception if something goes wrong */ public String graph() throws Exception { StringBuffer text = new StringBuffer(); text.append("digraph ADTree {\n"); graphTraverse(m_root, text, 0, 0, m_trainInstances); return text.toString() +"}\n"; } /** * Traverses the tree, graphing each node. 
* * @param currentNode the currentNode under investigation * @param text the string built so far * @param splitOrder the order the parent splitter was added to the tree * @param predOrder the order this predictor was added to the split * @exception Exception if something goes wrong */ protected void graphTraverse(PredictionNode currentNode, StringBuffer text, int splitOrder, int predOrder, Instances instances) throws Exception { text.append("S" + splitOrder + "P" + predOrder + " [label=\""); text.append(Utils.doubleToString(currentNode.getValue(),3)); if (splitOrder == 0) // show legend in root text.append(" (" + legend() + ")"); text.append("\" shape=box style=filled"); if (instances.numInstances() > 0) text.append(" data=\n" + instances + "\n,\n"); text.append("]\n"); for (Enumeration e = currentNode.children(); e.hasMoreElements(); ) { Splitter split = (Splitter) e.nextElement(); text.append("S" + splitOrder + "P" + predOrder + "->" + "S" + split.orderAdded + " [style=dotted]\n"); text.append("S" + split.orderAdded + " [label=\"" + split.orderAdded + ": " + split.attributeString(m_trainInstances) + "\"]\n"); for (int i=0; i<split.getNumOfBranches(); i++) { PredictionNode child = split.getChildForBranch(i); if (child != null) { text.append("S" + split.orderAdded + "->" + "S" + split.orderAdded + "P" + i + " [label=\"" + split.comparisonString(i, m_trainInstances) + "\"]\n"); graphTraverse(child, text, split.orderAdded, i, split.instancesDownBranch(i, instances)); } } } } /** * Returns the legend of the tree, describing how results are to be interpreted. 
* * @return a string containing the legend of the classifier */ public String legend() { Attribute classAttribute = null; if (m_trainInstances == null) return ""; try {classAttribute = m_trainInstances.classAttribute();} catch (Exception x){}; return ("-ve = " + classAttribute.value(0) + ", +ve = " + classAttribute.value(1)); } /** * @return a description of the classifier suitable for * displaying in the explorer/experimenter gui */ public String globalInfo() { return "Builds an alternating decision tree, optimized for 2-class problems only."; } /** * @return tip text for this property suitable for * displaying in the explorer/experimenter gui */ public String numOfBoostingIterationsTipText() { return "Sets the number of boosting iterations to perform. You will need to manually " + "tune this parameter to suit the dataset and the desired complexity/accuracy " + "tradeoff. More boosting iterations will result in larger (potentially more " + " accurate) trees, but will make learning slower. Each iteration will add 3 nodes " + "(1 split + 2 prediction) to the tree unless merging occurs."; } /** * Gets the number of boosting iterations. * * @return the number of boosting iterations */ public int getNumOfBoostingIterations() { return m_boostingIterations; } /** * Sets the number of boosting iterations. * * @param b the number of boosting iterations to use */ public void setNumOfBoostingIterations(int b) { m_boostingIterations = b; } /** * @return tip text for this property suitable for * displaying in the explorer/experimenter gui */ public String searchPathTipText() { return "Sets the type of search to perform when building the tree. The default option" + " (Expand all paths) will do an exhaustive search. The other search methods are" + " heuristic, so they are not guaranteed to find an optimal solution but they are" + " much faster. Expand the heaviest path: searches the path with the most heavily" + " weighted instances. 
Expand the best z-pure path: searches the path determined" + " by the best z-pure estimate. Expand a random path: the fastest method, simply" + " searches down a single random path on each iteration."; } /** * Gets the method of searching the tree for a new insertion. Will be one of * SEARCHPATH_ALL, SEARCHPATH_HEAVIEST, SEARCHPATH_ZPURE, SEARCHPATH_RANDOM. * * @return the tree searching mode */ public SelectedTag getSearchPath() { return new SelectedTag(m_searchPath, TAGS_SEARCHPATH); } /** * Sets the method of searching the tree for a new insertion. Will be one of * SEARCHPATH_ALL, SEARCHPATH_HEAVIEST, SEARCHPATH_ZPURE, SEARCHPATH_RANDOM. * * @param newMethod the new tree searching mode */ public void setSearchPath(SelectedTag newMethod) { if (newMethod.getTags() == TAGS_SEARCHPATH) { m_searchPath = newMethod.getSelectedTag().getID(); } } /** * @return tip text for this property suitable for * displaying in the explorer/experimenter gui */ public String randomSeedTipText() { return "Sets the random seed to use for a random search."; } /** * Gets random seed for a random walk. * * @return the random seed */ public int getRandomSeed() { return m_randomSeed; } /** * Sets random seed for a random walk. * * @param s the random seed */ public void setRandomSeed(int seed) { // the actual random object is created when the tree is initialized m_randomSeed = seed; } /** * @return tip text for this property suitable for * displaying in the explorer/experimenter gui */ public String saveInstanceDataTipText() { return "Sets whether the tree is to save instance data - the model will take up more" + " memory if it does. If enabled you will be able to visualize the instances at" + " the prediction nodes when visualizing the tree."; } /** * Gets whether the tree is to save instance data. * * @return the random seed */ public boolean getSaveInstanceData() { return m_saveInstanceData; } /** * Sets whether the tree is to save instance data. 
*
 * @param v true if the tree is to save instance data
 */
public void setSaveInstanceData(boolean v) {

  m_saveInstanceData = v;
}

/**
 * Lists the command-line options this classifier understands: -B (number
 * of boosting iterations), -E (which nodes to expand) and -D (save the
 * instance data with the model).
 *
 * @return an enumeration of all the available options
 */
public Enumeration listOptions() {

  Option boostingIterations = new Option(
    "\tNumber of boosting iterations.\n\t(Default = 10)",
    "B", 1, "-B <number of boosting iterations>");

  Option expandNodes = new Option(
    "\tExpand nodes: -3(all), -2(weight), -1(z_pure), >=0 seed for random walk\n"
    + "\t(Default = -3)",
    "E", 1, "-E <-3|-2|-1|>=0>");

  Option saveInstanceData = new Option(
    "\tSave the instance data with the model",
    "D", 0, "-D");

  Vector options = new Vector(3);
  options.addElement(boostingIterations);
  options.addElement(expandNodes);
  options.addElement(saveInstanceData);
  return options.elements();
}

/**
 * Parses a given list of options. Valid options are:<p>
 *
 * -B num <br>
 * Set the number of boosting iterations
 * (default 10) <p>
 *
 * -E num <br>
 * Set the nodes to expand: -3(all), -2(weight), -1(z_pure), >=0 seed for random walk
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -