📄 rulenode.java
字号:
if (m_left != null) { text.append(m_left.nodeToString()); } if (m_right != null) { text.append(m_right.nodeToString()); } return text.toString(); }

  /**
   * Recursively builds a textual description of the tree.
   *
   * An internal node emits one line per branch (indented by one
   * bar marker per level) followed by the rendering of the matching child;
   * a leaf emits "LM" plus its model number and a coverage summary.
   *
   * @param level the level of this node
   * @return string describing the tree
   */
  public String treeToString(int level) {
    StringBuffer text = new StringBuffer();

    if (!m_isLeaf) {
      text.append("\n");
      appendSplitBranch(text, level, " <= ", " false : ", m_left);
      appendSplitBranch(text, level, " > ", " true : ", m_right);
    } else {
      text.append("LM" + m_leafModelNum);
      text.append(" " + leafCoverage() + "\n");
    }
    return text.toString();
  }

  /**
   * Appends one branch of an internal node to a tree description: the
   * indentation, the split test, and the rendering of the child subtree
   * (or "NULL" when the child is missing).
   *
   * @param text the buffer to append to
   * @param level the level of this node (number of indentation markers)
   * @param numericRelation the relation text placed between the attribute
   *          name and the split value when the attribute name does not
   *          start with '['
   * @param binaryOutcome the text placed after the attribute name when the
   *          attribute name starts with '['
   * @param child the subtree reached through this branch, may be null
   */
  private void appendSplitBranch(StringBuffer text, int level,
      String numericRelation, String binaryOutcome, RuleNode child) {
    for (int i = 1; i <= level; i++) {
      text.append("| ");
    }

    String attName = m_instances.attribute(m_splitAtt).name();
    text.append(attName);
    if (attName.charAt(0) != '[') {
      // ordinary attribute: show the threshold test
      text.append(numericRelation);
      text.append(Utils.doubleToString(m_splitValue, 1, 3));
      text.append(" : ");
    } else {
      // names starting with '[' are printed as a true/false test
      text.append(binaryOutcome);
    }

    if (child != null) {
      text.append(child.treeToString(level + 1));
    } else {
      text.append("NULL\n");
    }
  }

  /**
   * Formats the bracketed coverage summary printed for a leaf.
   *
   * When the global deviation is positive the result is
   * "(n/p%)", where n is the number of instances at this node and p is
   * 100 * root mean squared error / global deviation; otherwise it is "(n)".
   *
   * @return the coverage summary, including the surrounding parentheses
   */
  private String leafCoverage() {
    if (m_globalDeviation > 0.0) {
      return "(" + m_numInstances + "/"
        + Utils.doubleToString(
            (100.0 * m_rootMeanSquaredError / m_globalDeviation), 1, 3)
        + "%)";
    }
    return "(" + m_numInstances + ")";
  }

  /**
   * Traverses the tree and installs linear models at each node.
   * This method must be called if pruning is not to be performed.
   *
   * Children are processed before their parent. After the model for this
   * node is built it is evaluated on this node's instances and the
   * resulting root mean squared error is stored.
   *
   * @throws Exception if an error occurs
   */
  public void installLinearModels() throws Exception {
    if (!m_isLeaf) {
      if (m_left != null) {
        m_left.installLinearModels();
      }
      if (m_right != null) {
        m_right.installLinearModels();
      }
    }
    buildLinearModel(m_indices);

    Evaluation nodeModelEval = new Evaluation(m_instances);
    nodeModelEval.evaluateModel(m_nodeModel, m_instances);
    m_rootMeanSquaredError = nodeModelEval.rootMeanSquaredError();

    // save space (presumably Instances(m_instances, 0) is an empty,
    // header-only copy -- confirm against the Instances API)
    if (!m_saveInstances) {
      m_instances = new Instances(m_instances, 0);
    }
  }

  /**
   * Replaces the model at every leaf with a smoothed model.
   *
   * Starting from the leaf's own coefficients and intercept, the method
   * walks up the parent chain to the root. At each step, with n the number
   * of instances at the lower node and k = SMOOTHING_CONSTANT, every value
   * is updated as (n * lower + k * parent) / (n + k). The class attribute's
   * coefficient is never copied from a model, so it only gets scaled.
   * Internal nodes are left unchanged; the method just recurses through them.
   *
   * @throws Exception if an error occurs
   */
  public void installSmoothedModels() throws Exception {
    if (m_isLeaf) {
      double[] coefficients = new double[m_numAttributes];
      double[] leafCoeffs = m_nodeModel.coefficients();

      // prime array with leaf node coefficients
      for (int i = 0; i < leafCoeffs.length; i++) {
        if (i != m_classIndex) {
          coefficients[i] = leafCoeffs[i];
        }
      }
      double intercept = m_nodeModel.intercept();

      RuleNode current = this;
      while (current.m_parent != null) {
        PreConstructedLinearModel parentModel = current.m_parent.getModel();
        double n = current.m_numInstances;

        // contribution of the model below
        for (int i = 0; i < coefficients.length; i++) {
          coefficients[i] = ((coefficients[i] * n) / (n + SMOOTHING_CONSTANT));
        }
        intercept = ((intercept * n) / (n + SMOOTHING_CONSTANT));

        // contribution of the parent's model
        double[] parentCoeffs = parentModel.coefficients();
        for (int i = 0; i < parentCoeffs.length; i++) {
          if (i != m_classIndex) {
            coefficients[i] +=
              ((SMOOTHING_CONSTANT * parentCoeffs[i])
               / (n + SMOOTHING_CONSTANT));
          }
        }
        intercept +=
          ((SMOOTHING_CONSTANT * parentModel.intercept())
           / (n + SMOOTHING_CONSTANT));

        current = current.m_parent;
      }

      m_nodeModel = new PreConstructedLinearModel(coefficients, intercept);
      m_nodeModel.buildClassifier(m_instances);
    }

    if (m_left != null) {
      m_left.installSmoothedModels();
    }
    if (m_right != null) {
      m_right.installSmoothedModels();
    }
  }

  /**
   * Recursively prune the tree.
   *
   * Subtrees are pruned first. For an internal node, the error of the
   * node's own linear model and the error of the subtree rooted here are
   * each multiplied by {@link #pruningFactor(int, int)}; the node is turned
   * into a leaf when the model's adjusted error is not larger than the
   * subtree's, or when it is below 0.00001 times the global deviation.
   *
   * @throws Exception if an error occurs
   */
  public void prune() throws Exception {
    if (m_isLeaf) {
      buildLinearModel(m_indices);

      Evaluation leafEval = new Evaluation(m_instances);
      leafEval.evaluateModel(m_nodeModel, m_instances);
      m_rootMeanSquaredError = leafEval.rootMeanSquaredError();
    } else {
      // Prune the left and right subtrees
      if (m_left != null) {
        m_left.prune();
      }
      if (m_right != null) {
        m_right.prune();
      }

      buildLinearModel(m_indices);

      // Evaluate the linear model at this node; the "+ 1" counts the
      // constant term as a parameter
      Evaluation nodeModelEval = new Evaluation(m_instances);
      nodeModelEval.evaluateModel(m_nodeModel, m_instances);
      double rmsModel = nodeModelEval.rootMeanSquaredError();
      double adjustedErrorModel =
        rmsModel * pruningFactor(m_numInstances,
                                 m_nodeModel.numParameters() + 1);

      // Evaluate this node (ie its left and right subtrees)
      Evaluation nodeEval = new Evaluation(m_instances);
      nodeEval.evaluateModel(this, m_instances);
      double rmsSubTree = nodeEval.rootMeanSquaredError();

      int l_params = 0;
      int r_params = 0;
      if (m_left != null) {
        l_params = m_left.numParameters();
      }
      if (m_right != null) {
        r_params = m_right.numParameters();
      }
      double adjustedErrorNode =
        rmsSubTree * pruningFactor(m_numInstances,
                                   (l_params + r_params + 1));

      if ((adjustedErrorModel <= adjustedErrorNode)
          || (adjustedErrorModel < (m_globalDeviation * 0.00001))) {
        // Choose linear model for this node rather than subtree model
        m_isLeaf = true;
        m_right = null;
        m_left = null;
        m_numParameters = m_nodeModel.numParameters() + 1;
        m_rootMeanSquaredError = rmsModel;
      } else {
        m_numParameters = (l_params + r_params + 1);
        m_rootMeanSquaredError = rmsSubTree;
      }
    }

    // save space (presumably Instances(m_instances, 0) is an empty,
    // header-only copy -- confirm against the Instances API)
    if (!m_saveInstances) {
      m_instances = new Instances(m_instances, 0);
    }
  }

  /**
   * Compute the pruning factor:
   * (num_instances + m_pruningMultiplier * num_params)
   * / (num_instances - num_params).
   *
   * @param num_instances number of instances
   * @param num_params number of parameters in the model
   * @return the pruning factor, or 10.0 when there are no more instances
   *         than parameters
   */
  private double pruningFactor(int num_instances, int num_params) {
    if (num_instances <= num_params) {
      return 10.0; // Caution says Yong in his code
    }

    return ((double) (num_instances + m_pruningMultiplier * num_params)
            / (double) (num_instances - num_params));
  }

  /**
   * Find the leaf with greatest coverage.
   *
   * Both arguments are single-element arrays used as in/out parameters:
   * element 0 is overwritten whenever a leaf with strictly more instances
   * than the current maximum is found.
   *
   * @param maxCoverage the greatest coverage found so far
   * @param bestLeaf the leaf with the greatest coverage
   */
  public void findBestLeaf(double[] maxCoverage, RuleNode[] bestLeaf) {
    if (m_isLeaf) {
      if (m_numInstances > maxCoverage[0]) {
        maxCoverage[0] = m_numInstances;
        bestLeaf[0] = this;
      }
      return;
    }

    if (m_left != null) {
      m_left.findBestLeaf(maxCoverage, bestLeaf);
    }
    if (m_right != null) {
      m_right.findBestLeaf(maxCoverage, bestLeaf);
    }
  }

  /**
   * Return a list containing all the leaves in the tree.
   *
   * Leaves are added to v[0], left subtree before right subtree.
   *
   * @param v a single element array containing a vector of leaves
   */
  public void returnLeaves(FastVector[] v) {
    if (m_isLeaf) {
      v[0].addElement(this);
      return;
    }

    if (m_left != null) {
      m_left.returnLeaves(v);
    }
    if (m_right != null) {
      m_right.returnLeaves(v);
    }
  }

  /**
   * Get the parent of this node.
   *
   * @return the parent of this node
   */
  public RuleNode parentNode() {
    return m_parent;
  }

  /**
   * Get the left child of this node.
   *
   * @return the left child of this node
   */
  public RuleNode leftNode() {
    return m_left;
  }

  /**
   * Get the right child of this node.
   *
   * @return the right child of this node
   */
  public RuleNode rightNode() {
    return m_right;
  }

  /**
   * Get the index of the splitting attribute for this node.
   *
   * @return the index of the splitting attribute
   */
  public int splitAtt() {
    return m_splitAtt;
  }

  /**
   * Get the split point for this node.
   *
   * @return the split point for this node
   */
  public double splitVal() {
    return m_splitValue;
  }

  /**
   * Get the number of linear models in the tree, i.e. the number of leaves
   * reachable from this node. A missing child contributes nothing, in line
   * with the other traversals in this class, which all tolerate null
   * children.
   *
   * @return the number of linear models
   */
  public int numberOfLinearModels() {
    if (m_isLeaf) {
      return 1;
    }

    int count = 0;
    if (m_left != null) {
      count += m_left.numberOfLinearModels();
    }
    if (m_right != null) {
      count += m_right.numberOfLinearModels();
    }
    return count;
  }

  /**
   * Return true if this node is a leaf.
   *
   * @return true if this node is a leaf
   */
  public boolean isLeaf() {
    return m_isLeaf;
  }

  /**
   * Get the root mean squared error at this node.
   *
   * @return the root mean squared error
   */
  protected double rootMeanSquaredError() {
    return m_rootMeanSquaredError;
  }

  /**
   * Get the linear model at this node.
   *
   * @return the linear model at this node
   */
  public PreConstructedLinearModel getModel() {
    return m_nodeModel;
  }

  /**
   * Return the number of instances that reach this node.
   *
   * @return the number of instances at this node.
   */
  public int getNumInstances() {
    return m_numInstances;
  }

  /**
   * Get the number of parameters in the model at this node.
   *
   * @return the number of parameters in the model at this node
   */
  private int numParameters() {
    return m_numParameters;
  }

  /**
   * Get the value of regressionTree.
   *
   * @return Value of regressionTree.
   */
  public boolean getRegressionTree() {
    return m_regressionTree;
  }

  /**
   * Set the minimum number of instances to allow at a leaf node.
   *
   * @param minNum the minimum number of instances
   */
  public void setMinNumInstances(double minNum) {
    m_splitNum = minNum;
  }

  /**
   * Get the minimum number of instances to allow at a leaf node.
   *
   * @return a <code>double</code> value
   */
  public double getMinNumInstances() {
    return m_splitNum;
  }

  /**
   * Set the value of regressionTree.
   *
   * @param newregressionTree Value to assign to regressionTree.
   */
  public void setRegressionTree(boolean newregressionTree) {
    m_regressionTree = newregressionTree;
  }

  /**
   * Print the linear model of this node and of every node below it to
   * standard output (debugging purposes). Children of a leaf are not
   * visited, and missing children are skipped.
   */
  public void printAllModels() {
    System.out.println(m_nodeModel.toString());

    if (!m_isLeaf) {
      if (m_left != null) {
        m_left.printAllModels();
      }
      if (m_right != null) {
        m_right.printAllModels();
      }
    }
  }

  /**
   * Assigns a unique identifier to each node in the tree.
   *
   * This node takes lastID + 1; the left subtree is numbered before the
   * right subtree.
   *
   * @param lastID last id number used
   * @return ID after processing child nodes
   */
  protected int assignIDs(int lastID) {
    int currLastID = lastID + 1;
    m_id = currLastID;

    if (m_left != null) {
      currLastID = m_left.assignIDs(currLastID);
    }
    if (m_right != null) {
      currLastID = m_right.assignIDs(currLastID);
    }
    return currLastID;
  }

  /**
   * Assign a unique identifier to each node in the tree and then
   * calls graphTree. Numbering starts at 0 for this node.
   *
   * @param text a <code>StringBuffer</code> value
   */
  public void graph(StringBuffer text) {
    assignIDs(-1);
    graphTree(text);
  }

  /**
   * Return a dotty style string describing the tree.
   *
   * Appends the declaration of this node, then for each existing child an
   * edge labelled with the split test followed by the child's own
   * description.
   *
   * @param text a <code>StringBuffer</code> value
   */
  protected void graphTree(StringBuffer text) {
    text.append("N" + m_id);
    if (m_isLeaf) {
      text.append(" [label=\"LM " + m_leafModelNum);
      text.append(" " + leafCoverage() + "\" shape=box style=filled ");
    } else {
      text.append(" [label=\"" + m_instances.attribute(m_splitAtt).name());
      text.append("\"");
    }
    if (m_saveInstances) {
      text.append("data=\n" + m_instances + "\n,\n");
    }
    text.append("]\n");

    if (m_left != null) {
      text.append("N" + m_id + "->" + "N" + m_left.m_id + " [label=\"<="
                  + Utils.doubleToString(m_splitValue, 1, 3) + "\"]\n");
      m_left.graphTree(text);
    }
    if (m_right != null) {
      text.append("N" + m_id + "->" + "N" + m_right.m_id + " [label=\">"
                  + Utils.doubleToString(m_splitValue, 1, 3) + "\"]\n");
      m_right.graphTree(text);
    }
  }

  /**
   * Set whether to save instances for visualization purposes.
   * Default is to save memory.
   *
   * @param save a <code>boolean</code> value
   */
  protected void setSaveInstances(boolean save) {
    m_saveInstances = save;
  }
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -