RandomTree.java

From "Weka" · Java code · 1,087 lines total · page 1/3

JAVA
1,087
Font size
    } else {      setSeed(1);    }        tmpStr = Utils.getOption("depth", options);    if (tmpStr.length() != 0) {      setMaxDepth(Integer.parseInt(tmpStr));    } else {      setMaxDepth(0);    }        super.setOptions(options);        Utils.checkForRemainingOptions(options);  }  /**   * Returns default capabilities of the classifier.   *   * @return      the capabilities of this classifier   */  public Capabilities getCapabilities() {    Capabilities result = super.getCapabilities();    // attributes    result.enable(Capability.NOMINAL_ATTRIBUTES);    result.enable(Capability.NUMERIC_ATTRIBUTES);    result.enable(Capability.DATE_ATTRIBUTES);    result.enable(Capability.MISSING_VALUES);    // class    result.enable(Capability.NOMINAL_CLASS);    result.enable(Capability.MISSING_CLASS_VALUES);        return result;  }  /**   * Builds classifier.   *    * @param data the data to train with   * @throws Exception if something goes wrong or the data doesn't fit   */  public void buildClassifier(Instances data) throws Exception {    // Make sure K value is in range    if (m_KValue > data.numAttributes()-1) m_KValue = data.numAttributes()-1;    if (m_KValue < 1) m_KValue = (int) Utils.log2(data.numAttributes())+1;    // can classifier handle the data?    getCapabilities().testWithFail(data);    // remove instances with missing class    data = new Instances(data);    data.deleteWithMissingClass();        // only class? 
-> build ZeroR model    if (data.numAttributes() == 1) {      System.err.println(	  "Cannot build model (only class attribute present in data!), "	  + "using ZeroR model instead!");      m_ZeroR = new weka.classifiers.rules.ZeroR();      m_ZeroR.buildClassifier(data);      return;    }    else {      m_ZeroR = null;    }        Instances train = data;    // Create array of sorted indices and weights    int[][] sortedIndices = new int[train.numAttributes()][0];    double[][] weights = new double[train.numAttributes()][0];    double[] vals = new double[train.numInstances()];    for (int j = 0; j < train.numAttributes(); j++) {      if (j != train.classIndex()) {	weights[j] = new double[train.numInstances()];	if (train.attribute(j).isNominal()) {	  // Handling nominal attributes. Putting indices of	  // instances with missing values at the end.	  sortedIndices[j] = new int[train.numInstances()];	  int count = 0;	  for (int i = 0; i < train.numInstances(); i++) {	    Instance inst = train.instance(i);	    if (!inst.isMissing(j)) {	      sortedIndices[j][count] = i;	      weights[j][count] = inst.weight();	      count++;	    }	  }	  for (int i = 0; i < train.numInstances(); i++) {	    Instance inst = train.instance(i);	    if (inst.isMissing(j)) {	      sortedIndices[j][count] = i;	      weights[j][count] = inst.weight();	      count++;	    }	  }	} else {	  	  // Sorted indices are computed for numeric attributes	  for (int i = 0; i < train.numInstances(); i++) {	    Instance inst = train.instance(i);	    vals[i] = inst.value(j);	  }	  sortedIndices[j] = Utils.sort(vals);	  for (int i = 0; i < train.numInstances(); i++) {	    weights[j][i] = train.instance(sortedIndices[j][i]).weight();	  }	}      }    }    // Compute initial class counts    double[] classProbs = new double[train.numClasses()];    for (int i = 0; i < train.numInstances(); i++) {      Instance inst = train.instance(i);      classProbs[(int)inst.classValue()] += inst.weight();    }    // Create the 
attribute indices window    int[] attIndicesWindow = new int[data.numAttributes()-1];    int j=0;    for (int i=0; i<attIndicesWindow.length; i++) {      if (j == data.classIndex()) j++; // do not include the class      attIndicesWindow[i] = j++;    }    // Build tree    buildTree(sortedIndices, weights, train, classProbs,	      new Instances(train, 0), m_MinNum, m_Debug,	      attIndicesWindow, data.getRandomNumberGenerator(m_randomSeed), 0);  }    /**   * Computes class distribution of an instance using the decision tree.   *    * @param instance the instance to compute the distribution for   * @return the computed class distribution   * @throws Exception if computation fails   */  public double[] distributionForInstance(Instance instance) throws Exception {        // default model?    if (m_ZeroR != null) {      return m_ZeroR.distributionForInstance(instance);    }        double[] returnedDist = null;        if (m_Attribute > -1) {            // Node is not a leaf      if (instance.isMissing(m_Attribute)) {	// Value is missing	returnedDist = new double[m_Info.numClasses()];	// Split instance up	for (int i = 0; i < m_Successors.length; i++) {	  double[] help = m_Successors[i].distributionForInstance(instance);	  if (help != null) {	    for (int j = 0; j < help.length; j++) {	      returnedDist[j] += m_Prop[i] * help[j];	    }	  }	}      } else if (m_Info.attribute(m_Attribute).isNominal()) {	  	// For nominal attributes	returnedDist =  m_Successors[(int)instance.value(m_Attribute)].	  
distributionForInstance(instance);      } else {		// For numeric attributes	if (instance.value(m_Attribute) < m_SplitPoint) {	  returnedDist = m_Successors[0].distributionForInstance(instance);	} else {	  returnedDist = m_Successors[1].distributionForInstance(instance);	}      }    }    if ((m_Attribute == -1) || (returnedDist == null)) {      // Node is a leaf or successor is empty      return m_ClassProbs;    } else {      return returnedDist;    }  }  /**   * Outputs the decision tree as a graph   *    * @return the tree as a graph   */  public String toGraph() {    try {      StringBuffer resultBuff = new StringBuffer();      toGraph(resultBuff, 0);      String result = "digraph Tree {\n" + "edge [style=bold]\n" + resultBuff.toString()	+ "\n}\n";      return result;    } catch (Exception e) {      return null;    }  }    /**   * Outputs one node for graph.   *    * @param text the buffer to append the output to   * @param num unique node id   * @return the next node id   * @throws Exception if generation fails   */  public int toGraph(StringBuffer text, int num) throws Exception {        int maxIndex = Utils.maxIndex(m_ClassProbs);    String classValue = m_Info.classAttribute().value(maxIndex);        num++;    if (m_Attribute == -1) {      text.append("N" + Integer.toHexString(hashCode()) +		  " [label=\"" + num + ": " + classValue + "\"" +		  "shape=box]\n");    }else {      text.append("N" + Integer.toHexString(hashCode()) +		  " [label=\"" + num + ": " + classValue + "\"]\n");      for (int i = 0; i < m_Successors.length; i++) {	text.append("N" + Integer.toHexString(hashCode()) 		    + "->" + 		    "N" + Integer.toHexString(m_Successors[i].hashCode())  +		    " [label=\"" + m_Info.attribute(m_Attribute).name());	if (m_Info.attribute(m_Attribute).isNumeric()) {	  if (i == 0) {	    text.append(" < " +			Utils.doubleToString(m_SplitPoint, 2));	  } else {	    text.append(" >= " +			Utils.doubleToString(m_SplitPoint, 2));	  }	} else {	  text.append(" = " + 
m_Info.attribute(m_Attribute).value(i));	}	text.append("\"]\n");	num = m_Successors[i].toGraph(text, num);      }    }        return num;  }    /**   * Outputs the decision tree.   *    * @return a string representation of the classifier   */  public String toString() {        // only ZeroR model?    if (m_ZeroR != null) {      StringBuffer buf = new StringBuffer();      buf.append(this.getClass().getName().replaceAll(".*\\.", "") + "\n");      buf.append(this.getClass().getName().replaceAll(".*\\.", "").replaceAll(".", "=") + "\n\n");      buf.append("Warning: No model could be built, hence ZeroR model is used:\n\n");      buf.append(m_ZeroR.toString());      return buf.toString();    }        if (m_Successors == null) {      return "RandomTree: no model has been built yet.";    } else {      return     	"\nRandomTree\n==========\n" + toString(0) + "\n" +	"\nSize of the tree : " + numNodes() +	(getMaxDepth() > 0 ? ("\nMax depth of tree: " + getMaxDepth()) : (""));    }  }  /**   * Outputs a leaf.   *    * @return the leaf as string   * @throws Exception if generation fails   */  protected String leafString() throws Exception {        int maxIndex = Utils.maxIndex(m_Distribution[0]);    return " : " + m_Info.classAttribute().value(maxIndex) +       " (" + Utils.doubleToString(Utils.sum(m_Distribution[0]), 2) + "/" +       Utils.doubleToString((Utils.sum(m_Distribution[0]) - 			    m_Distribution[0][maxIndex]), 2) + ")";  }    /**   * Recursively outputs the tree.   
*    * @param level the current level of the tree   * @return the generated subtree   */  protected String toString(int level) {    try {      StringBuffer text = new StringBuffer();            if (m_Attribute == -1) {		// Output leaf info	return leafString();      } else if (m_Info.attribute(m_Attribute).isNominal()) {		// For nominal attributes	for (int i = 0; i < m_Successors.length; i++) {	  text.append("\n");	  for (int j = 0; j < level; j++) {	    text.append("|   ");	  }	  text.append(m_Info.attribute(m_Attribute).name() + " = " +		      m_Info.attribute(m_Attribute).value(i));	  text.append(m_Successors[i].toString(level + 1));	}      } else {		// For numeric attributes	text.append("\n");	for (int j = 0; j < level; j++) {	  text.append("|   ");	}	text.append(m_Info.attribute(m_Attribute).name() + " < " +		    Utils.doubleToString(m_SplitPoint, 2));	text.append(m_Successors[0].toString(level + 1));	text.append("\n");	for (int j = 0; j < level; j++) {	  text.append("|   ");	}	text.append(m_Info.attribute(m_Attribute).name() + " >= " +		    Utils.doubleToString(m_SplitPoint, 2));	text.append(m_Successors[1].toString(level + 1));      }            return text.toString();    } catch (Exception e) {      e.printStackTrace();      return "RandomTree: tree can't be printed";    }  }       /**   * Recursively generates a tree.   *    * @param sortedIndices the indices of the instances   * @param weights the weights of the instances   * @param data the data to work with   * @param classProbs the class distribution   * @param header the header of the data   * @param minNum the minimum number of instances per leaf   * @param debug whether debugging is on   * @param attIndicesWindow the attribute window to choose attributes from

⌨️ Keyboard shortcuts

Copy code: Ctrl + C
Search code: Ctrl + F
Full-screen mode: F11
Increase font size: Ctrl + =
Decrease font size: Ctrl + -
Show shortcuts: ?