⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 reptree.java

📁 为了下东西 随便发了个 datamining 的源代码
💻 JAVA
📖 第 1 页 / 共 4 页
字号:
/*
 *    This program is free software; you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation; either version 2 of the License, or
 *    (at your option) any later version.
 *
 *    This program is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with this program; if not, write to the Free Software
 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

/*
 *    REPTree.java
 *    Copyright (C) 1999 Eibe Frank
 *
 */

package weka.classifiers.trees;

import java.io.Serializable;
import java.util.Enumeration;
import java.util.Random;
import java.util.Vector;

import weka.classifiers.Classifier;
import weka.classifiers.Evaluation;
import weka.classifiers.Sourcable;
import weka.core.AdditionalMeasureProducer;
import weka.core.Attribute;
import weka.core.ContingencyTables;
import weka.core.Drawable;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.UnsupportedClassTypeException;
import weka.core.Utils;
import weka.core.WeightedInstancesHandler;

/**
 * Fast decision tree learner. Builds a decision/regression tree using
 * information gain/variance reduction and prunes it using reduced-error pruning
 * (with backfitting).  Only sorts values for numeric attributes
 * once. Missing values are dealt with by splitting the corresponding
 * instances into pieces (i.e. as in C4.5).
 *
 * Valid options are: <p>
 *
 * -M number <br>
 * Set minimum number of instances per leaf (default 2). <p>
 *
 * -V number <br>
 * Set minimum numeric class variance proportion of train variance for
 * split (default 1e-3). <p>
 *			    
 * -N number <br>
 * Number of folds for reduced error pruning (default 3). <p>
 *
 * -S number <br> 
 * Seed for random data shuffling (default 1). <p>
 *
 * -P <br>
 * No pruning. <p>
 *
 * -L number <br>
 * Maximum tree depth (default -1, no maximum). <p>
 *
 * @author Eibe Frank (eibe@cs.waikato.ac.nz)
 * @version $Revision$ 
 */
public class REPTree extends Classifier 
  implements OptionHandler, WeightedInstancesHandler, Drawable, 
	     AdditionalMeasureProducer, Sourcable {

  /**
   * Provides a short, human-readable description of this classifier.
   *
   * @return the description text, suitable for displaying in the
   * explorer/experimenter gui
   */
  public String globalInfo() {

    StringBuffer description = new StringBuffer();

    description.append("Fast decision tree learner. Builds a decision/regression tree using ");
    description.append("information gain/variance and prunes it using reduced-error pruning ");
    description.append("(with backfitting).  Only sorts values for numeric attributes ");
    description.append("once. Missing values are dealt with by splitting the corresponding ");
    description.append("instances into pieces (i.e. as in C4.5).");

    return description.toString();
  }

  /** An inner class for building and storing the tree structure */
  protected class Tree implements Serializable {
    
    /** The header information (for printing the tree). */
    protected Instances m_Info = null;

    /** The subtrees of this tree. */ 
    protected Tree[] m_Successors;
    
    /** The attribute to split on. */
    protected int m_Attribute = -1;

    /** The split point. */
    protected double m_SplitPoint = Double.NaN;
    
    /** The proportions of training instances going down each branch. */
    protected double[] m_Prop = null;

    /** Class probabilities from the training data in the nominal case. 
	Holds the mean in the numeric case. */
    protected double[] m_ClassProbs = null;
    
    /** The (unnormalized) class distribution in the nominal
	case. Holds the sum of squared errors and the weight 
	in the numeric case. */
    protected double[] m_Distribution = null;
    
    /** Class distribution of hold-out set at node in the nominal case. 
	Straight sum of weights in the numeric case (i.e. the array has
	only one element). */
    protected double[] m_HoldOutDist = null;
    
    /** The hold-out error of the node. The number of misclassified
	instances in the nominal case, the sum of squared errors in the 
	numeric case. */
    protected double m_HoldOutError = 0;
  
    /**
     * Computes the class distribution of an instance by passing it down
     * the subtree rooted at this node.
     *
     * If the value of the split attribute is missing, the instance is sent
     * down every branch and the successors' distributions are summed,
     * each weighted by the matching entry of m_Prop. For a nominal split
     * attribute the branch is selected by the attribute value, for a
     * numeric one by comparing the value with the split point.
     *
     * @param instance the instance to compute the distribution for
     * @return the distribution found further down the tree, or this
     * node's own m_ClassProbs array (the array itself, not a copy) if the
     * node is a leaf or the selected successor returned null
     * @exception Exception declared by the signature; may be raised by the
     * recursive calls
     */
    protected double[] distributionForInstance(Instance instance) 
      throws Exception {

      // Anything that is not an inner node answers with its own estimate
      if (!(m_Attribute > -1)) {
        return m_ClassProbs;
      }

      double[] dist;
      if (instance.isMissing(m_Attribute)) {

        // Value is missing: spread the instance over all branches
        dist = new double[m_Info.numClasses()];
        for (int branch = 0; branch < m_Successors.length; branch++) {
          double[] branchDist =
            m_Successors[branch].distributionForInstance(instance);
          if (branchDist == null) {
            // Successor has nothing to contribute
            continue;
          }
          for (int c = 0; c < branchDist.length; c++) {
            dist[c] += m_Prop[branch] * branchDist[c];
          }
        }
      } else if (m_Info.attribute(m_Attribute).isNominal()) {

        // Nominal split: the attribute value is the branch index
        int branch = (int) instance.value(m_Attribute);
        dist = m_Successors[branch].distributionForInstance(instance);
      } else {

        // Numeric split: first branch below the split point, second otherwise
        int branch = (instance.value(m_Attribute) < m_SplitPoint) ? 0 : 1;
        dist = m_Successors[branch].distributionForInstance(instance);
      }

      // Fall back to this node's estimate if the successor was empty
      return (dist != null) ? dist : m_ClassProbs;
    }

   /**
    * Returns a string containing java source code equivalent to the test
    * made at this node. The instance being tested is called "i". This
    * routine expects to be called in the order of branching, which allows
    * the >= test of a numeric split point (the last one) to be emitted as
    * just "true": reaching it in the generated if/else chain implies that
    * the preceding less-than test failed.
    *
    * For a nominal split attribute the value is written as a java string
    * literal. Backslashes, double quotes, line feeds, carriage returns and
    * tabs in the value are escaped so that the generated source stays
    * compilable and compares against the original value.
    *
    * @param index index of the value tested; a negative index requests
    * the test for a missing value
    * @return the source code of the test
    */
    public final String sourceExpression(int index) {

      // Negative index: test for a missing value
      if (index < 0) {
        return "i[" + m_Attribute + "] == null";
      }

      StringBuffer expr = new StringBuffer();
      if (m_Info.attribute(m_Attribute).isNominal()) {

        // String.valueOf keeps the "null" rendering StringBuffer.append
        // would have produced for a null value
        String value =
          String.valueOf(m_Info.attribute(m_Attribute).value(index));

        expr.append("i[").append(m_Attribute).append("]");
        expr.append(".equals(\"");
        for (int k = 0; k < value.length(); k++) {
          char c = value.charAt(k);
          switch (c) {
          case '\\':
            expr.append("\\\\");
            break;
          case '"':
            expr.append("\\\"");
            break;
          case '\n':
            expr.append("\\n");
            break;
          case '\r':
            expr.append("\\r");
            break;
          case '\t':
            expr.append("\\t");
            break;
          default:
            expr.append(c);
            break;
          }
        }
        expr.append("\")");
      } else {
        if (index == 0) {
          expr.append("((Double)i[")
            .append(m_Attribute).append("]).doubleValue() < ")
            .append(m_SplitPoint);
        } else {
          // Second branch of a numeric split: the less-than test failed
          expr.append("true");
        }
      }
      return expr.toString();
    }

   /**
    * Produces java source code for the subtree rooted at this node, in
    * the form of if-then statements. The generated code assigns the
    * prediction to a variable named "p" and expects the tested instance
    * to be an array named "i".
    *
    * A leaf yields a plain assignment and no support code. An inner node
    * yields a call to a generated static method as the assignment code;
    * the method itself, followed by the methods generated for the inner
    * nodes below it, forms the support code.
    *
    * TODO: If the outputted source code encounters a missing value
    * for the evaluated attribute, it stops branching and uses the 
    * class distribution of the current node to decide the return value. 
    * This is unlike the behaviour of distributionForInstance(). 
    *
    * @param className the classname that this static classifier has; used
    * to qualify the calls to the generated methods
    * @param parent parent node of the current node; only consulted (and
    * then dereferenced) when this node has no class probabilities
    * @return an array containing two stringbuffers, the first containing
    * the assignment code, and the second the source for the support code
    * @exception Exception if something goes wrong
    */
    public StringBuffer [] toSource(String className, Tree parent) 
      throws Exception {

      // A node without a distribution of its own borrows the parent's
      final double[] nodeProbs =
        (m_ClassProbs != null) ? m_ClassProbs : parent.m_ClassProbs;

      final long id = nextID();

      if (m_Attribute == -1) {

        // Leaf: assign the prediction directly, no support code required
        StringBuffer assignment = new StringBuffer("	p = ");
        if (m_Info.classAttribute().isNumeric()) {
          assignment.append(nodeProbs[0]);
        } else {
          assignment.append(Utils.maxIndex(nodeProbs));
        }
        assignment.append(";\n");
        return new StringBuffer[] { assignment, new StringBuffer("") };
      }

      // Inner node: one static method per node, named after hash code and id
      final String nodeTag = Integer.toHexString(this.hashCode()) + id;
      StringBuffer method = new StringBuffer("");
      StringBuffer nested = new StringBuffer("");

      method.append("  static double N");
      method.append(nodeTag);
      method.append("(Object []i) {\n");
      method.append("    double p = Double.NaN;\n");
      method.append("    /* " + m_Info.attribute(m_Attribute).name() + " */\n");

      // Missing attribute value: predict from this node
      method.append("    if (" + this.sourceExpression(-1) + ") {\n");
      method.append("      p = ");
      final boolean numericClass = m_Info.classAttribute().isNumeric();
      method.append(numericClass
                    ? nodeProbs[0] + ";\n"
                    : Utils.maxIndex(nodeProbs) + ";\n");
      method.append("    } ");

      // One else-if per branch, in branching order
      final int lastBranch = m_Successors.length - 1;
      for (int branch = 0; branch <= lastBranch; branch++) {
        Tree child = m_Successors[branch];
        method.append("else if (" + this.sourceExpression(branch) + ") {\n");

        if (child.m_Attribute != -1) {

          // Inner successor: call its method and collect the support code
          StringBuffer [] childSource = child.toSource(className, this);
          method.append("" + childSource[0]);
          nested.append("" + childSource[1]);
        } else {

          // Leaf successor: inline its prediction; an empty leaf falls
          // back to this node's class probabilities
          double[] leafProbs =
            (child.m_ClassProbs != null) ? child.m_ClassProbs : m_ClassProbs;
          method.append("      p = ");
          method.append(numericClass
                        ? leafProbs[0] + ";\n"
                        : Utils.maxIndex(leafProbs) + ";\n");
        }

        method.append("    } ");
        if (branch == lastBranch) {
          method.append("\n");
        }
      }

      method.append("    return p;\n  }\n");

      StringBuffer call = new StringBuffer("    p = " + className + ".N");
      call.append(nodeTag).append("(i);\n");

      return new StringBuffer[] { call, method.append("" + nested) };
    }

	
    /**
     * Appends the graph description of this node, and recursively of all
     * nodes below it, to the given buffer. Nodes are named "N" followed by
     * the hexadecimal hash code of the tree object. A leaf is drawn as a
     * box labelled with its number and leaf description; an inner node is
     * labelled with its number and the name of the split attribute, and
     * each outgoing edge with the test leading to the successor.
     *
     * @param text the buffer the description is appended to
     * @param num the number given to the previously output node
     * @param parent the parent of this node, passed on to leafString()
     * @return the number given to the last node output for this subtree
     * @exception Exception declared by the signature; may be raised by
     * leafString() or the recursive calls
     */
    protected int toGraph(StringBuffer text, int num,
			Tree parent) throws Exception {

      final String nodeName = "N" + Integer.toHexString(Tree.this.hashCode());
      final int nodeNumber = num + 1;

      if (m_Attribute == -1) {

        // Leaf node
        text.append(nodeName + " [label=\"" + nodeNumber
                    + leafString(parent) + "\"" + "shape=box]\n");
        return nodeNumber;
      }

      // Inner node
      text.append(nodeName + " [label=\"" + nodeNumber + ": "
                  + m_Info.attribute(m_Attribute).name() + "\"]\n");

      int lastNumber = nodeNumber;
      for (int branch = 0; branch < m_Successors.length; branch++) {
        text.append(nodeName + "->" + "N"
                    + Integer.toHexString(m_Successors[branch].hashCode())
                    + " [label=\"");

        String edgeLabel;
        if (m_Info.attribute(m_Attribute).isNumeric()) {
          String relation = (branch == 0) ? " < " : " >= ";
          edgeLabel = relation + Utils.doubleToString(m_SplitPoint, 2);
        } else {
          edgeLabel = " = " + m_Info.attribute(m_Attribute).value(branch);
        }
        text.append(edgeLabel).append("\"]\n");

        // Numbering continues through the successor's subtree
        lastNumber = m_Successors[branch].toGraph(text, lastNumber, this);
      }

      return lastNumber;
    }

    /**
     * Builds the textual description of a leaf node.
     *
     * For a numeric class the description holds the class mean, followed
     * in parentheses by the training weight and the average training
     * error, and in square brackets by the hold-out weight and the average
     * hold-out error. For a nominal class it holds the predicted class
     * value, followed in parentheses by the total training weight and the
     * training weight not belonging to the predicted class, and in square
     * brackets by the same two figures for the hold-out set.
     *
     * @param parent the parent of this node; its class probabilities are
     * used if this node has none of its own
     * @return the description of the leaf
     * @exception Exception declared by the signature
     */
    protected String leafString(Tree parent) throws Exception {

      if (!m_Info.classAttribute().isNumeric()) {

        // Nominal class
        int predicted = Utils.maxIndex((m_ClassProbs == null)
                                       ? parent.m_ClassProbs
                                       : m_ClassProbs);
        double trainTotal = Utils.sum(m_Distribution);
        double holdOutTotal = Utils.sum(m_HoldOutDist);

        return " : " + m_Info.classAttribute().value(predicted)
          + " (" + Utils.doubleToString(trainTotal, 2)
          + "/"
          + Utils.doubleToString(trainTotal - m_Distribution[predicted], 2)
          + ")"
          + " [" + Utils.doubleToString(holdOutTotal, 2) + "/"
          + Utils.doubleToString(holdOutTotal - m_HoldOutDist[predicted], 2)
          + "]";
      }

      // Numeric class
      double mean = (m_ClassProbs == null)
        ? parent.m_ClassProbs[0]
        : m_ClassProbs[0];

      // Averages are reported as zero when there is no weight to divide by
      double trainAvgError = (m_Distribution[1] > 0)
        ? m_Distribution[0] / m_Distribution[1]
        : 0;
      double holdOutAvgError = (m_HoldOutDist[0] > 0)
        ? m_HoldOutError / m_HoldOutDist[0]
        : 0;

      StringBuffer description = new StringBuffer();
      description.append(" : " + Utils.doubleToString(mean, 2));
      description.append(" (" + Utils.doubleToString(m_Distribution[1], 2)
                         + "/" + Utils.doubleToString(trainAvgError, 2)
                         + ")");
      description.append(" [" + Utils.doubleToString(m_HoldOutDist[0], 2)
                         + "/" + Utils.doubleToString(holdOutAvgError, 2)
                         + "]");
      return description.toString();
    }
  
    /**
     * Recursively outputs the tree.
     */
    protected String toString(int level, Tree parent) {

      try {
	StringBuffer text = new StringBuffer();
      
	if (m_Attribute == -1) {
	
	  // Output leaf info
	  return leafString(parent);
	} else if (m_Info.attribute(m_Attribute).isNominal()) {
	
	  // For nominal attributes
	  for (int i = 0; i < m_Successors.length; i++) {
	    text.append("\n");
	    for (int j = 0; j < level; j++) {
	      text.append("|   ");
	    }
	    text.append(m_Info.attribute(m_Attribute).name() + " = " +
			m_Info.attribute(m_Attribute).value(i));
	    text.append(m_Successors[i].toString(level + 1, this));
	  }
	} else {
	
	  // For numeric attributes
	  text.append("\n");
	  for (int j = 0; j < level; j++) {
	    text.append("|   ");
	  }
	  text.append(m_Info.attribute(m_Attribute).name() + " < " +
		      Utils.doubleToString(m_SplitPoint, 2));
	  text.append(m_Successors[0].toString(level + 1, this));
	  text.append("\n");
	  for (int j = 0; j < level; j++) {
	    text.append("|   ");
	  }

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -