📄 bftree.java
字号:
// if expansion is specified (if pruning method used) if ( (m_PruningStrategy == PRUNING_PREPRUNING) || (m_PruningStrategy == PRUNING_POSTPRUNING) || (preExpansion != -1)) m_Expansion++; makeSuccessors(BestFirstElements,data,subsetIndices,subsetWeights,dists, att,useHeuristic, useGini); } // choose next node to split if (BestFirstElements.size()!=0) { FastVector nextSplitElement = (FastVector)BestFirstElements.elementAt(0); BFTree nextSplitNode = (BFTree)nextSplitElement.elementAt(0); nextSplitNode.makeTree(BestFirstElements,data, nextSplitNode.m_SortedIndices, nextSplitNode.m_Weights, nextSplitNode.m_Dists, nextSplitNode.m_ClassProbs, nextSplitNode.m_TotalWeight, nextSplitNode.m_Props, minNumObj, useHeuristic, useGini, preExpansion); } } } /** * This method is to find the number of expansions based on internal * cross-validation for just pre-pruning. It expands the first BestFirst * node in the BestFirstElements if it is expansible, otherwise it looks * for next exapansible node. If it finds a node is expansibel, expand the * node, then return true. (note it just expands one node at a time). * * @param BestFirstElements list to store BFTree nodes * @param root root node of tree in each fold * @param train training data * @param sortedIndices sorted indices of the instances * @param weights weights of the instances * @param dists class distributions for each attribute * @param classProbs class probabilities of this node * @param totalWeight total weight of this node (note if the node * can not split, this value is not calculated.) * @param branchProps proportions of two subbranches * @param minNumObj minimal number of instances at leaf nodes * @param useHeuristic if use heuristic search for nominal attributes * in multi-class problem * @param useGini if use Gini index as splitting criterion * @return true if expand successfully, otherwise return false * (all nodes in BestFirstElements cannot be * expanded). * @throws Exception if something goes wrong */ protected boolean makeTree(FastVector BestFirstElements, BFTree root, Instances train, int[][] sortedIndices, double[][] weights, double[][][] dists, double[] classProbs, double totalWeight, double[] branchProps, int minNumObj, boolean useHeuristic, boolean useGini) throws Exception { if (BestFirstElements.size()==0) return false; /////////////////////////////////////////////////////////////////////// // All information about the node to split (first BestFirst object in // BestFirstElements) FastVector firstElement = (FastVector)BestFirstElements.elementAt(0); // node to split BFTree nodeToSplit = (BFTree)firstElement.elementAt(0); // split attribute Attribute att = (Attribute)firstElement.elementAt(1); // info of split value or split string double splitValue = Double.NaN; String splitStr = null; if (att.isNumeric()) splitValue = ((Double)firstElement.elementAt(2)).doubleValue(); else { splitStr=((String)firstElement.elementAt(2)).toString(); } // the best gini gain or information gain of this node double gain = ((Double)firstElement.elementAt(3)).doubleValue(); /////////////////////////////////////////////////////////////////////// // If no enough data to split for this node or this node can not be split find next node to split. if (totalWeight < 2*minNumObj || branchProps[0]==0 || branchProps[1]==0) { // remove the first element BestFirstElements.removeElementAt(0); nodeToSplit.makeLeaf(train); BFTree nextNode = (BFTree) ((FastVector)BestFirstElements.elementAt(0)).elementAt(0); return root.makeTree(BestFirstElements, root, train, nextNode.m_SortedIndices, nextNode.m_Weights, nextNode.m_Dists, nextNode.m_ClassProbs, nextNode.m_TotalWeight, nextNode.m_Props, minNumObj, useHeuristic, useGini); } // If gini gain or information is 0, make all nodes in the BestFirstElements leaf nodes // because these node sorted descendingly according to gini gain or information gain. // (namely, gini gain or information gain of all nodes in BestFirstEelements is 0). if (gain==0) { for (int i=0; i<BestFirstElements.size(); i++) { FastVector element = (FastVector)BestFirstElements.elementAt(i); BFTree node = (BFTree)element.elementAt(0); node.makeLeaf(train); } BestFirstElements.removeAllElements(); return false; } else { // remove the first element BestFirstElements.removeElementAt(0); nodeToSplit.m_Attribute = att; if (att.isNumeric()) nodeToSplit.m_SplitValue = splitValue; else nodeToSplit.m_SplitString = splitStr; int[][][] subsetIndices = new int[2][train.numAttributes()][0]; double[][][] subsetWeights = new double[2][train.numAttributes()][0]; splitData(subsetIndices, subsetWeights, nodeToSplit.m_Attribute, nodeToSplit.m_SplitValue, nodeToSplit.m_SplitString, nodeToSplit.m_SortedIndices, nodeToSplit.m_Weights, train); // if split will generate node(s) which has total weights less than m_minNumObj, // do not split int attIndex = att.index(); if (subsetIndices[0][attIndex].length<minNumObj || subsetIndices[1][attIndex].length<minNumObj) { nodeToSplit.makeLeaf(train); BFTree nextNode = (BFTree) ((FastVector)BestFirstElements.elementAt(0)).elementAt(0); return root.makeTree(BestFirstElements, root, train, nextNode.m_SortedIndices, nextNode.m_Weights, nextNode.m_Dists, nextNode.m_ClassProbs, nextNode.m_TotalWeight, nextNode.m_Props, minNumObj, useHeuristic, useGini); } // split the node else { nodeToSplit.m_isLeaf = false; nodeToSplit.m_Attribute = att; nodeToSplit.makeSuccessors(BestFirstElements,train,subsetIndices, subsetWeights,dists, nodeToSplit.m_Attribute,useHeuristic,useGini); for (int i=0; i<2; i++){ nodeToSplit.m_Successors[i].makeLeaf(train); } return true; } } } /** * This method is to find the number of expansions based on internal * cross-validation for just post-pruning. It expands the first BestFirst * node in the BestFirstElements until no node can be split. When building * the tree, stroe error for each temporary tree, namely for each expansion. * * @param BestFirstElements list to store BFTree nodes * @param root root node of tree in each fold * @param train training data in each fold * @param test test data in each fold * @param modelError list to store error for each expansion in * each fold * @param sortedIndices sorted indices of the instances * @param weights weights of the instances * @param dists class distributions for each attribute * @param classProbs class probabilities of this node * @param totalWeight total weight of this node (note if the node * can not split, this value is not calculated.) * @param branchProps proportions of two subbranches * @param minNumObj minimal number of instances at leaf nodes * @param useHeuristic if use heuristic search for nominal attributes * in multi-class problem * @param useGini if use Gini index as splitting criterion * @param useErrorRate if use error rate in internal cross-validation * @throws Exception if something goes wrong */ protected void makeTree(FastVector BestFirstElements, BFTree root, Instances train, Instances test, FastVector modelError, int[][] sortedIndices, double[][] weights, double[][][] dists, double[] classProbs, double totalWeight, double[] branchProps, int minNumObj, boolean useHeuristic, boolean useGini, boolean useErrorRate) throws Exception { if (BestFirstElements.size()==0) return; /////////////////////////////////////////////////////////////////////// // All information about the node to split (first BestFirst object in // BestFirstElements) FastVector firstElement = (FastVector)BestFirstElements.elementAt(0); // node to split //BFTree nodeToSplit = (BFTree)firstElement.elementAt(0); // split attribute Attribute att = (Attribute)firstElement.elementAt(1); // info of split value or split string double splitValue = Double.NaN; String splitStr = null; if (att.isNumeric()) splitValue = ((Double)firstElement.elementAt(2)).doubleValue(); else { splitStr=((String)firstElement.elementAt(2)).toString(); } // the best gini gain or information of this node double gain = ((Double)firstElement.elementAt(3)).doubleValue(); /////////////////////////////////////////////////////////////////////// if (totalWeight < 2*minNumObj || branchProps[0]==0 || branchProps[1]==0) { // remove the first element BestFirstElements.removeElementAt(0); makeLeaf(train); BFTree nextSplitNode = (BFTree) ((FastVector)BestFirstElements.elementAt(0)).elementAt(0); nextSplitNode.makeTree(BestFirstElements, root, train, test, modelError, nextSplitNode.m_SortedIndices, nextSplitNode.m_Weights, nextSplitNode.m_Dists, nextSplitNode.m_ClassProbs, nextSplitNode.m_TotalWeight, nextSplitNode.m_Props, minNumObj, useHeuristic, useGini, useErrorRate); return; } // If gini gain or information gain is 0, make all nodes in the BestFirstElements leaf nodes // because these node sorted descendingly according to gini gain or information gain. // (namely, gini gain or information gain of all nodes in BestFirstEelements is 0). if (gain==0) { for (int i=0; i<BestFirstElements.size(); i++) { FastVector element = (FastVector)BestFirstElements.elementAt(i); BFTree node = (BFTree)element.elementAt(0); node.makeLeaf(train); } BestFirstElements.removeAllElements(); } // gini gain or information gain is not 0 else { // remove the first element BestFirstElements.removeElementAt(0); m_Attribute = att; if (att.isNumeric()) m_SplitValue = splitValue; else m_SplitString = splitStr; int[][][] subsetIndices = new int[2][train.numAttributes()][0]; double[][][] subsetWeights = new double[2][train.numAttributes()][0]; splitData(subsetIndices, subsetWeights, m_Attribute, m_SplitValue, m_SplitString, sortedIndices, weights, train); // if split will generate node(s) which has total weights less than m_minNumObj, // do not split int attIndex = att.index(); if (subsetIndices[0][attIndex].length<minNumObj || subsetIndices[1][attIndex].length<minNumObj) { makeLeaf(train); } // split the node and cauculate error rate of this temporary tree else { m_isLeaf = false; m_Attribute = att; makeSuccessors(BestFirstElements,train,subsetIndices, subsetWeights,dists, m_Attribute, useHeuristic, useGini); for (int i=0; i<2; i++){ m_Successors[i].makeLeaf(train); } Evaluation eval = new Evaluation(test); eval.evaluateModel(root, test); double error; if (useErrorRate) error = eval.errorRate(); else error = eval.rootMeanSquaredError(); modelError.addElement(new Double(error)); } if (BestFirstElements.size()!=0) { FastVector nextSplitElement = (FastVector)BestFirstElements.elementAt(0); BFTree nextSplitNode = (BFTree)nextSplitElement.elementAt(0); nextSplitNode.makeTree(BestFirstElements, root, train, test, modelError, nextSplitNode.m_SortedIndices, nextSplitNode.m_Weights, nextSplitNode.m_Dists, nextSplitNode.m_ClassProbs, nextSplitNode.m_TotalWeight, nextSplitNode.m_Props, minNumObj, useHeuristic, useGini,useErrorRate); } } } /** * Generate successor nodes for a node and put them into BestFirstElements * according to gini gain or information gain in a descending order. * * @param BestFirstElements list to store BestFirst nodes * @param data training instance * @param subsetSortedIndices sorted indices of instances of successor nodes * @param subsetWeights weights of instances of successor nodes * @param dists class distributions of successor nodes * @param att attribute used to split the node * @param useHeuristic if use heuristic search for nominal attributes in multi-class problem * @param useGini if use Gini index as splitting criterion * @throws Exception if something goes wrong */ protected void makeSuccessors(FastVector BestFirstElements,Instances data, int[][][] subsetSortedIndices, double[][][] subsetWeights, double[][][] dists, Attribute att, boolean useHeuristic, boolean useGini) throws Exception { m_Successors = new BFTree[2]; for (int i=0; i<2; i++) { m_Successors[i] = new BFTree(); m_Successors[i].m_isLeaf = true; // class probability and distribution for this successor node m_Successors[i].m_ClassProbs = new double[data.numClasses()]; m_Successors[i].m_Distribution = new double[data.numClasses()]; System.arraycopy(dists[att.index()][i], 0, m_Successors[i].m_ClassProbs, 0,m_Successors[i].m_ClassProbs.length); System.arraycopy(dists[att.index()][i], 0, m_Successors[i].m_Distribution, 0,m_Successors[i].m_Distribution.length); if (Utils.sum(m_Successors[i].m_ClassProbs)!=0) Utils.normalize(m_Successors[i].m_ClassProbs); // split information for this successor node double[][] props = new double[data.numAttributes()][2]; double[][][] subDists = new double[data.numAttributes()][2][data.numClasses()]; double[][] totalSubsetWeights = new double[data.numAttributes()][2]; FastVector splitInfo = m_Successors[i].computeSplitInfo(m_Successors[i], data, subsetSortedIndices[i], subsetWeights[i], subDists, props, totalSubsetWeights, useHeuristic, useGini); // branch proportion for this successor node int splitIndex = ((Attribute)splitInfo.elementAt(1)).index(); m_Successors[i].m_Props = new double[2]; System.arraycopy(props[splitIndex], 0, m_Successors[i].m_Props, 0, m_Successors[i].m_Props.length); // sorted indices and weights of each attribute for this successor node m_Successors[i].m_SortedIndices = new int[data.numAttributes()][0]; m_Successors[i].m_Weights = new double[data.numAttributes()][0]; for (int j=0; j<m_Successors[i].m_SortedIndices.length; j++) { m_Successors[i].m_SortedIndices[j] = subsetSortedIndices[i][j]; m_Successors[i].m_Weights[j] = subsetWeights[i][j]; } // distribution of each attribute for this successor node m_Successors[i].m_Dists = new double[data.numAttributes()][2][data.numClasses()]; for (int j=0; j<subDists.length; j++) { m_Successors[i].m_Dists[j] = subDists[j]; } // total weights for this successor node. m_Successors[i].m_TotalWeight = Utils.sum(totalSubsetWeights[splitIndex]); // insert this successor node into BestFirstElements according to gini gain or information gain // descendingly if (BestFirstElements.size()==0) { BestFirstElements.addElement(splitInfo); } else { double gGain = ((Double)(splitInfo.elementAt(3))).doubleValue(); int vectorSize = BestFirstElements.size(); FastVector lastNode = (FastVector)BestFirstElements.elementAt(vectorSize-1);
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -