📄 reptree.java
字号:
* @param weights the weights of the instances * @param data the data to work with * @param totalWeight * @param classProbs the class probabilities * @param header the header of the data * @param minNum the minimum number of instances in a leaf * @param minVariance * @param depth the current depth of the tree * @param maxDepth the maximum allowed depth of the tree * @throws Exception if generation fails */ protected void buildTree(int[][] sortedIndices, double[][] weights, Instances data, double totalWeight, double[] classProbs, Instances header, double minNum, double minVariance, int depth, int maxDepth) throws Exception { // Store structure of dataset, set minimum number of instances // and make space for potential info from pruning data m_Info = header; m_HoldOutDist = new double[data.numClasses()]; // Make leaf if there are no training instances int helpIndex = 0; if (data.classIndex() == 0) { helpIndex = 1; } if (sortedIndices[helpIndex].length == 0) { if (data.classAttribute().isNumeric()) { m_Distribution = new double[2]; } else { m_Distribution = new double[data.numClasses()]; } m_ClassProbs = null; return; } double priorVar = 0; if (data.classAttribute().isNumeric()) { // Compute prior variance double totalSum = 0, totalSumSquared = 0, totalSumOfWeights = 0; for (int i = 0; i < sortedIndices[helpIndex].length; i++) { Instance inst = data.instance(sortedIndices[helpIndex][i]); totalSum += inst.classValue() * weights[helpIndex][i]; totalSumSquared += inst.classValue() * inst.classValue() * weights[helpIndex][i]; totalSumOfWeights += weights[helpIndex][i]; } priorVar = singleVariance(totalSum, totalSumSquared, totalSumOfWeights); } // Check if node doesn't contain enough instances, is pure // or the maximum tree depth is reached m_ClassProbs = new double[classProbs.length]; System.arraycopy(classProbs, 0, m_ClassProbs, 0, classProbs.length); if ((totalWeight < (2 * minNum)) || // Nominal case (data.classAttribute().isNominal() && Utils.eq(m_ClassProbs[Utils.maxIndex(m_ClassProbs)], Utils.sum(m_ClassProbs))) || // Numeric case (data.classAttribute().isNumeric() && ((priorVar / totalWeight) < minVariance)) || // Check tree depth ((m_MaxDepth >= 0) && (depth >= maxDepth))) { // Make leaf m_Attribute = -1; if (data.classAttribute().isNominal()) { // Nominal case m_Distribution = new double[m_ClassProbs.length]; for (int i = 0; i < m_ClassProbs.length; i++) { m_Distribution[i] = m_ClassProbs[i]; } Utils.normalize(m_ClassProbs); } else { // Numeric case m_Distribution = new double[2]; m_Distribution[0] = priorVar; m_Distribution[1] = totalWeight; } return; } // Compute class distributions and value of splitting // criterion for each attribute double[] vals = new double[data.numAttributes()]; double[][][] dists = new double[data.numAttributes()][0][0]; double[][] props = new double[data.numAttributes()][0]; double[][] totalSubsetWeights = new double[data.numAttributes()][0]; double[] splits = new double[data.numAttributes()]; if (data.classAttribute().isNominal()) { // Nominal case for (int i = 0; i < data.numAttributes(); i++) { if (i != data.classIndex()) { splits[i] = distribution(props, dists, i, sortedIndices[i], weights[i], totalSubsetWeights, data); vals[i] = gain(dists[i], priorVal(dists[i])); } } } else { // Numeric case for (int i = 0; i < data.numAttributes(); i++) { if (i != data.classIndex()) { splits[i] = numericDistribution(props, dists, i, sortedIndices[i], weights[i], totalSubsetWeights, data, vals); } } } // Find best attribute m_Attribute = Utils.maxIndex(vals); int numAttVals = dists[m_Attribute].length; // Check if there are at least two subsets with // required minimum number of instances int count = 0; for (int i = 0; i < numAttVals; i++) { if (totalSubsetWeights[m_Attribute][i] >= minNum) { count++; } if (count > 1) { break; } } // Any useful split found? if ((vals[m_Attribute] > 0) && (count > 1)) { // Build subtrees m_SplitPoint = splits[m_Attribute]; m_Prop = props[m_Attribute]; int[][][] subsetIndices = new int[numAttVals][data.numAttributes()][0]; double[][][] subsetWeights = new double[numAttVals][data.numAttributes()][0]; splitData(subsetIndices, subsetWeights, m_Attribute, m_SplitPoint, sortedIndices, weights, data); m_Successors = new Tree[numAttVals]; for (int i = 0; i < numAttVals; i++) { m_Successors[i] = new Tree(); m_Successors[i]. buildTree(subsetIndices[i], subsetWeights[i], data, totalSubsetWeights[m_Attribute][i], dists[m_Attribute][i], header, minNum, minVariance, depth + 1, maxDepth); } } else { // Make leaf m_Attribute = -1; } // Normalize class counts if (data.classAttribute().isNominal()) { m_Distribution = new double[m_ClassProbs.length]; for (int i = 0; i < m_ClassProbs.length; i++) { m_Distribution[i] = m_ClassProbs[i]; } Utils.normalize(m_ClassProbs); } else { m_Distribution = new double[2]; m_Distribution[0] = priorVar; m_Distribution[1] = totalWeight; } } /** * Computes size of the tree. * * @return the number of nodes */ protected int numNodes() { if (m_Attribute == -1) { return 1; } else { int size = 1; for (int i = 0; i < m_Successors.length; i++) { size += m_Successors[i].numNodes(); } return size; } } /** * Splits instances into subsets. * * @param subsetIndices the sorted indices in the subset * @param subsetWeights the weights of the subset * @param att the attribute index * @param splitPoint the split point for numeric attributes * @param sortedIndices the sorted indices of the whole set * @param weights the weights of the whole set * @param data the data to work with * @throws Exception if something goes wrong */ protected void splitData(int[][][] subsetIndices, double[][][] subsetWeights, int att, double splitPoint, int[][] sortedIndices, double[][] weights, Instances data) throws Exception { int j; int[] num; // For each attribute for (int i = 0; i < data.numAttributes(); i++) { if (i != data.classIndex()) { if (data.attribute(att).isNominal()) { // For nominal attributes num = new int[data.attribute(att).numValues()]; for (int k = 0; k < num.length; k++) { subsetIndices[k][i] = new int[sortedIndices[i].length]; subsetWeights[k][i] = new double[sortedIndices[i].length]; } for (j = 0; j < sortedIndices[i].length; j++) { Instance inst = data.instance(sortedIndices[i][j]); if (inst.isMissing(att)) { // Split instance up for (int k = 0; k < num.length; k++) { if (m_Prop[k] > 0) { subsetIndices[k][i][num[k]] = sortedIndices[i][j]; subsetWeights[k][i][num[k]] = m_Prop[k] * weights[i][j]; num[k]++; } } } else { int subset = (int)inst.value(att); subsetIndices[subset][i][num[subset]] = sortedIndices[i][j]; subsetWeights[subset][i][num[subset]] = weights[i][j]; num[subset]++; } } } else { // For numeric attributes num = new int[2]; for (int k = 0; k < 2; k++) { subsetIndices[k][i] = new int[sortedIndices[i].length]; subsetWeights[k][i] = new double[weights[i].length]; } for (j = 0; j < sortedIndices[i].length; j++) { Instance inst = data.instance(sortedIndices[i][j]); if (inst.isMissing(att)) { // Split instance up for (int k = 0; k < num.length; k++) { if (m_Prop[k] > 0) { subsetIndices[k][i][num[k]] = sortedIndices[i][j]; subsetWeights[k][i][num[k]] = m_Prop[k] * weights[i][j]; num[k]++; } } } else { int subset = (inst.value(att) < splitPoint) ? 0 : 1; subsetIndices[subset][i][num[subset]] = sortedIndices[i][j]; subsetWeights[subset][i][num[subset]] = weights[i][j]; num[subset]++; } } } // Trim arrays for (int k = 0; k < num.length; k++) { int[] copy = new int[num[k]]; System.arraycopy(subsetIndices[k][i], 0, copy, 0, num[k]); subsetIndices[k][i] = copy; double[] copyWeights = new double[num[k]]; System.arraycopy(subsetWeights[k][i], 0, copyWeights, 0, num[k]); subsetWeights[k][i] = copyWeights; } } } } /** * Computes class distribution for an attribute. * * @param props * @param dists * @param att the attribute index * @param sortedIndices the sorted indices of the instances * @param weights the weights of the instances * @param subsetWeights the weights of the subset * @param data the data to work with * @return the split point * @throws Exception if computation fails */ protected double distribution(double[][] props, double[][][] dists, int att, int[] sortedIndices, double[] weights, double[][] subsetWeights, Instances data) throws Exception { double splitPoint = Double.NaN; Attribute attribute = data.attribute(att); double[][] dist = null; int i; if (attribute.isNominal()) { // For nominal attributes dist = new double[attribute.numValues()][data.numClasses()]; for (i = 0; i < sortedIndices.length; i++) { Instance inst = data.instance(sortedIndices[i]); if (inst.isMissing(att)) { break; } dist[(int)inst.value(att)][(int)inst.classValue()] += weights[i]; } } else { // For numeric attributes double[][] currDist = new double[2][data.numClasses()]; dist = new double[2][data.numClasses()]; // Move all instances into second subset for (int j = 0; j < sortedIndices.length; j++) { Instance inst = data.instance(sortedIndices[j]); if (inst.isMissing(att)) { break; } currDist[1][(int)inst.classValue()] += weights[j]; } double priorVal = priorVal(currDist); System.arraycopy(currDist[1], 0, dist[1], 0, dist[1].length); // Try all possible split points double currSplit = data.instance(sortedIndices[0]).value(att); double currVal, bestVal = -Double.MAX_VALUE; for (i = 0; i < sortedIndices.length; i++) { Instance inst = data.instance(sortedIndices[i]); if (inst.isMissing(att)) { break; } if (inst.value(att) > currSplit) { currVal = gain(currDist, priorVal); if (currVal > bestVal) { bestVal = currVal; splitPoint = (inst.value(att) + currSplit) / 2.0; for (int j = 0; j < currDist.length; j++) { System.arraycopy(currDist[j], 0, dist[j], 0, dist[j].length); } } } currSplit = inst.value(att); currDist[0][(int)inst.classValue()] += weights[i]; currDist[1][(int)inst.classValue()] -= weights[i]; } } // Compute weights props[att] = new double[dist.length]; for (int k = 0; k < props[att].length; k++) { props[att][k] = Utils.sum(dist[k]); } if (!(Utils.sum(props[att]) > 0)) { for (int k = 0; k < props[att].length; k++) { props[att][k] = 1.0 / (double)props[att].length; } } else { Utils.normalize(props[att]); } // Distribute counts while (i < sortedIndices.length) { Instance inst = data.instance(sortedIndices[i]); for (int j = 0; j < dist.length; j++) { dist[j][(int)inst.classValue()] += props[att][j] * weights[i]; } i++; } // Compute subset weights subsetWeights[att] = new double[dist.length]; for (int j = 0; j < dist.length; j++) { subsetWeights[att][j] += Utils.sum(dist[j]); } // Return distribution and split point dists[att] = dist; return splitPoint; } /** * Computes class distribution for an attribute. * * @param props * @param dists * @param att the attribute index * @param sortedIndices the sorted indices of the instances * @param weights the weights of the instances * @param subsetWeights the weights of the subset * @param data the data to work with * @param vals * @return the split point * @throws Exception if computation fails */ protected double numericDistribution(double[][] props, double[][][] dists, int att, int[] sortedIndices, double[] weights, double[][] subsetWeights, Instances data, double[] vals) throws Exception { double splitPoint = Double.NaN; Attribute attribute = data.attribute(att); double[][] dist = null; double[] sums = null; double[] sumSquared = null; double[] sumOfWeights = null; double totalSum = 0, totalSumSquared = 0, totalSumOfWeights = 0; int i; if (attribute.isNominal()) { // For nominal attributes sums = new double[attribute.numValues()]; sumSquared = new double[attribute.numValues()]; sumOfWeights = new double[attribute.numValues()]; int attVal; for (i = 0; i < sortedIndices.length; i++) { Instance inst = data.instance(sortedIndices[i]); if (inst.isMissing(att)) { break; } attVal = (int)inst.value(att); sums[attVal] += inst.classValue() * weights[i]; sumSquared[attVal] += inst.classValue() * inst.classValue() * weights[i]; sumOfWeights[attVal] += weights[i]; } totalSum = Utils.sum(sums); totalSumSquared = Utils.sum(sumSquared); totalSumOfWeights = Utils.sum(sumOfWeights); } else { // For numeric attributes sums = new double[2]; sumSquared = new double[2]; sumOfWeights = new double[2]; double[] currSums = new double[2]; double[] currSumSquared = new double[2]; double[] currSumOfWeights = new double[2]; // Move all instances into second subset for (int j = 0; j < sortedIndices.length; j++) { Instance inst = data.instance(sortedIndices[j]); if (inst.isMissing(att)) { break; } currSums[1] += inst.classValue() * weights[j]; currSumSquared[1] += inst.classValue() * inst.classValue() * weights[j]; currSumOfWeights[1] += weights[j]; } totalSum = currSums[1]; totalSumSquared = currSumSquared[1]; totalSumOfWeights = currSumOfWeights[1]; sums[1] = currSums[1]; sumSquared[1] = currSumSquared[1]; sumOfWeights[1] = currSumOfWeights[1];
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -