// RuleNode.java
/*
* RuleNode.java
* Copyright (C) 2000 Mark Hall
*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 2 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
*/
package weka.classifiers.trees.m5;
import weka.classifiers.Classifier;
import weka.classifiers.Evaluation;
import weka.classifiers.functions.LinearRegression;
import weka.core.FastVector;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Utils;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.Remove;
/**
* Constructs a node for use in an m5 tree or rule
*
* @author Mark Hall (mhall@cs.waikato.ac.nz)
* @version $Revision$
*/
public class RuleNode extends Classifier {
/** the instances reaching this node */
private Instances m_instances;

/** the index of the class attribute */
private int m_classIndex;

/** the number of instances reaching this node */
protected int m_numInstances;

/** the number of attributes */
private int m_numAttributes;

/** true if this node is a leaf */
private boolean m_isLeaf;

/** index of the attribute this node splits on (meaningful only for non-leaves) */
private int m_splitAtt;

/** the value of the split attribute (instances with value <= this go left) */
private double m_splitValue;

/** the linear model at this node (null until built) */
private PreConstructedLinearModel m_nodeModel;

/**
 * the number of parameters in the chosen model for this node---either
 * the subtree model or the linear model.
 * The constant term is counted as a parameter---this is for pruning
 * purposes
 */
public int m_numParameters;

/**
 * the root mean squared error of the model at this node (either linear or
 * subtree)
 */
private double m_rootMeanSquaredError;

/** left child (instances with split-attribute value <= m_splitValue) */
protected RuleNode m_left;

/** right child (instances with split-attribute value > m_splitValue) */
protected RuleNode m_right;

/** the parent of this node (null for the root) */
private RuleNode m_parent;

/**
 * a node will not be split if it contains fewer than m_splitNum instances
 */
private double m_splitNum = 4;

/**
 * a node will not be split if its class standard deviation is less
 * than m_devFraction (5%) of the class standard deviation of all the
 * instances
 */
private double m_devFraction = 0.05;

/** multiplier used during pruning — TODO(review): role not visible in this chunk; confirm against the pruning code */
private double m_pruningMultiplier = 2;

/**
 * the number assigned to the linear model if this node is a leaf.
 * = 0 if this node is not a leaf
 */
private int m_leafModelNum;

/**
 * the standard deviation of the global class; a node will not be split
 * if the class deviation of its instances is less than m_devFraction
 * of this value
 */
private double m_globalDeviation;

/** the absolute deviation of the global class */
private double m_globalAbsDeviation;

/**
 * indices of the attributes to be used in generating a linear model
 * at this node (the class index occupies the last slot)
 */
private int [] m_indices;

/** constant weight given to the node model in the original m5 smoothing calculation */
private static final double SMOOTHING_CONSTANT = 15.0;

/** node id */
private int m_id;

/**
 * save the instances at each node (for visualizing in the
 * Explorer's treevisualizer)
 */
private boolean m_saveInstances = false;

/** make a regression tree instead of a model tree */
private boolean m_regressionTree;
/**
 * Creates a new <code>RuleNode</code> instance. The node carries no
 * model or children until {@link #buildClassifier} is called.
 *
 * @param globalDev the global standard deviation of the class
 * @param globalAbsDev the global absolute deviation of the class
 * @param parent the parent of this node (null for the root of a tree)
 */
public RuleNode(double globalDev, double globalAbsDev, RuleNode parent) {
  m_globalDeviation = globalDev;
  m_globalAbsDeviation = globalAbsDev;
  m_parent = parent;
  m_left = null;
  m_right = null;
  m_nodeModel = null;
}
/**
 * Build this node: record statistics about the instances reaching it,
 * decide whether it should be a leaf, and attempt a split.
 *
 * @param data the instances on which to build this node
 * @exception Exception if an error occurs
 */
public void buildClassifier(Instances data) throws Exception {
  m_rootMeanSquaredError = Double.MAX_VALUE;
  // keep a reference to the data (no copy is made)
  m_instances = data;
  m_classIndex = m_instances.classIndex();
  m_numInstances = m_instances.numInstances();
  m_numAttributes = m_instances.numAttributes();
  m_nodeModel = null;
  m_left = null;
  m_right = null;

  // a node becomes a leaf when it has too few instances, or when its
  // class deviation is a negligible fraction of the global deviation
  m_isLeaf = (m_numInstances < m_splitNum)
    || (Rule.stdDev(m_classIndex, m_instances)
        < (m_globalDeviation * m_devFraction));

  split();
}
/**
 * Classify an instance using this node. Recursively calls classifyInstance
 * on the appropriate child node until a leaf's linear model produces the
 * prediction.
 *
 * @param inst the instance to classify
 * @return the prediction for this instance
 * @exception Exception if a leaf is reached whose model has not been built
 */
public double classifyInstance(Instance inst) throws Exception {
  // (fix: removed unused locals pred, n and tempInst)
  if (m_isLeaf) {
    if (m_nodeModel == null) {
      throw new Exception("Classifier has not been built correctly.");
    }
    return m_nodeModel.classifyInstance(inst);
  }

  // descend: values <= the split point go left, otherwise right
  if (inst.value(m_splitAtt) <= m_splitValue) {
    return m_left.classifyInstance(inst);
  } else {
    return m_right.classifyInstance(inst);
  }
}
/**
 * Applies the original m5 smoothing procedure to a prediction: a
 * weighted average of the prediction accumulated so far and the
 * prediction of the linear model at this node, where the node model
 * carries a fixed weight of SMOOTHING_CONSTANT.
 *
 * @param n number of instances in the selected child of this node
 * @param pred the prediction so far
 * @param supportPred the prediction of the linear model at this node
 * @return the current prediction smoothed with the prediction of the
 * linear model at this node
 * @exception Exception if an error occurs
 */
protected static double smoothingOriginal(double n, double pred,
                                          double supportPred)
  throws Exception {
  return ((n * pred) + (SMOOTHING_CONSTANT * supportPred))
    / (n + SMOOTHING_CONSTANT);
}
/**
 * Finds an attribute and split point for this node. If a usable split
 * exists the instances are partitioned and left/right subtrees are
 * built recursively; otherwise the node becomes a leaf. Also records
 * in m_indices the attributes to be used for this node's linear model.
 *
 * @exception Exception if an error occurs
 */
public void split() throws Exception {
  int i;
  Instances leftSubset, rightSubset;
  SplitEvaluate bestSplit, currentSplit;
  boolean[] attsBelow;

  if (!m_isLeaf) {
    bestSplit = new YongSplitInfo(0, m_numInstances - 1, -1);
    currentSplit = new YongSplitInfo(0, m_numInstances - 1, -1);

    // find the best attribute to split on: evaluate every attribute
    // except the class and keep the split with the greatest impurity value
    for (i = 0; i < m_numAttributes; i++) {
      if (i != m_classIndex) {
        // sort the instances by this attribute (in place) so candidate
        // split points can be scanned in order
        m_instances.sort(i);
        currentSplit.attrSplit(i, m_instances);

        // accept the new split only if it beats the best so far by more
        // than a small tolerance (guards against floating-point noise)
        if ((Math.abs(currentSplit.maxImpurity() -
                      bestSplit.maxImpurity()) > 1.e-6)
            && (currentSplit.maxImpurity()
                > bestSplit.maxImpurity() + 1.e-6)) {
          bestSplit = currentSplit.copy();
        }
      }
    }

    // cant find a good split or split point?
    if (bestSplit.splitAttr() < 0 || bestSplit.position() < 1
        || bestSplit.position() > m_numInstances - 1) {
      m_isLeaf = true;
    } else {
      m_splitAtt = bestSplit.splitAttr();
      m_splitValue = bestSplit.splitValue();

      // partition the instances on the chosen attribute/value
      leftSubset = new Instances(m_instances, m_numInstances);
      rightSubset = new Instances(m_instances, m_numInstances);
      for (i = 0; i < m_numInstances; i++) {
        if (m_instances.instance(i).value(m_splitAtt) <= m_splitValue) {
          leftSubset.add(m_instances.instance(i));
        } else {
          rightSubset.add(m_instances.instance(i));
        }
      }
      leftSubset.compactify();
      rightSubset.compactify();

      // build left and right nodes, propagating this node's settings
      m_left = new RuleNode(m_globalDeviation, m_globalAbsDeviation, this);
      m_left.setMinNumInstances(m_splitNum);
      m_left.setRegressionTree(m_regressionTree);
      m_left.setSaveInstances(m_saveInstances);
      m_left.buildClassifier(leftSubset);

      m_right = new RuleNode(m_globalDeviation, m_globalAbsDeviation, this);
      m_right.setMinNumInstances(m_splitNum);
      m_right.setRegressionTree(m_regressionTree);
      m_right.setSaveInstances(m_saveInstances);
      m_right.buildClassifier(rightSubset);

      // now find out what attributes are tested in the left and right
      // subtrees and use them to learn a linear model for this node
      if (!m_regressionTree) {
        attsBelow = attsTestedBelow();
        attsBelow[m_classIndex] = true;

        // first pass: count the flagged attributes (class included)
        int count = 0, j;
        for (j = 0; j < m_numAttributes; j++) {
          if (attsBelow[j]) {
            count++;
          }
        }

        // second pass: collect the non-class attribute indices; after
        // the loop, count has been reset/re-advanced so that the class
        // index lands in the final slot of the array
        int[] indices = new int[count];
        count = 0;
        for (j = 0; j < m_numAttributes; j++) {
          if (attsBelow[j] && (j != m_classIndex)) {
            indices[count++] = j;
          }
        }
        indices[count] = m_classIndex;
        m_indices = indices;
      } else {
        // regression tree: only the class attribute is used, and the
        // single constant term counts as one parameter
        m_indices = new int [1];
        m_indices[0] = m_classIndex;
        m_numParameters = 1;
      }
    }
  }

  if (m_isLeaf) {
    // a leaf's "model" involves only the class attribute
    int [] indices = new int [1];
    indices[0] = m_classIndex;
    m_indices = indices;
    m_numParameters = 1;

    // need to evaluate the model here if want correct stats for unpruned
    // tree
  }
}
/**
 * Build a linear model for this node using those attributes
 * specified in indices.
 *
 * @param indices an array of attribute indices to include in the linear
 * model (per split(), the class index occupies the last slot)
 * @exception Exception if the filter or the regression cannot be built
 */
private void buildLinearModel(int [] indices) throws Exception {
  // copy the training instances and remove all but the tested
  // attributes
  Instances reducedInst = new Instances(m_instances);
  Remove attributeFilter = new Remove();
  // invert selection so the listed indices are KEPT, everything else removed
  attributeFilter.setInvertSelection(true);
  attributeFilter.setAttributeIndicesArray(indices);
  attributeFilter.setInputFormat(reducedInst);
  reducedInst = Filter.useFilter(reducedInst, attributeFilter);

  // build a linear regression for the training data using the
  // tested attributes
  LinearRegression temp = new LinearRegression();
  temp.buildClassifier(reducedInst);

  // map the reduced-space coefficients back to positions in the full
  // attribute space; the last element of lmCoeffs is passed as the
  // model's constant term.
  // NOTE(review): assumes lmCoeffs[0..len-2] line up index-for-index
  // with the entries of indices — confirm against
  // LinearRegression.coefficients()
  double [] lmCoeffs = temp.coefficients();
  double [] coeffs = new double [m_instances.numAttributes()];
  for (int i = 0; i < lmCoeffs.length - 1; i++) {
    if (indices[i] != m_classIndex) {
      coeffs[indices[i]] = lmCoeffs[i];
    }
  }
  m_nodeModel = new PreConstructedLinearModel(coeffs, lmCoeffs[lmCoeffs.length - 1]);
  m_nodeModel.buildClassifier(m_instances);
}
/**
 * Returns a boolean array, indexed by attribute, flagging the
 * attributes used in tests at this node and all of its ancestors.
 *
 * @return per-attribute flags for attributes tested at or above this node
 */
private boolean[] attsTestedAbove() {
  boolean[] tested = new boolean[m_numAttributes];

  // start from everything the ancestors have tested, if any
  if (m_parent != null) {
    boolean[] fromAncestors = m_parent.attsTestedAbove();
    System.arraycopy(fromAncestors, 0, tested, 0, m_numAttributes);
  }

  tested[m_splitAtt] = true;
  return tested;
}
/**
 * Returns a boolean array, indexed by attribute, flagging the
 * attributes used in tests at this node and anywhere in its subtrees.
 *
 * @return per-attribute flags for attributes tested at or below this node
 */
private boolean[] attsTestedBelow() {
  boolean[] tested = new boolean[m_numAttributes];

  // fold in everything tested by each existing subtree
  if (m_left != null) {
    boolean[] fromLeft = m_left.attsTestedBelow();
    for (int i = 0; i < m_numAttributes; i++) {
      tested[i] |= fromLeft[i];
    }
  }
  if (m_right != null) {
    boolean[] fromRight = m_right.attsTestedBelow();
    for (int i = 0; i < m_numAttributes; i++) {
      tested[i] |= fromRight[i];
    }
  }

  // an interior node also tests its own split attribute
  if (!m_isLeaf) {
    tested[m_splitAtt] = true;
  }
  return tested;
}
/**
 * Numbers the leaves of the subtree rooted at this node, continuing
 * from leafCounter. Interior nodes get a model number of 0; each leaf
 * receives the next consecutive number.
 *
 * @param leafCounter the number of leaves counted so far
 * @return the total number of leaves once this subtree has been counted
 */
public int numLeaves(int leafCounter) {
  if (m_isLeaf) {
    // leaf: take the next available model number
    m_leafModelNum = ++leafCounter;
    return leafCounter;
  }

  // interior node: no model number; recurse into any children
  m_leafModelNum = 0;
  if (m_left != null) {
    leafCounter = m_left.numLeaves(leafCounter);
  }
  if (m_right != null) {
    leafCounter = m_right.numLeaves(leafCounter);
  }
  return leafCounter;
}
/**
 * Returns the textual form of the linear model at this node.
 *
 * @return the node's linear model as a string
 */
@Override
public String toString() {
  return printNodeLinearModel();
}
/**
 * print the linear model at this node
 * (throws NullPointerException if the model has not been built yet)
 */
public String printNodeLinearModel() {
return m_nodeModel.toString();
}
/**
 * Returns a description of every leaf model in the subtree rooted at
 * this node, each preceded by its model number.
 *
 * @return the concatenated leaf model descriptions
 */
public String printLeafModels() {
  if (!m_isLeaf) {
    // interior node: gather the leaves of both subtrees, left first
    return m_left.printLeafModels() + m_right.printLeafModels();
  }
  // leaf: the model number followed by the model itself
  return "\nLM num: " + m_leafModelNum + m_nodeModel.toString() + "\n";
}
/**
* Returns a description of this node (debugging purposes)
*
* @return a string describing this node
*/
public String nodeToString() {
StringBuffer text = new StringBuffer();
System.out.println("In to string");
text.append("Node:\n\tnum inst: " + m_numInstances);
if (m_isLeaf) {
text.append("\n\tleaf");
} else {
text.append("\tnode");
}
text.append("\n\tSplit att: " + m_instances.attribute(m_splitAtt).name());
text.append("\n\tSplit val: " + Utils.doubleToString(m_splitValue, 1, 3));
text.append("\n\tLM num: " + m_leafModelNum);
text.append("\n\tLinear model\n" + m_nodeModel.toString());
text.append("\n\n");
// NOTE(review): the remainder of nodeToString() and the rest of
// RuleNode.java were lost to web-scrape artifacts (code-viewer
// keyboard-shortcut help text appeared here); restore the missing
// code from the original Weka 3.x source distribution.