📄 reptree.java
字号:
return size; } } /** * Splits instances into subsets. */ protected void splitData(int[][][] subsetIndices, double[][][] subsetWeights, int att, double splitPoint, int[][] sortedIndices, double[][] weights, Instances data) throws Exception { int j; int[] num; // For each attribute for (int i = 0; i < data.numAttributes(); i++) { if (i != data.classIndex()) { if (data.attribute(att).isNominal()) { // For nominal attributes num = new int[data.attribute(att).numValues()]; for (int k = 0; k < num.length; k++) { subsetIndices[k][i] = new int[sortedIndices[i].length]; subsetWeights[k][i] = new double[sortedIndices[i].length]; } for (j = 0; j < sortedIndices[i].length; j++) { Instance inst = data.instance(sortedIndices[i][j]); if (inst.isMissing(att)) { // Split instance up for (int k = 0; k < num.length; k++) { if (m_Prop[k] > 0) { subsetIndices[k][i][num[k]] = sortedIndices[i][j]; subsetWeights[k][i][num[k]] = m_Prop[k] * weights[i][j]; num[k]++; } } } else { int subset = (int)inst.value(att); subsetIndices[subset][i][num[subset]] = sortedIndices[i][j]; subsetWeights[subset][i][num[subset]] = weights[i][j]; num[subset]++; } } } else { // For numeric attributes num = new int[2]; for (int k = 0; k < 2; k++) { subsetIndices[k][i] = new int[sortedIndices[i].length]; subsetWeights[k][i] = new double[weights[i].length]; } for (j = 0; j < sortedIndices[i].length; j++) { Instance inst = data.instance(sortedIndices[i][j]); if (inst.isMissing(att)) { // Split instance up for (int k = 0; k < num.length; k++) { if (m_Prop[k] > 0) { subsetIndices[k][i][num[k]] = sortedIndices[i][j]; subsetWeights[k][i][num[k]] = m_Prop[k] * weights[i][j]; num[k]++; } } } else { int subset = (inst.value(att) < splitPoint) ? 0 : 1; subsetIndices[subset][i][num[subset]] = sortedIndices[i][j]; subsetWeights[subset][i][num[subset]] = weights[i][j]; num[subset]++; } } } // Trim arrays for (int k = 0; k < num.length; k++) { int[] copy = new int[num[k]]; System.arraycopy(subsetIndices[k][i], 0, copy, 0, num[k]); subsetIndices[k][i] = copy; double[] copyWeights = new double[num[k]]; System.arraycopy(subsetWeights[k][i], 0, copyWeights, 0, num[k]); subsetWeights[k][i] = copyWeights; } } } } /** * Computes class distribution for an attribute. */ protected double distribution(double[][] props, double[][][] dists, int att, int[] sortedIndices, double[] weights, double[][] subsetWeights, Instances data) throws Exception { double splitPoint = Double.NaN; Attribute attribute = data.attribute(att); double[][] dist = null; int i; if (attribute.isNominal()) { // For nominal attributes dist = new double[attribute.numValues()][data.numClasses()]; for (i = 0; i < sortedIndices.length; i++) { Instance inst = data.instance(sortedIndices[i]); if (inst.isMissing(att)) { break; } dist[(int)inst.value(att)][(int)inst.classValue()] += weights[i]; } } else { // For numeric attributes double[][] currDist = new double[2][data.numClasses()]; dist = new double[2][data.numClasses()]; // Move all instances into second subset for (int j = 0; j < sortedIndices.length; j++) { Instance inst = data.instance(sortedIndices[j]); if (inst.isMissing(att)) { break; } currDist[1][(int)inst.classValue()] += weights[j]; } double priorVal = priorVal(currDist); System.arraycopy(currDist[1], 0, dist[1], 0, dist[1].length); // Try all possible split points double currSplit = data.instance(sortedIndices[0]).value(att); double currVal, bestVal = -Double.MAX_VALUE; for (i = 0; i < sortedIndices.length; i++) { Instance inst = data.instance(sortedIndices[i]); if (inst.isMissing(att)) { break; } if (inst.value(att) > currSplit) { currVal = gain(currDist, priorVal); if (currVal > bestVal) { bestVal = currVal; splitPoint = (inst.value(att) + currSplit) / 2.0; for (int j = 0; j < currDist.length; j++) { System.arraycopy(currDist[j], 0, dist[j], 0, dist[j].length); } } } currSplit = inst.value(att); currDist[0][(int)inst.classValue()] += weights[i]; currDist[1][(int)inst.classValue()] -= weights[i]; } } // Compute weights props[att] = new double[dist.length]; for (int k = 0; k < props[att].length; k++) { props[att][k] = Utils.sum(dist[k]); } if (!(Utils.sum(props[att]) > 0)) { for (int k = 0; k < props[att].length; k++) { props[att][k] = 1.0 / (double)props[att].length; } } else { Utils.normalize(props[att]); } // Distribute counts while (i < sortedIndices.length) { Instance inst = data.instance(sortedIndices[i]); for (int j = 0; j < dist.length; j++) { dist[j][(int)inst.classValue()] += props[att][j] * weights[i]; } i++; } // Compute subset weights subsetWeights[att] = new double[dist.length]; for (int j = 0; j < dist.length; j++) { subsetWeights[att][j] += Utils.sum(dist[j]); } // Return distribution and split point dists[att] = dist; return splitPoint; } /** * Computes class distribution for an attribute. */ protected double numericDistribution(double[][] props, double[][][] dists, int att, int[] sortedIndices, double[] weights, double[][] subsetWeights, Instances data, double[] vals) throws Exception { double splitPoint = Double.NaN; Attribute attribute = data.attribute(att); double[][] dist = null; double[] sums = null; double[] sumSquared = null; double[] sumOfWeights = null; double totalSum = 0, totalSumSquared = 0, totalSumOfWeights = 0; int i; if (attribute.isNominal()) { // For nominal attributes sums = new double[attribute.numValues()]; sumSquared = new double[attribute.numValues()]; sumOfWeights = new double[attribute.numValues()]; int attVal; for (i = 0; i < sortedIndices.length; i++) { Instance inst = data.instance(sortedIndices[i]); if (inst.isMissing(att)) { break; } attVal = (int)inst.value(att); sums[attVal] += inst.classValue() * weights[i]; sumSquared[attVal] += inst.classValue() * inst.classValue() * weights[i]; sumOfWeights[attVal] += weights[i]; } totalSum = Utils.sum(sums); totalSumSquared = Utils.sum(sumSquared); totalSumOfWeights = Utils.sum(sumOfWeights); } else { // For numeric attributes sums = new double[2]; sumSquared = new double[2]; sumOfWeights = new double[2]; double[] currSums = new double[2]; double[] currSumSquared = new double[2]; double[] currSumOfWeights = new double[2]; // Move all instances into second subset for (int j = 0; j < sortedIndices.length; j++) { Instance inst = data.instance(sortedIndices[j]); if (inst.isMissing(att)) { break; } currSums[1] += inst.classValue() * weights[j]; currSumSquared[1] += inst.classValue() * inst.classValue() * weights[j]; currSumOfWeights[1] += weights[j]; } totalSum = currSums[1]; totalSumSquared = currSumSquared[1]; totalSumOfWeights = currSumOfWeights[1]; sums[1] = currSums[1]; sumSquared[1] = currSumSquared[1]; sumOfWeights[1] = currSumOfWeights[1]; // Try all possible split points double currSplit = data.instance(sortedIndices[0]).value(att); double currVal, bestVal = Double.MAX_VALUE; for (i = 0; i < sortedIndices.length; i++) { Instance inst = data.instance(sortedIndices[i]); if (inst.isMissing(att)) { break; } if (inst.value(att) > currSplit) { currVal = variance(currSums, currSumSquared, currSumOfWeights); if (currVal < bestVal) { bestVal = currVal; splitPoint = (inst.value(att) + currSplit) / 2.0; for (int j = 0; j < 2; j++) { sums[j] = currSums[j]; sumSquared[j] = currSumSquared[j]; sumOfWeights[j] = currSumOfWeights[j]; } } } currSplit = inst.value(att); double classVal = inst.classValue() * weights[i]; double classValSquared = inst.classValue() * classVal; currSums[0] += classVal; currSumSquared[0] += classValSquared; currSumOfWeights[0] += weights[i]; currSums[1] -= classVal; currSumSquared[1] -= classValSquared; currSumOfWeights[1] -= weights[i]; } } // Compute weights props[att] = new double[sums.length]; for (int k = 0; k < props[att].length; k++) { props[att][k] = sumOfWeights[k]; } if (!(Utils.sum(props[att]) > 0)) { for (int k = 0; k < props[att].length; k++) { props[att][k] = 1.0 / (double)props[att].length; } } else { Utils.normalize(props[att]); } // Distribute counts for missing values while (i < sortedIndices.length) { Instance inst = data.instance(sortedIndices[i]); for (int j = 0; j < sums.length; j++) { sums[j] += props[att][j] * inst.classValue() * weights[i]; sumSquared[j] += props[att][j] * inst.classValue() * inst.classValue() * weights[i]; sumOfWeights[j] += props[att][j] * weights[i]; } totalSum += inst.classValue() * weights[i]; totalSumSquared += inst.classValue() * inst.classValue() * weights[i]; totalSumOfWeights += weights[i]; i++; } // Compute final distribution dist = new double[sums.length][data.numClasses()]; for (int j = 0; j < sums.length; j++) { if (sumOfWeights[j] > 0) { dist[j][0] = sums[j] / sumOfWeights[j]; } else { dist[j][0] = totalSum / totalSumOfWeights; } } // Compute variance gain double priorVar = singleVariance(totalSum, totalSumSquared, totalSumOfWeights); double var = variance(sums, sumSquared, sumOfWeights); double gain = priorVar - var; // Return distribution and split point subsetWeights[att] = sumOfWeights; dists[att] = dist; vals[att] = gain; return splitPoint; } /** * Computes variance for subsets. */ protected double variance(double[] s, double[] sS, double[] sumOfWeights) { double var = 0; for (int i = 0; i < s.length; i++) { if (sumOfWeights[i] > 0) { var += singleVariance(s[i], sS[i], sumOfWeights[i]); } } return var; } /** * Computes the variance for a single set */ protected double singleVariance(double s, double sS, double weight) { return sS - ((s * s) / weight); } /** * Computes value of splitting criterion before split. */ protected double priorVal(double[][] dist) { return ContingencyTables.entropyOverColumns(dist); } /** * Computes value of splitting criterion after split. */ protected double gain(double[][] dist, double priorVal) { return priorVal - ContingencyTables.entropyConditionedOnRows(dist); } /** * Prunes the tree using the hold-out data (bottom-up). */ protected double reducedErrorPrune() throws Exception { // Is node leaf ? if (m_Attribute == -1) { return m_HoldOutError; } // Prune all sub trees double errorTree = 0; for (int i = 0; i < m_Successors.length; i++) { errorTree += m_Successors[i].reducedErrorPrune(); } // Replace sub tree with leaf if error doesn't get worse if (errorTree >= m_HoldOutError) { m_Attribute = -1; m_Successors = null; return m_HoldOutError; } else { return errorTree; } } /** * Inserts hold-out set into tree. */ protected void insertHoldOutSet(Instances data) throws Exception { for (int i = 0; i < data.numInstances(); i++) { insertHoldOutInstance(data.instance(i), data.instance(i).weight(), this); } } /** * Inserts an instance from the hold-out set into the tree. */ protected void insertHoldOutInstance(Instance inst, double weight, Tree parent) throws Exception { // Insert instance into hold-out class distribution if (inst.classAttribute().isNominal()) { // Nominal case m_HoldOutDist[(int)inst.classValue()] += weight; int predictedClass = 0; if (m_ClassProbs == null) { predictedClass = Utils.maxIndex(parent.m_ClassProbs); } else { predictedClass = Utils.maxIndex(m_ClassProbs); } if (predictedClass != (int)inst.classValue()) { m_HoldOutError += weight; } } else { // Numeric case m_HoldOutDist[0] += weight; double diff = 0; if (m_ClassProbs == null) { diff = parent.m_ClassProbs[0] - inst.classValue(); } else { diff = m_ClassProbs[0] - inst.classValue(); } m_HoldOutError += diff * diff * weight; } // Th process is recursive if (m_Attribute != -1) { // If node is not a leaf
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -