/*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 2 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
*/
/*
* DecisionStump.java
* Copyright (C) 1999 Eibe Frank
*
*/
package weka.classifiers.trees;
import weka.classifiers.Classifier;
import weka.classifiers.Evaluation;
import weka.classifiers.Sourcable;
import weka.core.Attribute;
import weka.core.ContingencyTables;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.UnsupportedAttributeTypeException;
import weka.core.Utils;
import weka.core.WeightedInstancesHandler;
/**
* Class for building and using a decision stump. Usually used in conjunction
* with a boosting algorithm.
*
* Typical usage: <p>
* <code>java weka.classifiers.meta.LogitBoost -I 100 -W weka.classifiers.trees.DecisionStump
* -t training_data </code><p>
*
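* A minimal programmatic sketch (exception handling is omitted, and the ARFF
* file name and class index below are illustrative assumptions, not part of
* this class): <p>
* <pre>
* Instances data = new Instances(new java.io.FileReader("training.arff"));
* data.setClassIndex(data.numAttributes() - 1);
* DecisionStump stump = new DecisionStump();
* stump.buildClassifier(data);
* double[] dist = stump.distributionForInstance(data.instance(0));
* </pre><p>
*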
* @author Eibe Frank (eibe@cs.waikato.ac.nz)
* @version $Revision$
*/
public class DecisionStump extends Classifier
implements WeightedInstancesHandler, Sourcable {
/** The attribute used for classification. */
private int m_AttIndex;
/** The split point (or the index of the split value, for a nominal attribute). */
private double m_SplitPoint;
/** The distribution of class values (or the means, for a numeric class) in each of the three subsets: split satisfied, split not satisfied, attribute value missing. */
private double[][] m_Distribution;
/** The instances used for training. */
private Instances m_Instances;
/**
* Returns a string describing this classifier.
*
* @return a description suitable for displaying in the
* explorer/experimenter GUI
*/
public String globalInfo() {
return "Class for building and using a decision stump. Usually used in "
+ "conjunction with a boosting algorithm. Does regression (based on "
+ "mean-squared error) or classification (based on entropy). Missing "
+ "is treated as a separate value.";
}
/**
* Generates the classifier.
*
* @param instances set of instances serving as training data
* @exception Exception if the classifier has not been generated successfully
*/
public void buildClassifier(Instances instances) throws Exception {
double bestVal = Double.MAX_VALUE, currVal;
double bestPoint = -Double.MAX_VALUE, sum;
int bestAtt = -1, numClasses;
if (instances.checkForStringAttributes()) {
throw new UnsupportedAttributeTypeException("Can't handle string attributes!");
}
double[][] bestDist = new double[3][instances.numClasses()];
m_Instances = new Instances(instances);
m_Instances.deleteWithMissingClass();
if (m_Instances.numInstances() == 0) {
throw new IllegalArgumentException("No instances without missing " +
"class values in training file!");
}
if (instances.numAttributes() == 1) {
throw new IllegalArgumentException("Attribute missing. Need at least one " +
"attribute other than class attribute!");
}
if (m_Instances.classAttribute().isNominal()) {
numClasses = m_Instances.numClasses();
} else {
numClasses = 1;
}
// For each attribute
boolean first = true;
for (int i = 0; i < m_Instances.numAttributes(); i++) {
if (i != m_Instances.classIndex()) {
// Reserve space for distribution.
m_Distribution = new double[3][numClasses];
// Compute value of criterion for best split on attribute
if (m_Instances.attribute(i).isNominal()) {
currVal = findSplitNominal(i);
} else {
currVal = findSplitNumeric(i);
}
if ((first) || (currVal < bestVal)) {
bestVal = currVal;
bestAtt = i;
bestPoint = m_SplitPoint;
for (int j = 0; j < 3; j++) {
System.arraycopy(m_Distribution[j], 0, bestDist[j], 0,
numClasses);
}
}
// First attribute has been investigated
first = false;
}
}
// Set attribute, split point and distribution.
m_AttIndex = bestAtt;
m_SplitPoint = bestPoint;
m_Distribution = bestDist;
if (m_Instances.classAttribute().isNominal()) {
for (int i = 0; i < m_Distribution.length; i++) {
double sumCounts = Utils.sum(m_Distribution[i]);
if (sumCounts == 0) { // This means there were only missing attribute values
System.arraycopy(m_Distribution[2], 0, m_Distribution[i], 0,
m_Distribution[2].length);
Utils.normalize(m_Distribution[i]);
} else {
Utils.normalize(m_Distribution[i], sumCounts);
}
}
}
// Save memory
m_Instances = new Instances(m_Instances, 0);
}
/**
* Calculates the class membership probabilities for the given test instance.
*
* @param instance the instance to be classified
* @return predicted class probability distribution
* @exception Exception if distribution can't be computed
*/
public double[] distributionForInstance(Instance instance) throws Exception {
return m_Distribution[whichSubset(instance)];
}
/**
* Returns the decision tree as Java source code.
*
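* An illustrative sketch of the kind of source produced for a numeric split
* (the class name, attribute index, split point and return values below are
* made up, and the attribute-name comment emitted by this method is left
* out): <p>
* <pre>
* class MyStump {
*   public static double classify(Object [] i) {
*     if (i[3] == null) { return 0; } else if (((Double)i[3]).doubleValue() <= 0.8) { return 0; } else { return 1; }
*   }
* }
* </pre><p>
*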
* @param className the name that should be given to the source class
* @return the tree as Java source code
* @exception Exception if something goes wrong
*/
public String toSource(String className) throws Exception {
StringBuffer text = new StringBuffer("class ");
Attribute c = m_Instances.classAttribute();
text.append(className)
.append(" {\n"
+" public static double classify(Object [] i) {\n");
text.append(" /* " + m_Instances.attribute(m_AttIndex).name() + " */\n");
text.append(" if (i[").append(m_AttIndex);
text.append("] == null) { return ");
text.append(sourceClass(c, m_Distribution[2])).append(";");
if (m_Instances.attribute(m_AttIndex).isNominal()) {
text.append(" } else if (((String)i[").append(m_AttIndex);
text.append("]).equals(\"");
text.append(m_Instances.attribute(m_AttIndex).value((int)m_SplitPoint));
text.append("\")");
} else {
text.append(" } else if (((Double)i[").append(m_AttIndex);
text.append("]).doubleValue() <= ").append(m_SplitPoint);
}
text.append(") { return ");
text.append(sourceClass(c, m_Distribution[0])).append(";");
text.append(" } else { return ");
text.append(sourceClass(c, m_Distribution[1])).append(";");
text.append(" }\n }\n}\n");
return text.toString();
}
/**
 * Returns the value emitted in the generated source for the given class
 * distribution: the majority class index for a nominal class, the mean for
 * a numeric one.
 */
private String sourceClass(Attribute c, double[] dist) {
if (c.isNominal()) {
return Integer.toString(Utils.maxIndex(dist));
} else {
return Double.toString(dist[0]);
}
}
/**
* Returns a description of the classifier.
*
* @return a description of the classifier as a string.
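*
* An illustrative sketch of the output for a numeric attribute and a nominal
* class (attribute, split point and class names are made up; the class
* distribution section printed for nominal classes is left out): <p>
* <pre>
* Decision Stump
*
* Classifications
*
* petalwidth <= 0.8 : Iris-setosa
* petalwidth > 0.8 : Iris-versicolor
* petalwidth is missing : Iris-setosa
* </pre>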
*/
public String toString(){
if (m_Instances == null) {
return "Decision Stump: No model built yet.";
}
try {
StringBuffer text = new StringBuffer();
text.append("Decision Stump\n\n");
text.append("Classifications\n\n");
Attribute att = m_Instances.attribute(m_AttIndex);
if (att.isNominal()) {
text.append(att.name() + " = " + att.value((int)m_SplitPoint) +
" : ");
text.append(printClass(m_Distribution[0]));
text.append(att.name() + " != " + att.value((int)m_SplitPoint) +
" : ");
text.append(printClass(m_Distribution[1]));
} else {
text.append(att.name() + " <= " + m_SplitPoint + " : ");
text.append(printClass(m_Distribution[0]));
text.append(att.name() + " > " + m_SplitPoint + " : ");
text.append(printClass(m_Distribution[1]));
}
text.append(att.name() + " is missing : ");
text.append(printClass(m_Distribution[2]));
if (m_Instances.classAttribute().isNominal()) {
text.append("\nClass distributions\n\n");
if (att.isNominal()) {
text.append(att.name() + " = " + att.value((int)m_SplitPoint) +
"\n");
text.append(printDist(m_Distribution[0]));
text.append(att.name() + " != " + att.value((int)m_SplitPoint) +
"\n");
text.append(printDist(m_Distribution[1]));
} else {
text.append(att.name() + " <= " + m_SplitPoint + "\n");
text.append(printDist(m_Distribution[0]));
text.append(att.name() + " > " + m_SplitPoint + "\n");
text.append(printDist(m_Distribution[1]));
}
text.append(att.name() + " is missing\n");
text.append(printDist(m_Distribution[2]));
}
return text.toString();
} catch (Exception e) {
return "Can't print decision stump classifier!";
}
}
/**
* Prints a class distribution.
*
* @param dist the class distribution to print
* @return the distribution as a string
* @exception Exception if distribution can't be printed
*/
private String printDist(double[] dist) throws Exception {
StringBuffer text = new StringBuffer();
if (m_Instances.classAttribute().isNominal()) {
for (int i = 0; i < m_Instances.numClasses(); i++) {
text.append(m_Instances.classAttribute().value(i) + "\t");
}
text.append("\n");
for (int i = 0; i < m_Instances.numClasses(); i++) {
text.append(dist[i] + "\t");
}
text.append("\n");
}
return text.toString();
}
/**
* Prints a classification.
*
* @param dist the class distribution
* @return the classification as a string
* @exception Exception if the classification can't be printed
*/
private String printClass(double[] dist) throws Exception {
StringBuffer text = new StringBuffer();
if (m_Instances.classAttribute().isNominal()) {
text.append(m_Instances.classAttribute().value(Utils.maxIndex(dist)));
} else {
text.append(dist[0]);
}
return text.toString() + "\n";
}
/**
* Finds best split for nominal attribute and returns value.
*
* @param index attribute index
* @return value of criterion for the best split
* @exception Exception if something goes wrong
*/
private double findSplitNominal(int index) throws Exception {
if (m_Instances.classAttribute().isNominal()) {
return findSplitNominalNominal(index);
} else {
return findSplitNominalNumeric(index);
}
}
/**
* Finds best split for nominal attribute and nominal class
* and returns value.
*
* @param index attribute index
* @return value of criterion for the best split
* @exception Exception if something goes wrong
*/
private double findSplitNominalNominal(int index) throws Exception {
double bestVal = Double.MAX_VALUE, currVal;
double[][] counts = new double[m_Instances.attribute(index).numValues()
+ 1][m_Instances.numClasses()];