rulenode.java
Java source code for the entire workflow of the data mining software AlphaMiner
Page 1 of 2

/*
 *    RuleNode.java
 *    Copyright (C) 2000 Mark Hall
 *
 *    This program is free software; you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation; either version 2 of the License, or
 *    (at your option) any later version.
 *
 *    This program is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with this program; if not, write to the Free Software
 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */
package weka.classifiers.trees.m5;

import weka.classifiers.Classifier;
import weka.classifiers.Evaluation;
import weka.classifiers.functions.LinearRegression;
import weka.core.FastVector;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Utils;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.Remove;

/**
 * Constructs a node for use in an m5 tree or rule
 *
 * @author Mark Hall (mhall@cs.waikato.ac.nz)
 * @version $Revision$
 */
public class RuleNode extends Classifier {

  /**
   * instances reaching this node
   */
  private Instances	   m_instances;

  /**
   * the class index
   */
  private int		   m_classIndex;

  /**
   * the number of instances reaching this node
   */
  protected int		   m_numInstances;

  /**
   * the number of attributes
   */
  private int		   m_numAttributes;

  /**
   * Node is a leaf
   */
  private boolean	   m_isLeaf;

  /**
   * attribute this node splits on
   */
  private int		   m_splitAtt;

  /**
   * the value of the split attribute
   */
  private double	   m_splitValue;

  /**
   * the linear model at this node
   */
  private PreConstructedLinearModel m_nodeModel;

  /**
   * the number of parameters in the chosen model for this node---either
   * the subtree model or the linear model.
   * The constant term is counted as a parameter---this is for pruning
   * purposes
   */
  public int		   m_numParameters;

  /**
   * the root mean squared error of the model at this node (either linear or
   * subtree)
   */
  private double	   m_rootMeanSquaredError;

  /**
   * child nodes
   */
  protected RuleNode	   m_left;
  protected RuleNode	   m_right;

  /**
   * the parent of this node
   */
  private RuleNode	   m_parent;

  /**
   * a node will not be split if it contains fewer than m_splitNum instances
   */
  private double	   m_splitNum = 4;

  /**
   * a node will not be split if its class standard deviation is less
   * than 5% of the class standard deviation of all the instances
   */
  private double	   m_devFraction = 0.05;
  private double	   m_pruningMultiplier = 2;

  /**
   * the number assigned to the linear model if this node is a leaf
   * (0 if this node is not a leaf)
   */
  private int		   m_leafModelNum;

  /**
   * the global standard deviation of the class; a node will not be split
   * if the class deviation of its instances is less than m_devFraction
   * of this value
   */
  private double	   m_globalDeviation;

  /**
   * the absolute deviation of the global class
   */
  private double	   m_globalAbsDeviation;

  /**
   * Indices of the attributes to be used in generating a linear model
   * at this node
   */
  private int [] m_indices;
    
  /**
   * Constant used in original m5 smoothing calculation
   */
  private static final double	   SMOOTHING_CONSTANT = 15.0;

  /**
   * Node id.
   */
  private int m_id;

  /**
   * Save the instances at each node (for visualizing in the
   * Explorer's treevisualizer).
   */
  private boolean m_saveInstances = false;

  /**
   * Make a regression tree instead of a model tree
   */
  private boolean m_regressionTree;

  /**
   * Creates a new <code>RuleNode</code> instance.
   *
   * @param globalDev the global standard deviation of the class
   * @param globalAbsDev the global absolute deviation of the class
   * @param parent the parent of this node
   */
  public RuleNode(double globalDev, double globalAbsDev, RuleNode parent) {
    m_nodeModel = null;
    m_right = null;
    m_left = null;
    m_parent = parent;
    m_globalDeviation = globalDev;
    m_globalAbsDeviation = globalAbsDev;
  }

    
  /**
   * Build this node (find an attribute and split point)
   *
   * @param data the instances on which to build this node
   * @exception Exception if an error occurs
   */
  public void buildClassifier(Instances data) throws Exception {

    m_rootMeanSquaredError = Double.MAX_VALUE;
    //    m_instances = new Instances(data);
    m_instances = data;
    m_classIndex = m_instances.classIndex();
    m_numInstances = m_instances.numInstances();
    m_numAttributes = m_instances.numAttributes();
    m_nodeModel = null;
    m_right = null;
    m_left = null;

    if ((m_numInstances < m_splitNum) 
	|| (Rule.stdDev(m_classIndex, m_instances) 
	    < (m_globalDeviation * m_devFraction))) {
      m_isLeaf = true;
    } else {
      m_isLeaf = false;
    } 

    split();
  } 
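
  // Illustrative defaults: with m_splitNum = 4 and m_devFraction = 0.05,
  // a node is made a leaf as soon as it receives fewer than 4 instances,
  // or the class standard deviation of its instances falls below 5% of
  // the global class standard deviation supplied to the constructor.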
 
  /**
   * Classify an instance using this node. Recursively calls classifyInstance
   * on child nodes.
   *
   * @param inst the instance to classify
   * @return the prediction for this instance
   * @exception Exception if an error occurs
   */
  public double classifyInstance(Instance inst) throws Exception {

    if (m_isLeaf) {
      if (m_nodeModel == null) {
	throw new Exception("Classifier has not been built correctly.");
      } 

      return m_nodeModel.classifyInstance(inst);
    }

    if (inst.value(m_splitAtt) <= m_splitValue) {
      return m_left.classifyInstance(inst);
    } else {
      return m_right.classifyInstance(inst);
    } 
  } 
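
  // Worked route (hypothetical values): if this node splits on attribute 2
  // at m_splitValue = 5.0, an instance with value(2) == 4.0 is handed to
  // m_left and one with value(2) == 6.0 to m_right; recursion bottoms out
  // at a leaf, whose linear model supplies the final prediction.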

  /**
   * Applies the m5 smoothing procedure to a prediction
   *
   * @param n number of instances in selected child of this node
   * @param pred the prediction so far
   * @param supportPred the prediction of the linear model at this node
   * @return the current prediction smoothed with the prediction of the
   * linear model at this node
   * @exception Exception if an error occurs
   */
  protected static double smoothingOriginal(double n, double pred, 
					    double supportPred) 
    throws Exception {
    double   smoothed;

    smoothed = 
      ((n * pred) + (SMOOTHING_CONSTANT * supportPred)) /
      (n + SMOOTHING_CONSTANT);

    return smoothed;
  } 
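
  // Worked example: with n = 20 instances reaching the child, a running
  // prediction of 10.0 and a support prediction of 16.0 from this node's
  // model, the smoothed value is
  //   (20 * 10.0 + 15.0 * 16.0) / (20 + 15.0) = 440.0 / 35.0 = 12.571...
  // i.e. the child's prediction is pulled part of the way towards the
  // prediction of the linear model at this node.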


  /**
   * Finds an attribute and split point for this node
   *
   * @exception Exception if an error occurs
   */
  public void split() throws Exception {
    int		  i;
    Instances     leftSubset, rightSubset;
    SplitEvaluate bestSplit, currentSplit;
    boolean[]     attsBelow;

    if (!m_isLeaf) {
     
      bestSplit = new YongSplitInfo(0, m_numInstances - 1, -1);
      currentSplit = new YongSplitInfo(0, m_numInstances - 1, -1);

      // find the best attribute to split on
      for (i = 0; i < m_numAttributes; i++) {
	if (i != m_classIndex) {

	  // sort the instances by this attribute
	  m_instances.sort(i);
	  currentSplit.attrSplit(i, m_instances);

	  if ((Math.abs(currentSplit.maxImpurity() - 
			bestSplit.maxImpurity()) > 1.e-6) 
	      && (currentSplit.maxImpurity() 
		  > bestSplit.maxImpurity() + 1.e-6)) {
	    bestSplit = currentSplit.copy();
	  } 
	} 
      } 

      // can't find a good split attribute or split point?
      if (bestSplit.splitAttr() < 0 || bestSplit.position() < 1 
	  || bestSplit.position() > m_numInstances - 1) {
	m_isLeaf = true;
      } else {
	m_splitAtt = bestSplit.splitAttr();
	m_splitValue = bestSplit.splitValue();
	leftSubset = new Instances(m_instances, m_numInstances);
	rightSubset = new Instances(m_instances, m_numInstances);

	for (i = 0; i < m_numInstances; i++) {
	  if (m_instances.instance(i).value(m_splitAtt) <= m_splitValue) {
	    leftSubset.add(m_instances.instance(i));
	  } else {
	    rightSubset.add(m_instances.instance(i));
	  } 
	} 

	leftSubset.compactify();
	rightSubset.compactify();

	// build left and right nodes
	m_left = new RuleNode(m_globalDeviation, m_globalAbsDeviation, this);
	m_left.setMinNumInstances(m_splitNum);
	m_left.setRegressionTree(m_regressionTree);
	m_left.setSaveInstances(m_saveInstances);
	m_left.buildClassifier(leftSubset);

	m_right = new RuleNode(m_globalDeviation, m_globalAbsDeviation, this);
	m_right.setMinNumInstances(m_splitNum);
	m_right.setRegressionTree(m_regressionTree);
	m_right.setSaveInstances(m_saveInstances);
	m_right.buildClassifier(rightSubset);

	// now find out what attributes are tested in the left and right
	// subtrees and use them to learn a linear model for this node
	if (!m_regressionTree) {
	  attsBelow = attsTestedBelow();
	  attsBelow[m_classIndex] = true;
	  int count = 0, j;

	  for (j = 0; j < m_numAttributes; j++) {
	    if (attsBelow[j]) {
	      count++;
	    } 
	  } 
	  
	  int[] indices = new int[count];

	  count = 0;
	  
	  for (j = 0; j < m_numAttributes; j++) {
	    if (attsBelow[j] && (j != m_classIndex)) {
	      indices[count++] = j;
	    } 
	  } 
	  
	  indices[count] = m_classIndex;
	  m_indices = indices;
	} else {
	  m_indices = new int [1];
	  m_indices[0] = m_classIndex;
	  m_numParameters = 1;
	}
      } 
    } 

    if (m_isLeaf) {
      int [] indices = new int [1];
      indices[0] = m_classIndex;
      m_indices = indices;
      m_numParameters = 1;
     
      // need to evaluate the model here if we want correct stats for an
      // unpruned tree
    } 
  } 
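
  // Note: after split() returns, m_indices holds the original column
  // positions of every attribute tested in this subtree plus the class
  // index in the last slot; for a leaf, or when building a pure regression
  // tree, it holds only the class index. These indices are the natural
  // argument for buildLinearModel() below.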

  /**
   * Build a linear model for this node using those attributes
   * specified in indices.
   *
   * @param indices an array of attribute indices to include in the linear
   * model
   * @exception Exception if the linear model cannot be built
   */
  private void buildLinearModel(int [] indices) throws Exception {
    // copy the training instances and remove all but the tested
    // attributes
    Instances reducedInst = new Instances(m_instances);
    Remove attributeFilter = new Remove();
    
    attributeFilter.setInvertSelection(true);
    attributeFilter.setAttributeIndicesArray(indices);
    attributeFilter.setInputFormat(reducedInst);

    reducedInst = Filter.useFilter(reducedInst, attributeFilter);
    
    // build a linear regression for the training data using the
    // tested attributes
    LinearRegression temp = new LinearRegression();
    temp.buildClassifier(reducedInst);

    double [] lmCoeffs = temp.coefficients();
    double [] coeffs = new double [m_instances.numAttributes()];

    for (int i = 0; i < lmCoeffs.length - 1; i++) {
      if (indices[i] != m_classIndex) {
	coeffs[indices[i]] = lmCoeffs[i];
      }
    }
    m_nodeModel = new PreConstructedLinearModel(coeffs, lmCoeffs[lmCoeffs.length - 1]);
    m_nodeModel.buildClassifier(m_instances);
  }
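
  // Reading the mapping above: lmCoeffs holds one coefficient per attribute
  // of the filtered data with the intercept in its final slot, so the loop
  // copies each coefficient back to the original column it came from
  // (assuming the filtered attributes keep their original order) and leaves
  // every attribute outside 'indices' with a 0.0 coefficient.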

  /**
   * Returns a boolean array with one entry per attribute; an entry is true
   * if the attribute is used in a test at this node or anywhere above it
   *
   * @return an array of boolean flags indexed by attribute
   */
  private boolean[] attsTestedAbove() {
    boolean[] atts = new boolean[m_numAttributes];
    boolean[] attsAbove = null;

    if (m_parent != null) {
      attsAbove = m_parent.attsTestedAbove();
    } 

    if (attsAbove != null) {
      for (int i = 0; i < m_numAttributes; i++) {
	atts[i] = attsAbove[i];
      } 
    } 

    atts[m_splitAtt] = true;
    return atts;
  } 

  /**
   * Returns a boolean array with one entry per attribute; an entry is true
   * if the attribute is used in a test at this node or anywhere below it
   *
   * @return an array of boolean flags indexed by attribute
   */
  private boolean[] attsTestedBelow() {
    boolean[] attsBelow = new boolean[m_numAttributes];
    boolean[] attsBelowLeft = null;
    boolean[] attsBelowRight = null;

    if (m_right != null) {
      attsBelowRight = m_right.attsTestedBelow();
    } 

    if (m_left != null) {
      attsBelowLeft = m_left.attsTestedBelow();
    } 

    for (int i = 0; i < m_numAttributes; i++) {
      if (attsBelowLeft != null) {
	attsBelow[i] = (attsBelow[i] || attsBelowLeft[i]);
      } 

      if (attsBelowRight != null) {
	attsBelow[i] = (attsBelow[i] || attsBelowRight[i]);
      } 
    } 

    if (!m_isLeaf) {
      attsBelow[m_splitAtt] = true;
    } 
    return attsBelow;
  } 

  /**
   * Numbers the leaf models from left to right and counts the leaves
   * under this node
   *
   * @param leafCounter the number of leaves counted so far
   * @return the total number of leaves counted once this subtree is included
   */
  public int numLeaves(int leafCounter) {

    if (!m_isLeaf) {
      // node
      m_leafModelNum = 0;

      if (m_left != null) {
	leafCounter = m_left.numLeaves(leafCounter);
      } 

      if (m_right != null) {
	leafCounter = m_right.numLeaves(leafCounter);
      } 
    } else {
      // leaf
      leafCounter++;
      m_leafModelNum = leafCounter;
    } 
    return leafCounter;
  } 
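
  // Example: numLeaves(0) called on the root numbers the leaf models
  // 1, 2, 3, ... from left to right via m_leafModelNum, while every
  // interior node has its m_leafModelNum reset to 0.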

  /**
   * print the linear model at this node
   */
  public String toString() {
    return printNodeLinearModel();
  } 

  /**
   * print the linear model at this node
   */
  public String printNodeLinearModel() {
    return m_nodeModel.toString();
  } 

  /**
   * print all leaf models
   */
  public String printLeafModels() {
    StringBuffer text = new StringBuffer();

    if (m_isLeaf) {
      text.append("\nLM num: " + m_leafModelNum);
      text.append(m_nodeModel.toString());
      text.append("\n");
    } else {
      text.append(m_left.printLeafModels());
      text.append(m_right.printLeafModels());
    } 
    return text.toString();
  } 

  /**
   * Returns a description of this node (debugging purposes)
   *
   * @return a string describing this node
   */
  public String nodeToString() {
    StringBuffer text = new StringBuffer();

    text.append("Node:\n\tnum inst: " + m_numInstances);

    if (m_isLeaf) {
      text.append("\n\tleaf");
    } else {
      text.append("\n\tnode");
    }

    text.append("\n\tSplit att: " + m_instances.attribute(m_splitAtt).name());
    text.append("\n\tSplit val: " + Utils.doubleToString(m_splitValue, 1, 3));
    text.append("\n\tLM num: " + m_leafModelNum);
    text.append("\n\tLinear model\n" + m_nodeModel.toString());
    text.append("\n\n");
