// 📄 rulenode.java — code-viewer page header, not part of the original Java source.
// 字号: ("font size" — viewer UI label, not part of the original Java source.)
/*
 * RuleNode.java
 * Copyright (C) 2000 Mark Hall
 *
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 2 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */
package weka.classifiers.trees.m5;

import java.io.*;
import java.util.*;
import weka.core.*;
import weka.classifiers.*;
import weka.classifiers.functions.LinearRegression;
import weka.filters.unsupervised.attribute.Remove;
import weka.filters.Filter;

/**
 * Constructs a node for use in an m5 tree or rule.
 *
 * @author Mark Hall (mhall@cs.waikato.ac.nz)
 * @version $Revision: 1.8.2.1 $
 */
public class RuleNode extends Classifier {

  /** instances reaching this node */
  private Instances m_instances;

  /** the class index */
  private int m_classIndex;

  /** the number of instances reaching this node */
  protected int m_numInstances;

  /** the number of attributes */
  private int m_numAttributes;

  /** true if this node is a leaf */
  private boolean m_isLeaf;

  /** index of the attribute this node splits on */
  private int m_splitAtt;

  /** the value of the split attribute (instances <= value go left) */
  private double m_splitValue;

  /** the linear model at this node */
  private PreConstructedLinearModel m_nodeModel;

  /**
   * the number of parameters in the chosen model for this node---either
   * the subtree model or the linear model. The constant term is counted
   * as a parameter---this is for pruning purposes
   */
  public int m_numParameters;

  /**
   * the mean squared error of the model at this node (either linear or
   * subtree)
   */
  private double m_rootMeanSquaredError;

  /** left child node */
  protected RuleNode m_left;

  /** right child node */
  protected RuleNode m_right;

  /** the parent of this node (null for the root) */
  private RuleNode m_parent;

  /** a node will not be split if it contains less than m_splitNum instances */
  private double m_splitNum = 4;

  /**
   * a node will not be split if its class standard deviation is less
   * than 5% of the class standard deviation of all the instances
   */
  private double m_devFraction = 0.05;

  /** multiplier used during pruning */
  private double m_pruningMultiplier = 2;

  /**
   * the number assigned to the linear model if this node is a leaf.
   * = 0 if this node is not a leaf
   */
  private int m_leafModelNum;

  /**
   * a node will not be split if the class deviation of its
   * instances is less than m_devFraction of the deviation of the
   * global class
   */
  private double m_globalDeviation;

  /** the absolute deviation of the global class */
  private double m_globalAbsDeviation;

  /**
   * Indices of the attributes to be used in generating a linear model
   * at this node
   */
  private int [] m_indices;

  /** Constant used in original m5 smoothing calculation */
  private static final double SMOOTHING_CONSTANT = 15.0;

  /** Node id. */
  private int m_id;

  /**
   * Save the instances at each node (for visualizing in the
   * Explorer's treevisualizer).
   */
  private boolean m_saveInstances = false;

  /** Make a regression tree instead of a model tree */
  private boolean m_regressionTree;

  /**
   * Creates a new <code>RuleNode</code> instance.
*
   * @param globalDev the global standard deviation of the class
   * @param globalAbsDev the global absolute deviation of the class
   * @param parent the parent of this node
   */
  public RuleNode(double globalDev, double globalAbsDev, RuleNode parent) {
    m_nodeModel = null;
    m_right = null;
    m_left = null;
    m_parent = parent;
    m_globalDeviation = globalDev;
    m_globalAbsDeviation = globalAbsDev;
  }

  /**
   * Build this node (find an attribute and split point).
   * Recursively builds child nodes via split().
   *
   * @param data the instances on which to build this node
   * @exception Exception if an error occurs
   */
  public void buildClassifier(Instances data) throws Exception {
    m_rootMeanSquaredError = Double.MAX_VALUE;
    // NOTE: the training data is referenced directly rather than copied
    // (originally: m_instances = new Instances(data);) — callers must not
    // rely on the order of instances in data afterwards, since split()
    // sorts m_instances in place.
    // m_instances = new Instances(data);
    m_instances = data;
    m_classIndex = m_instances.classIndex();
    m_numInstances = m_instances.numInstances();
    m_numAttributes = m_instances.numAttributes();
    m_nodeModel = null;
    m_right = null;
    m_left = null;

    // stop splitting when too few instances remain, or when the class
    // deviation here is a negligible fraction of the global deviation
    if ((m_numInstances < m_splitNum)
	|| (Rule.stdDev(m_classIndex, m_instances)
	    < (m_globalDeviation * m_devFraction))) {
      m_isLeaf = true;
    } else {
      m_isLeaf = false;
    }

    split();
  }

  /**
   * Classify an instance using this node. Recursively calls classifyInstance
   * on child nodes.
* * @param inst the instance to classify * @return the prediction for this instance * @exception Exception if an error occurs */ public double classifyInstance(Instance inst) throws Exception { double pred; double n = 0; Instance tempInst; if (m_isLeaf) { if (m_nodeModel == null) { throw new Exception("Classifier has not been built correctly."); } return m_nodeModel.classifyInstance(inst); } if (inst.value(m_splitAtt) <= m_splitValue) { return m_left.classifyInstance(inst); } else { return m_right.classifyInstance(inst); } } /** * Applies the m5 smoothing procedure to a prediction * * @param n number of instances in selected child of this node * @param pred the prediction so far * @param supportPred the prediction of the linear model at this node * @return the current prediction smoothed with the prediction of the * linear model at this node * @exception Exception if an error occurs */ protected static double smoothingOriginal(double n, double pred, double supportPred) throws Exception { double smoothed; smoothed = ((n * pred) + (SMOOTHING_CONSTANT * supportPred)) / (n + SMOOTHING_CONSTANT); return smoothed; } /** * Finds an attribute and split point for this node * * @exception Exception if an error occurs */ public void split() throws Exception { int i; Instances leftSubset, rightSubset; SplitEvaluate bestSplit, currentSplit; boolean[] attsBelow; if (!m_isLeaf) { bestSplit = new YongSplitInfo(0, m_numInstances - 1, -1); currentSplit = new YongSplitInfo(0, m_numInstances - 1, -1); // find the best attribute to split on for (i = 0; i < m_numAttributes; i++) { if (i != m_classIndex) { // sort the instances by this attribute m_instances.sort(i); currentSplit.attrSplit(i, m_instances); if ((Math.abs(currentSplit.maxImpurity() - bestSplit.maxImpurity()) > 1.e-6) && (currentSplit.maxImpurity() > bestSplit.maxImpurity() + 1.e-6)) { bestSplit = currentSplit.copy(); } } } // cant find a good split or split point? 
if (bestSplit.splitAttr() < 0 || bestSplit.position() < 1 || bestSplit.position() > m_numInstances - 1) { m_isLeaf = true; } else { m_splitAtt = bestSplit.splitAttr(); m_splitValue = bestSplit.splitValue(); leftSubset = new Instances(m_instances, m_numInstances); rightSubset = new Instances(m_instances, m_numInstances); for (i = 0; i < m_numInstances; i++) { if (m_instances.instance(i).value(m_splitAtt) <= m_splitValue) { leftSubset.add(m_instances.instance(i)); } else { rightSubset.add(m_instances.instance(i)); } } leftSubset.compactify(); rightSubset.compactify(); // build left and right nodes m_left = new RuleNode(m_globalDeviation, m_globalAbsDeviation, this); m_left.setMinNumInstances(m_splitNum); m_left.setRegressionTree(m_regressionTree); m_left.setSaveInstances(m_saveInstances); m_left.buildClassifier(leftSubset); m_right = new RuleNode(m_globalDeviation, m_globalAbsDeviation, this); m_right.setMinNumInstances(m_splitNum); m_right.setRegressionTree(m_regressionTree); m_right.setSaveInstances(m_saveInstances); m_right.buildClassifier(rightSubset); // now find out what attributes are tested in the left and right // subtrees and use them to learn a linear model for this node if (!m_regressionTree) { attsBelow = attsTestedBelow(); attsBelow[m_classIndex] = true; int count = 0, j; for (j = 0; j < m_numAttributes; j++) { if (attsBelow[j]) { count++; } } int[] indices = new int[count]; count = 0; for (j = 0; j < m_numAttributes; j++) { if (attsBelow[j] && (j != m_classIndex)) { indices[count++] = j; } } indices[count] = m_classIndex; m_indices = indices; } else { m_indices = new int [1]; m_indices[0] = m_classIndex; m_numParameters = 1; } } } if (m_isLeaf) { int [] indices = new int [1]; indices[0] = m_classIndex; m_indices = indices; m_numParameters = 1; // need to evaluate the model here if want correct stats for unpruned // tree } } /** * Build a linear model for this node using those attributes * specified in indices. 
*
   * @param indices an array of attribute indices to include in the linear
   * model
   */
  private void buildLinearModel(int [] indices) throws Exception {
    // copy the training instances and remove all but the tested
    // attributes
    Instances reducedInst = new Instances(m_instances);
    Remove attributeFilter = new Remove();

    attributeFilter.setInvertSelection(true);
    attributeFilter.setAttributeIndicesArray(indices);
    attributeFilter.setInputFormat(reducedInst);

    reducedInst = Filter.useFilter(reducedInst, attributeFilter);

    // build a linear regression for the training data using the
    // tested attributes
    LinearRegression temp = new LinearRegression();
    temp.buildClassifier(reducedInst);

    // map the regression coefficients back to the original attribute
    // indices.  NOTE(review): assumes Remove preserves attribute order so
    // that lmCoeffs[i] corresponds to indices[i] — verify against the
    // Remove filter's behavior.  The last lmCoeffs entry is the intercept.
    double [] lmCoeffs = temp.coefficients();
    double [] coeffs = new double [m_instances.numAttributes()];

    for (int i = 0; i < lmCoeffs.length - 1; i++) {
      if (indices[i] != m_classIndex) {
	coeffs[indices[i]] = lmCoeffs[i];
      }
    }
    m_nodeModel = new PreConstructedLinearModel(coeffs,
						lmCoeffs[lmCoeffs.length - 1]);
    m_nodeModel.buildClassifier(m_instances);
  }

  /**
   * Returns an array containing the indexes of attributes used in tests
   * above this node.
   *
   * @return an array of attribute indexes
   */
  private boolean[] attsTestedAbove() {
    boolean[] atts = new boolean[m_numAttributes];
    boolean[] attsAbove = null;

    // collect the tests from all ancestors first
    if (m_parent != null) {
      attsAbove = m_parent.attsTestedAbove();
    }

    if (attsAbove != null) {
      for (int i = 0; i < m_numAttributes; i++) {
	atts[i] = attsAbove[i];
      }
    }
    atts[m_splitAtt] = true;
    return atts;
  }

  /**
   * Returns an array containing the indexes of attributes used in tests
   * below this node.
   *
   * @return an array of attribute indexes
   */
  private boolean[] attsTestedBelow() {
    boolean[] attsBelow = new boolean[m_numAttributes];
    boolean[] attsBelowLeft = null;
    boolean[] attsBelowRight = null;

    if (m_right != null) {
      attsBelowRight = m_right.attsTestedBelow();
    }

    if (m_left != null) {
      attsBelowLeft = m_left.attsTestedBelow();
    }

    // union of the attributes tested in either subtree
    for (int i = 0; i < m_numAttributes; i++) {
      if (attsBelowLeft != null) {
	attsBelow[i] = (attsBelow[i] || attsBelowLeft[i]);
      }

      if (attsBelowRight != null) {
	attsBelow[i] = (attsBelow[i] || attsBelowRight[i]);
      }
    }

    if (!m_isLeaf) {
      attsBelow[m_splitAtt] = true;
    }
    return attsBelow;
  }

  /**
   * Sets the leaves' numbers.
   *
   * @param leafCounter the number of leaves counted
   * @return the number of the total leaves under the node
   */
  public int numLeaves(int leafCounter) {
    if (!m_isLeaf) {
      // node
      m_leafModelNum = 0;
      if (m_left != null) {
	leafCounter = m_left.numLeaves(leafCounter);
      }

      if (m_right != null) {
	leafCounter = m_right.numLeaves(leafCounter);
      }
    } else {
      // leaf
      leafCounter++;
      m_leafModelNum = leafCounter;
    }
    return leafCounter;
  }

  /**
   * print the linear model at this node
   */
  public String toString() {
    return printNodeLinearModel();
  }

  /**
   * print the linear model at this node
   */
  public String printNodeLinearModel() {
    return m_nodeModel.toString();
  }

  /**
   * print all leaf models
   */
  public String printLeafModels() {
    StringBuffer text = new StringBuffer();

    if (m_isLeaf) {
      text.append("\nLM num: " + m_leafModelNum);
      text.append(m_nodeModel.toString());
      text.append("\n");
    } else {
      text.append(m_left.printLeafModels());
      text.append(m_right.printLeafModels());
    }
    return text.toString();
  }

  /**
   * Returns a description of this node (debugging purposes).
   *
   * @return a string describing this node
   */
  public String nodeToString() {
    StringBuffer text = new StringBuffer();

    // NOTE(review): debugging leftover — writes to stdout on every call
    System.out.println("In to string");
    text.append("Node:\n\tnum inst: " + m_numInstances);

    if (m_isLeaf) {
      text.append("\n\tleaf");
    } else {
      text.append("\tnode");
    }

    text.append("\n\tSplit att: " + m_instances.attribute(m_splitAtt).name());
    text.append("\n\tSplit val: " + Utils.doubleToString(m_splitValue, 1, 3));
    text.append("\n\tLM num: " + m_leafModelNum);
    text.append("\n\tLinear model\n" + m_nodeModel.toString());
    text.append("\n\n");
    // NOTE(review): method is truncated here in this view — the remainder
    // (and the closing braces) lies outside the visible chunk.
/*
 * Code-viewer UI residue (not part of the original Java source).
 * Keyboard-shortcut help, translated from Chinese:
 *   Copy code:        Ctrl + C
 *   Search code:      Ctrl + F
 *   Fullscreen mode:  F11
 *   Switch theme:     Ctrl + Shift + D
 *   Show shortcuts:   ?
 *   Increase font:    Ctrl + =
 *   Decrease font:    Ctrl + -
 */