📄 bftree.java
字号:
for (int i = 0; i < m_numFoldsPruning; i++) { train[i] = cvData.trainCV(m_numFoldsPruning, i); test[i] = cvData.testCV(m_numFoldsPruning, i); parallelBFElements[i] = new FastVector(); m_roots[i] = new BFTree(); // calculate sorted indices, weights, initial class counts and total weights for each training data totalWeight[i] = computeSortedInfo(train[i],sortedIndices[i], weights[i], classProbs[i]); // compute information of the best split for this node (include split attribute, // split value and gini gain (or information gain)) in this fold nodeInfo[i] = computeSplitInfo(m_roots[i], train[i], sortedIndices[i], weights[i], dists[i], props[i], totalSubsetWeights[i], m_Heuristic, m_UseGini); // compute information for root nodes int attIndex = ((Attribute)nodeInfo[i].elementAt(1)).index(); m_roots[i].m_SortedIndices = new int[sortedIndices[i].length][0]; m_roots[i].m_Weights = new double[weights[i].length][0]; m_roots[i].m_Dists = new double[dists[i].length][0][0]; m_roots[i].m_ClassProbs = new double[classProbs[i].length]; m_roots[i].m_Distribution = new double[classProbs[i].length]; m_roots[i].m_Props = new double[2]; for (int j=0; j<m_roots[i].m_SortedIndices.length; j++) { m_roots[i].m_SortedIndices[j] = sortedIndices[i][j]; m_roots[i].m_Weights[j] = weights[i][j]; m_roots[i].m_Dists[j] = dists[i][j]; } System.arraycopy(classProbs[i], 0, m_roots[i].m_ClassProbs, 0, classProbs[i].length); if (Utils.sum(m_roots[i].m_ClassProbs)!=0) Utils.normalize(m_roots[i].m_ClassProbs); System.arraycopy(classProbs[i], 0, m_roots[i].m_Distribution, 0, classProbs[i].length); System.arraycopy(props[i][attIndex], 0, m_roots[i].m_Props, 0, props[i][attIndex].length); m_roots[i].m_TotalWeight = totalWeight[i]; parallelBFElements[i].addElement(nodeInfo[i]); } // build a pre-pruned tree if (m_PruningStrategy == PRUNING_PREPRUNING) { double previousError = Double.MAX_VALUE; double currentError = previousError; double minError = Double.MAX_VALUE; int minExpansion = 0; FastVector errorList 
= new FastVector(); while(true) { // compute average error double expansionError = 0; int count = 0; for (int i=0; i<m_numFoldsPruning; i++) { Evaluation eval; // calculate error rate if only root node if (expansion==0) { m_roots[i].m_isLeaf = true; eval = new Evaluation(test[i]); eval.evaluateModel(m_roots[i], test[i]); if (m_UseErrorRate) expansionError += eval.errorRate(); else expansionError += eval.rootMeanSquaredError(); count ++; } // make tree - expand one node at a time else { if (m_roots[i] == null) continue; // if the tree cannot be expanded, go to next fold m_roots[i].m_isLeaf = false; BFTree nodeToSplit = (BFTree) (((FastVector)(parallelBFElements[i].elementAt(0))).elementAt(0)); if (!m_roots[i].makeTree(parallelBFElements[i], m_roots[i], train[i], nodeToSplit.m_SortedIndices, nodeToSplit.m_Weights, nodeToSplit.m_Dists, nodeToSplit.m_ClassProbs, nodeToSplit.m_TotalWeight, nodeToSplit.m_Props, m_minNumObj, m_Heuristic, m_UseGini)) { m_roots[i] = null; // cannot be expanded continue; } eval = new Evaluation(test[i]); eval.evaluateModel(m_roots[i], test[i]); if (m_UseErrorRate) expansionError += eval.errorRate(); else expansionError += eval.rootMeanSquaredError(); count ++; } } // no tree can be expanded any more if (count==0) break; expansionError /=count; errorList.addElement(new Double(expansionError)); currentError = expansionError; if (!m_UseOneSE) { if (currentError>previousError) break; } else { if (expansionError < minError) { minError = expansionError; minExpansion = expansion; } if (currentError>previousError) { double oneSE = Math.sqrt(minError*(1-minError)/ data.numInstances()); if (currentError > minError + oneSE) { break; } } } expansion ++; previousError = currentError; } if (!m_UseOneSE) expansion = expansion - 1; else { double oneSE = Math.sqrt(minError*(1-minError)/data.numInstances()); for (int i=0; i<errorList.size(); i++) { double error = ((Double)(errorList.elementAt(i))).doubleValue(); if (error<=minError + oneSE) { // && 
counts[i]>=m_numFoldsPruning/2) { expansion = i; break; } } } } // build a postpruned tree else { FastVector[] modelError = new FastVector[m_numFoldsPruning]; // calculate error of each expansion for each fold for (int i = 0; i < m_numFoldsPruning; i++) { modelError[i] = new FastVector(); m_roots[i].m_isLeaf = true; Evaluation eval = new Evaluation(test[i]); eval.evaluateModel(m_roots[i], test[i]); double error; if (m_UseErrorRate) error = eval.errorRate(); else error = eval.rootMeanSquaredError(); modelError[i].addElement(new Double(error)); m_roots[i].m_isLeaf = false; BFTree nodeToSplit = (BFTree) (((FastVector)(parallelBFElements[i].elementAt(0))).elementAt(0)); m_roots[i].makeTree(parallelBFElements[i], m_roots[i], train[i], test[i], modelError[i],nodeToSplit.m_SortedIndices, nodeToSplit.m_Weights, nodeToSplit.m_Dists, nodeToSplit.m_ClassProbs, nodeToSplit.m_TotalWeight, nodeToSplit.m_Props, m_minNumObj, m_Heuristic, m_UseGini, m_UseErrorRate); m_roots[i] = null; } // find the expansion with minimal error rate double minError = Double.MAX_VALUE; int maxExpansion = modelError[0].size(); for (int i=1; i<modelError.length; i++) { if (modelError[i].size()>maxExpansion) maxExpansion = modelError[i].size(); } double[] error = new double[maxExpansion]; int[] counts = new int[maxExpansion]; for (int i=0; i<maxExpansion; i++) { counts[i] = 0; error[i] = 0; for (int j=0; j<m_numFoldsPruning; j++) { if (i<modelError[j].size()) { error[i] += ((Double)modelError[j].elementAt(i)).doubleValue(); counts[i]++; } } error[i] = error[i]/counts[i]; //average error for each expansion if (error[i]<minError) {// && counts[i]>=m_numFoldsPruning/2) { minError = error[i]; expansion = i; } } // the 1 SE rule choosen if (m_UseOneSE) { double oneSE = Math.sqrt(minError*(1-minError)/ data.numInstances()); for (int i=0; i<maxExpansion; i++) { if (error[i]<=minError + oneSE) { // && counts[i]>=m_numFoldsPruning/2) { expansion = i; break; } } } } // make tree on all data based on the expansion 
calculated // from cross-validation // calculate sorted indices, weights and initial class counts int[][] prune_sortedIndices = new int[data.numAttributes()][0]; double[][] prune_weights = new double[data.numAttributes()][0]; double[] prune_classProbs = new double[data.numClasses()]; double prune_totalWeight = computeSortedInfo(data, prune_sortedIndices, prune_weights, prune_classProbs); // compute information of the best split for this node (include split attribute, // split value and gini gain) double[][][] prune_dists = new double[data.numAttributes()][2][data.numClasses()]; double[][] prune_props = new double[data.numAttributes()][2]; double[][] prune_totalSubsetWeights = new double[data.numAttributes()][2]; FastVector prune_nodeInfo = computeSplitInfo(this, data, prune_sortedIndices, prune_weights, prune_dists, prune_props, prune_totalSubsetWeights, m_Heuristic,m_UseGini); // add the root node (with its split info) to BestFirstElements FastVector BestFirstElements = new FastVector(); BestFirstElements.addElement(prune_nodeInfo); int attIndex = ((Attribute)prune_nodeInfo.elementAt(1)).index(); m_Expansion = 0; makeTree(BestFirstElements, data, prune_sortedIndices, prune_weights, prune_dists, prune_classProbs, prune_totalWeight, prune_props[attIndex] ,m_minNumObj, m_Heuristic, m_UseGini, expansion); } /** * Recursively build a best-first decision tree. * Method for building a Best-First tree for a given number of expansions. * preExpansion is -1 means that no expansion is specified (just for a * tree without any pruning method). Pre-pruning and post-pruning methods also * use this method to build the final tree on all training data based on the * expansion calculated from internal cross-validation. 
* * @param BestFirstElements list to store BFTree nodes * @param data training data * @param sortedIndices sorted indices of the instances * @param weights weights of the instances * @param dists class distributions for each attribute * @param classProbs class probabilities of this node * @param totalWeight total weight of this node (note if the node * can not split, this value is not calculated.) * @param branchProps proportions of two subbranches * @param minNumObj minimal number of instances at leaf nodes * @param useHeuristic if use heuristic search for nominal attributes * in multi-class problem * @param useGini if use Gini index as splitting criterion * @param preExpansion the number of expansions the tree to be expanded * @throws Exception if something goes wrong */ protected void makeTree(FastVector BestFirstElements,Instances data, int[][] sortedIndices, double[][] weights, double[][][] dists, double[] classProbs, double totalWeight, double[] branchProps, int minNumObj, boolean useHeuristic, boolean useGini, int preExpansion) throws Exception { if (BestFirstElements.size()==0) return; /////////////////////////////////////////////////////////////////////// // All information about the node to split (the first BestFirst object in // BestFirstElements) FastVector firstElement = (FastVector)BestFirstElements.elementAt(0); // split attribute Attribute att = (Attribute)firstElement.elementAt(1); // info of split value or split string double splitValue = Double.NaN; String splitStr = null; if (att.isNumeric()) splitValue = ((Double)firstElement.elementAt(2)).doubleValue(); else { splitStr=((String)firstElement.elementAt(2)).toString(); } // the best gini gain or information gain of this node double gain = ((Double)firstElement.elementAt(3)).doubleValue(); /////////////////////////////////////////////////////////////////////// if (m_ClassProbs==null) { m_SortedIndices = new int[sortedIndices.length][0]; m_Weights = new double[weights.length][0]; m_Dists = new 
double[dists.length][0][0]; m_ClassProbs = new double[classProbs.length]; m_Distribution = new double[classProbs.length]; m_Props = new double[2]; for (int i=0; i<m_SortedIndices.length; i++) { m_SortedIndices[i] = sortedIndices[i]; m_Weights[i] = weights[i]; m_Dists[i] = dists[i]; } System.arraycopy(classProbs, 0, m_ClassProbs, 0, classProbs.length); System.arraycopy(classProbs, 0, m_Distribution, 0, classProbs.length); System.arraycopy(branchProps, 0, m_Props, 0, m_Props.length); m_TotalWeight = totalWeight; if (Utils.sum(m_ClassProbs)!=0) Utils.normalize(m_ClassProbs); } // If no enough data or this node can not be split, find next node to split. if (totalWeight < 2*minNumObj || branchProps[0]==0 || branchProps[1]==0) { // remove the first element BestFirstElements.removeElementAt(0); makeLeaf(data); if (BestFirstElements.size()!=0) { FastVector nextSplitElement = (FastVector)BestFirstElements.elementAt(0); BFTree nextSplitNode = (BFTree)nextSplitElement.elementAt(0); nextSplitNode.makeTree(BestFirstElements,data, nextSplitNode.m_SortedIndices, nextSplitNode.m_Weights, nextSplitNode.m_Dists, nextSplitNode.m_ClassProbs, nextSplitNode.m_TotalWeight, nextSplitNode.m_Props, minNumObj, useHeuristic, useGini, preExpansion); } return; } // If gini gain or information gain is 0, make all nodes in the BestFirstElements leaf nodes // because these nodes are sorted descendingly according to gini gain or information gain. // (namely, gini gain or information gain of all nodes in BestFirstEelements is 0). 
if (gain==0 || preExpansion==m_Expansion) { for (int i=0; i<BestFirstElements.size(); i++) { FastVector element = (FastVector)BestFirstElements.elementAt(i); BFTree node = (BFTree)element.elementAt(0); node.makeLeaf(data); } BestFirstElements.removeAllElements(); } // gain is not 0 else { // remove the first element BestFirstElements.removeElementAt(0); m_Attribute = att; if (m_Attribute.isNumeric()) m_SplitValue = splitValue; else m_SplitString = splitStr; int[][][] subsetIndices = new int[2][data.numAttributes()][0]; double[][][] subsetWeights = new double[2][data.numAttributes()][0]; splitData(subsetIndices, subsetWeights, m_Attribute, m_SplitValue, m_SplitString, sortedIndices, weights, data); // If split will generate node(s) which has total weights less than m_minNumObj, // do not split. int attIndex = att.index(); if (subsetIndices[0][attIndex].length<minNumObj || subsetIndices[1][attIndex].length<minNumObj) { makeLeaf(data); } // split the node else { m_isLeaf = false; m_Attribute = att;
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -