REPTree.java
/*
 *    This program is free software; you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation; either version 2 of the License, or
 *    (at your option) any later version.
 *
 *    This program is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with this program; if not, write to the Free Software
 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

/*
 *    REPTree.java
 *    Copyright (C) 1999 Eibe Frank
 *
 */

package weka.classifiers.trees;

import weka.classifiers.Classifier;
import weka.classifiers.Evaluation;
import weka.classifiers.Sourcable;
import weka.classifiers.rules.ZeroR;
import weka.core.AdditionalMeasureProducer;
import weka.core.Attribute;
import weka.core.Capabilities;
import weka.core.ContingencyTables;
import weka.core.Drawable;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.Utils;
import weka.core.WeightedInstancesHandler;
import weka.core.Capabilities.Capability;

import java.io.Serializable;
import java.util.Enumeration;
import java.util.Random;
import java.util.Vector;

/**
 <!-- globalinfo-start -->
 * Fast decision tree learner. Builds a decision/regression tree using information gain/variance and prunes it using reduced-error pruning (with backfitting). Only sorts values for numeric attributes once. Missing values are dealt with by splitting the corresponding instances into pieces (i.e. as in C4.5).
 * <p/>
 <!-- globalinfo-end -->
 *
 <!-- options-start -->
 * Valid options are: <p/>
 *
 * <pre> -M <minimum number of instances>
 *  Set minimum number of instances per leaf (default 2).</pre>
 *
 * <pre> -V <minimum variance for split>
 *  Set minimum numeric class variance proportion
 *  of train variance for split (default 1e-3).</pre>
 *
 * <pre> -N <number of folds>
 *  Number of folds for reduced error pruning (default 3).</pre>
 *
 * <pre> -S <seed>
 *  Seed for random data shuffling (default 1).</pre>
 *
 * <pre> -P
 *  No pruning.</pre>
 *
 * <pre> -L
 *  Maximum tree depth (default -1, no maximum)</pre>
 *
 <!-- options-end -->
 *
 * @author Eibe Frank (eibe@cs.waikato.ac.nz)
 * @version $Revision: 1.23 $
 */
public class REPTree extends Classifier
  implements OptionHandler, WeightedInstancesHandler, Drawable,
             AdditionalMeasureProducer, Sourcable {

  /** for serialization */
  static final long serialVersionUID = -8562443428621539458L;

  /** ZeroR model that is used if no attributes are present. */
  protected ZeroR m_zeroR;

  /**
   * Returns a string describing classifier
   * @return a description suitable for
   * displaying in the explorer/experimenter gui
   */
  public String globalInfo() {

    return
        "Fast decision tree learner. Builds a decision/regression tree using "
      + "information gain/variance and prunes it using reduced-error pruning "
      + "(with backfitting). Only sorts values for numeric attributes "
      + "once. Missing values are dealt with by splitting the corresponding "
      + "instances into pieces (i.e. as in C4.5).";
  }

  /** An inner class for building and storing the tree structure */
  protected class Tree implements Serializable {

    /** for serialization */
    static final long serialVersionUID = -1635481717888437935L;

    /** The header information (for printing the tree). */
    protected Instances m_Info = null;

    /** The subtrees of this tree. */
    protected Tree[] m_Successors;

    /** The attribute to split on. */
    protected int m_Attribute = -1;

    /** The split point. */
    protected double m_SplitPoint = Double.NaN;

    /** The proportions of training instances going down each branch. */
    protected double[] m_Prop = null;

    /** Class probabilities from the training data in the nominal case.
        Holds the mean in the numeric case. */
    protected double[] m_ClassProbs = null;

    /** The (unnormalized) class distribution in the nominal case.
        Holds the sum of squared errors and the weight in the numeric case. */
    protected double[] m_Distribution = null;

    /** Class distribution of hold-out set at node in the nominal case.
        Straight sum of weights in the numeric case (i.e. array has only
        one element. */
    protected double[] m_HoldOutDist = null;

    /** The hold-out error of the node. The number of miss-classified
        instances in the nominal case, the sum of squared errors in the
        numeric case. */
    protected double m_HoldOutError = 0;

    /**
     * Computes class distribution of an instance using the tree.
     *
     * @param instance the instance to compute the distribution for
     * @return the distribution
     * @throws Exception if computation fails
     */
    protected double[] distributionForInstance(Instance instance)
      throws Exception {

      double[] returnedDist = null;

      if (m_Attribute > -1) {

        // Node is not a leaf
        if (instance.isMissing(m_Attribute)) {

          // Value is missing
          returnedDist = new double[m_Info.numClasses()];

          // Split instance up
          for (int i = 0; i < m_Successors.length; i++) {
            double[] help = m_Successors[i].distributionForInstance(instance);
            if (help != null) {
              for (int j = 0; j < help.length; j++) {
                returnedDist[j] += m_Prop[i] * help[j];
              }
            }
          }
        } else if (m_Info.attribute(m_Attribute).isNominal()) {

          // For nominal attributes
          returnedDist = m_Successors[(int)instance.value(m_Attribute)].
            distributionForInstance(instance);
        } else {

          // For numeric attributes
          if (instance.value(m_Attribute) < m_SplitPoint) {
            returnedDist = m_Successors[0].distributionForInstance(instance);
          } else {
            returnedDist = m_Successors[1].distributionForInstance(instance);
          }
        }
      }
      if ((m_Attribute == -1) || (returnedDist == null)) {

        // Node is a leaf or successor is empty
        return m_ClassProbs;
      } else {
        return returnedDist;
      }
    }

    /**
     * Returns a string containing java source code equivalent to the test
     * made at this node. The instance being tested is called "i". This
     * routine assumes to be called in the order of branching, enabling us to
     * set the >= condition test (the last one) of a numeric splitpoint
     * to just "true" (because being there in the flow implies that the
     * previous less-than test failed).
     *
     * @param index index of the value tested
     * @return a value of type 'String'
     */
    public final String sourceExpression(int index) {

      StringBuffer expr = null;
      if (index < 0) {
        return "i[" + m_Attribute + "] == null";
      }
      if (m_Info.attribute(m_Attribute).isNominal()) {
        expr = new StringBuffer("i[");
        expr.append(m_Attribute).append("]");
        expr.append(".equals(\"").append(m_Info.attribute(m_Attribute)
                                         .value(index)).append("\")");
      } else {
        expr = new StringBuffer("");
        if (index == 0) {
          expr.append("((Double)i[")
            .append(m_Attribute).append("]).doubleValue() < ")
            .append(m_SplitPoint);
        } else {
          expr.append("true");
        }
      }
      return expr.toString();
    }

    /**
     * Returns source code for the tree as if-then statements. The
     * class is assigned to variable "p", and assumes the tested
     * instance is named "i". The results are returned as two stringbuffers:
     * a section of code for assignment of the class, and a section of
     * code containing support code (eg: other support methods).
     * <p/>
     * TODO: If the outputted source code encounters a missing value
     * for the evaluated attribute, it stops branching and uses the
     * class distribution of the current node to decide the return value.
     * This is unlike the behaviour of distributionForInstance().
     *
     * @param className the classname that this static classifier has
     * @param parent parent node of the current node
     * @return an array containing two stringbuffers, the first string containing
     *         assignment code, and the second containing source for support code.
     * @throws Exception if something goes wrong
     */
    public StringBuffer [] toSource(String className, Tree parent)
      throws Exception {

      StringBuffer [] result = new StringBuffer[2];
      double[] currentProbs;

      if (m_ClassProbs == null)
        currentProbs = parent.m_ClassProbs;
      else
        currentProbs = m_ClassProbs;

      long printID = nextID();

      // Is this a leaf?
      if (m_Attribute == -1) {
        result[0] = new StringBuffer(" p = ");
        if (m_Info.classAttribute().isNumeric())
          result[0].append(currentProbs[0]);
        else {
          result[0].append(Utils.maxIndex(currentProbs));
        }
        result[0].append(";\n");
        result[1] = new StringBuffer("");
      } else {
        StringBuffer text = new StringBuffer("");
        StringBuffer atEnd = new StringBuffer("");

        text.append(" static double N")
          .append(Integer.toHexString(this.hashCode()) + printID)
          .append("(Object []i) {\n")
          .append(" double p = Double.NaN;\n");

        text.append(" /* " + m_Info.attribute(m_Attribute).name() + " */\n");
        // Missing attribute?
        text.append(" if (" + this.sourceExpression(-1) + ") {\n")
          .append(" p = ");
        if (m_Info.classAttribute().isNumeric())
          text.append(currentProbs[0] + ";\n");
        else
          text.append(Utils.maxIndex(currentProbs) + ";\n");
        text.append(" } ");

        // Branching of the tree
        for (int i = 0; i < m_Successors.length; i++) {
          text.append("else if (" + this.sourceExpression(i) + ") {\n");
          // Is the successor a leaf?
          if (m_Successors[i].m_Attribute == -1) {
            double[] successorProbs = m_Successors[i].m_ClassProbs;
            if (successorProbs == null)
              successorProbs = m_ClassProbs;
            text.append(" p = ");
            if (m_Info.classAttribute().isNumeric()) {
              text.append(successorProbs[0] + ";\n");
            } else {
              text.append(Utils.maxIndex(successorProbs) + ";\n");
            }
          } else {
            StringBuffer [] sub = m_Successors[i].toSource(className, this);
            text.append("" + sub[0]);
            atEnd.append("" + sub[1]);
          }
          text.append(" } ");
          if (i == m_Successors.length - 1) {
            text.append("\n");
          }
        }

        text.append(" return p;\n }\n");

        result[0] = new StringBuffer(" p = " + className + ".N");
        result[0].append(Integer.toHexString(this.hashCode()) + printID)
          .append("(i);\n");
        result[1] = text.append("" + atEnd);
      }
      return result;
    }

    /**
     * Outputs one node for graph.
     *
     * @param text the buffer to append the output to
     * @param num the current node id
     * @param parent the parent of the nodes
     * @return the next node id
     * @throws Exception if something goes wrong
     */
    protected int toGraph(StringBuffer text, int num, Tree parent)
      throws Exception {

      num++;
      if (m_Attribute == -1) {
        text.append("N" + Integer.toHexString(Tree.this.hashCode())
                    + " [label=\"" + num + leafString(parent) + "\""
                    + "shape=box]\n");
      } else {
        text.append("N" + Integer.toHexString(Tree.this.hashCode())
                    + " [label=\"" + num + ": "
                    + m_Info.attribute(m_Attribute).name()
                    + "\"]\n");
        for (int i = 0; i < m_Successors.length; i++) {
          text.append("N" + Integer.toHexString(Tree.this.hashCode())
                      + "->"
                      + "N" + Integer.toHexString(m_Successors[i].hashCode())
                      + " [label=\"");
          if (m_Info.attribute(m_Attribute).isNumeric()) {
            if (i == 0) {
              text.append(" < " + Utils.doubleToString(m_SplitPoint, 2));
            } else {
              text.append(" >= " + Utils.doubleToString(m_SplitPoint, 2));
            }
          } else {
            text.append(" = " + m_Info.attribute(m_Attribute).value(i));
          }
          text.append("\"]\n");
          num = m_Successors[i].toGraph(text, num, this);
        }
      }
      return num;
    }

    /**
     * Outputs description of a leaf node.
     *
     * @param parent the parent of the node
     * @return the description of the node
     * @throws Exception if generation fails
     */
    protected String leafString(Tree parent) throws Exception {

      if (m_Info.classAttribute().isNumeric()) {
        double classMean;
        if (m_ClassProbs == null) {
          classMean = parent.m_ClassProbs[0];
        } else {
          classMean = m_ClassProbs[0];
        }
        StringBuffer buffer = new StringBuffer();
        buffer.append(" : " + Utils.doubleToString(classMean, 2));
        double avgError = 0;
        if (m_Distribution[1] > 0) {
          avgError = m_Distribution[0] / m_Distribution[1];
        }
        buffer.append(" (" + Utils.doubleToString(m_Distribution[1], 2) + "/"
                      + Utils.doubleToString(avgError, 2) + ")");
        avgError = 0;
        if (m_HoldOutDist[0] > 0) {
          avgError = m_HoldOutError / m_HoldOutDist[0];
        }
        buffer.append(" [" + Utils.doubleToString(m_HoldOutDist[0], 2) + "/"
                      + Utils.doubleToString(avgError, 2) + "]");
        return buffer.toString();
      } else {
        int maxIndex;
        if (m_ClassProbs == null) {
          maxIndex = Utils.maxIndex(parent.m_ClassProbs);
        } else {
          maxIndex = Utils.maxIndex(m_ClassProbs);
        }
        return " : " + m_Info.classAttribute().value(maxIndex) +
          " (" + Utils.doubleToString(Utils.sum(m_Distribution), 2) + "/" +
          Utils.doubleToString((Utils.sum(m_Distribution) -
                                m_Distribution[maxIndex]), 2) + ")" +
          " [" + Utils.doubleToString(Utils.sum(m_HoldOutDist), 2) + "/" +
          Utils.doubleToString((Utils.sum(m_HoldOutDist) -
                                m_HoldOutDist[maxIndex]), 2) + "]";
      }
    }

    /**
     * Recursively outputs the tree.
     *
     * @param level the current level
     * @param parent the current parent
     * @return the generated substree
     */
    protected String toString(int level, Tree parent) {

      try {
        StringBuffer text = new StringBuffer();

        if (m_Attribute == -1) {

          // Output leaf info
          return leafString(parent);
        } else if (m_Info.attribute(m_Attribute).isNominal()) {

          // For nominal attributes
          for (int i = 0; i < m_Successors.length; i++) {
            text.append("\n");
            for (int j = 0; j < level; j++) {
              text.append("| ");
            }
            text.append(m_Info.attribute(m_Attribute).name() + " = " +
                        m_Info.attribute(m_Attribute).value(i));
            text.append(m_Successors[i].toString(level + 1, this));
          }
        } else {

          // For numeric attributes
          text.append("\n");
          for (int j = 0; j < level; j++) {
            text.append("| ");
          }
          text.append(m_Info.attribute(m_Attribute).name() + " < " +
                      Utils.doubleToString(m_SplitPoint, 2));
          text.append(m_Successors[0].toString(level + 1, this));
          text.append("\n");
          for (int j = 0; j < level; j++) {
            text.append("| ");
          }
          text.append(m_Info.attribute(m_Attribute).name() + " >= " +
                      Utils.doubleToString(m_SplitPoint, 2));
          text.append(m_Successors[1].toString(level + 1, this));
        }

        return text.toString();
      } catch (Exception e) {
        e.printStackTrace();
        return "Decision tree: tree can't be printed";
      }
    }

    /**
     * Recursively generates a tree.
     *
     * @param sortedIndices the sorted indices of the instances
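For context, below is a minimal sketch of how a classifier like the one above is typically driven from the Weka API. It is not part of REPTree.java: the class name REPTreeDemo and the path iris.arff are placeholders, and the tree is configured through setOptions with the options documented in the class javadoc (-M, -N, -S) rather than through setter methods, since those setters do not appear in this excerpt.

import java.io.BufferedReader;
import java.io.FileReader;
import java.util.Random;

import weka.classifiers.Evaluation;
import weka.classifiers.trees.REPTree;
import weka.core.Instances;
import weka.core.Utils;

public class REPTreeDemo {

  public static void main(String[] args) throws Exception {

    // Load a dataset from ARFF; "iris.arff" is a placeholder path.
    Instances data =
      new Instances(new BufferedReader(new FileReader("iris.arff")));
    data.setClassIndex(data.numAttributes() - 1);

    // Configure via the documented options:
    // -M 2 (min instances per leaf), -N 3 (pruning folds), -S 1 (shuffle seed).
    REPTree tree = new REPTree();
    tree.setOptions(Utils.splitOptions("-M 2 -N 3 -S 1"));

    // Build on the full training set and print the textual model
    // (rendered by the Tree.toString machinery shown above).
    tree.buildClassifier(data);
    System.out.println(tree);

    // REPTree implements Sourcable, so the learned tree can also be
    // emitted as standalone Java source (see Tree.toSource above).
    System.out.println(tree.toSource("WekaWrapper"));

    // Ten-fold cross-validation using Weka's Evaluation class.
    Evaluation eval = new Evaluation(data);
    eval.crossValidateModel(tree, data, 10, new Random(1));
    System.out.println(eval.toSummaryString());
  }
}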