📄 decisionstump.java
字号:
/*
 *    This program is free software; you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation; either version 2 of the License, or
 *    (at your option) any later version.
 *
 *    This program is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with this program; if not, write to the Free Software
 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

/*
 *    DecisionStump.java
 *    Copyright (C) 1999 Eibe Frank
 *
 */

package weka.classifiers.trees;

import weka.classifiers.meta.LogitBoost;
import weka.classifiers.Classifier;
import weka.classifiers.DistributionClassifier;
import weka.classifiers.Evaluation;
import weka.classifiers.meta.LogitBoost;
import weka.classifiers.Sourcable;
import java.io.*;
import java.util.*;
import weka.core.*;

/**
 * Class for building and using a decision stump: a single test on one
 * attribute that sends an instance to one of three subsets (test holds,
 * test fails, attribute value missing). Usually used in conjunction
 * with a boosting algorithm.
 *
 * Typical usage: <p>
 * <code>java weka.classifiers.meta.LogitBoost -I 100 -W weka.classifiers.trees.DecisionStump
 * -t training_data </code><p>
 *
 * @author Eibe Frank (eibe@cs.waikato.ac.nz)
 * @version $Revision: 1.1.1.1 $
 */
public class DecisionStump extends DistributionClassifier
  implements WeightedInstancesHandler, Sourcable {

  /** Index of the attribute used for classification. */
  private int m_AttIndex;

  /**
   * The split point. For a numeric attribute it is the threshold compared
   * with {@code <=}; for a nominal attribute it is the index of the value
   * compared for equality (it is cast to int before being used as an index).
   */
  private double m_SplitPoint;

  /**
   * The distribution of class values or the means in each subset.
   * Row 0 is used when the test holds, row 1 when it fails, and row 2 when
   * the attribute value is missing. The rows are normalized by
   * buildClassifier when the class is nominal.
   */
  private double[][] m_Distribution;

  /**
   * The instances used for training. buildClassifier replaces them with
   * {@code new Instances(m_Instances, 0)} once the stump is built, to save
   * memory; a null value means no model has been built yet.
   */
  private Instances m_Instances;

  /**
   * Generates the classifier.
* * @param instances set of instances serving as training data * @exception Exception if the classifier has not been generated successfully */ public void buildClassifier(Instances instances) throws Exception { double bestVal = Double.MAX_VALUE, currVal; double bestPoint = -Double.MAX_VALUE, sum; int bestAtt = -1, numClasses; if (instances.checkForStringAttributes()) { throw new UnsupportedAttributeTypeException("Can't handle string attributes!"); } double[][] bestDist = new double[3][instances.numClasses()]; m_Instances = new Instances(instances); m_Instances.deleteWithMissingClass(); if (m_Instances.classAttribute().isNominal()) { numClasses = m_Instances.numClasses(); } else { numClasses = 1; } // For each attribute boolean first = true; for (int i = 0; i < m_Instances.numAttributes(); i++) { if (i != m_Instances.classIndex()) { // Reserve space for distribution. m_Distribution = new double[3][numClasses]; // Compute value of criterion for best split on attribute if (m_Instances.attribute(i).isNominal()) { currVal = findSplitNominal(i); } else { currVal = findSplitNumeric(i); } if ((first) || (Utils.sm(currVal, bestVal))) { bestVal = currVal; bestAtt = i; bestPoint = m_SplitPoint; for (int j = 0; j < 3; j++) { System.arraycopy(m_Distribution[j], 0, bestDist[j], 0, numClasses); } } // First attribute has been investigated first = false; } } // Set attribute, split point and distribution. m_AttIndex = bestAtt; m_SplitPoint = bestPoint; m_Distribution = bestDist; if (m_Instances.classAttribute().isNominal()) { for (int i = 0; i < m_Distribution.length; i++) { Utils.normalize(m_Distribution[i]); } } // Save memory m_Instances = new Instances(m_Instances, 0); } /** * Calculates the class membership probabilities for the given test instance. 
* * @param instance the instance to be classified * @return predicted class probability distribution * @exception Exception if distribution can't be computed */ public double[] distributionForInstance(Instance instance) throws Exception { return m_Distribution[whichSubset(instance)]; } /** * Returns the decision tree as Java source code. * * @return the tree as Java source code * @exception Exception if something goes wrong */ public String toSource(String className) throws Exception { StringBuffer text = new StringBuffer("class "); Attribute c = m_Instances.classAttribute(); text.append(className) .append(" {\n" +" public static double classify(Object [] i) {\n"); text.append(" /* " + m_Instances.attribute(m_AttIndex).name() + " */\n"); text.append(" if (i[").append(m_AttIndex); text.append("] == null) { return "); text.append(sourceClass(c, m_Distribution[2])).append(";"); if (m_Instances.attribute(m_AttIndex).isNominal()) { text.append(" } else if (((String)i[").append(m_AttIndex); text.append("]).equals(\""); text.append(m_Instances.attribute(m_AttIndex).value((int)m_SplitPoint)); text.append("\")"); } else { text.append(" } else if (((Double)i[").append(m_AttIndex); text.append("]).doubleValue() <= ").append(m_SplitPoint); } text.append(") { return "); text.append(sourceClass(c, m_Distribution[0])).append(";"); text.append(" } else { return "); text.append(sourceClass(c, m_Distribution[1])).append(";"); text.append(" }\n }\n}\n"); return text.toString(); } private String sourceClass(Attribute c, double []dist) { if (c.isNominal()) { return Integer.toString(Utils.maxIndex(dist)); } else { return Double.toString(dist[0]); } } /** * Returns a description of the classifier. * * @return a description of the classifier as a string. 
*/ public String toString(){ if (m_Instances == null) { return "Decision Stump: No model built yet."; } try { StringBuffer text = new StringBuffer(); text.append("Decision Stump\n\n"); text.append("Classifications\n\n"); Attribute att = m_Instances.attribute(m_AttIndex); if (att.isNominal()) { text.append(att.name() + " = " + att.value((int)m_SplitPoint) + " : "); text.append(printClass(m_Distribution[0])); text.append(att.name() + " != " + att.value((int)m_SplitPoint) + " : "); text.append(printClass(m_Distribution[1])); } else { text.append(att.name() + " <= " + m_SplitPoint + " : "); text.append(printClass(m_Distribution[0])); text.append(att.name() + " > " + m_SplitPoint + " : "); text.append(printClass(m_Distribution[1])); } text.append(att.name() + " is missing : "); text.append(printClass(m_Distribution[2])); if (m_Instances.classAttribute().isNominal()) { text.append("\nClass distributions\n\n"); if (att.isNominal()) { text.append(att.name() + " = " + att.value((int)m_SplitPoint) + "\n"); text.append(printDist(m_Distribution[0])); text.append(att.name() + " != " + att.value((int)m_SplitPoint) + "\n"); text.append(printDist(m_Distribution[1])); } else { text.append(att.name() + " <= " + m_SplitPoint + "\n"); text.append(printDist(m_Distribution[0])); text.append(att.name() + " > " + m_SplitPoint + "\n"); text.append(printDist(m_Distribution[1])); } text.append(att.name() + " is missing\n"); text.append(printDist(m_Distribution[2])); } return text.toString(); } catch (Exception e) { return "Can't print decision stump classifier!"; } } /** * Prints a class distribution. 
* * @param dist the class distribution to print * @return the distribution as a string * @exception Exception if distribution can't be printed */ private String printDist(double[] dist) throws Exception { StringBuffer text = new StringBuffer(); if (m_Instances.classAttribute().isNominal()) { for (int i = 0; i < m_Instances.numClasses(); i++) { text.append(m_Instances.classAttribute().value(i) + "\t"); } text.append("\n"); for (int i = 0; i < m_Instances.numClasses(); i++) { text.append(dist[i] + "\t"); } text.append("\n"); } return text.toString(); } /** * Prints a classification. * * @param dist the class distribution * @return the classificationn as a string * @exception Exception if the classification can't be printed */ private String printClass(double[] dist) throws Exception { StringBuffer text = new StringBuffer(); if (m_Instances.classAttribute().isNominal()) { text.append(m_Instances.classAttribute().value(Utils.maxIndex(dist))); } else { text.append(dist[0]); } return text.toString() + "\n"; } /** * Finds best split for nominal attribute and returns value. * * @param index attribute index * @return value of criterion for the best split * @exception Exception if something goes wrong */ private double findSplitNominal(int index) throws Exception { if (m_Instances.classAttribute().isNominal()) { return findSplitNominalNominal(index); } else { return findSplitNominalNumeric(index); } } /** * Finds best split for nominal attribute and nominal class * and returns value. 
* * @param index attribute index * @return value of criterion for the best split * @exception Exception if something goes wrong */ private double findSplitNominalNominal(int index) throws Exception { double bestVal = Double.MAX_VALUE, currVal; double[][] counts = new double[m_Instances.attribute(index).numValues() + 1][m_Instances.numClasses()]; double[] sumCounts = new double[m_Instances.numClasses()]; double[][] bestDist = new double[3][m_Instances.numClasses()]; int numMissing = 0; // Compute counts for all the values for (int i = 0; i < m_Instances.numInstances(); i++) { Instance inst = m_Instances.instance(i); if (inst.isMissing(index)) { numMissing++; counts[m_Instances.attribute(index).numValues()] [(int)inst.classValue()] += inst.weight(); } else { counts[(int)inst.value(index)][(int)inst.classValue()] += inst
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -