// RandomTree.java (code-hosting site banner removed — it was not valid Java)
/*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 2 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
*/
/*
* RandomTree.java
* Copyright (C) 2001 Richard Kirkby, Eibe Frank
*
*/
package weka.classifiers.trees;
import java.util.Enumeration;
import java.util.Random;
import java.util.Vector;
import weka.classifiers.Classifier;
import weka.classifiers.Evaluation;
import weka.core.Attribute;
import weka.core.ContingencyTables;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.Randomizable;
import weka.core.UnsupportedClassTypeException;
import weka.core.Utils;
import weka.core.WeightedInstancesHandler;
/**
* Class for constructing a tree that considers K random features at each node.
* Performs no pruning.
*
* @author Eibe Frank (eibe@cs.waikato.ac.nz)
* @author Richard Kirkby (rkirkby@cs.waikato.ac.nz)
* @version $Revision$
*/
public class RandomTree extends Classifier
implements OptionHandler, WeightedInstancesHandler, Randomizable {
/** The subtrees appended to this tree (one per branch of the split). */
protected RandomTree[] m_Successors;
/** The index of the attribute to split on; -1 marks a leaf node. */
protected int m_Attribute = -1;
/** The split point for numeric attributes (NaN while unset). */
protected double m_SplitPoint = Double.NaN;
/** The class distribution from the training data. */
protected double[][] m_Distribution = null;
/** The header information (dataset structure, no instances). */
protected Instances m_Info = null;
/** The proportions of training instances going down each branch. */
protected double[] m_Prop = null;
/** Class probabilities from the training data. */
protected double[] m_ClassProbs = null;
/** Minimum total instance weight required at a leaf. */
protected double m_MinNum = 1.0;
/** Whether debug info is output to the console. */
protected boolean m_Debug = false;
/** The number of attributes considered for a split. */
protected int m_KValue = 1;
/** The random seed to use for attribute selection. */
protected int m_randomSeed = 1;
/**
 * Returns a string describing the classifier.
 *
 * @return a description suitable for
 * displaying in the explorer/experimenter gui
 */
public String globalInfo() {
  // Fixed: the original pair of literals ("...randomly " + " chosen...")
  // rendered with a double space in the GUI description.
  return "Class for constructing a tree that considers K randomly " +
    "chosen attributes at each node. Performs no pruning.";
}
/**
 * Returns the tip text for the minNum property.
 *
 * @return tip text suitable for displaying in the
 * explorer/experimenter gui
 */
public String minNumTipText() {
  final String tip = "The minimum total weight of the instances in a leaf.";
  return tip;
}
/**
 * Gets the minimum total instance weight required at a leaf.
 *
 * @return the current value of MinNum
 */
public double getMinNum() {
  return this.m_MinNum;
}
/**
 * Sets the minimum total instance weight required at a leaf.
 *
 * @param newMinNum the value to assign to MinNum
 */
public void setMinNum(double newMinNum) {
  this.m_MinNum = newMinNum;
}
/**
 * Returns the tip text for the KValue property.
 *
 * @return tip text suitable for displaying in the
 * explorer/experimenter gui
 */
public String KValueTipText() {
  final String tip = "Sets the number of randomly chosen attributes.";
  return tip;
}
/**
 * Gets the number of randomly chosen attributes considered per split.
 *
 * @return the current value of K
 */
public int getKValue() {
  return this.m_KValue;
}
/**
 * Sets the number of randomly chosen attributes considered per split.
 *
 * @param k the value to assign to K
 */
public void setKValue(int k) {
  this.m_KValue = k;
}
/**
 * Returns the tip text for the debug property.
 *
 * @return tip text suitable for displaying in the
 * explorer/experimenter gui
 */
public String debugTipText() {
  final String tip = "Whether debug information is output to the console.";
  return tip;
}
/**
 * Gets whether debug output is enabled.
 *
 * @return the current value of Debug
 */
public boolean getDebug() {
  return this.m_Debug;
}
/**
 * Sets whether debug output is enabled.
 *
 * @param newDebug the value to assign to Debug
 */
public void setDebug(boolean newDebug) {
  this.m_Debug = newDebug;
}
/**
 * Returns the tip text for the seed property.
 *
 * @return tip text suitable for displaying in the
 * explorer/experimenter gui
 */
public String seedTipText() {
  final String tip = "The random number seed used for selecting attributes.";
  return tip;
}
/**
 * Sets the seed used for random number generation.
 *
 * @param seed the seed value
 */
public void setSeed(int seed) {
  this.m_randomSeed = seed;
}
/**
 * Gets the seed used for random number generation.
 *
 * @return the seed value
 */
public int getSeed() {
  return this.m_randomSeed;
}
/**
 * Lists the command-line options for this classifier.
 *
 * @return an enumeration of the available Option objects
 */
public Enumeration listOptions() {
  // Exactly four options are defined below; size the Vector to match
  // (the original allocated capacity 6 for 4 elements).
  Vector newVector = new Vector(4);
  newVector.addElement(new Option(
      "\tNumber of attributes to randomly investigate.",
      "K", 1, "-K <number of attributes>"));
  newVector.addElement(new Option(
      "\tSet minimum number of instances per leaf.",
      "M", 1, "-M <minimum number of instances>"));
  newVector.addElement(new Option(
      "\tTurns debugging info on.",
      "D", 0, "-D"));
  newVector.addElement(new Option(
      "\tSeed for random number generator.\n" + "\t(default 1)",
      "S", 1, "-S"));
  return newVector.elements();
}
/**
 * Gets the current option settings for this classifier.
 *
 * @return an array of strings suitable for passing to setOptions
 */
public String[] getOptions() {
  // At most 7 slots are ever written: -K <v>, -M <v>, -S <v>, and
  // optionally -D. The original over-allocated 10 slots.
  String[] options = new String[7];
  int current = 0;
  options[current++] = "-K";
  options[current++] = "" + getKValue();
  options[current++] = "-M";
  options[current++] = "" + getMinNum();
  options[current++] = "-S";
  options[current++] = "" + getSeed();
  if (getDebug()) {
    options[current++] = "-D";
  }
  // Pad any unused tail slots with empty strings, as Weka expects.
  while (current < options.length) {
    options[current++] = "";
  }
  return options;
}
/**
 * Parses a given list of options. Valid options are:<p>
 *
 * -K &lt;number of attributes&gt; attributes to randomly investigate.<br>
 * -M &lt;minimum weight&gt; minimum total weight of instances per leaf
 * (may be fractional, since leaf size is a weight).<br>
 * -S &lt;seed&gt; seed for the random number generator (default 1).<br>
 * -D turn debugging info on.<p>
 *
 * @param options the list of options as an array of strings
 * @exception Exception if an option is not supported
 */
public void setOptions(String[] options) throws Exception {
  String kValueString = Utils.getOption('K', options);
  if (kValueString.length() != 0) {
    m_KValue = Integer.parseInt(kValueString);
  } else {
    m_KValue = 1;
  }
  String minNumString = Utils.getOption('M', options);
  if (minNumString.length() != 0) {
    // m_MinNum is a weight (double): accept fractional values as well.
    // The original Integer.parseInt rejected inputs such as "1.5".
    m_MinNum = Double.parseDouble(minNumString);
  } else {
    m_MinNum = 1;
  }
  String seed = Utils.getOption('S', options);
  if (seed.length() != 0) {
    setSeed(Integer.parseInt(seed));
  } else {
    setSeed(1);
  }
  m_Debug = Utils.getFlag('D', options);
  Utils.checkForRemainingOptions(options);
}
/**
 * Builds the classifier: grows a single unpruned tree, considering K
 * randomly chosen attributes at each node.
 *
 * @param data the training instances
 * @exception Exception if the classifier cannot be built (non-nominal
 * class, no usable instances, or no attributes besides the class)
 */
public void buildClassifier(Instances data) throws Exception {
// Make sure K value is in range
// (clamp K to the number of non-class attributes; NOTE(review): no
// lower-bound check — K < 1 is passed through unchanged here)
if (m_KValue > data.numAttributes()-1) m_KValue = data.numAttributes()-1;
// Check for non-nominal classes
if (!data.classAttribute().isNominal()) {
throw new UnsupportedClassTypeException("RandomTree: Nominal class, please.");
}
// Delete instances with missing class
// (operate on a copy so the caller's dataset is left untouched)
data = new Instances(data);
data.deleteWithMissingClass();
if (data.numInstances() == 0) {
throw new IllegalArgumentException("RandomTree: zero training instances or all " +
"instances have missing class!");
}
if (data.numAttributes() == 1) {
throw new IllegalArgumentException("RandomTree: Attribute missing. Need at least " +
"one attribute other than class attribute!");
}
Instances train = data;
// Create array of sorted indices and weights
// (one row per attribute; the class attribute's row stays empty)
int[][] sortedIndices = new int[train.numAttributes()][0];
double[][] weights = new double[train.numAttributes()][0];
double[] vals = new double[train.numInstances()];
for (int j = 0; j < train.numAttributes(); j++) {
if (j != train.classIndex()) {
weights[j] = new double[train.numInstances()];
if (train.attribute(j).isNominal()) {
// Handling nominal attributes. Putting indices of
// instances with missing values at the end.
sortedIndices[j] = new int[train.numInstances()];
int count = 0;
// First pass: instances with a value for attribute j
for (int i = 0; i < train.numInstances(); i++) {
Instance inst = train.instance(i);
if (!inst.isMissing(j)) {
sortedIndices[j][count] = i;
weights[j][count] = inst.weight();
count++;
}
}
// Second pass: instances missing attribute j go at the tail
for (int i = 0; i < train.numInstances(); i++) {
Instance inst = train.instance(i);
if (inst.isMissing(j)) {
sortedIndices[j][count] = i;
weights[j][count] = inst.weight();
count++;
}
}
} else {
// Sorted indices are computed for numeric attributes
for (int i = 0; i < train.numInstances(); i++) {
Instance inst = train.instance(i);
vals[i] = inst.value(j);
}
sortedIndices[j] = Utils.sort(vals);
// Align weights with the sorted order
for (int i = 0; i < train.numInstances(); i++) {
weights[j][i] = train.instance(sortedIndices[j][i]).weight();
}
}
}
}
// Compute initial class counts (weighted by instance weight)
double[] classProbs = new double[train.numClasses()];
for (int i = 0; i < train.numInstances(); i++) {
Instance inst = train.instance(i);
classProbs[(int)inst.classValue()] += inst.weight();
}
// Create the attribute indices window
// (pool of candidate attribute indices, excluding the class index)
int[] attIndicesWindow = new int[data.numAttributes()-1];
int j=0;
for (int i=0; i<attIndicesWindow.length; i++) {
if (j == data.classIndex()) j++; // do not include the class
attIndicesWindow[i] = j++;
}
// Build tree
// (new Instances(train, 0) passes header-only structure; the RNG is
// seeded with m_randomSeed for reproducible attribute selection)
buildTree(sortedIndices, weights, train, classProbs,
new Instances(train, 0), m_MinNum, m_Debug,
attIndicesWindow, data.getRandomNumberGenerator(m_randomSeed));
}
/**
 * Computes the class distribution of an instance using the decision tree.
 * Missing split-attribute values are handled by descending every branch
 * and combining the results weighted by the training proportions m_Prop.
 *
 * @param instance the instance to classify
 * @return the class probability distribution; falls back to the node's
 * training class probabilities at a leaf or when a successor is empty
 * @exception Exception if the distribution cannot be computed
 */
public double[] distributionForInstance(Instance instance) throws Exception {
double[] returnedDist = null;
if (m_Attribute > -1) {
// Node is not a leaf
if (instance.isMissing(m_Attribute)) {
// Value is missing
returnedDist = new double[m_Info.numClasses()];
// Split instance up: descend all branches, weight each child's
// distribution by the fraction of training weight that went there
for (int i = 0; i < m_Successors.length; i++) {
double[] help = m_Successors[i].distributionForInstance(instance);
if (help != null) {
for (int j = 0; j < help.length; j++) {
returnedDist[j] += m_Prop[i] * help[j];
}
}
}
} else if (m_Info.attribute(m_Attribute).isNominal()) {
// For nominal attributes: one successor per attribute value
returnedDist = m_Successors[(int)instance.value(m_Attribute)].
distributionForInstance(instance);
} else {
// For numeric attributes: branch 0 below the split point, 1 above
if (Utils.sm(instance.value(m_Attribute), m_SplitPoint)) {
returnedDist = m_Successors[0].distributionForInstance(instance);
} else {
returnedDist = m_Successors[1].distributionForInstance(instance);
}
}
}
if ((m_Attribute == -1) || (returnedDist == null)) {
// Node is a leaf or successor is empty
return m_ClassProbs;
} else {
return returnedDist;
}
}
/**
 * Outputs the decision tree as a graph in dot format.
 *
 * @return the graph description, or null if it cannot be produced
 */
public String toGraph() {
  try {
    StringBuffer body = new StringBuffer();
    toGraph(body, 0);
    return "digraph Tree {\n" + "edge [style=bold]\n" + body.toString()
      + "\n}\n";
  } catch (Exception e) {
    // Rendering failed somewhere in the recursion; signal with null,
    // exactly as callers already expect.
    return null;
  }
}
/**
* Outputs one node for graph.
 */