⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 decisiontree.java

📁 Fast implementation of C4.5 in Java
💻 JAVA
📖 第 1 页 / 共 3 页
字号:
/** * @(#)DecisionTree.java        1.5.0 09/01/18 */package ml.classifier.dt;import java.util.Arrays;import ml.dataset.DataSet;import ml.dataset.Attribute;import ml.dataset.ContinuousAttribute;import ml.dataset.DiscreteAttribute;import ml.classifier.TreeClassifier;import ml.tree.*;/** * A decision tree built with C4.5 algorithm. * * @author Ping He * @author Xiaohua Xu */public class DecisionTree implements TreeClassifier {    // The root of the decision tree	private TreeNode root;	// The DataSet contains all the information in the input files	private DataSet dataSet;	// The delegate of attributes to assist tree building and tree pruning	private AttributeDelegate[] attributeDelegates;	/**	 * Build a decision tree with the specified data set	 */	public DecisionTree(DataSet dataSet) {		this.dataSet = dataSet;		build();		root.setName(dataSet.getName());	}	public int size() {        return treeSize(root);	}	/**	 * Compute the number of tree nodes in the subtree started from the specified tree node.	 
*/    private int treeSize(TreeNode root) {		if(root instanceof LeafNode) return 1;		int sum = 0;		int childrenCount = root.getChildrenCount();		for(int i = 0; i < childrenCount; i ++) {			sum += treeSize(root.getChildAt(i));		}		return sum + 1;	}	public int getTrainError() {		return getTestError(dataSet.getTrainData());	}	public TreeNode getRoot() {		return root;	}	public void setRoot(TreeNode root) {		this.root = root;	}	public int getTestError(String[][] testData){		String[] classificationResults = classify(testData);		int testError = 0;		int classAttributeIndex = dataSet.getClassAttributeIndex();		for(int i = 0; i < classificationResults.length; i ++)			if(!classificationResults[i].equals(testData[i][classAttributeIndex]))				testError ++;		return testError;	}	public String[] classify(String[][] testData) {		// Ready to record the classification results		String[] results = new String[testData.length];		// Get the number of different class values		int numberOfClasses = dataSet.getClassCount();		String[] classValues = dataSet.getClassValues();		// Initialize the test error		float[] testClassDistribution = new float[numberOfClasses];		TreeNode node = root;		for(int testIndex = 0; testIndex < testData.length; testIndex ++) {			// Classify the test data into a specific class			// Initialize the probability of the test data belonging to each class value as 0			Arrays.fill(testClassDistribution, 0.0f);			// Classify a single test data from top to bottom			classifyDownward(root, testData[testIndex], testClassDistribution, 1.0f);			// Select the branch whose probability is the greatest as the classification of the test data			float max = -1.0f;		    int maxIndex = -1;			for(int i = 0; i < testClassDistribution.length; i ++) {				if(testClassDistribution[i] > max) {					maxIndex = i;					max = testClassDistribution[i];				}			}			results[testIndex] = classValues[maxIndex];		}		return results;	}	/**	 * Classify a test data from top to bottom from one tree node to its 
offspring (if there is any).	 * @param node the current tree node classify the test data	 * @param record the test data with its attribute values extracted	 * @param testClassDistribution actually the output of this method, recording the weight	 *               distribution of the test data in different class values.	 * @param weight the weight of the test data on the current tree node	 */	private void classifyDownward(TreeNode node, String[] record, float[] testClassDistribution, float weight){		TreeNodeContent content = node.getContent();		if(node instanceof LeafNode) {			// If there is no train data distributed on this tree node,			// then add the weight of the test data to its corresponding class branch			if(content.getTrainWeight() <= 0){				// Get the branch index of the tree node's classification				int classificationIndex = indexOf(content.getClassification(), dataSet.getClassValues());				testClassDistribution[classificationIndex] += weight;			}			// Otherwise, distribute the weight of the test data with the coefficient			// of trainClassDistri[classValueIndex]/trainWeight			else {				float[] trainClassDistribution = content.getTrainClassDistribution();			    for(int i = 0; i < testClassDistribution.length; i ++){			    	testClassDistribution[i] += weight * trainClassDistribution[i]/content.getTrainWeight();			    }			}		}		// If the current tree node is an InternalNode		else {			// if the test attribute value of the test data is not missing, then			// pass it to its child tree node for classification.			
Attribute testAttribute = ((InternalNode)node).getTestAttribute();		    int testAttributeIndex = indexOf(testAttribute.getName(), dataSet.getMetaData().getAttributeNames());			if(!record[testAttributeIndex].equals("?")) {			    int branchIndex = findChildBranch(record[testAttributeIndex], (InternalNode)node);			    classifyDownward(node.getChildAt(branchIndex), record, testClassDistribution, weight);			}			/* If the test attribute value of the test data is missing or not exists in declaration,			 * the test data is then passed to all the children tree nodes with the partitioned weight			 * of (weight*children[childindex].getTrainWeight()/trainWeight)			 */			else {				TreeNode[] children = node.getChildren();				for(int i = 0; i < children.length; i ++) {					TreeNodeContent childContent = children[i].getContent();					float childWeight = (float)weight*childContent.getTrainWeight()/content.getTrainWeight();					classifyDownward(children[i], record, testClassDistribution, childWeight);				}			}		}	}	/**	 * Find the branch index of the child tree node to which the parent tree node should	 * classify the test data to.	 * @param value the attribute value of the test data on the parent tree node's test attribute	 * @param node the parent tree node which need to classify the test data to its offspring.	 */	private int findChildBranch(String value, InternalNode node) {		Attribute testAttribute = node.getTestAttribute();		// If the test attribute is continuous, find the branch of the test data		// belong to by comparing its test attribute value and the cut value.		if(testAttribute instanceof ContinuousAttribute) {			float continValue = Float.parseFloat(value);			return (continValue < (node.getCut() + Parameter.PRECISION)) ? 
0 : 1;
        }
        else{
            // If the test attribute is discrete, find the branch whose value is
            // the same as the test attribute value of the test data
            String[] nominalValues = ((DiscreteAttribute)testAttribute).getNominalValues();
            for(int i = 0; i < nominalValues.length; i ++) {
                if(nominalValues[i].equals(value)) return i;
            }
            // Not found: the test attribute value is not among the declared
            // nominal values; the caller must handle the -1 sentinel.
            return -1;
        }
    }

    /**
     * Build a decision tree.
     */
    public void build() {
        // Local (method-scoped) class: all tree-growing state lives here and
        // is discarded once construction finishes; its constructor does the work.
        class TreeBuilder {
            // The sequence of the cases used for tree construction
            private int[] cases;
            // The weight of each case used for tree construction
            private float[] weight;
            // The number of the candidate test attributes
            private int candidateTestAttrCount;
            // Whether the attributes are candidate for test attribute selection
            private boolean[] isCandidateTestAttr;

            /**
             * Initialize a tree builder which builds a decision tree.
             */
            TreeBuilder() {
                // Create one Attribute Delegate object per attribute,
                // continuous and discrete attributes getting different kinds
                attributeDelegates = new AttributeDelegate[dataSet.getAttributeCount()];
                int attributeIndex = 0;
                for(Attribute attribute : dataSet.getAttributes()){
                    if(attribute instanceof ContinuousAttribute)
                        attributeDelegates[attributeIndex] = new ContinuousAttributeDelegate((ContinuousAttribute)attribute);
                    else
                        attributeDelegates[attributeIndex] = new DiscreteAttributeDelegate((DiscreteAttribute)attribute);
                    attributeIndex ++;
                }
                // Initialize the qualification of candidate test attributes:
                // every attribute except the class attribute is a candidate
                candidateTestAttrCount = dataSet.getAttributeCount()-1;
                this.isCandidateTestAttr = new boolean[dataSet.getAttributeCount()];
                Arrays.fill(isCandidateTestAttr, true);
                isCandidateTestAttr[dataSet.getClassAttributeIndex()] = false;
                // Initialize the data sequence and their weight
                initializeCasesWeight();
                root = constructTreeNode(0, dataSet.getCaseCount());
            }

            /**
             * Initialize the sequence of the train data from 1 to n, and initialize their
             * weight with all 1.0.
             */
            void initializeCasesWeight(){
                int caseCount = dataSet.getCaseCount();
                this.cases = new int[caseCount];
                for(int i = 0; i < cases.length; i ++) cases[i] = i;
                this.weight = new float[caseCount];
                Arrays.fill(weight, 1.0f);
                // All the attribute delegates share the same cases and weight array
                for(AttributeDelegate attributeDelegate : attributeDelegates){
                    attributeDelegate.setCasesWeight(cases, weight);
                }
            }

            /**
             * Construct tree node from top to bottom.
             * @param first the start(inclusive) index of the train data used for tree node
             *              construction.
             * @param last the end(exclusive) index of the train data used for tree node
             *             construction.
             * @return the constructed tree node.
             */
            private TreeNode constructTreeNode (int first, int last) {
                // Construct an initial Leaf tree node
                TreeNodeContent content = createContent(first, last);
                float errorAsLeafNode = content.getErrorAsLeafNode();
                // If any of the leaf conditions is satisfied, return the Leaf tree node
                // NOTE(review): source is truncated here (page 1 of 3); the
                // remainder of this method continues on the following pages.
                if(content.satisfyLeafNode(Parameter.MINWEIGHT) || candidateTestAttrCount <= 0) {
                    return new LeafNode(content);

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -