📄 reptree.java
字号:
/*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 2 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
*/
/*
* REPTree.java
* Copyright (C) 1999 Eibe Frank
*
*/
package weka.classifiers.trees;
import java.io.Serializable;
import java.util.Enumeration;
import java.util.Random;
import java.util.Vector;
import weka.classifiers.Classifier;
import weka.classifiers.Evaluation;
import weka.classifiers.Sourcable;
import weka.core.AdditionalMeasureProducer;
import weka.core.Attribute;
import weka.core.ContingencyTables;
import weka.core.Drawable;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.UnsupportedClassTypeException;
import weka.core.Utils;
import weka.core.WeightedInstancesHandler;
/**
* Fast decision tree learner. Builds a decision/regression tree using
* information gain/variance reduction and prunes it using reduced-error pruning
* (with backfitting). Only sorts values for numeric attributes
* once. Missing values are dealt with by splitting the corresponding
* instances into pieces (i.e. as in C4.5).
*
* Valid options are: <p>
*
* -M number <br>
* Set minimum number of instances per leaf (default 2). <p>
*
* -V number <br>
* Set minimum numeric class variance proportion of train variance for
* split (default 1e-3). <p>
*
* -N number <br>
* Number of folds for reduced error pruning (default 3). <p>
*
* -S number <br>
* Seed for random data shuffling (default 1). <p>
*
* -P <br>
* No pruning. <p>
*
* -L number <br>
* Maximum tree depth (default -1, no maximum). <p>
*
* @author Eibe Frank (eibe@cs.waikato.ac.nz)
* @version $Revision$
*/
public class REPTree extends Classifier
implements OptionHandler, WeightedInstancesHandler, Drawable,
AdditionalMeasureProducer, Sourcable {
/**
 * Supplies the short, human-readable description of this classifier.
 *
 * @return a description suitable for displaying in the
 *         explorer/experimenter gui
 */
public String globalInfo() {
  StringBuffer description = new StringBuffer();
  description.append("Fast decision tree learner. ");
  description.append("Builds a decision/regression tree using information gain/variance ");
  description.append("and prunes it using reduced-error pruning (with backfitting). ");
  description.append("Only sorts values for numeric attributes once. ");
  description.append("Missing values are dealt with by splitting the corresponding ");
  description.append("instances into pieces (i.e. as in C4.5).");
  return description.toString();
}
/** An inner class for building and storing the tree structure */
protected class Tree implements Serializable {
/** The header information (for printing the tree). */
protected Instances m_Info = null;
/** The subtrees of this tree, one per branch of the split. */
protected Tree[] m_Successors;
/** The index of the attribute to split on (-1 if this node is a leaf). */
protected int m_Attribute = -1;
/** The split point of a numeric split: values below it follow the
first successor, all other values follow the second one. */
protected double m_SplitPoint = Double.NaN;
/** The proportions of training instances going down each branch. Used to
weight the successors' distributions when the value of the split
attribute is missing. */
protected double[] m_Prop = null;
/** Class probabilities from the training data in the nominal case.
Holds the mean in the numeric case. */
protected double[] m_ClassProbs = null;
/** The (unnormalized) class distribution in the nominal
case. Holds the sum of squared errors (index 0) and the weight
(index 1) in the numeric case. */
protected double[] m_Distribution = null;
/** Class distribution of hold-out set at node in the nominal case.
Straight sum of weights in the numeric case (i.e. array has
only one element). */
protected double[] m_HoldOutDist = null;
/** The hold-out error of the node. The number of misclassified
instances in the nominal case, the sum of squared errors in the
numeric case. */
protected double m_HoldOutError = 0;
/**
 * Computes the class distribution of an instance by passing it down
 * the tree, starting at this node.
 *
 * @param instance the instance to compute the distribution for
 * @return the distribution obtained from the subtree, or the class
 * probabilities stored at this node if the node is a leaf or the
 * selected successor has no distribution to offer
 * @exception Exception if something goes wrong
 */
protected double[] distributionForInstance(Instance instance)
  throws Exception {

  // A leaf simply reports what it stored during training
  if (m_Attribute <= -1) {
    return m_ClassProbs;
  }

  double[] fromSubtree;
  if (instance.isMissing(m_Attribute)) {
    // Missing value: send the instance down every branch and blend the
    // answers using the training proportions of the branches
    fromSubtree = new double[m_Info.numClasses()];
    for (int branch = 0; branch < m_Successors.length; branch++) {
      double[] branchDist =
        m_Successors[branch].distributionForInstance(instance);
      if (branchDist == null) {
        continue;
      }
      for (int k = 0; k < branchDist.length; k++) {
        fromSubtree[k] += m_Prop[branch] * branchDist[k];
      }
    }
  } else {
    int branch;
    if (m_Info.attribute(m_Attribute).isNominal()) {
      // Nominal split: the value index selects the branch
      branch = (int) instance.value(m_Attribute);
    } else if (instance.value(m_Attribute) < m_SplitPoint) {
      // Numeric split, below the split point
      branch = 0;
    } else {
      // Numeric split, at or above the split point
      branch = 1;
    }
    fromSubtree = m_Successors[branch].distributionForInstance(instance);
  }

  // An empty successor yields null; fall back to this node's values
  return (fromSubtree == null) ? m_ClassProbs : fromSubtree;
}
/**
 * Builds the java source code for the test performed at this node on
 * the instance called "i". Callers are expected to emit the tests in
 * branch order, which is why the second branch of a numeric split can be
 * written as plain "true": reaching it means the preceding less-than
 * test has already failed.
 *
 * @param index index of the branch value to test; a negative index
 * yields the test for a missing value
 * @return the source code of the test
 */
public final String sourceExpression(int index) {

  // Negative index: test for a missing value
  if (index < 0) {
    return "i[" + m_Attribute + "] == null";
  }

  // Nominal attribute: compare against the label of the branch
  if (m_Info.attribute(m_Attribute).isNominal()) {
    String label = m_Info.attribute(m_Attribute).value(index);
    return "i[" + m_Attribute + "]" + ".equals(\"" + label + "\")";
  }

  // Numeric attribute: only the first branch needs an explicit test
  if (index == 0) {
    return "((Double)i[" + m_Attribute + "]).doubleValue() < "
      + m_SplitPoint;
  }
  return "true";
}
/**
 * Generates java source code for the subtree rooted at this node, as
 * if-then statements. The generated code stores the prediction in a
 * variable named "p" and expects the tested instance to be named "i".
 * A leaf only contributes an assignment; an inner node contributes a
 * call to a static method generated for it, and the method itself
 * (followed by the methods of its inner-node descendants) is returned
 * as support code.
 *
 * TODO: If the outputted source code encounters a missing value
 * for the evaluated attribute, it stops branching and uses the
 * class distribution of the current node to decide the return value.
 * This is unlike the behaviour of distributionForInstance().
 *
 * @param className the classname that this static classifier has
 * @param parent parent node of the current node; its class probabilities
 * are used when this node has none
 * @return an array of two stringbuffers: the first holds the assignment
 * code, the second the source of the support code
 * @exception Exception if something goes wrong
 */
public StringBuffer [] toSource(String className, Tree parent)
  throws Exception {

  // Nodes without own probabilities borrow those of the parent
  double[] nodeProbs =
    (m_ClassProbs != null) ? m_ClassProbs : parent.m_ClassProbs;

  // An id is drawn for every visited node, leaf or not
  long id = nextID();
  boolean numericClass = m_Info.classAttribute().isNumeric();

  StringBuffer assignment;
  StringBuffer support;

  if (m_Attribute == -1) {
    // Leaf: assign the prediction directly, no support code needed
    assignment = new StringBuffer(" p = ");
    if (numericClass) {
      assignment.append(String.valueOf(nodeProbs[0]));
    } else {
      assignment.append(String.valueOf(Utils.maxIndex(nodeProbs)));
    }
    assignment.append(";\n");
    support = new StringBuffer("");
  } else {
    StringBuffer method = new StringBuffer("");
    StringBuffer nestedMethods = new StringBuffer("");

    // Method header
    method.append(" static double N");
    method.append(Integer.toHexString(this.hashCode()) + id);
    method.append("(Object []i) {\n");
    method.append(" double p = Double.NaN;\n");
    method.append(" /* " + m_Info.attribute(m_Attribute).name() + " */\n");

    // Missing value: predict from this node
    method.append(" if (" + this.sourceExpression(-1) + ") {\n");
    method.append(" p = ");
    if (numericClass) {
      method.append(String.valueOf(nodeProbs[0]) + ";\n");
    } else {
      method.append(String.valueOf(Utils.maxIndex(nodeProbs)) + ";\n");
    }
    method.append(" } ");

    // One else-if per branch
    int lastBranch = m_Successors.length - 1;
    for (int branch = 0; branch <= lastBranch; branch++) {
      method.append("else if (" + this.sourceExpression(branch) + ") {\n");
      Tree child = m_Successors[branch];
      if (child.m_Attribute != -1) {
        // Inner node: delegate to the code generated for the subtree
        StringBuffer [] nested = child.toSource(className, this);
        method.append(nested[0]);
        nestedMethods.append(nested[1]);
      } else {
        // Leaf successor: inline its prediction, falling back to this
        // node's probabilities if the leaf has none
        double[] childProbs =
          (child.m_ClassProbs != null) ? child.m_ClassProbs : m_ClassProbs;
        method.append(" p = ");
        if (numericClass) {
          method.append(String.valueOf(childProbs[0]) + ";\n");
        } else {
          method.append(String.valueOf(Utils.maxIndex(childProbs)) + ";\n");
        }
      }
      method.append(" } ");
      if (branch == lastBranch) {
        method.append("\n");
      }
    }
    method.append(" return p;\n }\n");

    // The caller only needs to invoke the generated method
    assignment = new StringBuffer(" p = " + className + ".N");
    assignment.append(Integer.toHexString(this.hashCode()) + id);
    assignment.append("(i);\n");

    method.append(nestedMethods);
    support = method;
  }

  return new StringBuffer [] { assignment, support };
}
/**
 * Appends the graph description of this node, and recursively of all
 * nodes below it, to the given buffer.
 *
 * @param text the buffer the description is appended to
 * @param num the number given to the most recently output node
 * @param parent the parent of this node (used to describe leaves)
 * @return the number given to the last node output for this subtree
 * @exception Exception if something goes wrong
 */
protected int toGraph(StringBuffer text, int num,
  Tree parent) throws Exception {

  int count = num + 1;
  String self = "N" + Integer.toHexString(Tree.this.hashCode());

  // Leaves are drawn as boxes and have no outgoing edges
  if (m_Attribute == -1) {
    text.append(self + " [label=\"" + count + leafString(parent) + "\""
                + "shape=box]\n");
    return count;
  }

  // Inner node: labelled with the name of the split attribute
  Attribute splitAtt = m_Info.attribute(m_Attribute);
  text.append(self + " [label=\"" + count + ": " + splitAtt.name()
              + "\"]\n");

  boolean numericSplit = splitAtt.isNumeric();
  for (int branch = 0; branch < m_Successors.length; branch++) {
    Tree child = m_Successors[branch];

    // Edge to the successor, labelled with the branch condition
    text.append(self + "->" + "N"
                + Integer.toHexString(child.hashCode()) + " [label=\"");
    if (!numericSplit) {
      text.append(" = " + splitAtt.value(branch));
    } else {
      String relation = (branch == 0) ? " < " : " >= ";
      text.append(relation + Utils.doubleToString(m_SplitPoint, 2));
    }
    text.append("\"]\n");

    // Numbering continues through the successor's subtree
    count = child.toGraph(text, count, this);
  }
  return count;
}
/**
 * Builds the textual description of a leaf node. The description has
 * the form " : prediction (a/b) [c/d]", where a/b is computed from the
 * training data and c/d from the hold-out data. For a numeric class,
 * a and c are weights and b and d average squared errors; for a nominal
 * class, a and c are total weights and b and d the weight that does not
 * belong to the predicted class.
 *
 * @param parent the parent node, whose class probabilities are used if
 * this node has none
 * @return the description of the leaf
 * @exception Exception if something goes wrong
 */
protected String leafString(Tree parent) throws Exception {

  boolean numericClass = m_Info.classAttribute().isNumeric();

  // Nodes without own probabilities borrow those of the parent
  double[] probs =
    (m_ClassProbs != null) ? m_ClassProbs : parent.m_ClassProbs;

  StringBuffer description = new StringBuffer();

  if (numericClass) {
    double classMean = probs[0];
    description.append(" : " + Utils.doubleToString(classMean, 2));

    // Training data: weight and average squared error
    double trainWeight = m_Distribution[1];
    double trainAvgError =
      (trainWeight > 0) ? m_Distribution[0] / trainWeight : 0;
    description.append(" (" + Utils.doubleToString(trainWeight, 2) + "/"
                       + Utils.doubleToString(trainAvgError, 2) + ")");

    // Hold-out data: weight and average squared error
    double holdOutWeight = m_HoldOutDist[0];
    double holdOutAvgError =
      (holdOutWeight > 0) ? m_HoldOutError / holdOutWeight : 0;
    description.append(" [" + Utils.doubleToString(holdOutWeight, 2) + "/"
                       + Utils.doubleToString(holdOutAvgError, 2) + "]");
    return description.toString();
  }

  // Nominal class: report the most probable class label
  int best = Utils.maxIndex(probs);
  description.append(" : " + m_Info.classAttribute().value(best));

  // Training data: total weight and weight outside the predicted class
  double trainTotal = Utils.sum(m_Distribution);
  description.append(" (" + Utils.doubleToString(trainTotal, 2) + "/"
                     + Utils.doubleToString(trainTotal - m_Distribution[best], 2)
                     + ")");

  // Hold-out data: total weight and weight outside the predicted class
  double holdOutTotal = Utils.sum(m_HoldOutDist);
  description.append(" [" + Utils.doubleToString(holdOutTotal, 2) + "/"
                     + Utils.doubleToString(holdOutTotal - m_HoldOutDist[best], 2)
                     + "]");
  return description.toString();
}
/**
* Recursively outputs the tree.
*/
protected String toString(int level, Tree parent) {
try {
StringBuffer text = new StringBuffer();
if (m_Attribute == -1) {
// Output leaf info
return leafString(parent);
} else if (m_Info.attribute(m_Attribute).isNominal()) {
// For nominal attributes
for (int i = 0; i < m_Successors.length; i++) {
text.append("\n");
for (int j = 0; j < level; j++) {
text.append("| ");
}
text.append(m_Info.attribute(m_Attribute).name() + " = " +
m_Info.attribute(m_Attribute).value(i));
text.append(m_Successors[i].toString(level + 1, this));
}
} else {
// For numeric attributes
text.append("\n");
for (int j = 0; j < level; j++) {
text.append("| ");
}
text.append(m_Info.attribute(m_Attribute).name() + " < " +
Utils.doubleToString(m_SplitPoint, 2));
text.append(m_Successors[0].toString(level + 1, this));
text.append("\n");
for (int j = 0; j < level; j++) {
text.append("| ");
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -