⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 decisionstump.java

📁 一个数据挖掘软件ALPHAMINERR的整个过程的JAVA版源代码
💻 JAVA
📖 第 1 页 / 共 2 页
字号:
/*
 *    This program is free software; you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation; either version 2 of the License, or
 *    (at your option) any later version.
 *
 *    This program is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with this program; if not, write to the Free Software
 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

/*
 *    DecisionStump.java
 *    Copyright (C) 1999 Eibe Frank
 *
 */

package weka.classifiers.trees;

import weka.classifiers.Classifier;
import weka.classifiers.Evaluation;
import weka.classifiers.Sourcable;
import weka.core.Attribute;
import weka.core.ContingencyTables;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.UnsupportedAttributeTypeException;
import weka.core.Utils;
import weka.core.WeightedInstancesHandler;

/**
 * Class for building and using a decision stump. Usually used in conjunction
 * with a boosting algorithm.
 *
 * Typical usage: <p>
 * <code>java weka.classifiers.meta.LogitBoost -I 100 -W weka.classifiers.trees.DecisionStump
 * -t training_data </code><p>
 * 
 * @author Eibe Frank (eibe@cs.waikato.ac.nz)
 * @version $Revision$
 */
public class DecisionStump extends Classifier 
  implements WeightedInstancesHandler, Sourcable {

  /** The attribute used for classification. */
  private int m_AttIndex;

  /** The split point: a threshold for a numeric attribute, or the index of the split value for a nominal attribute. */
  private double m_SplitPoint;

  /** The distribution of class values or the means in each subset. */
  private double[][] m_Distribution;

  /** The instances used for training. */
  private Instances m_Instances;

  /**
   * Returns a string describing this classifier.
   * @return a description suitable for
   * displaying in the explorer/experimenter gui
   */
  public String globalInfo() {

    // Assemble the description shown in the GUI one sentence at a time.
    StringBuffer description = new StringBuffer();
    description.append("Class for building and using a decision stump. ");
    description.append("Usually used in conjunction with a boosting algorithm. ");
    description.append("Does regression (based on mean-squared error) ");
    description.append("or classification (based on entropy). ");
    description.append("Missing is treated as a separate value.");
    return description.toString();
  }

  /**
   * Generates the classifier.
   *
   * @param instances set of instances serving as training data 
   * @exception Exception if the classifier has not been generated successfully
   */
  public void buildClassifier(Instances instances) throws Exception {

    // String attributes are rejected up front.
    if (instances.checkForStringAttributes()) {
      throw new UnsupportedAttributeTypeException("Can't handle string attributes!");
    }

    // Distribution of the best split seen so far (three rows, as in
    // m_Distribution; row 2 is reported as the "is missing" case by toString).
    double[][] winningDist = new double[3][instances.numClasses()];

    // Train on a private copy from which instances without a class value
    // have been removed.
    m_Instances = new Instances(instances);
    m_Instances.deleteWithMissingClass();

    if (m_Instances.numInstances() == 0) {
      throw new IllegalArgumentException("No instances without missing " +
                                         "class values in training file!");
    }
    if (instances.numAttributes() == 1) {
      throw new IllegalArgumentException("Attribute missing. Need at least one " +
                                         "attribute other than class attribute!");
    }

    // A nominal class gets one slot per class value; any other class type
    // gets a single slot per subset.
    final boolean nominalClass = m_Instances.classAttribute().isNominal();
    final int width = nominalClass ? m_Instances.numClasses() : 1;

    double lowestCriterion = Double.MAX_VALUE;
    double winningPoint = -Double.MAX_VALUE;
    int winningAtt = -1;
    boolean nothingEvaluatedYet = true;

    // Evaluate every attribute except the class attribute.
    for (int att = 0; att < m_Instances.numAttributes(); att++) {
      if (att == m_Instances.classIndex()) {
        continue;
      }

      // The split search's results are read back from m_Distribution and
      // m_SplitPoint below, so it gets a fresh buffer for each attribute.
      m_Distribution = new double[3][width];

      double criterion = m_Instances.attribute(att).isNominal()
        ? findSplitNominal(att)
        : findSplitNumeric(att);

      // The first attribute examined always becomes the incumbent; after
      // that only a strictly smaller criterion value replaces it.
      if (nothingEvaluatedYet || criterion < lowestCriterion) {
        lowestCriterion = criterion;
        winningAtt = att;
        winningPoint = m_SplitPoint;
        for (int row = 0; row < 3; row++) {
          System.arraycopy(m_Distribution[row], 0, winningDist[row], 0, width);
        }
      }
      nothingEvaluatedYet = false;
    }

    // Install the winning attribute, split point and distribution.
    m_AttIndex = winningAtt;
    m_SplitPoint = winningPoint;
    m_Distribution = winningDist;

    // For a nominal class, turn the counts of each row into proportions.
    if (nominalClass) {
      for (int row = 0; row < m_Distribution.length; row++) {
        double[] counts = m_Distribution[row];
        double total = Utils.sum(counts);
        if (total != 0) {
          Utils.normalize(counts, total);
        } else {
          // Empty row: fall back to the contents of row 2 before normalizing.
          System.arraycopy(m_Distribution[2], 0, counts, 0,
                           m_Distribution[2].length);
          Utils.normalize(counts);
        }
      }
    }

    // Keep only the header information to save memory.
    m_Instances = new Instances(m_Instances, 0);
  }

  /**
   * Calculates the class membership probabilities for the given test instance.
   *
   * @param instance the instance to be classified
   * @return predicted class probability distribution
   * @exception Exception if distribution can't be computed
   */
  public double[] distributionForInstance(Instance instance) throws Exception {

    // whichSubset picks the row of m_Distribution that applies to this
    // instance; the stored row itself is returned, not a copy.
    final int subset = whichSubset(instance);
    return m_Distribution[subset];
  }

  /**
   * Returns the decision stump as Java source code.
   *
   * @param className the name to give the generated class
   * @return the stump as Java source code
   * @exception Exception if something goes wrong
   */
  public String toSource(String className) throws Exception {

    // Same "model not built" condition that toString() checks; fail with a
    // clear message instead of a NullPointerException.
    if (m_Instances == null) {
      throw new IllegalStateException("Decision Stump: No model built yet.");
    }

    Attribute c = m_Instances.classAttribute();
    Attribute splitAtt = m_Instances.attribute(m_AttIndex);

    StringBuffer text = new StringBuffer("class ");
    text.append(className)
      .append(" {\n"
              +"  public static double classify(Object [] i) {\n");
    text.append("    /* " + splitAtt.name() + " */\n");

    // Missing attribute value: answer from the third distribution row.
    text.append("    if (i[").append(m_AttIndex);
    text.append("] == null) { return ");
    text.append(sourceClass(c, m_Distribution[2])).append(";");

    if (splitAtt.isNominal()) {
      // The split value ends up inside a Java string literal in the
      // generated code, so quotes, backslashes and line breaks in it must
      // be escaped or the generated class will not compile.
      text.append(" } else if (((String)i[").append(m_AttIndex);
      text.append("]).equals(\"");
      text.append(escapeForJavaLiteral(splitAtt.value((int)m_SplitPoint)));
      text.append("\")");
    } else {
      text.append(" } else if (((Double)i[").append(m_AttIndex);
      text.append("]).doubleValue() <= ").append(m_SplitPoint);
    }
    text.append(") { return ");
    text.append(sourceClass(c, m_Distribution[0])).append(";");
    text.append(" } else { return ");
    text.append(sourceClass(c, m_Distribution[1])).append(";");
    text.append(" }\n  }\n}\n");
    return text.toString();
  }

  /**
   * Escapes a string so that it can be placed between double quotes in
   * generated Java source code.
   *
   * @param raw the text to escape
   * @return the text with backslash, double quote, newline, carriage return
   * and tab replaced by their Java escape sequences
   */
  private static String escapeForJavaLiteral(String raw) {

    StringBuffer escaped = new StringBuffer(raw.length() + 8);
    for (int k = 0; k < raw.length(); k++) {
      char ch = raw.charAt(k);
      switch (ch) {
        case '\\': escaped.append("\\\\"); break;
        case '"':  escaped.append("\\\""); break;
        case '\n': escaped.append("\\n");  break;
        case '\r': escaped.append("\\r");  break;
        case '\t': escaped.append("\\t");  break;
        default:   escaped.append(ch);
      }
    }
    return escaped.toString();
  }

  private String sourceClass(Attribute c, double []dist) {

    // Nominal class: emit the index of the largest entry of dist.
    // Otherwise: emit the first entry of dist as a double literal.
    return c.isNominal()
      ? String.valueOf(Utils.maxIndex(dist))
      : String.valueOf(dist[0]);
  }

  /**
   * Returns a description of the classifier.
   *
   * @return a description of the classifier as a string.
   */
  public String toString(){

    if (m_Instances == null) {
      return "Decision Stump: No model built yet.";
    }
    try {
      Attribute att = m_Instances.attribute(m_AttIndex);
      String attName = att.name();

      // Textual conditions for the two branches; they are reused by both
      // the classification section and the distribution section.
      String firstBranch;
      String secondBranch;
      if (att.isNominal()) {
        String splitValue = att.value((int)m_SplitPoint);
        firstBranch = attName + " = " + splitValue;
        secondBranch = attName + " != " + splitValue;
      } else {
        firstBranch = attName + " <= " + m_SplitPoint;
        secondBranch = attName + " > " + m_SplitPoint;
      }
      String missingBranch = attName + " is missing";

      StringBuffer out = new StringBuffer("Decision Stump\n\n");

      // Predicted class (or value) per branch.
      out.append("Classifications\n\n");
      out.append(firstBranch).append(" : ");
      out.append(printClass(m_Distribution[0]));
      out.append(secondBranch).append(" : ");
      out.append(printClass(m_Distribution[1]));
      out.append(missingBranch).append(" : ");
      out.append(printClass(m_Distribution[2]));

      // Full class distribution per branch, only for a nominal class.
      if (m_Instances.classAttribute().isNominal()) {
        out.append("\nClass distributions\n\n");
        out.append(firstBranch).append("\n");
        out.append(printDist(m_Distribution[0]));
        out.append(secondBranch).append("\n");
        out.append(printDist(m_Distribution[1]));
        out.append(missingBranch).append("\n");
        out.append(printDist(m_Distribution[2]));
      }

      return out.toString();
    } catch (Exception e) {
      // Best effort: any failure while printing yields a fixed message.
      return "Can't print decision stump classifier!";
    }
  }

  /** 
   * Prints a class distribution.
   *
   * @param dist the class distribution to print
   * @return the distribution as a string
   * @exception Exception if distribution can't be printed
   */
  private String printDist(double[] dist) throws Exception {

    Attribute classAtt = m_Instances.classAttribute();

    // Nothing is printed unless the class is nominal.
    if (!classAtt.isNominal()) {
      return "";
    }

    // Build the label row and the value row side by side; every cell is
    // followed by a tab and every row by a newline.
    StringBuffer labels = new StringBuffer();
    StringBuffer values = new StringBuffer();
    int classCount = m_Instances.numClasses();
    for (int k = 0; k < classCount; k++) {
      labels.append(classAtt.value(k)).append("\t");
      values.append(dist[k]).append("\t");
    }
    labels.append("\n");
    values.append("\n");

    return labels.toString() + values.toString();
  }

  /** 
   * Prints a classification.
   *
   * @param dist the class distribution
   * @return the classification as a string
   * @exception Exception if the classification can't be printed
   */
  private String printClass(double[] dist) throws Exception {

    Attribute classAtt = m_Instances.classAttribute();

    // Nominal class: label of the largest entry; otherwise the first entry.
    String label;
    if (classAtt.isNominal()) {
      label = classAtt.value(Utils.maxIndex(dist));
    } else {
      label = String.valueOf(dist[0]);
    }
    return label + "\n";
  }

  /**
   * Finds best split for nominal attribute and returns value.
   *
   * @param index attribute index
   * @return value of criterion for the best split
   * @exception Exception if something goes wrong
   */
  private double findSplitNominal(int index) throws Exception {

    // Dispatch on the class attribute's type.
    boolean nominalClass = m_Instances.classAttribute().isNominal();
    return nominalClass
      ? findSplitNominalNominal(index)
      : findSplitNominalNumeric(index);
  }

  /**
   * Finds best split for nominal attribute and nominal class
   * and returns value.
   *
   * @param index attribute index
   * @return value of criterion for the best split
   * @exception Exception if something goes wrong
   */
  private double findSplitNominalNominal(int index) throws Exception {

    double bestVal = Double.MAX_VALUE, currVal;
    double[][] counts = new double[m_Instances.attribute(index).numValues() 
				  + 1][m_Instances.numClasses()];

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -