⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 rulenode.java

📁 This is the source code of a Weka tool, implemented in Java for data mining purposes
💻 JAVA
📖 第 1 页 / 共 2 页
字号:
/* *    RuleNode.java *    Copyright (C) 2000 Mark Hall * *    This program is free software; you can redistribute it and/or modify *    it under the terms of the GNU General Public License as published by *    the Free Software Foundation; either version 2 of the License, or *    (at your option) any later version. * *    This program is distributed in the hope that it will be useful, *    but WITHOUT ANY WARRANTY; without even the implied warranty of *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the *    GNU General Public License for more details. * *    You should have received a copy of the GNU General Public License *    along with this program; if not, write to the Free Software *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA. */package weka.classifiers.trees.m5;import java.io.*;import java.util.*;import weka.core.*;import weka.classifiers.*;import weka.classifiers.functions.LinearRegression;import weka.filters.unsupervised.attribute.Remove;import weka.filters.Filter;/** * Constructs a node for use in an m5 tree or rule * * @author Mark Hall (mhall@cs.waikato.ac.nz) * @version $Revision: 1.8.2.1 $ */public class RuleNode extends Classifier {  /**   * instances reaching this node   */  private Instances	   m_instances;  /**   * the class index   */  private int		   m_classIndex;  /**   * the number of instances reaching this node   */  protected int		   m_numInstances;  /**   * the number of attributes   */  private int		   m_numAttributes;  /**   * Node is a leaf   */  private boolean	   m_isLeaf;  /**   * attribute this node splits on   */  private int		   m_splitAtt;  /**   * the value of the split attribute   */  private double	   m_splitValue;  /**   * the linear model at this node   */  private PreConstructedLinearModel m_nodeModel;  /**   * the number of paramters in the chosen model for this node---either   * the subtree model or the linear model.   
* The constant term is counted as a paramter---this is for pruning   * purposes   */  public int		   m_numParameters;  /**   * the mean squared error of the model at this node (either linear or   * subtree)   */  private double	   m_rootMeanSquaredError;  /**   * child nodes   */  protected RuleNode	   m_left;  protected RuleNode	   m_right;  /**   * the parent of this node   */  private RuleNode	   m_parent;  /**   * a node will not be split if it contains less then m_splitNum instances   */  private double	   m_splitNum = 4;  /**   * a node will not be split if its class standard deviation is less   * than 5% of the class standard deviation of all the instances   */  private double	   m_devFraction = 0.05;  private double	   m_pruningMultiplier = 2;  /**   * the number assigned to the linear model if this node is a leaf.   * = 0 if this node is not a leaf   */  private int		   m_leafModelNum;  /**   * a node will not be split if the class deviation of its   * instances is less than m_devFraction of the deviation of the   * global class   */  private double	   m_globalDeviation;  /**   * the absolute deviation of the global class   */  private double	   m_globalAbsDeviation;  /**   * Indices of the attributes to be used in generating a linear model   * at this node   */  private int [] m_indices;      /**   * Constant used in original m5 smoothing calculation   */  private static final double	   SMOOTHING_CONSTANT = 15.0;  /**   * Node id.   */  private int m_id;  /**   * Save the instances at each node (for visualizing in the   * Explorer's treevisualizer.   */  private boolean m_saveInstances = false;  /**   * Make a regression tree instead of a model tree   */  private boolean m_regressionTree;  /**   * Creates a new <code>RuleNode</code> instance.   
*   * @param globalDev the global standard deviation of the class   * @param globalAbsDev the global absolute deviation of the class   * @param parent the parent of this node   */  public RuleNode(double globalDev, double globalAbsDev, RuleNode parent) {    m_nodeModel = null;    m_right = null;    m_left = null;    m_parent = parent;    m_globalDeviation = globalDev;    m_globalAbsDeviation = globalAbsDev;  }      /**   * Build this node (find an attribute and split point)   *   * @param data the instances on which to build this node   * @exception Exception if an error occurs   */  public void buildClassifier(Instances data) throws Exception {    m_rootMeanSquaredError = Double.MAX_VALUE;    //    m_instances = new Instances(data);    m_instances = data;    m_classIndex = m_instances.classIndex();    m_numInstances = m_instances.numInstances();    m_numAttributes = m_instances.numAttributes();    m_nodeModel = null;    m_right = null;    m_left = null;    if ((m_numInstances < m_splitNum) 	|| (Rule.stdDev(m_classIndex, m_instances) 	    < (m_globalDeviation * m_devFraction))) {      m_isLeaf = true;    } else {      m_isLeaf = false;    }     split();  }    /**   * Classify an instance using this node. Recursively calls classifyInstance   * on child nodes.   
*   * @param inst the instance to classify   * @return the prediction for this instance   * @exception Exception if an error occurs   */  public double classifyInstance(Instance inst) throws Exception {    double   pred;    double   n = 0;    Instance tempInst;    if (m_isLeaf) {      if (m_nodeModel == null) {	throw new Exception("Classifier has not been built correctly.");      }       return m_nodeModel.classifyInstance(inst);    }    if (inst.value(m_splitAtt) <= m_splitValue) {      return m_left.classifyInstance(inst);    } else {      return m_right.classifyInstance(inst);    }   }   /**   * Applies the m5 smoothing procedure to a prediction   *   * @param n number of instances in selected child of this node   * @param pred the prediction so far   * @param supportPred the prediction of the linear model at this node   * @return the current prediction smoothed with the prediction of the   * linear model at this node   * @exception Exception if an error occurs   */  protected static double smoothingOriginal(double n, double pred, 					    double supportPred)     throws Exception {    double   smoothed;    smoothed =       ((n * pred) + (SMOOTHING_CONSTANT * supportPred)) /      (n + SMOOTHING_CONSTANT);    return smoothed;  }   /**   * Finds an attribute and split point for this node   *   * @exception Exception if an error occurs   */  public void split() throws Exception {    int		  i;    Instances     leftSubset, rightSubset;    SplitEvaluate bestSplit, currentSplit;    boolean[]     attsBelow;    if (!m_isLeaf) {           bestSplit = new YongSplitInfo(0, m_numInstances - 1, -1);      currentSplit = new YongSplitInfo(0, m_numInstances - 1, -1);      // find the best attribute to split on      for (i = 0; i < m_numAttributes; i++) {	if (i != m_classIndex) {	  // sort the instances by this attribute	  m_instances.sort(i);	  currentSplit.attrSplit(i, m_instances);	  if ((Math.abs(currentSplit.maxImpurity() - 			bestSplit.maxImpurity()) > 1.e-6) 	      && 
(currentSplit.maxImpurity() 		  > bestSplit.maxImpurity() + 1.e-6)) {	    bestSplit = currentSplit.copy();	  } 	}       }       // cant find a good split or split point?      if (bestSplit.splitAttr() < 0 || bestSplit.position() < 1 	  || bestSplit.position() > m_numInstances - 1) {	m_isLeaf = true;      } else {	m_splitAtt = bestSplit.splitAttr();	m_splitValue = bestSplit.splitValue();	leftSubset = new Instances(m_instances, m_numInstances);	rightSubset = new Instances(m_instances, m_numInstances);	for (i = 0; i < m_numInstances; i++) {	  if (m_instances.instance(i).value(m_splitAtt) <= m_splitValue) {	    leftSubset.add(m_instances.instance(i));	  } else {	    rightSubset.add(m_instances.instance(i));	  } 	} 	leftSubset.compactify();	rightSubset.compactify();	// build left and right nodes	m_left = new RuleNode(m_globalDeviation, m_globalAbsDeviation, this);	m_left.setMinNumInstances(m_splitNum);	m_left.setRegressionTree(m_regressionTree);	m_left.setSaveInstances(m_saveInstances);	m_left.buildClassifier(leftSubset);	m_right = new RuleNode(m_globalDeviation, m_globalAbsDeviation, this);	m_right.setMinNumInstances(m_splitNum);	m_right.setRegressionTree(m_regressionTree);	m_right.setSaveInstances(m_saveInstances);	m_right.buildClassifier(rightSubset);	// now find out what attributes are tested in the left and right	// subtrees and use them to learn a linear model for this node	if (!m_regressionTree) {	  attsBelow = attsTestedBelow();	  attsBelow[m_classIndex] = true;	  int count = 0, j;	  for (j = 0; j < m_numAttributes; j++) {	    if (attsBelow[j]) {	      count++;	    } 	  } 	  	  int[] indices = new int[count];	  count = 0;	  	  for (j = 0; j < m_numAttributes; j++) {	    if (attsBelow[j] && (j != m_classIndex)) {	      indices[count++] = j;	    } 	  } 	  	  indices[count] = m_classIndex;	  m_indices = indices;	} else {	  m_indices = new int [1];	  m_indices[0] = m_classIndex;	  m_numParameters = 1;	}      }     }     if (m_isLeaf) {      int [] indices = new int 
[1];      indices[0] = m_classIndex;      m_indices = indices;      m_numParameters = 1;           // need to evaluate the model here if want correct stats for unpruned      // tree    }   }   /**   * Build a linear model for this node using those attributes   * specified in indices.   *   * @param indices an array of attribute indices to include in the linear   * model   */  private void buildLinearModel(int [] indices) throws Exception {    // copy the training instances and remove all but the tested    // attributes    Instances reducedInst = new Instances(m_instances);    Remove attributeFilter = new Remove();        attributeFilter.setInvertSelection(true);    attributeFilter.setAttributeIndicesArray(indices);    attributeFilter.setInputFormat(reducedInst);    reducedInst = Filter.useFilter(reducedInst, attributeFilter);        // build a linear regression for the training data using the    // tested attributes    LinearRegression temp = new LinearRegression();    temp.buildClassifier(reducedInst);    double [] lmCoeffs = temp.coefficients();    double [] coeffs = new double [m_instances.numAttributes()];    for (int i = 0; i < lmCoeffs.length - 1; i++) {      if (indices[i] != m_classIndex) {	coeffs[indices[i]] = lmCoeffs[i];      }    }    m_nodeModel = new PreConstructedLinearModel(coeffs, lmCoeffs[lmCoeffs.length - 1]);    m_nodeModel.buildClassifier(m_instances);  }  /**   * Returns an array containing the indexes of attributes used in tests   * above this node   *   * @return an array of attribute indexes   */  private boolean[] attsTestedAbove() {    boolean[] atts = new boolean[m_numAttributes];    boolean[] attsAbove = null;    if (m_parent != null) {      attsAbove = m_parent.attsTestedAbove();    }     if (attsAbove != null) {      for (int i = 0; i < m_numAttributes; i++) {	atts[i] = attsAbove[i];      }     }     atts[m_splitAtt] = true;    return atts;  }   /**   * Returns an array containing the indexes of attributes used in tests   * below this 
node   *   * @return an array of attribute indexes   */  private boolean[] attsTestedBelow() {    boolean[] attsBelow = new boolean[m_numAttributes];    boolean[] attsBelowLeft = null;    boolean[] attsBelowRight = null;    if (m_right != null) {      attsBelowRight = m_right.attsTestedBelow();    }     if (m_left != null) {      attsBelowLeft = m_left.attsTestedBelow();    }     for (int i = 0; i < m_numAttributes; i++) {      if (attsBelowLeft != null) {	attsBelow[i] = (attsBelow[i] || attsBelowLeft[i]);      }       if (attsBelowRight != null) {	attsBelow[i] = (attsBelow[i] || attsBelowRight[i]);      }     }     if (!m_isLeaf) {      attsBelow[m_splitAtt] = true;    }     return attsBelow;  }   /**   * Sets the leaves' numbers   * @param leafCounter the number of leaves counted   * @return the number of the total leaves under the node   */  public int numLeaves(int leafCounter) {    if (!m_isLeaf) {      // node      m_leafModelNum = 0;      if (m_left != null) {	leafCounter = m_left.numLeaves(leafCounter);      }       if (m_right != null) {	leafCounter = m_right.numLeaves(leafCounter);      }     } else {      // leaf      leafCounter++;      m_leafModelNum = leafCounter;    }     return leafCounter;  }   /**   * print the linear model at this node   */  public String toString() {    return printNodeLinearModel();  }   /**   * print the linear model at this node   */  public String printNodeLinearModel() {    return m_nodeModel.toString();  }   /**   * print all leaf models   */  public String printLeafModels() {    StringBuffer text = new StringBuffer();    if (m_isLeaf) {      text.append("\nLM num: " + m_leafModelNum);      text.append(m_nodeModel.toString());      text.append("\n");    } else {      text.append(m_left.printLeafModels());      text.append(m_right.printLeafModels());    }     return text.toString();  }   /**   * Returns a description of this node (debugging purposes)   *   * @return a string describing this node   */  public String 
nodeToString() {    StringBuffer text = new StringBuffer();    System.out.println("In to string");    text.append("Node:\n\tnum inst: " + m_numInstances);    if (m_isLeaf) {      text.append("\n\tleaf");    } else {      text.append("\tnode");    }    text.append("\n\tSplit att: " + m_instances.attribute(m_splitAtt).name());    text.append("\n\tSplit val: " + Utils.doubleToString(m_splitValue, 1, 3));    text.append("\n\tLM num: " + m_leafModelNum);    text.append("\n\tLinear model\n" + m_nodeModel.toString());    text.append("\n\n");

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -