randomtree.java
来自「Weka」· Java 代码 · 共 1,087 行 · 第 1/3 页
JAVA
1,087 行
} else { setSeed(1); } tmpStr = Utils.getOption("depth", options); if (tmpStr.length() != 0) { setMaxDepth(Integer.parseInt(tmpStr)); } else { setMaxDepth(0); } super.setOptions(options); Utils.checkForRemainingOptions(options); } /** * Returns default capabilities of the classifier. * * @return the capabilities of this classifier */ public Capabilities getCapabilities() { Capabilities result = super.getCapabilities(); // attributes result.enable(Capability.NOMINAL_ATTRIBUTES); result.enable(Capability.NUMERIC_ATTRIBUTES); result.enable(Capability.DATE_ATTRIBUTES); result.enable(Capability.MISSING_VALUES); // class result.enable(Capability.NOMINAL_CLASS); result.enable(Capability.MISSING_CLASS_VALUES); return result; } /** * Builds classifier. * * @param data the data to train with * @throws Exception if something goes wrong or the data doesn't fit */ public void buildClassifier(Instances data) throws Exception { // Make sure K value is in range if (m_KValue > data.numAttributes()-1) m_KValue = data.numAttributes()-1; if (m_KValue < 1) m_KValue = (int) Utils.log2(data.numAttributes())+1; // can classifier handle the data? getCapabilities().testWithFail(data); // remove instances with missing class data = new Instances(data); data.deleteWithMissingClass(); // only class? -> build ZeroR model if (data.numAttributes() == 1) { System.err.println( "Cannot build model (only class attribute present in data!), " + "using ZeroR model instead!"); m_ZeroR = new weka.classifiers.rules.ZeroR(); m_ZeroR.buildClassifier(data); return; } else { m_ZeroR = null; } Instances train = data; // Create array of sorted indices and weights int[][] sortedIndices = new int[train.numAttributes()][0]; double[][] weights = new double[train.numAttributes()][0]; double[] vals = new double[train.numInstances()]; for (int j = 0; j < train.numAttributes(); j++) { if (j != train.classIndex()) { weights[j] = new double[train.numInstances()]; if (train.attribute(j).isNominal()) { // Handling nominal attributes. Putting indices of // instances with missing values at the end. sortedIndices[j] = new int[train.numInstances()]; int count = 0; for (int i = 0; i < train.numInstances(); i++) { Instance inst = train.instance(i); if (!inst.isMissing(j)) { sortedIndices[j][count] = i; weights[j][count] = inst.weight(); count++; } } for (int i = 0; i < train.numInstances(); i++) { Instance inst = train.instance(i); if (inst.isMissing(j)) { sortedIndices[j][count] = i; weights[j][count] = inst.weight(); count++; } } } else { // Sorted indices are computed for numeric attributes for (int i = 0; i < train.numInstances(); i++) { Instance inst = train.instance(i); vals[i] = inst.value(j); } sortedIndices[j] = Utils.sort(vals); for (int i = 0; i < train.numInstances(); i++) { weights[j][i] = train.instance(sortedIndices[j][i]).weight(); } } } } // Compute initial class counts double[] classProbs = new double[train.numClasses()]; for (int i = 0; i < train.numInstances(); i++) { Instance inst = train.instance(i); classProbs[(int)inst.classValue()] += inst.weight(); } // Create the attribute indices window int[] attIndicesWindow = new int[data.numAttributes()-1]; int j=0; for (int i=0; i<attIndicesWindow.length; i++) { if (j == data.classIndex()) j++; // do not include the class attIndicesWindow[i] = j++; } // Build tree buildTree(sortedIndices, weights, train, classProbs, new Instances(train, 0), m_MinNum, m_Debug, attIndicesWindow, data.getRandomNumberGenerator(m_randomSeed), 0); } /** * Computes class distribution of an instance using the decision tree. * * @param instance the instance to compute the distribution for * @return the computed class distribution * @throws Exception if computation fails */ public double[] distributionForInstance(Instance instance) throws Exception { // default model? if (m_ZeroR != null) { return m_ZeroR.distributionForInstance(instance); } double[] returnedDist = null; if (m_Attribute > -1) { // Node is not a leaf if (instance.isMissing(m_Attribute)) { // Value is missing returnedDist = new double[m_Info.numClasses()]; // Split instance up for (int i = 0; i < m_Successors.length; i++) { double[] help = m_Successors[i].distributionForInstance(instance); if (help != null) { for (int j = 0; j < help.length; j++) { returnedDist[j] += m_Prop[i] * help[j]; } } } } else if (m_Info.attribute(m_Attribute).isNominal()) { // For nominal attributes returnedDist = m_Successors[(int)instance.value(m_Attribute)]. distributionForInstance(instance); } else { // For numeric attributes if (instance.value(m_Attribute) < m_SplitPoint) { returnedDist = m_Successors[0].distributionForInstance(instance); } else { returnedDist = m_Successors[1].distributionForInstance(instance); } } } if ((m_Attribute == -1) || (returnedDist == null)) { // Node is a leaf or successor is empty return m_ClassProbs; } else { return returnedDist; } } /** * Outputs the decision tree as a graph * * @return the tree as a graph */ public String toGraph() { try { StringBuffer resultBuff = new StringBuffer(); toGraph(resultBuff, 0); String result = "digraph Tree {\n" + "edge [style=bold]\n" + resultBuff.toString() + "\n}\n"; return result; } catch (Exception e) { return null; } } /** * Outputs one node for graph. * * @param text the buffer to append the output to * @param num unique node id * @return the next node id * @throws Exception if generation fails */ public int toGraph(StringBuffer text, int num) throws Exception { int maxIndex = Utils.maxIndex(m_ClassProbs); String classValue = m_Info.classAttribute().value(maxIndex); num++; if (m_Attribute == -1) { text.append("N" + Integer.toHexString(hashCode()) + " [label=\"" + num + ": " + classValue + "\"" + "shape=box]\n"); }else { text.append("N" + Integer.toHexString(hashCode()) + " [label=\"" + num + ": " + classValue + "\"]\n"); for (int i = 0; i < m_Successors.length; i++) { text.append("N" + Integer.toHexString(hashCode()) + "->" + "N" + Integer.toHexString(m_Successors[i].hashCode()) + " [label=\"" + m_Info.attribute(m_Attribute).name()); if (m_Info.attribute(m_Attribute).isNumeric()) { if (i == 0) { text.append(" < " + Utils.doubleToString(m_SplitPoint, 2)); } else { text.append(" >= " + Utils.doubleToString(m_SplitPoint, 2)); } } else { text.append(" = " + m_Info.attribute(m_Attribute).value(i)); } text.append("\"]\n"); num = m_Successors[i].toGraph(text, num); } } return num; } /** * Outputs the decision tree. * * @return a string representation of the classifier */ public String toString() { // only ZeroR model? if (m_ZeroR != null) { StringBuffer buf = new StringBuffer(); buf.append(this.getClass().getName().replaceAll(".*\\.", "") + "\n"); buf.append(this.getClass().getName().replaceAll(".*\\.", "").replaceAll(".", "=") + "\n\n"); buf.append("Warning: No model could be built, hence ZeroR model is used:\n\n"); buf.append(m_ZeroR.toString()); return buf.toString(); } if (m_Successors == null) { return "RandomTree: no model has been built yet."; } else { return "\nRandomTree\n==========\n" + toString(0) + "\n" + "\nSize of the tree : " + numNodes() + (getMaxDepth() > 0 ? ("\nMax depth of tree: " + getMaxDepth()) : ("")); } } /** * Outputs a leaf. * * @return the leaf as string * @throws Exception if generation fails */ protected String leafString() throws Exception { int maxIndex = Utils.maxIndex(m_Distribution[0]); return " : " + m_Info.classAttribute().value(maxIndex) + " (" + Utils.doubleToString(Utils.sum(m_Distribution[0]), 2) + "/" + Utils.doubleToString((Utils.sum(m_Distribution[0]) - m_Distribution[0][maxIndex]), 2) + ")"; } /** * Recursively outputs the tree. * * @param level the current level of the tree * @return the generated subtree */ protected String toString(int level) { try { StringBuffer text = new StringBuffer(); if (m_Attribute == -1) { // Output leaf info return leafString(); } else if (m_Info.attribute(m_Attribute).isNominal()) { // For nominal attributes for (int i = 0; i < m_Successors.length; i++) { text.append("\n"); for (int j = 0; j < level; j++) { text.append("| "); } text.append(m_Info.attribute(m_Attribute).name() + " = " + m_Info.attribute(m_Attribute).value(i)); text.append(m_Successors[i].toString(level + 1)); } } else { // For numeric attributes text.append("\n"); for (int j = 0; j < level; j++) { text.append("| "); } text.append(m_Info.attribute(m_Attribute).name() + " < " + Utils.doubleToString(m_SplitPoint, 2)); text.append(m_Successors[0].toString(level + 1)); text.append("\n"); for (int j = 0; j < level; j++) { text.append("| "); } text.append(m_Info.attribute(m_Attribute).name() + " >= " + Utils.doubleToString(m_SplitPoint, 2)); text.append(m_Successors[1].toString(level + 1)); } return text.toString(); } catch (Exception e) { e.printStackTrace(); return "RandomTree: tree can't be printed"; } } /** * Recursively generates a tree. * * @param sortedIndices the indices of the instances * @param weights the weights of the instances * @param data the data to work with * @param classProbs the class distribution * @param header the header of the data * @param minNum the minimum number of instances per leaf * @param debug whether debugging is on * @param attIndicesWindow the attribute window to choose attributes from
⌨️ 快捷键说明
复制代码Ctrl + C
搜索代码Ctrl + F
全屏模式F11
增大字号Ctrl + =
减小字号Ctrl + -
显示快捷键?