// 📄 simplecart.java
// 字号: (font-size control — website chrome from the page this code was scraped from; not part of the source)
// NOTE(review): this chunk begins mid-method — the statements below are the tail of the
// enclosing minimal cost-complexity pruning loop, whose start is outside this view.
// Line structure below is reconstructed: the scraped source collapsed all newlines, so
// '//' comments were swallowing the code that followed them on the same physical line.
      nodeList = getInnerNodes();
      prune = (nodeList.size() > 0);
      continue;
    }

    preAlpha = nodeToPrune.m_Alpha;

    //update tree errors and alphas
    treeErrors();
    calculateAlphas();

    nodeList = getInnerNodes();
    prune = (nodeList.size() > 0);
  }
  }

  /**
   * Method for performing one fold in the cross-validation of minimal
   * cost-complexity pruning. Generates a sequence of alpha-values with error
   * estimates for the corresponding (partially pruned) trees, given the test
   * set of that fold.
   *
   * @param alphas array to hold the generated alpha-values
   * @param errors array to hold the corresponding error estimates
   * @param test test set of that fold (to obtain error estimates)
   * @return the iteration of the pruning
   * @throws Exception if something goes wrong
   */
  public int prune(double[] alphas, double[] errors, Instances test)
    throws Exception {

    // NOTE(review): raw Vector retained — pre-generics style used throughout this file.
    Vector nodeList;

    // determine training error of subtrees (both with and without replacing a subtree),
    // and calculate alpha-values from them
    modelErrors();
    treeErrors();
    calculateAlphas();

    // get list of all inner nodes in the tree
    nodeList = getInnerNodes();

    boolean prune = (nodeList.size() > 0);

    //alpha_0 is always zero (unpruned tree)
    alphas[0] = 0;

    Evaluation eval;

    // error of unpruned tree
    if (errors != null) {
      eval = new Evaluation(test);
      eval.evaluateModel(this, test);
      errors[0] = eval.errorRate();
    }

    int iteration = 0;
    double preAlpha = Double.MAX_VALUE;
    while (prune) {

      iteration++;

      // get node with minimum alpha
      SimpleCart nodeToPrune = nodeToPrune(nodeList);

      // do not set m_sons null, want to unprune
      nodeToPrune.m_isLeaf = true;

      // normally would not happen
      // NOTE(review): exact double '==' comparison is deliberate here — it detects the
      // degenerate case of re-selecting a node with the identical alpha, so the
      // iteration counter is rolled back instead of recording a duplicate entry.
      if (nodeToPrune.m_Alpha==preAlpha) {
        iteration--;
        treeErrors();
        calculateAlphas();
        nodeList = getInnerNodes();
        prune = (nodeList.size() > 0);
        continue;
      }

      // get alpha-value of node
      alphas[iteration] = nodeToPrune.m_Alpha;

      // log error
      if (errors != null) {
        eval = new Evaluation(test);
        eval.evaluateModel(this, test);
        errors[iteration] = eval.errorRate();
      }
      preAlpha = nodeToPrune.m_Alpha;

      //update errors/alphas
      treeErrors();
      calculateAlphas();

      nodeList = getInnerNodes();
      prune = (nodeList.size() > 0);
    }

    //set last alpha 1 to indicate end
    alphas[iteration + 1] = 1.0;

    return iteration;
  }

  /**
   * Method to "unprune" the CART tree. Sets all leaf-fields to false.
   * Faster than re-growing the tree because CART do not have to be fit again.
   */
  protected void unprune() {
    if (m_Successors != null) {
      m_isLeaf = false;
      for (int i = 0; i < m_Successors.length; i++)
        m_Successors[i].unprune();
    }
  }

  /**
   * Compute distributions, proportions and total weights of two successor
   * nodes for a given numeric attribute.
   *
   * @param props proportions of each two branches for each attribute
   * @param dists class distributions of two branches for each attribute
   * @param att numeric att split on
   * @param sortedIndices sorted indices of instances for the attribute
   * @param weights weights of instances for the attribute
   * @param subsetWeights total weight of two branches split based on the attribute
   * @param giniGains Gini gains for each attribute
   * @param data training instances
   * @return Gini gain the given numeric attribute
   * @throws Exception if something goes wrong
   */
  protected double numericDistribution(double[][] props, double[][][] dists,
      Attribute att, int[] sortedIndices, double[] weights,
      double[][] subsetWeights, double[] giniGains, Instances data)
    throws Exception {

    double splitPoint = Double.NaN;
    double[][] dist = null;
    int numClasses = data.numClasses();
    int i;

    // differ instances with or without missing values
    double[][] currDist = new double[2][numClasses];
    dist = new double[2][numClasses];

    // Move all instances without missing values into second subset
    // (sortedIndices places instances with missing values last, so missingStart
    // ends up as the index of the first missing-value instance)
    double[] parentDist = new double[numClasses];
    int missingStart = 0;
    for (int j = 0; j < sortedIndices.length; j++) {
      Instance inst = data.instance(sortedIndices[j]);
      if (!inst.isMissing(att)) {
        missingStart ++;
        currDist[1][(int)inst.classValue()] += weights[j];
      }
      parentDist[(int)inst.classValue()] += weights[j];
    }
    System.arraycopy(currDist[1], 0, dist[1], 0, dist[1].length);

    // Try all possible split points
    double currSplit = data.instance(sortedIndices[0]).value(att);
    double currGiniGain;
    double bestGiniGain = -Double.MAX_VALUE;

    for (i = 0; i < sortedIndices.length; i++) {
      Instance inst = data.instance(sortedIndices[i]);
      if (inst.isMissing(att)) {
        break;
      }

      if (inst.value(att) > currSplit) {

        // snapshot current left/right distributions before adding missing-value mass
        double[][] tempDist = new double[2][numClasses];
        for (int k=0; k<2; k++) {
          //tempDist[k] = currDist[k];
          System.arraycopy(currDist[k], 0, tempDist[k], 0, tempDist[k].length);
        }

        double[] tempProps = new double[2];
        for (int k=0; k<2; k++) {
          tempProps[k] = Utils.sum(tempDist[k]);
        }

        if (Utils.sum(tempProps) !=0) Utils.normalize(tempProps);

        // split missing values: distribute their weight to both branches in
        // proportion to the branch sizes
        int index = missingStart;
        while (index < sortedIndices.length) {
          Instance insta = data.instance(sortedIndices[index]);
          for (int j = 0; j < 2; j++) {
            tempDist[j][(int)insta.classValue()] += tempProps[j] * weights[index];
          }
          index++;
        }

        currGiniGain = computeGiniGain(parentDist,tempDist);

        if (currGiniGain > bestGiniGain) {
          bestGiniGain = currGiniGain;
          // clean split point (midpoint between adjacent values, rounded to 5 decimals)
          splitPoint = Math.rint((inst.value(att) + currSplit)/2.0*100000)/100000.0;
          for (int j = 0; j < currDist.length; j++) {
            System.arraycopy(tempDist[j], 0, dist[j], 0, dist[j].length);
          }
        }
      }
      // shift this instance's weight from the right branch to the left one
      currSplit = inst.value(att);
      currDist[0][(int)inst.classValue()] += weights[i];
      currDist[1][(int)inst.classValue()] -= weights[i];
    }

    // Compute weights
    int attIndex = att.index();
    props[attIndex] = new double[2];
    for (int k = 0; k < 2; k++) {
      props[attIndex][k] = Utils.sum(dist[k]);
    }
    if (Utils.sum(props[attIndex]) != 0) Utils.normalize(props[attIndex]);

    // Compute subset weights
    subsetWeights[attIndex] = new double[2];
    for (int j = 0; j < 2; j++) {
      subsetWeights[attIndex][j] += Utils.sum(dist[j]);
    }

    // clean Gini gain (round to 7 decimals)
    giniGains[attIndex] = Math.rint(bestGiniGain*10000000)/10000000.0;

    dists[attIndex] = dist;
    return splitPoint;
  }

  /**
   * Compute distributions, proportions and total weights
of two successor
   * nodes for a given nominal attribute.
   *
   * @param props proportions of each two branches for each attribute
   * @param dists class distributions of two branches for each attribute
   * @param att nominal att split on
   * @param sortedIndices sorted indices of instances for the attribute
   * @param weights weights of instances for the attribute
   * @param subsetWeights total weight of two branches split based on the attribute
   * @param giniGains Gini gains for each attribute
   * @param data training instances
   * @param useHeuristic if use heuristic search
   * @return Gini gain for the given nominal attribute
   * @throws Exception if something goes wrong
   */
  protected String nominalDistribution(double[][] props, double[][][] dists,
      Attribute att, int[] sortedIndices, double[] weights,
      double[][] subsetWeights, double[] giniGains, Instances data,
      boolean useHeuristic)
    throws Exception {

    String[] values = new String[att.numValues()];
    int numCat = values.length; // number of values of the attribute
    int numClasses = data.numClasses();

    String bestSplitString = "";
    double bestGiniGain = -Double.MAX_VALUE;

    // class frequency for each value
    int[] classFreq = new int[numCat];
    for (int j=0; j<numCat; j++) classFreq[j] = 0;

    double[] parentDist = new double[numClasses];
    double[][] currDist = new double[2][numClasses];
    double[][] dist = new double[2][numClasses];
    int missingStart = 0;

    for (int i = 0; i < sortedIndices.length; i++) {
      Instance inst = data.instance(sortedIndices[i]);
      if (!inst.isMissing(att)) {
        missingStart++;
        classFreq[(int)inst.value(att)] ++;
      }
      parentDist[(int)inst.classValue()] += weights[i];
    }

    // count the number of values that class frequency is not 0
    int nonEmpty = 0;
    for (int j=0; j<numCat; j++) {
      if (classFreq[j]!=0) nonEmpty ++;
    }

    // attribute values that class frequency is not 0
    String[] nonEmptyValues = new String[nonEmpty];
    int nonEmptyIndex = 0;
    for (int j=0; j<numCat; j++) {
      if (classFreq[j]!=0) {
        nonEmptyValues[nonEmptyIndex] = att.value(j);
        nonEmptyIndex ++;
      }
    }

    // attribute values that class frequency is 0
    int empty = numCat - nonEmpty;
    String[] emptyValues = new String[empty];
    int emptyIndex = 0;
    for (int j=0; j<numCat; j++) {
      if (classFreq[j]==0) {
        emptyValues[emptyIndex] = att.value(j);
        emptyIndex ++;
      }
    }

    // fewer than two observed values: no split is possible
    if (nonEmpty<=1) {
      giniGains[att.index()] = 0;
      return "";
    }

    // for two-class problems: sort values by P(class 0) and scan ordered subsets
    if (data.numClasses()==2) {

      //// Firstly, for attribute values which class frequency is not zero

      // probability of class 0 for each attribute value
      double[] pClass0 = new double[nonEmpty];
      // class distribution for each attribute value
      double[][] valDist = new double[nonEmpty][2];

      for (int j=0; j<nonEmpty; j++) {
        for (int k=0; k<2; k++) {
          valDist[j][k] = 0;
        }
      }

      for (int i = 0; i < sortedIndices.length; i++) {
        Instance inst = data.instance(sortedIndices[i]);
        if (inst.isMissing(att)) {
          break;
        }

        for (int j=0; j<nonEmpty; j++) {
          if (att.value((int)inst.value(att)).compareTo(nonEmptyValues[j])==0) {
            valDist[j][(int)inst.classValue()] += inst.weight();
            break;
          }
        }
      }

      for (int j=0; j<nonEmpty; j++) {
        double distSum = Utils.sum(valDist[j]);
        if (distSum==0) pClass0[j]=0;
        else pClass0[j] = valDist[j][0]/distSum;
      }

      // sort category according to the probability of the first class
      // (selection by repeated minIndex; pClass0 entries are consumed with MAX_VALUE)
      String[] sortedValues = new String[nonEmpty];
      for (int j=0; j<nonEmpty; j++) {
        sortedValues[j] = nonEmptyValues[Utils.minIndex(pClass0)];
        pClass0[Utils.minIndex(pClass0)] = Double.MAX_VALUE;
      }

      // Find a subset of attribute values that maximize Gini decrease
      // for the attribute values that class frequency is not 0
      String tempStr = "";
      for (int j=0; j<nonEmpty-1; j++) {
        currDist = new double[2][numClasses];
        // NOTE(review): 'tempStr==""' is a reference comparison; it only works because
        // tempStr still aliases the interned "" literal on the first iteration.
        // equals()/isEmpty() would be the robust form — confirm before changing.
        if (tempStr=="") tempStr="(" + sortedValues[j] + ")";
        else tempStr += "|"+ "(" + sortedValues[j] + ")";
        for (int i=0; i<sortedIndices.length;i++) {
          Instance inst = data.instance(sortedIndices[i]);
          if (inst.isMissing(att)) {
            break;
          }
          if (tempStr.indexOf ("(" + att.value((int)inst.value(att)) + ")")!=-1) {
            currDist[0][(int)inst.classValue()] += weights[i];
          } else currDist[1][(int)inst.classValue()] += weights[i];
        }

        // NOTE(review): unlike numericDistribution, tempDist rows alias currDist rows
        // here (no arraycopy); harmless only because currDist is re-allocated each
        // iteration and not reused after this point.
        double[][] tempDist = new double[2][numClasses];
        for (int kk=0; kk<2; kk++) {
          tempDist[kk] = currDist[kk];
        }

        double[] tempProps = new double[2];
        for (int kk=0; kk<2; kk++) {
          tempProps[kk] = Utils.sum(tempDist[kk]);
        }

        if (Utils.sum(tempProps)!=0) Utils.normalize(tempProps);

        // split missing values proportionally to branch sizes
        int mstart = missingStart;
        while (mstart < sortedIndices.length) {
          Instance insta = data.instance(sortedIndices[mstart]);
          for (int jj = 0; jj < 2; jj++) {
            tempDist[jj][(int)insta.classValue()] += tempProps[jj] * weights[mstart];
          }
          mstart++;
        }

        double currGiniGain = computeGiniGain(parentDist,tempDist);

        if (currGiniGain>bestGiniGain) {
          bestGiniGain = currGiniGain;
          bestSplitString = tempStr;
          for (int jj = 0; jj < 2; jj++) {
            //dist[jj] = new double[currDist[jj].length];
            System.arraycopy(tempDist[jj], 0, dist[jj], 0, dist[jj].length);
          }
        }
      }
    }

    // multi-class problems - exhaustive search over all 2^(nonEmpty-1) value subsets
    else if (!useHeuristic || nonEmpty<=4) {

      // Firstly, for attribute values which class frequency is not zero
      for (int i=0; i<(int)Math.pow(2,nonEmpty-1); i++) {
        String tempStr="";
        currDist = new double[2][numClasses];
        int mod;
        int bit10 = i;
        // decode subset i: each binary digit selects a value for the left branch
        for (int j=nonEmpty-1; j>=0; j--) {
          mod = bit10%2;   // convert from 10bit to 2bit
          if (mod==1) {
            if (tempStr=="") tempStr = "("+nonEmptyValues[j]+")";
            else tempStr += "|" + "("+nonEmptyValues[j]+")";
          }
          bit10 = bit10/2;
        }
        for (int j=0; j<sortedIndices.length;j++) {
          Instance inst = data.instance(sortedIndices[j]);
          if (inst.isMissing(att)) {
            break;
          }
          if (tempStr.indexOf("("+att.value((int)inst.value(att))+")")!=-1) {
            currDist[0][(int)inst.classValue()] += weights[j];
          } else currDist[1][(int)inst.classValue()] += weights[j];
        }

        double[][] tempDist = new double[2][numClasses];
        for (int k=0; k<2; k++) {
          tempDist[k] = currDist[k];
        }

        double[] tempProps = new double[2];
        for (int k=0; k<2; k++) {
          tempProps[k] = Utils.sum(tempDist[k]);
        }

        if (Utils.sum(tempProps)!=0) Utils.normalize(tempProps);

        // split missing values
        int index = missingStart;
        while (index < sortedIndices.length) {
          Instance insta = data.instance(sortedIndices[index]);
          for (int j = 0; j < 2; j++) {
            tempDist[j][(int)insta.classValue()] += tempProps[j] * weights[index];
          }
          index++;
        }

        double currGiniGain = computeGiniGain(parentDist,tempDist);

        if (currGiniGain>bestGiniGain) {
          bestGiniGain = currGiniGain;
          bestSplitString = tempStr;
          for (int j = 0; j < 2; j++) {
            //dist[jj] = new double[currDist[jj].length];
            System.arraycopy(tempDist[j], 0, dist[j], 0, dist[j].length);
          }
        }
      }
    }

    // heuristic search to solve multi-class problems
    else {

      // Firstly, for attribute values which class frequency is not zero
      int n = nonEmpty;
      int k = data.numClasses();    // number of classes of the data
      double[][] P = new double[n][k];     // class probability matrix
      int[] numInstancesValue = new int[n];  // number of instances for an attribute value
      double[] meanClass = new double[k];    // vector of mean class probability
      int numInstances = data.numInstances(); // total number of instances

      // initialize the vector of mean class probability
      for (int j=0; j<meanClass.length; j++) meanClass[j]=0;

      for (int j=0; j<numInstances; j++) {
        Instance inst = (Instance)data.instance(j);
        int valueIndex = 0; // attribute value index in nonEmptyValues
        for (int i=0; i<nonEmpty; i++) {
          if (att.value((int)inst.value(att)).compareToIgnoreCase(nonEmptyValues[i])==0){
            valueIndex = i;
            break;
          }
        }
        P[valueIndex][(int)inst.classValue()]++;
        numInstancesValue[valueIndex]++;
        meanClass[(int)inst.classValue()]++;
      }

      // calculate the class probability matrix
      for (int i=0; i<P.length; i++) {
        for (int j=0; j<P[0].length; j++) {
          if (numInstancesValue[i]==0) P[i][j]=0;
          else P[i][j]/=numInstancesValue[i];
        }
        // NOTE(review): source chunk truncates here — the remainder of the heuristic
        // search (and the end of this method) is outside this view.
/*
 * Website chrome from the page this code was scraped from (keyboard-shortcut
 * help, not part of the Java source):
 * ⌨️ 快捷键说明 (keyboard shortcuts)
 * 复制代码 (copy code): Ctrl + C
 * 搜索代码 (search code): Ctrl + F
 * 全屏模式 (full screen): F11
 * 切换主题 (toggle theme): Ctrl + Shift + D
 * 显示快捷键 (show shortcuts): ?
 * 增大字号 (increase font size): Ctrl + =
 * 减小字号 (decrease font size): Ctrl + -
 */