📄 decisiontree.java
                        // If the selected test attribute is continuous, record the split ranks as well
                        if (testAttribute instanceof ContinuousAttribute) {
                            winSplitIndex = splitRank[gainIndex];
                            winPreSplitIndex = preSplitRank[gainIndex];
                        }
                    }
                    gainIndex++;
                    attrIndex++;
                }
                // If a test attribute has been selected, record its index (and cut rank)
                if (testAttribute != null) {
                    testAttrInfo[0] = bestAttrIndex;
                    // If the test attribute is continuous, record its cutRank as well
                    if (testAttribute instanceof ContinuousAttribute) {
                        testAttrInfo[1] = attributeDelegates[bestAttrIndex].findCutRank(winSplitIndex, winPreSplitIndex);
                    }
                }
                // Return the test attribute object
                return testAttribute;
            }
        }
        TreeBuilder builder = new TreeBuilder();
    }

    /**
     * Prune the built decision tree.
     */
    public void prune() {
        class TreePruner {
            // The sequence of the cases used for tree construction
            private int[] cases;
            // The weight of each case used for tree construction
            private float[] weight;

            /**
             * Initialize a tree pruner which prunes the built decision tree.
             */
            TreePruner() {
                // Reinitialize the data sequence and the weight of each case
                int caseCount = dataSet.getCaseCount();
                this.cases = new int[caseCount];
                for (int i = 0; i < cases.length; i++)
                    cases[i] = i;
                this.weight = new float[caseCount];
                Arrays.fill(weight, 1.0f);
                // Reset the cases and weight arrays of all attribute delegate objects
                for (AttributeDelegate attributeDelegate : attributeDelegates) {
                    attributeDelegate.setCasesWeight(cases, weight);
                }
                // Prune the whole tree, starting from the root
                ebpPrune(root, 0, caseCount, true);
            }
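            /*
             * Background note (not part of the original listing): EBP
             * (Error-Based Pruning) is the pruning strategy of C4.5. For each
             * internal node, ebpPrune below compares three pessimistic error
             * estimates: the node collapsed to a leaf, the subtree kept as is,
             * and its largest branch raised into the node's place, keeping the
             * cheapest alternative. If Estimator follows C4.5's convention,
             * even a leaf with zero observed errors receives a positive
             * pessimistic error (about 1.29 for 10 cases at the default 25%
             * confidence level), which is what makes collapsing subtrees
             * attractive at all.
             */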
            /**
             * Prune the decision tree from top to bottom with the EBP strategy.
             * @param node the current tree node to be pruned
             * @param first the start (inclusive) index of the train data used for pruning
             * @param last the end (exclusive) index of the train data used for pruning
             * @param update whether the current call is only a trial to estimate the
             *               error after pruning (update = false) or an actual pruning
             *               (update = true)
             * @return the estimated error after the subtree rooted at the current
             *         tree node has been completely pruned
             */
            private float ebpPrune(TreeNode node, int first, int last, boolean update) {
                TreeNodeContent content = createContent(first, last, node);
                float estimatedLeafError = content.getErrorAsLeafNode();
                // If this is an actual pruning instead of an error estimation,
                // reset the tree node information
                if (update)
                    node.setContent(content);
                InternalNode internalNode;
                // If the current tree node is a leaf, its pruning is finished
                if (node instanceof LeafNode) {
                    return estimatedLeafError;
                } else {
                    internalNode = (InternalNode) node;
                }
                /* Estimate the error of each branch to get the errorAsInternalNode
                 * of the tree node */
                // The estimated test error of the tree node as an InternalNode
                float estimatedTreeError = 0;
                // The branch index with the maximal weight distribution
                int maxBranch = -1;
                // The maximal branch weight
                float maxBranchWeight = 0;
                // The index of the test attribute on the tree node
                int testAttributeIndex = indexOf(internalNode.getTestAttribute().getName(),
                        dataSet.getMetaData().getAttributeNames());
                AttributeDelegate testAttributeDelegate = attributeDelegates[testAttributeIndex];
                int testBranchCount = testAttributeDelegate.getBranchCount();
                // The weight distribution over the branches of the test attribute;
                // index 0 accumulates the weight of the cases with missing values
                float[] branchDistri = new float[testBranchCount + 1];
                /* 'missingBegin' records the begin index of the missing data if
                 * there is any, otherwise it coincides with 'first';
                 * 'groupBegin' records the begin index used to group the cases
                 * of the current branch;
                 * 'nextGroupBegin' records the begin index used to group the
                 * cases of the next branch */
                int missingBegin = first;
                int groupBegin = first;
                // Group the missing data to the front
                if (testAttributeDelegate.hasMissingData()) {
                    groupBegin = testAttributeDelegate.groupForward(first, last, -1, branchDistri);
                }
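                /*
                 * Illustrative note (not part of the original listing): as in
                 * C4.5, a case whose value of the test attribute is missing is
                 * sent down every branch with a fractionally reduced weight.
                 * The loop below computes
                 *     ratio = branchWeight / (totalWeight - missingWeight)
                 * so if a branch holds 6.0 of 9.0 units of known weight, a
                 * missing case of weight 1.0 follows that branch with weight
                 * 6.0 / 9.0 ~= 0.67; the division after the recursive call
                 * restores the original weight.
                 */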
                // Classify the [first, last) cases into the branches of the
                // test attribute and recursively prune the child tree nodes
                for (int index = 0; index < testBranchCount; index++) {
                    // For a continuous attribute the group criterion is the cutRank;
                    // for a discrete attribute it is the branch value (or index)
                    int split = testAttributeDelegate instanceof ContinuousAttributeDelegate
                            ? internalNode.getCutRank() : index;
                    // For all but the last branch, group the cases with the specified
                    // branch value forward near "groupBegin" and compute the branch weight
                    int nextGroupBegin;
                    if (index < testBranchCount - 1) {
                        nextGroupBegin = testAttributeDelegate.groupForward(groupBegin, last, split, branchDistri);
                    }
                    // For the last branch, "nextGroupBegin" must be last and its branch
                    // weight must be the rest of the total weight
                    else {
                        nextGroupBegin = last;
                        float lastWeight = content.getTrainWeight();
                        for (int j = 0; j < branchDistri.length - 1; j++) {
                            lastWeight -= branchDistri[j];
                        }
                        branchDistri[branchDistri.length - 1] = lastWeight;
                    }
                    // If no cases are distributed into this branch, skip it
                    if (groupBegin == nextGroupBegin) {
                        continue;
                    }
                    // If there is missing data
                    else if (groupBegin > missingBegin) {
                        // Compute the weight ratio of this branch
                        float ratio = branchDistri[index + 1] / (content.getTrainWeight() - branchDistri[0]);
                        // Split the weight of the missing data by multiplying it by the ratio
                        for (int i = missingBegin; i < groupBegin; i++)
                            weight[cases[i]] *= ratio;
                        // Accumulate the estimated errorAsInternalNode
                        estimatedTreeError += ebpPrune(internalNode.getChildAt(index), missingBegin, nextGroupBegin, update);
                        // Restore the original sequence of the cases after the recursive call
                        missingBegin = testAttributeDelegate.groupBackward(missingBegin, nextGroupBegin);
                        // Restore the weight of the missing data by dividing it by the ratio
                        for (int i = missingBegin; i < nextGroupBegin; i++)
                            weight[cases[i]] /= ratio;
                    } else {
                        estimatedTreeError += ebpPrune(internalNode.getChildAt(index), missingBegin, nextGroupBegin, update);
                        // When there is no missing data, missingBegin moves together with groupBegin
                        missingBegin = nextGroupBegin;
                    }
                    // For the next branch, group from the nextGroupBegin index
                    groupBegin = nextGroupBegin;
                    // Remember the branch with the maximal weight for the branch-error estimation
                    if (branchDistri[index + 1] > maxBranchWeight) {
                        maxBranchWeight = branchDistri[index + 1];
                        maxBranch = index;
                    }
                }
                // Without this early return, subtree raising would be evaluated
                // recursively (subtree raising of subtree raising), which would
                // lead to overly aggressive pruning
                if (!update)
                    return estimatedTreeError;
                // Estimate the subtree-raising error with a trial run on the largest branch
                float estimatedBranchError = ebpPrune(internalNode.getChildAt(maxBranch), first, last, false);
                InternalNode parent = (InternalNode) internalNode.getParent();
                // Select the strategy with the minimal estimated error; the 0.1
                // slack favors the simpler structure on near ties
                if (estimatedLeafError <= estimatedBranchError + 0.1
                        && estimatedLeafError <= estimatedTreeError + 0.1) {
                    // Collapse the subtree into a leaf node
                    LeafNode newNode = new LeafNode(content);
                    if (parent != null) {
                        int childIndex = parent.indexOfChild(internalNode);
                        parent.setChildAt(childIndex, newNode);
                    } else
                        setRoot(newNode);
                    node = newNode;
                } else if (estimatedBranchError <= estimatedTreeError + 0.1) {
                    // Raise the largest branch into the place of the current node
                    ebpPrune(internalNode.getChildAt(maxBranch), first, last, true);
                    TreeNode newNode = node.getChildAt(maxBranch);
                    if (parent != null) {
                        int childIndex = parent.indexOfChild(internalNode);
                        parent.setChildAt(childIndex, newNode);
                    } else
                        setRoot(newNode);
                    node = newNode;
                } else {
                    // Keep the subtree as it is
                    node.setTrainError(estimatedTreeError);
                }
                return node.getTrainError();
            }
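            /*
             * Illustrative note (not part of the original listing): with, say,
             * estimatedLeafError = 5.2, estimatedTreeError = 4.8 and
             * estimatedBranchError = 4.9, the first condition above fails
             * (5.2 > 4.9 + 0.1) but the second holds (4.9 <= 4.8 + 0.1), so
             * the largest branch is raised even though the full subtree is
             * marginally cheaper.
             */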
            /**
             * Recreate a tree node content with the specified data based on
             * the tree node's existing content.
             *
             * @param first the start (inclusive) index of the train data used
             *              for creating the tree node content
             * @param last the end (exclusive) index of the train data used
             *             for creating the tree node content
             * @param node the tree node whose content needs to be recreated
             * @return the recreated tree node content
             */
            private TreeNodeContent createContent(int first, int last, TreeNode node) {
                // Compute the total weight and the class distribution of the
                // [first, last) prune cases
                float totalWeight = 0;
                AttributeDelegate classAttributeDelegate = attributeDelegates[dataSet.getClassAttributeIndex()];
                float[] totalClassDistri = new float[dataSet.getClassCount()];
                Arrays.fill(totalClassDistri, 0);
                for (int i = first; i < last; i++) {
                    int classLabel = classAttributeDelegate.getClassBranch(cases[i]);
                    totalClassDistri[classLabel] += weight[cases[i]];
                }
                // Find the original classification of the tree node
                String nodeClassification = node.getContent().getClassification();
                String[] classValues = dataSet.getClassValues();
                int maxClassIndex = indexOf(nodeClassification, classValues);
                // Find the most probable classification of the prune data on the current tree node
                for (int i = 0; i < totalClassDistri.length; i++) {
                    totalWeight += totalClassDistri[i];
                    if (totalClassDistri[i] > totalClassDistri[maxClassIndex])
                        maxClassIndex = i;
                }
                String classification = classValues[maxClassIndex];
                // Estimate the leaf error of the tree node with the [first, last) prune data
                float basicLeafError = totalWeight - totalClassDistri[maxClassIndex];
                float extraLeafError = Estimator.getExtraError(totalWeight, basicLeafError);
                float estimatedLeafError = basicLeafError + extraLeafError;
                return new TreeNodeContent(totalWeight, totalClassDistri, classification, estimatedLeafError);
            }
        }
        TreePruner pruner = new TreePruner();
    }

    /**
     * Find the index of a String value in a String array.
     */
    private int indexOf(String target, String[] from) {
        for (int i = 0; i < from.length; i++) {
            if (from[i].equals(target))
                return i;
        }
        return -1;
    }
}
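The Estimator class used by createContent is not included in this listing. The sketch below shows one plausible implementation of getExtraError, assuming the C4.5-style pessimistic estimate: the observed error rate is replaced by the upper limit of its confidence interval at C4.5's default 25% confidence level, computed here with the normal (Wilson) approximation and with the zero-error case special-cased as in C4.5. The original class may well use the exact binomial limit instead.

// Sketch only: a plausible stand-in for the Estimator class referenced in
// createContent above; the real implementation is not part of this listing.
final class Estimator {

    // Default C4.5 confidence level and the matching upper 25% point of the
    // standard normal distribution
    private static final double CF = 0.25;
    private static final double Z = 0.6745;

    /**
     * Return the extra (pessimistic) error to add to the observed error of a
     * leaf that covers 'total' weighted cases with 'error' observed errors.
     */
    static float getExtraError(float total, float error) {
        if (total <= 0)
            return 0;
        double n = total;
        if (error < 1e-6)
            // C4.5 special-cases zero observed errors: N * (1 - CF^(1/N))
            return (float) (n * (1 - Math.exp(Math.log(CF) / n)));
        double f = error / n;       // observed error rate
        double z2 = Z * Z;
        // Upper confidence limit of the error rate (normal approximation)
        double upper = (f + z2 / (2 * n)
                + Z * Math.sqrt(f / n - f * f / n + z2 / (4 * n * n)))
                / (1 + z2 / n);
        // Extra error = pessimistic error count minus the observed count
        return (float) (upper * n - error);
    }
}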
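For orientation, a typical call sequence might look like the sketch below. Only prune() appears in this fragment, so the enclosing class name DecisionTree, its constructor, the DataSet constructor, and the build() method are assumptions made purely for illustration.

// Hypothetical usage; every name except prune() is assumed, not taken from
// this listing.
DataSet dataSet = new DataSet("weather.data");      // assumed constructor
DecisionTree tree = new DecisionTree(dataSet);      // assumed class name
tree.build();    // grow the full tree (the TreeBuilder part of this file)
tree.prune();    // apply EBP pruning (the TreePruner part of this file)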