📄 rulenode.java
字号:
if (m_left != null) {
text.append(m_left.nodeToString());
}
if (m_right != null) {
text.append(m_right.nodeToString());
}
return text.toString();
}
/**
 * Recursively builds an indented textual description of the subtree rooted
 * at this node.
 *
 * <p>A leaf is rendered as {@code LM<modelNumber> (<count>)}, or as
 * {@code LM<modelNumber> (<count>/<pct>%)} when {@code m_globalDeviation} is
 * positive, where pct is 100 times {@code m_rootMeanSquaredError} divided by
 * {@code m_globalAbsDeviation}. An interior node starts with a newline and
 * emits one line per branch, each prefixed with {@code level} copies of
 * {@code "| "}. If the split attribute's name begins with {@code '['} the
 * branches are labelled "false" / "true"; otherwise they are labelled with
 * "&lt;=" / "&gt;" and the split value. A missing child prints {@code NULL}.
 *
 * @param level the level of this node (number of indent markers to print)
 * @return string describing the tree
 */
public String treeToString(int level) {
  StringBuffer out = new StringBuffer();

  if (m_isLeaf) {
    out.append("LM" + m_leafModelNum);
    if (m_globalDeviation > 0.0) {
      // NOTE(review): the guard tests m_globalDeviation but the divisor is
      // m_globalAbsDeviation -- confirm the two fields are meant to differ.
      String relativeError = Utils.doubleToString(
          (100.0 * m_rootMeanSquaredError / m_globalAbsDeviation), 1, 3);
      out.append(" (" + m_numInstances + "/" + relativeError + "%)\n");
    } else {
      out.append(" (" + m_numInstances + ")\n");
    }
    return out.toString();
  }

  // Build the per-line indent once; it is used for both branch lines.
  StringBuffer indent = new StringBuffer();
  for (int depth = 1; depth <= level; depth++) {
    indent.append("| ");
  }
  String prefix = indent.toString();

  String attName = m_instances.attribute(m_splitAtt).name();
  boolean bracketed = (attName.charAt(0) == '[');

  // Left branch.
  out.append("\n");
  out.append(prefix);
  if (bracketed) {
    out.append(attName + " false : ");
  } else {
    out.append(attName + " <= "
        + Utils.doubleToString(m_splitValue, 1, 3) + " : ");
  }
  out.append((m_left == null) ? "NULL\n" : m_left.treeToString(level + 1));

  // Right branch.
  out.append(prefix);
  if (bracketed) {
    out.append(attName + " true : ");
  } else {
    out.append(attName + " > "
        + Utils.doubleToString(m_splitValue, 1, 3) + " : ");
  }
  out.append((m_right == null) ? "NULL\n" : m_right.treeToString(level + 1));

  return out.toString();
}
/**
 * Traverses the tree and installs a linear model at each node, children
 * before parents, then records the root mean squared error of that model on
 * this node's instances. This method must be called if pruning is not to be
 * performed.
 *
 * <p>Unless {@code m_saveInstances} is set, {@code m_instances} is afterwards
 * replaced by {@code new Instances(m_instances, 0)} to save space.
 *
 * @throws Exception if an error occurs
 */
public void installLinearModels() throws Exception {
  // Interior nodes handle their subtrees first; a leaf has nothing to visit.
  if (!m_isLeaf) {
    if (m_left != null) {
      m_left.installLinearModels();
    }
    if (m_right != null) {
      m_right.installLinearModels();
    }
  }
  buildLinearModel(m_indices);

  // Score the freshly built model against the instances held at this node.
  Evaluation evaluation = new Evaluation(m_instances);
  evaluation.evaluateModel(m_nodeModel, m_instances);
  m_rootMeanSquaredError = evaluation.rootMeanSquaredError();

  // save space
  if (!m_saveInstances) {
    m_instances = new Instances(m_instances, 0);
  }
}
/**
 * Replaces the linear model at every leaf of this subtree with a smoothed
 * model obtained by blending the leaf's coefficients and intercept with those
 * of each ancestor, walking from the leaf up to the node that has no parent.
 *
 * <p>At each step, with {@code n} the instance count of the node below the
 * ancestor and {@code k = SMOOTHING_CONSTANT}, every value {@code p} carried
 * up from below is combined with the ancestor's value {@code q} as
 * {@code (p * n) / (n + k) + (k * q) / (n + k)}. The entry at
 * {@code m_classIndex} is never copied or blended in. The result is wrapped
 * in a new {@code PreConstructedLinearModel}, on which
 * {@code buildClassifier(m_instances)} is then called.
 *
 * <p>After handling this node, the method recurses into whichever children
 * are present.
 *
 * @throws Exception if any of the model calls made here fails
 */
public void installSmoothedModels() throws Exception {
  if (m_isLeaf) {
    // NOTE(review): the target array has m_numAttributes entries while the
    // loops run over the models' own coefficient arrays -- assumes those are
    // never longer than m_numAttributes; confirm.
    double[] coefficients = new double[m_numAttributes];

    // prime array with leaf node coefficients
    double[] leafCoefficients = m_nodeModel.coefficients();
    for (int i = 0; i < leafCoefficients.length; i++) {
      if (i != m_classIndex) {
        coefficients[i] = leafCoefficients[i];
      }
    }
    double intercept = m_nodeModel.intercept();

    // Climb towards the root, blending in one ancestor model per step.
    RuleNode current = this;
    while (current.m_parent != null) {
      PreConstructedLinearModel parentModel = current.m_parent.getModel();
      double n = current.m_numInstances;

      // contribution of the model below
      for (int i = 0; i < coefficients.length; i++) {
        coefficients[i] = ((coefficients[i] * n) / (n + SMOOTHING_CONSTANT));
      }
      intercept = ((intercept * n) / (n + SMOOTHING_CONSTANT));

      // contribution of the ancestor's model
      double[] parentCoefficients = parentModel.coefficients();
      for (int i = 0; i < parentCoefficients.length; i++) {
        if (i != m_classIndex) {
          coefficients[i] +=
              ((SMOOTHING_CONSTANT * parentCoefficients[i])
                  / (n + SMOOTHING_CONSTANT));
        }
      }
      intercept +=
          ((SMOOTHING_CONSTANT * parentModel.intercept())
              / (n + SMOOTHING_CONSTANT));

      current = current.m_parent;
    }

    m_nodeModel = new PreConstructedLinearModel(coefficients, intercept);
    m_nodeModel.buildClassifier(m_instances);
  }

  if (m_left != null) {
    m_left.installSmoothedModels();
  }
  if (m_right != null) {
    m_right.installSmoothedModels();
  }
}
/**
 * Recursively prunes the subtree rooted at this node.
 *
 * <p>Children (if any) are pruned first. A linear model is then built for
 * this node and its root mean squared error on {@code m_instances} is
 * measured. For a node that was already a leaf, that error is simply stored.
 * For an interior node the error of the node's own model and the error of
 * the subtree below it are each multiplied by {@code pruningFactor} (using
 * the model's parameter count + 1, and the children's parameter counts + 1,
 * respectively). The node is collapsed into a leaf when the adjusted model
 * error does not exceed the adjusted subtree error, or when it is below
 * {@code m_globalDeviation * 0.00001}. {@code m_numParameters} and
 * {@code m_rootMeanSquaredError} are updated to match whichever alternative
 * is kept.
 *
 * <p>Unless {@code m_saveInstances} is set, {@code m_instances} is afterwards
 * replaced by {@code new Instances(m_instances, 0)} to save space.
 *
 * @throws Exception if an error occurs
 */
public void prune() throws Exception {
  // Read the flag once: the decision below may turn this node into a leaf.
  final boolean startedAsLeaf = m_isLeaf;

  if (!startedAsLeaf) {
    // Prune the left and right subtrees
    if (m_left != null) {
      m_left.prune();
    }
    if (m_right != null) {
      m_right.prune();
    }
  }

  // Build and score this node's own linear model.
  buildLinearModel(m_indices);
  Evaluation modelEval = new Evaluation(m_instances);
  modelEval.evaluateModel(m_nodeModel, m_instances);
  double modelRmse = modelEval.rootMeanSquaredError();

  if (startedAsLeaf) {
    // NOTE(review): the original comment here spoke of counting the constant
    // term as a parameter for a leaf, but nothing in this branch touches
    // m_numParameters -- presumably handled in buildLinearModel; confirm.
    m_rootMeanSquaredError = modelRmse;
  } else {
    double adjustedModelError = modelRmse
        * pruningFactor(m_numInstances, m_nodeModel.numParameters() + 1);

    // Evaluate this node (ie its left and right subtrees)
    Evaluation subtreeEval = new Evaluation(m_instances);
    subtreeEval.evaluateModel(this, m_instances);
    double subtreeRmse = subtreeEval.rootMeanSquaredError();

    int leftParams = (m_left != null) ? m_left.numParameters() : 0;
    int rightParams = (m_right != null) ? m_right.numParameters() : 0;
    int subtreeParams = leftParams + rightParams + 1;

    double adjustedSubtreeError = subtreeRmse
        * pruningFactor(m_numInstances, subtreeParams);

    boolean modelNoWorse = (adjustedModelError <= adjustedSubtreeError);
    boolean modelNegligible =
        (adjustedModelError < (m_globalDeviation * 0.00001));

    if (modelNoWorse || modelNegligible) {
      // Choose linear model for this node rather than subtree model
      m_isLeaf = true;
      m_right = null;
      m_left = null;
      m_numParameters = m_nodeModel.numParameters() + 1;
      m_rootMeanSquaredError = modelRmse;
    } else {
      m_numParameters = subtreeParams;
      m_rootMeanSquaredError = subtreeRmse;
    }
  }

  // save space
  if (!m_saveInstances) {
    m_instances = new Instances(m_instances, 0);
  }
}
/**
 * Computes the factor by which an error estimate is multiplied during
 * pruning:
 * {@code (num_instances + m_pruningMultiplier * num_params)
 * / (num_instances - num_params)}.
 *
 * <p>When there are no more instances than parameters the formula is not
 * used and the fixed value 10.0 is returned instead (the original comment
 * attributes this cautionary value to "Yong's code").
 *
 * @param num_instances number of instances
 * @param num_params number of parameters in the model
 * @return the pruning factor
 */
private double pruningFactor(int num_instances, int num_params) {
  if (num_instances > num_params) {
    double inflated =
        (double) (num_instances + m_pruningMultiplier * num_params);
    double remaining = (double) (num_instances - num_params);
    return inflated / remaining;
  }
  return 10.0; // Caution says Yong in his code
}
/**
 * Finds the leaf with the greatest coverage in the subtree rooted at this
 * node. Results are passed back through the single-element arrays: a leaf
 * whose {@code m_numInstances} is strictly greater than
 * {@code maxCoverage[0]} overwrites both slots, so on ties the leaf visited
 * first (left subtree before right) is kept.
 *
 * @param maxCoverage the greatest coverage found so far (slot 0 is updated)
 * @param bestLeaf the leaf with the greatest coverage (slot 0 is updated)
 */
public void findBestLeaf(double[] maxCoverage, RuleNode[] bestLeaf) {
  if (m_isLeaf) {
    if (m_numInstances > maxCoverage[0]) {
      maxCoverage[0] = m_numInstances;
      bestLeaf[0] = this;
    }
    return;
  }

  if (m_left != null) {
    m_left.findBestLeaf(maxCoverage, bestLeaf);
  }
  if (m_right != null) {
    m_right.findBestLeaf(maxCoverage, bestLeaf);
  }
}
/**
 * Collects all the leaves of the subtree rooted at this node, left subtree
 * before right, by appending each leaf to {@code v[0]}.
 *
 * @param v a single element array containing a vector of leaves
 */
public void returnLeaves(FastVector[] v) {
  if (!m_isLeaf) {
    if (m_left != null) {
      m_left.returnLeaves(v);
    }
    if (m_right != null) {
      m_right.returnLeaves(v);
    }
    return;
  }
  v[0].addElement(this);
}
/**
 * Returns the parent of this node. The smoothing code in this class treats a
 * {@code null} parent as the top of the tree.
 *
 * @return the parent of this node
 */
public RuleNode parentNode() {
  return this.m_parent;
}
/**
 * Returns the left child of this node. The reference may be {@code null};
 * pruning clears it when a node is turned into a leaf.
 *
 * @return the left child of this node
 */
public RuleNode leftNode() {
  return this.m_left;
}
/**
 * Returns the right child of this node. The reference may be {@code null};
 * pruning clears it when a node is turned into a leaf.
 *
 * @return the right child of this node
 */
public RuleNode rightNode() {
  return this.m_right;
}
/**
 * Returns the index of the splitting attribute for this node, i.e. the value
 * this class passes to {@code m_instances.attribute(...)} when it needs the
 * split attribute's name.
 *
 * @return the index of the splitting attribute
 */
public int splitAtt() {
  return this.m_splitAtt;
}
/**
 * Returns the split point for this node: the value printed after "&lt;=" on
 * the left branch and after "&gt;" on the right branch in the tree output.
 *
 * @return the split point for this node
 */
public double splitVal() {
  return this.m_splitValue;
}
/**
 * Counts the linear models in the subtree rooted at this node. A leaf counts
 * as one model; an interior node contributes the sum of its children's
 * counts.
 *
 * <p>A missing child contributes zero, matching the null checks used by the
 * other traversal methods of this class, instead of raising a
 * {@code NullPointerException}.
 *
 * @return the number of linear models
 */
public int numberOfLinearModels() {
  if (m_isLeaf) {
    return 1;
  }

  int count = 0;
  if (m_left != null) {
    count += m_left.numberOfLinearModels();
  }
  if (m_right != null) {
    count += m_right.numberOfLinearModels();
  }
  return count;
}
/**
 * Reports whether this node is a leaf. Pruning sets the underlying flag when
 * it collapses a subtree into a single node.
 *
 * @return true if this node is a leaf
 */
public boolean isLeaf() {
  return this.m_isLeaf;
}
/**
 * Returns the root mean squared error stored at this node, as last recorded
 * by {@code installLinearModels()} or {@code prune()}.
 *
 * @return the root mean squared error
 */
protected double rootMeanSquaredError() {
  return this.m_rootMeanSquaredError;
}
/**
 * Returns the linear model currently held at this node.
 *
 * @return the linear model at this node
 */
public PreConstructedLinearModel getModel() {
  return this.m_nodeModel;
}
/**
 * Returns the number of instances that reach this node.
 *
 * @return the number of instances at this node.
 */
public int getNumInstances() {
  return this.m_numInstances;
}
/**
 * Returns the parameter count recorded for this node. {@code prune()} sets
 * it either to the node model's parameter count + 1 or to the sum of the
 * children's counts + 1, depending on which alternative it keeps.
 *
 * @return the number of parameters in the model at this node
 */
private int numParameters() {
  return this.m_numParameters;
}
/**
 * Returns the value of the regressionTree flag.
 * NOTE(review): presumably this selects regression-tree rather than
 * model-tree behaviour; the code using the flag is not in this part of the
 * file -- confirm.
 *
 * @return Value of regressionTree.
 */
public boolean getRegressionTree() {
  return this.m_regressionTree;
}
/**
 * Sets the minimum number of instances to allow at a leaf node. The value is
 * stored in {@code m_splitNum}.
 *
 * @param minNum the minimum number of instances
 */
public void setMinNumInstances(double minNum) {
  this.m_splitNum = minNum;
}
/**
 * Returns the minimum number of instances to allow at a leaf node, as held
 * in {@code m_splitNum}.
 *
 * @return a <code>double</code> value
 */
public double getMinNumInstances() {
  return this.m_splitNum;
}
/**
 * Sets the value of the regressionTree flag.
 *
 * @param newregressionTree Value to assign to regressionTree.
 */
public void setRegressionTree(boolean newregressionTree) {
  this.m_regressionTree = newregressionTree;
}
/**
 * Prints the linear model of every node in the subtree rooted at this node
 * to standard output (debugging purposes). This node's model is printed
 * first, followed by the models of the left and then the right subtree.
 *
 * <p>A missing child is skipped, matching the null checks used by the other
 * traversal methods of this class, instead of raising a
 * {@code NullPointerException}.
 */
public void printAllModels() {
  System.out.println(m_nodeModel.toString());

  if (!m_isLeaf) {
    if (m_left != null) {
      m_left.printAllModels();
    }
    if (m_right != null) {
      m_right.printAllModels();
    }
  }
}
/**
 * Assigns a unique identifier to each node in the subtree rooted at this
 * node. This node takes {@code lastID + 1}; numbering then continues through
 * the left subtree and afterwards the right subtree.
 *
 * @param lastID last id number used
 * @return ID after processing child nodes (the highest id handed out)
 */
protected int assignIDs(int lastID) {
  final int ownId = lastID + 1;
  m_id = ownId;

  int highestUsed = ownId;
  if (m_left != null) {
    highestUsed = m_left.assignIDs(highestUsed);
  }
  if (m_right != null) {
    highestUsed = m_right.assignIDs(highestUsed);
  }
  return highestUsed;
}
/**
 * Assigns a unique identifier to each node in the tree and then calls
 * {@code graphTree} to append the graph description to {@code text}.
 * Numbering is started from -1, so with this class's {@code assignIDs} the
 * node this method is invoked on receives id 0.
 *
 * @param text a <code>StringBuffer</code> value that receives the output
 */
public void graph(StringBuffer text) {
  this.assignIDs(-1);
  this.graphTree(text);
}
/**
 * Appends a dotty style description of the subtree rooted at this node to
 * {@code text}.
 *
 * <p>Each node produces one line starting with {@code N<id>}. A leaf is
 * labelled {@code LM <modelNumber> (<count>)}, or
 * {@code LM <modelNumber> (<count>/<pct>%)} when {@code m_globalDeviation}
 * is positive, and is given {@code shape=box style=filled}. An interior node
 * is labelled with the name of its split attribute. When
 * {@code m_saveInstances} is set, the node's instances are embedded after
 * {@code data=}. For each child that is present, an edge line labelled
 * "&lt;=" (left) or "&gt;" (right) plus the split value follows, and then
 * the child's own description.
 *
 * @param text a <code>StringBuffer</code> value that receives the output
 */
protected void graphTree(StringBuffer text) {
  final String self = "N" + m_id;

  final String labelStart;
  final String labelEnd;
  if (m_isLeaf) {
    labelStart = " [label=\"LM " + m_leafModelNum;

    final String coverage;
    if (m_globalDeviation > 0.0) {
      // NOTE(review): the guard tests m_globalDeviation but the divisor is
      // m_globalAbsDeviation -- confirm the two fields are meant to differ.
      coverage = m_numInstances + "/"
          + Utils.doubleToString(
              (100.0 * m_rootMeanSquaredError / m_globalAbsDeviation), 1, 3)
          + "%)";
    } else {
      coverage = m_numInstances + ")";
    }
    labelEnd = " (" + coverage + "\" shape=box style=filled ";
  } else {
    labelStart = " [label=\"" + m_instances.attribute(m_splitAtt).name();
    labelEnd = "\"";
  }

  String data = "";
  if (m_saveInstances) {
    data = "data=\n" + m_instances + "\n,\n";
  }

  // The node line is appended in one piece, after it has been fully built.
  text.append(self + labelStart + labelEnd + data + "]\n");

  if (m_left != null) {
    String edgeLabel = " [label=\"<="
        + Utils.doubleToString(m_splitValue, 1, 3) + "\"]\n";
    text.append(self + "->" + "N" + m_left.m_id + edgeLabel);
    m_left.graphTree(text);
  }
  if (m_right != null) {
    String edgeLabel = " [label=\">"
        + Utils.doubleToString(m_splitValue, 1, 3) + "\"]\n";
    text.append(self + "->" + "N" + m_right.m_id + edgeLabel);
    m_right.graphTree(text);
  }
}
/**
 * Sets whether to save instances for visualization purposes. Default is to
 * save memory. When the flag is off, {@code installLinearModels()} and
 * {@code prune()} replace the node's instances with
 * {@code new Instances(m_instances, 0)}; when it is on, {@code graphTree}
 * embeds the instances in its output.
 *
 * @param save a <code>boolean</code> value
 */
protected void setSaveInstances(boolean save) {
  this.m_saveInstances = save;
}
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -