randomtree.java
来自「Weka」· Java 代码 · 共 1,087 行 · 第 1/3 页
JAVA
1,087 行
/**
 * Recursively builds the subtree rooted at this node from the instances
 * referenced by {@code sortedIndices}.
 *
 * <p>The node becomes a leaf when it holds no instances, when its total class
 * weight is below {@code 2 * minNum}, when all of its weight belongs to one
 * class, when the depth limit is reached, or when none of the sampled
 * attributes yields a positive gain. Otherwise it splits on the sampled
 * attribute with the largest gain and builds one child per branch.
 *
 * @param sortedIndices for every attribute, the indices (into {@code data}) of
 *          the instances at this node. NOTE(review): numeric entries are
 *          presumably sorted by value with missing values last, which
 *          {@link #distribution} relies on; confirm against the caller.
 * @param weights for every attribute, the instance weights aligned with
 *          {@code sortedIndices}
 * @param data the instances the indices refer to
 * @param classProbs the per-class weight totals of the instances at this node;
 *          copied into {@code m_ClassProbs}, which is normalized before
 *          returning (and left {@code null} for an empty node)
 * @param header the dataset header, stored in {@code m_Info}
 * @param minNum the minimum-weight setting, stored in {@code m_MinNum}; a node
 *          whose total class weight is below {@code 2 * minNum} becomes a leaf
 * @param debug the debug flag, stored in {@code m_Debug}
 * @param attIndicesWindow the candidate split attributes; this array is
 *          permuted in place and shared with every recursive call
 * @param random random number generator for choosing random attributes
 * @param depth the current depth
 * @throws Exception if generation fails
 */
protected void buildTree(int[][] sortedIndices, double[][] weights, Instances data, double[] classProbs, Instances header, double minNum, boolean debug, int[] attIndicesWindow, Random random, int depth) throws Exception {

  // Store structure of dataset, set minimum number of instances
  m_Info = header;
  m_Debug = debug;
  m_MinNum = minNum;

  // Make leaf if there are no training instances.
  // The instance count is taken from attribute 0's index array, or from
  // attribute 1's when the class is attribute 0 (splitData never fills the
  // class attribute's entry).
  // NOTE(review): with classIndex() == 0 and no other attribute,
  // sortedIndices[1] is out of bounds; confirm callers reject class-only data.
  // NOTE(review): this path does not assign m_Attribute, so numNodes() relies
  // on the field's initial value (declared outside this view) being -1.
  if (((data.classIndex() > 0) && (sortedIndices[0].length == 0)) || ((data.classIndex() == 0) && sortedIndices[1].length == 0)) {
    m_Distribution = new double[1][data.numClasses()];
    m_ClassProbs = null;
    return;
  }

  // Check if node doesn't contain enough instances or is pure
  // or maximum depth reached
  m_ClassProbs = new double[classProbs.length];
  System.arraycopy(classProbs, 0, m_ClassProbs, 0, classProbs.length);
  // Pure node: the largest class weight equals the total weight.
  // Depth limit applies only when getMaxDepth() > 0.
  if (Utils.sum(m_ClassProbs) < 2 * m_MinNum || Utils.eq(m_ClassProbs[Utils.maxIndex(m_ClassProbs)], Utils.sum(m_ClassProbs)) || ((getMaxDepth() > 0) && (depth >= getMaxDepth()))) {
    // Make leaf: keep the raw class weights as the single distribution row,
    // then turn m_ClassProbs into probabilities.
    m_Attribute = -1;
    m_Distribution = new double[1][m_ClassProbs.length];
    for (int i = 0; i < m_ClassProbs.length; i++) {
      m_Distribution[0][i] = m_ClassProbs[i];
    }
    Utils.normalize(m_ClassProbs);
    return;
  }

  // Compute class distributions and value of splitting
  // criterion for each attribute
  // (entries for attributes that are never sampled stay at 0 / empty arrays)
  double[] vals = new double[data.numAttributes()];
  double[][][] dists = new double[data.numAttributes()][0][0];
  double[][] props = new double[data.numAttributes()][0];
  double[] splits = new double[data.numAttributes()];

  // Investigate K random attributes.
  // Attributes are drawn without replacement from the window. At least
  // m_KValue are evaluated (while any remain); beyond that, drawing continues
  // only until some evaluated attribute has positive gain.
  int attIndex = 0;
  int windowSize = attIndicesWindow.length;
  int k = m_KValue;
  boolean gainFound = false;
  while ((windowSize > 0) && (k-- > 0 || !gainFound)) {
    int chosenIndex = random.nextInt(windowSize);
    attIndex = attIndicesWindow[chosenIndex];

    // shift chosen attIndex out of window
    // (swap it to the end and shrink the active prefix; every call starts
    // from the full array length, so children see the same set of attributes,
    // only in a different order)
    attIndicesWindow[chosenIndex] = attIndicesWindow[windowSize-1];
    attIndicesWindow[windowSize-1] = attIndex;
    windowSize--;

    splits[attIndex] = distribution(props, dists, attIndex, sortedIndices[attIndex], weights[attIndex], data);
    vals[attIndex] = gain(dists[attIndex], priorVal(dists[attIndex]));

    if (vals[attIndex] > 0) gainFound = true;
  }

  // Find best attribute.
  // If no evaluated attribute has positive gain, this may select an
  // unevaluated index (value 0, empty distribution); the vals check below
  // then turns the node into a leaf.
  m_Attribute = Utils.maxIndex(vals);
  m_Distribution = dists[m_Attribute];

  // Any useful split found?
  if (vals[m_Attribute] > 0) {
    // Build subtrees
    m_SplitPoint = splits[m_Attribute];
    // m_Prop is read by splitData to share out instances with a missing value.
    m_Prop = props[m_Attribute];
    int[][][] subsetIndices = new int[m_Distribution.length][data.numAttributes()][0];
    double[][][] subsetWeights = new double[m_Distribution.length][data.numAttributes()][0];
    splitData(subsetIndices, subsetWeights, m_Attribute, m_SplitPoint, sortedIndices, weights, m_Distribution, data);
    m_Successors = new RandomTree[m_Distribution.length];
    for (int i = 0; i < m_Distribution.length; i++) {
      m_Successors[i] = new RandomTree();
      m_Successors[i].setKValue(m_KValue);
      m_Successors[i].setMaxDepth(getMaxDepth());
      // Each child's class weights are this node's distribution row for
      // branch i.
      m_Successors[i].buildTree(subsetIndices[i], subsetWeights[i], data, m_Distribution[i], header, m_MinNum, m_Debug, attIndicesWindow, random, depth + 1);
    }
  } else {
    // Make leaf
    m_Attribute = -1;
    m_Distribution = new double[1][m_ClassProbs.length];
    for (int i = 0; i < m_ClassProbs.length; i++) {
      m_Distribution[0][i] = m_ClassProbs[i];
    }
  }

  // Normalize class counts
  Utils.normalize(m_ClassProbs);
}

/**
 * Computes size of the tree.
 *
 * <p>Counts this node plus, recursively, all of its successors. A node with
 * {@code m_Attribute == -1} is a leaf and counts as one.
 *
 * @return the number of nodes
 */
public int numNodes() {
  if (m_Attribute == -1) {
    return 1;
  } else {
    int size = 1;
    for (int i = 0; i < m_Successors.length; i++) {
      size += m_Successors[i].numNodes();
    }
    return size;
  }
}

/**
 * Splits instances into subsets.
 *
 * <p>For every non-class attribute {@code i}, walks {@code sortedIndices[i]}
 * in order and appends each instance to the branch chosen by attribute
 * {@code att}:
 * <ul>
 *   <li>nominal: the branch is the value index of {@code att};</li>
 *   <li>numeric: branch 0 if the value is below {@code splitPoint}, otherwise
 *       branch 1;</li>
 *   <li>missing: the instance goes to every branch {@code k} with
 *       {@code m_Prop[k] > 0}, with its weight scaled by {@code m_Prop[k]}.</li>
 * </ul>
 * Because instances are appended in the parent's order, each subset keeps the
 * parent's per-attribute ordering. The output arrays are allocated at full
 * size and then trimmed to the number of entries actually filled. The class
 * attribute's entries are left untouched.
 *
 * @param subsetIndices the sorted indices of the subset; output, indexed as
 *          [branch][attribute]
 * @param subsetWeights the weights of the subset; output, indexed as
 *          [branch][attribute]
 * @param att the attribute index
 * @param splitPoint the splitpoint for numeric attributes
 * @param sortedIndices the sorted indices of the whole set
 * @param weights the weights of the whole set
 * @param dist the distribution (not read by this method; the branch
 *          proportions come from the field {@code m_Prop})
 * @param data the data to work with
 * @throws Exception if something goes wrong
 */
protected void splitData(int[][][] subsetIndices, double[][][] subsetWeights, int att, double splitPoint, int[][] sortedIndices, double[][] weights, double[][] dist, Instances data) throws Exception {

  int j;
  // num[k] = number of entries filled so far in branch k
  int[] num;

  // For each attribute
  for (int i = 0; i < data.numAttributes(); i++) {
    if (i != data.classIndex()) {
      if (data.attribute(att).isNominal()) {

        // For nominal attributes
        num = new int[data.attribute(att).numValues()];
        for (int k = 0; k < num.length; k++) {
          subsetIndices[k][i] = new int[sortedIndices[i].length];
          subsetWeights[k][i] = new double[sortedIndices[i].length];
        }
        for (j = 0; j < sortedIndices[i].length; j++) {
          Instance inst = data.instance(sortedIndices[i][j]);
          if (inst.isMissing(att)) {

            // Split instance up
            for (int k = 0; k < num.length; k++) {
              if (m_Prop[k] > 0) {
                subsetIndices[k][i][num[k]] = sortedIndices[i][j];
                subsetWeights[k][i][num[k]] = m_Prop[k] * weights[i][j];
                num[k]++;
              }
            }
          } else {
            int subset = (int)inst.value(att);
            subsetIndices[subset][i][num[subset]] = sortedIndices[i][j];
            subsetWeights[subset][i][num[subset]] = weights[i][j];
            num[subset]++;
          }
        }
      } else {

        // For numeric attributes
        num = new int[2];
        for (int k = 0; k < 2; k++) {
          subsetIndices[k][i] = new int[sortedIndices[i].length];
          subsetWeights[k][i] = new double[weights[i].length];
        }
        for (j = 0; j < sortedIndices[i].length; j++) {
          Instance inst = data.instance(sortedIndices[i][j]);
          if (inst.isMissing(att)) {

            // Split instance up
            for (int k = 0; k < num.length; k++) {
              if (m_Prop[k] > 0) {
                subsetIndices[k][i][num[k]] = sortedIndices[i][j];
                subsetWeights[k][i][num[k]] = m_Prop[k] * weights[i][j];
                num[k]++;
              }
            }
          } else {
            int subset = (inst.value(att) < splitPoint) ? 0 : 1;
            subsetIndices[subset][i][num[subset]] = sortedIndices[i][j];
            subsetWeights[subset][i][num[subset]] = weights[i][j];
            num[subset]++;
          }
        }
      }

      // Trim arrays
      for (int k = 0; k < num.length; k++) {
        int[] copy = new int[num[k]];
        System.arraycopy(subsetIndices[k][i], 0, copy, 0, num[k]);
        subsetIndices[k][i] = copy;
        double[] copyWeights = new double[num[k]];
        System.arraycopy(subsetWeights[k][i], 0, copyWeights, 0, num[k]);
        subsetWeights[k][i] = copyWeights;
      }
    }
  }
}

/**
 * Computes class distribution for an attribute.
 *
 * <p>Both passes stop at the first instance whose value of {@code att} is
 * missing, so instances with a missing value are assumed to sit at the end of
 * {@code sortedIndices} (NOTE(review): that ordering is established by the
 * caller, outside this view). For a numeric attribute the non-missing values
 * are also assumed to be in ascending order.
 *
 * <p>Numeric attributes are scanned for the best binary split. All known
 * weight starts in branch 1, and instances are moved into branch 0 one at a
 * time. Gain is evaluated only where the value strictly increases, so equal
 * values are never separated. The candidate split point is the midpoint
 * between the two neighbouring values.
 *
 * <p>Finally, the weight of instances with a missing value is spread over all
 * branches in proportion to the known weight per branch. This happens after
 * the numeric split point has been chosen.
 *
 * @param props output: {@code props[att]} receives the fraction of known
 *          weight in each branch (uniform if that weight is essentially zero)
 * @param dists output: {@code dists[att]} receives the per-branch class
 *          weights, indexed as [branch][class], including the shares of
 *          instances with a missing value
 * @param att the attribute index
 * @param sortedIndices the sorted indices of the data; for a numeric
 *          attribute this must be non-empty because element 0 is read
 *          unconditionally (buildTree calls this only after its empty-node
 *          check)
 * @param weights the instance weights, aligned with {@code sortedIndices}
 * @param data the data to work with
 * @return the chosen split point for a numeric attribute. It stays
 *          {@code Double.NaN} for a nominal attribute, and also when the
 *          numeric sweep never accepts a boundary (for example, when all
 *          known values are equal).
 * @throws Exception if something goes wrong
 */
protected double distribution(double[][] props, double[][][] dists, int att, int[] sortedIndices, double[] weights, Instances data) throws Exception {

  double splitPoint = Double.NaN;
  Attribute attribute = data.attribute(att);
  double[][] dist = null;
  // Deliberately declared outside the loops: after them it holds the position
  // of the first missing value (or length if none), which is used below.
  int i;

  if (attribute.isNominal()) {

    // For nominal attributes
    dist = new double[attribute.numValues()][data.numClasses()];
    for (i = 0; i < sortedIndices.length; i++) {
      Instance inst = data.instance(sortedIndices[i]);
      if (inst.isMissing(att)) {
        break;
      }
      dist[(int)inst.value(att)][(int)inst.classValue()] += weights[i];
    }
  } else {

    // For numeric attributes
    double[][] currDist = new double[2][data.numClasses()];
    dist = new double[2][data.numClasses()];

    // Move all instances into second subset
    for (int j = 0; j < sortedIndices.length; j++) {
      Instance inst = data.instance(sortedIndices[j]);
      if (inst.isMissing(att)) {
        break;
      }
      currDist[1][(int)inst.classValue()] += weights[j];
    }
    // Criterion value before splitting, computed once from the known values.
    double priorVal = priorVal(currDist);
    // Default result (everything in branch 1) in case no boundary is accepted.
    for (int j = 0; j < currDist.length; j++) {
      System.arraycopy(currDist[j], 0, dist[j], 0, dist[j].length);
    }

    // Try all possible split points
    double currSplit = data.instance(sortedIndices[0]).value(att);
    double currVal, bestVal = -Double.MAX_VALUE;
    for (i = 0; i < sortedIndices.length; i++) {
      Instance inst = data.instance(sortedIndices[i]);
      if (inst.isMissing(att)) {
        break;
      }
      // Only a strict increase marks a boundary between distinct values.
      if (inst.value(att) > currSplit) {
        currVal = gain(currDist, priorVal);
        if (currVal > bestVal) {
          bestVal = currVal;
          splitPoint = (inst.value(att) + currSplit) / 2.0;
          // Remember the branch distributions of the best split so far.
          for (int j = 0; j < currDist.length; j++) {
            System.arraycopy(currDist[j], 0, dist[j], 0, dist[j].length);
          }
        }
      }
      // Move the current instance from branch 1 to branch 0.
      currSplit = inst.value(att);
      currDist[0][(int)inst.classValue()] += weights[i];
      currDist[1][(int)inst.classValue()] -= weights[i];
    }
  }

  // Compute weights
  props[att] = new double[dist.length];
  for (int k = 0; k < props[att].length; k++) {
    props[att][k] = Utils.sum(dist[k]);
  }
  if (Utils.eq(Utils.sum(props[att]), 0)) {
    // No known weight: fall back to equal proportions.
    for (int k = 0; k < props[att].length; k++) {
      props[att][k] = 1.0 / (double)props[att].length;
    }
  } else {
    Utils.normalize(props[att]);
  }

  // Any instances with missing values ?
  if (i < sortedIndices.length) {

    // Distribute counts
    while (i < sortedIndices.length) {
      Instance inst = data.instance(sortedIndices[i]);
      for (int j = 0; j < dist.length; j++) {
        dist[j][(int)inst.classValue()] += props[att][j] * weights[i];
      }
      i++;
    }
  }

  // Return distribution and split point
  dists[att] = dist;
  return splitPoint;
}

/**
 * Computes value of splitting criterion before split.
 *
 * <p>Delegates to {@code ContingencyTables.entropyOverColumns(dist)}.
 * NOTE(review): presumably the entropy of the class (column) totals; confirm
 * against the ContingencyTables documentation.
 *
 * @param dist the distributions, indexed as [branch][class]
 * @return the splitting criterion
 */
protected double priorVal(double[][] dist) {
  return ContingencyTables.entropyOverColumns(dist);
}

/**
 * Computes value of splitting criterion after split.
 *
 * <p>Returns {@code priorVal} minus
 * {@code ContingencyTables.entropyConditionedOnRows(dist)}. That is the
 * information gain, assuming the helper computes the class entropy
 * conditioned on the branch as its name suggests.
 *
 * @param dist the distributions, indexed as [branch][class]
 * @param priorVal the splitting criterion
 * @return the gain after the split
 */
protected double gain(double[][] dist, double priorVal) {
  return priorVal - ContingencyTables.entropyConditionedOnRows(dist);
}

/**
 * Main method for this class.
 *
 * <p>Passes the command-line arguments to {@code runClassifier} together with
 * a new {@code RandomTree} instance.
 *
 * @param argv the commandline parameters
 */
public static void main(String[] argv) {
  runClassifier(new RandomTree(), argv);
}
} // closes the enclosing class (its header is not part of this excerpt)
⌨️ 快捷键说明
复制代码Ctrl + C
搜索代码Ctrl + F
全屏模式F11
增大字号Ctrl + =
减小字号Ctrl + -
显示快捷键?