📄 reptree.java
字号:
/* * This program is free software; you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation; either version 2 of the License, or * (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with this program; if not, write to the Free Software * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA. *//* * REPTree.java * Copyright (C) 1999 Eibe Frank * */package weka.classifiers.trees;import weka.classifiers.*;import weka.core.*;import java.util.*;import java.io.*;/** * Fast decision tree learner. Builds a decision/regression tree using * information gain/variance and prunes it using reduced-error pruning. * Only sorts values for numeric attributes once. Missing values are * dealt with by splitting the corresponding instances into pieces * (i.e. as in C4.5). * * Valid options are: <p> * * -M number <br> * Set minimum number of instances per leaf (default 2). <p> * * -V number <br> * Set minimum numeric class variance proportion of train variance for * split (default 1e-3). <p> * * -N number <br> * Number of folds for reduced error pruning (default 3). <p> * * -S number <br> * Seed for random data shuffling (default 1). <p> * * -P <br> * No pruning. <p> * * -D <br> * Maximum tree depth (default -1, no maximum). 
 * <p>
 *
 * @author Eibe Frank (eibe@cs.waikato.ac.nz)
 * @version $Revision: 1.1.1.1 $
 */
public class REPTree extends DistributionClassifier
  implements OptionHandler, WeightedInstancesHandler, Drawable,
	     AdditionalMeasureProducer {

  /** An inner class for building and storing the tree structure. */
  protected class Tree implements Serializable {

    /** The header information (for printing the tree). */
    protected Instances m_Info = null;

    /** The subtrees of this tree. */
    protected Tree[] m_Successors;

    /** The attribute to split on; -1 marks a leaf node. */
    protected int m_Attribute = -1;

    /** The split point (only meaningful for numeric split attributes). */
    protected double m_SplitPoint = Double.NaN;

    /** The proportions of training instances going down each branch
	(used to weight the branch distributions for missing values). */
    protected double[] m_Prop = null;

    /** Class probabilities from the training data in the nominal case.
	Holds the mean in the numeric case. */
    protected double[] m_ClassProbs = null;

    /** The (unnormalized) class distribution in the nominal case. Holds
	the sum of squared errors and the weight in the numeric case. */
    protected double[] m_Distribution = null;

    /** Class distribution of hold-out set at node in the nominal case.
	Straight sum of weights in the numeric case (i.e. the array has
	only one element). */
    protected double[] m_HoldOutDist = null;

    /** The hold-out error of the node. The number of miss-classified
	instances in the nominal case, the sum of squared errors in the
	numeric case. */
    protected double m_HoldOutError = 0;

    /**
     * Computes class distribution of an instance using the tree.
*/ protected double[] distributionForInstance(Instance instance) throws Exception { double[] returnedDist = null; if (m_Attribute > -1) { // Node is not a leaf if (instance.isMissing(m_Attribute)) { // Value is missing returnedDist = new double[m_Info.numClasses()]; // Split instance up for (int i = 0; i < m_Successors.length; i++) { double[] help = m_Successors[i].distributionForInstance(instance); if (help != null) { for (int j = 0; j < help.length; j++) { returnedDist[j] += m_Prop[i] * help[j]; } } } } else if (m_Info.attribute(m_Attribute).isNominal()) { // For nominal attributes returnedDist = m_Successors[(int)instance.value(m_Attribute)]. distributionForInstance(instance); } else { // For numeric attributes if (instance.value(m_Attribute) < m_SplitPoint) { returnedDist = m_Successors[0].distributionForInstance(instance); } else { returnedDist = m_Successors[1].distributionForInstance(instance); } } } if ((m_Attribute == -1) || (returnedDist == null)) { // Node is a leaf or successor is empty return m_ClassProbs; } else { return returnedDist; } } /** * Outputs one node for graph. 
     */
    protected int toGraph(StringBuffer text, int num,
			  Tree parent) throws Exception {

      num++;
      if (m_Attribute == -1) {
	// Leaf: emit a box node labelled with the leaf statistics
	text.append("N" + Integer.toHexString(Tree.this.hashCode()) +
		    " [label=\"" + num + leafString(parent) +"\"" +
		    "shape=box]\n");
      } else {
	// Interior node: emit the split attribute, then one labelled
	// edge and subgraph per successor
	text.append("N" + Integer.toHexString(Tree.this.hashCode()) +
		    " [label=\"" + num + ": " +
		    m_Info.attribute(m_Attribute).name() + "\"]\n");
	for (int i = 0; i < m_Successors.length; i++) {
	  text.append("N" + Integer.toHexString(Tree.this.hashCode()) + "->" +
		      "N" + Integer.toHexString(m_Successors[i].hashCode()) +
		      " [label=\"");
	  if (m_Info.attribute(m_Attribute).isNumeric()) {
	    // Numeric split: successor 0 is "< split", successor 1 is ">= split"
	    if (i == 0) {
	      text.append(" < " + Utils.doubleToString(m_SplitPoint, 2));
	    } else {
	      text.append(" >= " + Utils.doubleToString(m_SplitPoint, 2));
	    }
	  } else {
	    // Nominal split: one branch per attribute value
	    text.append(" = " + m_Info.attribute(m_Attribute).value(i));
	  }
	  text.append("\"]\n");
	  num = m_Successors[i].toGraph(text, num, this);
	}
      }
      return num;
    }

    /**
     * Outputs description of a leaf node.
     */
    protected String leafString(Tree parent) throws Exception {

      if (m_Info.classAttribute().isNumeric()) {
	// Numeric class: report mean, (weight/avg. squared error) on the
	// training data and [weight/avg. squared error] on the hold-out set.
	// An empty leaf inherits its prediction from the parent node.
	double classMean;
	if (m_ClassProbs == null) {
	  classMean = parent.m_ClassProbs[0];
	} else {
	  classMean = m_ClassProbs[0];
	}
	StringBuffer buffer = new StringBuffer();
	buffer.append(" : " + Utils.doubleToString(classMean, 2));
	double avgError = 0;
	if (m_Distribution[1] > 0) {
	  // m_Distribution = {sum of squared errors, total weight}
	  avgError = m_Distribution[0] / m_Distribution[1];
	}
	buffer.append(" (" +
		      Utils.doubleToString(m_Distribution[1], 2) + "/" +
		      Utils.doubleToString(avgError, 2) + ")");
	avgError = 0;
	if (m_HoldOutDist[0] > 0) {
	  avgError = m_HoldOutError / m_HoldOutDist[0];
	}
	buffer.append(" [" +
		      Utils.doubleToString(m_HoldOutDist[0], 2) + "/" +
		      Utils.doubleToString(avgError, 2) + "]");
	return buffer.toString();
      } else {
	// Nominal class: report majority class, (total/misclassified
	// weight) on the training data and the same for the hold-out set.
	int maxIndex;
	if (m_ClassProbs == null) {
	  maxIndex = Utils.maxIndex(parent.m_ClassProbs);
	} else {
	  maxIndex = Utils.maxIndex(m_ClassProbs);
	}
	return " : " + m_Info.classAttribute().value(maxIndex) +
	  " (" + Utils.doubleToString(Utils.sum(m_Distribution), 2) + "/" +
	  Utils.doubleToString((Utils.sum(m_Distribution) -
				m_Distribution[maxIndex]), 2) + ")" +
	  " [" + Utils.doubleToString(Utils.sum(m_HoldOutDist), 2) + "/" +
	  Utils.doubleToString((Utils.sum(m_HoldOutDist) -
				m_HoldOutDist[maxIndex]), 2) + "]";
      }
    }

    /**
     * Recursively outputs the tree.
     */
    protected String toString(int level, Tree parent) {

      try {
	StringBuffer text = new StringBuffer();

	if (m_Attribute == -1) {

	  // Output leaf info
	  return leafString(parent);
	} else if (m_Info.attribute(m_Attribute).isNominal()) {

	  // For nominal attributes: one "attr = value" line per branch
	  for (int i = 0; i < m_Successors.length; i++) {
	    text.append("\n");
	    for (int j = 0; j < level; j++) {
	      text.append("|   ");
	    }
	    text.append(m_Info.attribute(m_Attribute).name() + " = " +
			m_Info.attribute(m_Attribute).value(i));
	    text.append(m_Successors[i].toString(level + 1, this));
	  }
	} else {

	  // For numeric attributes: a "< split" line and a ">= split" line
	  text.append("\n");
	  for (int j = 0; j < level; j++) {
	    text.append("|   ");
	  }
	  text.append(m_Info.attribute(m_Attribute).name() + " < " +
		      Utils.doubleToString(m_SplitPoint, 2));
	  text.append(m_Successors[0].toString(level + 1, this));
	  text.append("\n");
	  for (int j = 0; j < level; j++) {
	    text.append("|   ");
	  }
	  text.append(m_Info.attribute(m_Attribute).name() + " >= " +
		      Utils.doubleToString(m_SplitPoint, 2));
	  text.append(m_Successors[1].toString(level + 1, this));
	}

	return text.toString();
      } catch (Exception e) {
	e.printStackTrace();
	return "Decision tree: tree can't be printed";
      }
    }

    /**
     * Recursively generates a tree.
     */
    protected void buildTree(int[][] sortedIndices, double[][] weights,
			     Instances data, double totalWeight,
			     double[] classProbs, Instances header,
			     double minNum, double minVariance,
			     int depth, int maxDepth)
      throws Exception {

      // Store structure of dataset, set minimum number of instances
      // and make space for potential info from pruning data
      m_Info = header;
      m_HoldOutDist = new double[data.numClasses()];

      // Make leaf if there are no training instances
      if (sortedIndices[0].length == 0) {
	if (data.classAttribute().isNumeric()) {
	  m_Distribution = new double[2];
	} else {
	  m_Distribution = new double[data.numClasses()];
	}
	// null probs signal "empty leaf": predictions fall back to the parent
	m_ClassProbs = null;
	return;
      }

      double priorVar = 0;
      if (data.classAttribute().isNumeric()) {

	// Compute prior variance (weighted, via sum and sum of squares)
	double totalSum = 0, totalSumSquared = 0, totalSumOfWeights = 0;
	for (int i = 0; i < sortedIndices[0].length; i++) {
	  Instance inst = data.instance(sortedIndices[0][i]);
	  totalSum += inst.classValue() * weights[0][i];
	  totalSumSquared +=
	    inst.classValue() * inst.classValue() * weights[0][i];
	  totalSumOfWeights += weights[0][i];
	}
	priorVar = singleVariance(totalSum, totalSumSquared,
				  totalSumOfWeights);
      }

      // Check if node doesn't contain enough instances, is pure
      // or the maximum tree depth is reached
      m_ClassProbs = new double[classProbs.length];
      System.arraycopy(classProbs, 0, m_ClassProbs, 0, classProbs.length);
      // NOTE(review): the depth guard tests the field m_MaxDepth but
      // compares against the parameter maxDepth; presumably both always
      // hold the same value at every call site -- TODO confirm.
      if ((totalWeight < (2 * minNum)) ||

	  // Nominal case: pure node (all weight in one class)
	  (data.classAttribute().isNominal() &&
	   Utils.eq(m_ClassProbs[Utils.maxIndex(m_ClassProbs)],
		    Utils.sum(m_ClassProbs))) ||

	  // Numeric case: variance proportion below threshold
	  (data.classAttribute().isNumeric() &&
	   ((priorVar / totalWeight) < minVariance)) ||

	  // Check tree depth
	  ((m_MaxDepth >= 0) && (depth >= maxDepth))) {

	// Make leaf
	m_Attribute = -1;
	if (data.classAttribute().isNominal()) {

	  // Nominal case: keep raw counts, normalize probabilities
	  m_Distribution = new double[m_ClassProbs.length];
	  for (int i = 0; i < m_ClassProbs.length; i++) {
	    m_Distribution[i] = m_ClassProbs[i];
	  }
	  Utils.normalize(m_ClassProbs);
	} else {

	  // Numeric case: store {sum of squared errors, weight}
	  m_Distribution = new double[2];
	  m_Distribution[0] = priorVar;
	  m_Distribution[1] = totalWeight;
	}
	return;
      }

      // Compute class distributions and value of splitting
      // criterion for each attribute
      double[] vals = new double[data.numAttributes()];
      double[][][] dists = new double[data.numAttributes()][0][0];
      double[][] props = new double[data.numAttributes()][0];
      double[][] totalSubsetWeights = new double[data.numAttributes()][0];
      double[] splits = new double[data.numAttributes()];
      if (data.classAttribute().isNominal()) {

	// Nominal case: information gain
	for (int i = 0; i < data.numAttributes(); i++) {
	  if (i != data.classIndex()) {
	    splits[i] = distribution(props, dists, i, sortedIndices[i],
				     weights[i], totalSubsetWeights, data);
	    vals[i] = gain(dists[i], priorVal(dists[i]));
	  }
	}
      } else {

	// Numeric case: variance reduction (score written into vals)
	for (int i = 0; i < data.numAttributes(); i++) {
	  if (i != data.classIndex()) {
	    splits[i] =
	      numericDistribution(props, dists, i, sortedIndices[i],
				  weights[i], totalSubsetWeights, data,
				  vals);
	  }
	}
      }

      // Find best attribute
      m_Attribute = Utils.maxIndex(vals);
      int numAttVals = dists[m_Attribute].length;

      // Check if there are at least two subsets with
      // required minimum number of instances
      int count = 0;
      for (int i = 0; i < numAttVals; i++) {
	if (totalSubsetWeights[m_Attribute][i] >= minNum) {
	  count++;
	}
	if (count > 1) {
	  break;
	}
      }

      // Any useful split found?
      if ((vals[m_Attribute] > 0) && (count > 1)) {

	// Build subtrees
	m_SplitPoint = splits[m_Attribute];
	m_Prop = props[m_Attribute];
	int[][][] subsetIndices =
	  new int[numAttVals][data.numAttributes()][0];
	double[][][] subsetWeights =
	  new double[numAttVals][data.numAttributes()][0];
	splitData(subsetIndices, subsetWeights, m_Attribute, m_SplitPoint,
		  sortedIndices, weights, data);
	m_Successors = new Tree[numAttVals];
	for (int i = 0; i < numAttVals; i++) {
	  m_Successors[i] = new Tree();
	  m_Successors[i].buildTree(subsetIndices[i], subsetWeights[i],
				    data, totalSubsetWeights[m_Attribute][i],
				    dists[m_Attribute][i], header, minNum,
				    minVariance, depth + 1, maxDepth);
	}
      } else {

	// Make leaf
	m_Attribute = -1;
      }

      // Normalize class counts
      if (data.classAttribute().isNominal()) {
	m_Distribution = new double[m_ClassProbs.length];
	for (int i = 0; i < m_ClassProbs.length; i++) {
	  m_Distribution[i] = m_ClassProbs[i];
	}
	Utils.normalize(m_ClassProbs);
      } else {
	m_Distribution = new double[2];
	m_Distribution[0] = priorVar;
	m_Distribution[1] = totalWeight;
      }
    }

    /**
     * Computes size of the tree.
     */
    protected int numNodes() {

      if (m_Attribute == -1) {
	// Leaf counts as a single node
	return 1;
      } else {
	int size = 1;
	// NOTE(review): source is truncated here -- the remainder of
	// numNodes() (returning the accumulated size) lies outside this view.
	for (int i = 0; i < m_Successors.length; i++) {
	  size += m_Successors[i].numNodes();
	}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -