📄 DecisionTree.java
/**
 * @(#)DecisionTree.java 1.5.0 09/01/18
 */
package ml.classifier.dt;

import java.util.Arrays;

import ml.dataset.DataSet;
import ml.dataset.Attribute;
import ml.dataset.ContinuousAttribute;
import ml.dataset.DiscreteAttribute;
import ml.classifier.TreeClassifier;
import ml.tree.*;

/**
 * A decision tree built with the C4.5 algorithm.
 *
 * @author Ping He
 * @author Xiaohua Xu
 */
public class DecisionTree implements TreeClassifier {

    // The root of the decision tree
    private TreeNode root;
    // The DataSet containing all the information read from the input files
    private DataSet dataSet;
    // The attribute delegates that assist tree building and tree pruning
    private AttributeDelegate[] attributeDelegates;

    /**
     * Build a decision tree from the specified data set.
     */
    public DecisionTree(DataSet dataSet) {
        this.dataSet = dataSet;
        build();
        root.setName(dataSet.getName());
    }

    public int size() {
        return treeSize(root);
    }

    /**
     * Compute the number of tree nodes in the subtree rooted at the specified tree node.
     */
    private int treeSize(TreeNode root) {
        if (root instanceof LeafNode) return 1;
        int sum = 0;
        int childrenCount = root.getChildrenCount();
        for (int i = 0; i < childrenCount; i++) {
            sum += treeSize(root.getChildAt(i));
        }
        return sum + 1;
    }

    public int getTrainError() {
        return getTestError(dataSet.getTrainData());
    }

    public TreeNode getRoot() {
        return root;
    }

    public void setRoot(TreeNode root) {
        this.root = root;
    }

    public int getTestError(String[][] testData) {
        String[] classificationResults = classify(testData);
        int testError = 0;
        int classAttributeIndex = dataSet.getClassAttributeIndex();
        for (int i = 0; i < classificationResults.length; i++) {
            if (!classificationResults[i].equals(testData[i][classAttributeIndex]))
                testError++;
        }
        return testError;
    }

    public String[] classify(String[][] testData) {
        // Ready to record the classification results
        String[] results = new String[testData.length];
        // Get the number of distinct class values
        int numberOfClasses = dataSet.getClassCount();
        String[] classValues = dataSet.getClassValues();
        // Accumulates, for one test record, its weight in each class
        float[] testClassDistribution = new float[numberOfClasses];

        for (int testIndex = 0; testIndex < testData.length; testIndex++) {
            // Reset the per-class weights of the current test record to 0
            Arrays.fill(testClassDistribution, 0.0f);
            // Classify a single test record from top to bottom
            classifyDownward(root, testData[testIndex], testClassDistribution, 1.0f);
            // Pick the class whose accumulated weight is the greatest
            float max = -1.0f;
            int maxIndex = -1;
            for (int i = 0; i < testClassDistribution.length; i++) {
                if (testClassDistribution[i] > max) {
                    maxIndex = i;
                    max = testClassDistribution[i];
                }
            }
            results[testIndex] = classValues[maxIndex];
        }
        return results;
    }
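    /*
     * Worked example (illustrative numbers only, not from the source): with
     * classValues = {"yes", "no"}, one classifyDownward call might accumulate
     * testClassDistribution = {0.7, 0.3}, e.g. after a missing attribute value
     * split the record's weight across two subtrees. The argmax loop above then
     * yields "yes". Since the entries for one record always sum to its initial
     * weight of 1.0, the distribution can also be read as class probabilities.
     */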
    /**
     * Classify a test record from top to bottom, from one tree node to its
     * offspring (if there is any).
     *
     * @param node the current tree node classifying the test record
     * @param record the test record with its attribute values extracted
     * @param testClassDistribution effectively the output of this method, recording
     *        the weight distribution of the test record over the class values
     * @param weight the weight of the test record at the current tree node
     */
    private void classifyDownward(TreeNode node, String[] record,
                                  float[] testClassDistribution, float weight) {
        TreeNodeContent content = node.getContent();
        if (node instanceof LeafNode) {
            // If no train data is distributed on this tree node, add the whole
            // weight of the test record to the branch of the node's classification
            if (content.getTrainWeight() <= 0) {
                int classificationIndex =
                        indexOf(content.getClassification(), dataSet.getClassValues());
                testClassDistribution[classificationIndex] += weight;
            }
            // Otherwise, distribute the weight of the test record with the
            // coefficient trainClassDistribution[i] / trainWeight
            else {
                float[] trainClassDistribution = content.getTrainClassDistribution();
                for (int i = 0; i < testClassDistribution.length; i++) {
                    testClassDistribution[i] +=
                            weight * trainClassDistribution[i] / content.getTrainWeight();
                }
            }
        }
        // The current tree node is an InternalNode
        else {
            Attribute testAttribute = ((InternalNode) node).getTestAttribute();
            int testAttributeIndex =
                    indexOf(testAttribute.getName(), dataSet.getMetaData().getAttributeNames());
            // If the test attribute value is known, find its branch; a branch index
            // of -1 marks a value that is missing ("?") or was never declared
            int branchIndex = record[testAttributeIndex].equals("?")
                    ? -1
                    : findChildBranch(record[testAttributeIndex], (InternalNode) node);
            if (branchIndex >= 0) {
                classifyDownward(node.getChildAt(branchIndex), record,
                        testClassDistribution, weight);
            }
            /* If the test attribute value is missing or does not exist in the
             * declaration, the test record is passed to all children with the
             * partitioned weight weight * children[i].getTrainWeight() / trainWeight */
            else {
                TreeNode[] children = node.getChildren();
                for (int i = 0; i < children.length; i++) {
                    TreeNodeContent childContent = children[i].getContent();
                    float childWeight =
                            weight * childContent.getTrainWeight() / content.getTrainWeight();
                    classifyDownward(children[i], record, testClassDistribution, childWeight);
                }
            }
        }
    }

    /**
     * Find the branch index of the child tree node to which the parent tree node
     * should pass the test record.
     *
     * @param value the value of the test record on the parent node's test attribute
     * @param node the parent tree node that needs to pass the test record to its offspring
     */
    private int findChildBranch(String value, InternalNode node) {
        Attribute testAttribute = node.getTestAttribute();
        // For a continuous test attribute, pick the branch by comparing the
        // attribute value against the cut value
        if (testAttribute instanceof ContinuousAttribute) {
            float continValue = Float.parseFloat(value);
            return (continValue < (node.getCut() + Parameter.PRECISION)) ? 0 : 1;
        } else {
            // For a discrete test attribute, find the branch whose nominal value
            // equals the test attribute value of the test record
            String[] nominalValues = ((DiscreteAttribute) testAttribute).getNominalValues();
            for (int i = 0; i < nominalValues.length; i++) {
                if (nominalValues[i].equals(value)) return i;
            }
            // The test attribute value was not found among the declared values
            return -1;
        }
    }
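    /*
     * Numeric trace of the missing-value case (assumed numbers): suppose an
     * internal node with trainWeight 10.0 has two children with trainWeights
     * {6.0, 4.0}, and a test record arrives with weight 1.0 but "?" for the test
     * attribute. The record is sent down both branches with childWeight
     * 1.0 * 6.0 / 10.0 = 0.6 and 1.0 * 4.0 / 10.0 = 0.4. A leaf with trainWeight
     * 6.0 and trainClassDistribution {4.5, 1.5} then contributes
     * 0.6 * 4.5 / 6.0 = 0.45 to class 0 and 0.6 * 1.5 / 6.0 = 0.15 to class 1.
     */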
    /**
     * Build a decision tree.
     */
    public void build() {
        class TreeBuilder {
            // The sequence of the cases used for tree construction
            private int[] cases;
            // The weight of each case used for tree construction
            private float[] weight;
            // The number of candidate test attributes
            private int candidateTestAttrCount;
            // Whether each attribute is a candidate for test attribute selection
            private boolean[] isCandidateTestAttr;

            /**
             * Initialize a tree builder which builds a decision tree.
             */
            TreeBuilder() {
                // Create the attribute delegate objects
                attributeDelegates = new AttributeDelegate[dataSet.getAttributeCount()];
                int attributeIndex = 0;
                for (Attribute attribute : dataSet.getAttributes()) {
                    if (attribute instanceof ContinuousAttribute)
                        attributeDelegates[attributeIndex] =
                                new ContinuousAttributeDelegate((ContinuousAttribute) attribute);
                    else
                        attributeDelegates[attributeIndex] =
                                new DiscreteAttributeDelegate((DiscreteAttribute) attribute);
                    attributeIndex++;
                }
                // Initialize the qualification of candidate test attributes:
                // every attribute except the class attribute is a candidate
                candidateTestAttrCount = dataSet.getAttributeCount() - 1;
                this.isCandidateTestAttr = new boolean[dataSet.getAttributeCount()];
                Arrays.fill(isCandidateTestAttr, true);
                isCandidateTestAttr[dataSet.getClassAttributeIndex()] = false;
                // Initialize the case sequence and the case weights
                initializeCasesWeight();
                root = constructTreeNode(0, dataSet.getCaseCount());
            }

            /**
             * Initialize the sequence of the train data from 0 to n-1, and
             * initialize every weight to 1.0.
             */
            void initializeCasesWeight() {
                int caseCount = dataSet.getCaseCount();
                this.cases = new int[caseCount];
                for (int i = 0; i < cases.length; i++)
                    cases[i] = i;
                this.weight = new float[caseCount];
                Arrays.fill(weight, 1.0f);
                // All the attribute delegates share the same cases and weight arrays
                for (AttributeDelegate attributeDelegate : attributeDelegates) {
                    attributeDelegate.setCasesWeight(cases, weight);
                }
            }

            /**
             * Construct a tree node from top to bottom.
             *
             * @param first the start (inclusive) index of the train data used for
             *        tree node construction
             * @param last the end (exclusive) index of the train data used for
             *        tree node construction
             * @return the constructed tree node
             */
            private TreeNode constructTreeNode(int first, int last) {
                // Construct an initial leaf tree node
                TreeNodeContent content = createContent(first, last);
                float errorAsLeafNode = content.getErrorAsLeafNode();
                // If any of the leaf conditions is satisfied, return the leaf node
                if (content.satisfyLeafNode(Parameter.MINWEIGHT) || candidateTestAttrCount <= 0) {
                    return new LeafNode(content);
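    /*
     * A minimal usage sketch. How a DataSet is loaded is outside this listing,
     * so its construction is left as "..." rather than guessing at the loader API:
     *
     *   DataSet data = ...;                          // e.g. parsed train/test input files
     *   DecisionTree tree = new DecisionTree(data);  // the constructor builds the tree
     *   System.out.println(tree.size());             // number of nodes in the tree
     *   System.out.println(tree.getTrainError());    // misclassified train records
     */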