⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 randomtree.java

📁 Java 编写的多种数据挖掘算法 包括聚类、分类、预处理等
💻 JAVA
📖 第 1 页 / 共 2 页
字号:
    } else {
      return returnedDist;
    }
  }

  /**
   * Outputs the decision tree as a graph in Graphviz DOT format.
   * Any exception during generation is swallowed and null is returned.
   *
   * @return the tree as a DOT digraph string, or null on failure
   */
  public String toGraph() {
    try {
      StringBuffer resultBuff = new StringBuffer();
      toGraph(resultBuff, 0);
      String result = "digraph Tree {\n" + "edge [style=bold]\n" + resultBuff.toString() + "\n}\n";
      return result;
    } catch (Exception e) {
      return null;
    }
  }

  /**
   * Recursively outputs one node (and its subtree) for the DOT graph.
   * Node identity is this object's hash code in hex; the label is the
   * majority class at this node.
   *
   * @param text the buffer to append the output to
   * @param num unique node id (incremented per node, pre-order)
   * @return the next node id
   * @throws Exception if generation fails
   */
  public int toGraph(StringBuffer text, int num) throws Exception {

    int maxIndex = Utils.maxIndex(m_ClassProbs);
    String classValue = m_Info.classAttribute().value(maxIndex);

    num++;
    if (m_Attribute == -1) {
      // Leaf node: render as a box.
      // NOTE(review): there is no separator between the closing label quote
      // and "shape=box", so the emitted DOT is [label="..."shape=box] —
      // confirm Graphviz accepts this before relying on the output.
      text.append("N" + Integer.toHexString(hashCode()) +
                  " [label=\"" + num + ": " + classValue + "\"" +
                  "shape=box]\n");
    } else {
      // Interior node: default shape, then one labelled edge per successor.
      text.append("N" + Integer.toHexString(hashCode()) +
                  " [label=\"" + num + ": " + classValue + "\"]\n");
      for (int i = 0; i < m_Successors.length; i++) {
        text.append("N" + Integer.toHexString(hashCode())
                    + "->" +
                    "N" + Integer.toHexString(m_Successors[i].hashCode()) +
                    " [label=\"" + m_Info.attribute(m_Attribute).name());
        if (m_Info.attribute(m_Attribute).isNumeric()) {
          // Numeric split: branch 0 is "< splitPoint", branch 1 is ">=".
          if (i == 0) {
            text.append(" < " +
                        Utils.doubleToString(m_SplitPoint, 2));
          } else {
            text.append(" >= " +
                        Utils.doubleToString(m_SplitPoint, 2));
          }
        } else {
          // Nominal split: one branch per attribute value.
          text.append(" = " + m_Info.attribute(m_Attribute).value(i));
        }
        text.append("\"]\n");
        num = m_Successors[i].toGraph(text, num);
      }
    }

    return num;
  }

  /**
   * Outputs the decision tree in the classic Weka indented-text form.
   *
   * @return a string representation of the classifier
   */
  public String toString() {

    if (m_Successors == null) {
      return "RandomTree: no model has been built yet.";
    } else {
      return
        "\nRandomTree\n==========\n" + toString(0) + "\n" +
        "\nSize of the tree : " + numNodes() +
        (getMaxDepth() > 0 ? ("\nMax depth of tree: " + getMaxDepth()) : (""));
    }
  }

  /**
   * Outputs a leaf: majority class plus (total weight / misclassified weight)
   * taken from this node's class distribution.
   *
   * @return the leaf as string
   * @throws Exception if generation fails
   */
  protected String leafString() throws Exception {

    int maxIndex = Utils.maxIndex(m_Distribution[0]);
    return " : " + m_Info.classAttribute().value(maxIndex) +
      " (" + Utils.doubleToString(Utils.sum(m_Distribution[0]), 2) + "/" +
      Utils.doubleToString((Utils.sum(m_Distribution[0]) -
                            m_Distribution[0][maxIndex]), 2) + ")";
  }

  /**
   * Recursively outputs the tree, indenting with "|   " per level.
   * Exceptions are caught and reported via a placeholder string.
   *
   * @param level the current level of the tree
   * @return the generated subtree
   */
  protected String toString(int level) {
    try {
      StringBuffer text = new StringBuffer();

      if (m_Attribute == -1) {

        // Output leaf info
        return leafString();
      } else if (m_Info.attribute(m_Attribute).isNominal()) {

        // For nominal attributes: one line per attribute value
        for (int i = 0; i < m_Successors.length; i++) {
          text.append("\n");
          for (int j = 0; j < level; j++) {
            text.append("|   ");
          }
          text.append(m_Info.attribute(m_Attribute).name() + " = " +
                      m_Info.attribute(m_Attribute).value(i));
          text.append(m_Successors[i].toString(level + 1));
        }
      } else {

        // For numeric attributes: "< split" branch then ">= split" branch
        text.append("\n");
        for (int j = 0; j < level; j++) {
          text.append("|   ");
        }
        text.append(m_Info.attribute(m_Attribute).name() + " < " +
                    Utils.doubleToString(m_SplitPoint, 2));
        text.append(m_Successors[0].toString(level + 1));
        text.append("\n");
        for (int j = 0; j < level; j++) {
          text.append("|   ");
        }
        text.append(m_Info.attribute(m_Attribute).name() + " >= " +
                    Utils.doubleToString(m_SplitPoint, 2));
        text.append(m_Successors[1].toString(level + 1));
      }

      return text.toString();
    } catch (Exception e) {
      e.printStackTrace();
      return "RandomTree: tree can't be printed";
    }
  }

  /**
   * Recursively generates a tree. Picks m_KValue attributes at random from
   * the window (continuing past K until some attribute shows positive gain),
   * splits on the best one, and recurses; otherwise makes a leaf.
   *
   * @param sortedIndices the indices of the instances, per attribute, sorted
   * @param weights the weights of the instances, aligned with sortedIndices
   * @param data the data to work with
   * @param classProbs the class distribution at this node
   * @param header the header of the data
   * @param minNum the minimum number of instances per leaf
   * @param debug whether debugging is on
   * @param attIndicesWindow the attribute window to choose attributes from
   *        (mutated in place: chosen indices are swapped to the end)
   * @param random random number generator for choosing random attributes
   * @param depth the current depth
   * @throws Exception if generation fails
   */
  protected void buildTree(int[][] sortedIndices, double[][] weights,
                           Instances data, double[] classProbs,
                           Instances header, double minNum, boolean debug,
                           int[] attIndicesWindow, Random random, int depth)
    throws Exception {

    // Store structure of dataset, set minimum number of instances
    m_Info = header;
    m_Debug = debug;
    m_MinNum = minNum;

    // Make leaf if there are no training instances. Probe a non-class
    // attribute's index list: attribute 0 if the class isn't at index 0,
    // otherwise attribute 1.
    if (((data.classIndex() > 0) && (sortedIndices[0].length == 0)) ||
        ((data.classIndex() == 0) && sortedIndices[1].length == 0)) {
      m_Distribution = new double[1][data.numClasses()];
      m_ClassProbs = null;
      return;
    }

    // Check if node doesn't contain enough instances or is pure
    // or maximum depth reached
    m_ClassProbs = new double[classProbs.length];
    System.arraycopy(classProbs, 0, m_ClassProbs, 0, classProbs.length);
    if (Utils.sm(Utils.sum(m_ClassProbs), 2 * m_MinNum) ||
        Utils.eq(m_ClassProbs[Utils.maxIndex(m_ClassProbs)],
                 Utils.sum(m_ClassProbs)) ||
        ((getMaxDepth() > 0) && (depth >= getMaxDepth()))) {

      // Make leaf: keep raw counts in m_Distribution, normalized
      // probabilities in m_ClassProbs.
      m_Attribute = -1;
      m_Distribution = new double[1][m_ClassProbs.length];
      for (int i = 0; i < m_ClassProbs.length; i++) {
        m_Distribution[0][i] = m_ClassProbs[i];
      }
      Utils.normalize(m_ClassProbs);
      return;
    }

    // Compute class distributions and value of splitting
    // criterion for each attribute
    double[] vals = new double[data.numAttributes()];
    double[][][] dists = new double[data.numAttributes()][0][0];
    double[][] props = new double[data.numAttributes()][0];
    double[] splits = new double[data.numAttributes()];

    // Investigate K random attributes; keep drawing past K while no
    // attribute has shown positive gain and the window is non-empty.
    int attIndex = 0;
    int windowSize = attIndicesWindow.length;
    int k = m_KValue;
    boolean gainFound = false;
    while ((windowSize > 0) && (k-- > 0 || !gainFound)) {

      int chosenIndex = random.nextInt(windowSize);
      attIndex = attIndicesWindow[chosenIndex];

      // shift chosen attIndex out of window (swap-to-end, shrink window)
      attIndicesWindow[chosenIndex] = attIndicesWindow[windowSize - 1];
      attIndicesWindow[windowSize - 1] = attIndex;
      windowSize--;

      splits[attIndex] = distribution(props, dists, attIndex,
                                      sortedIndices[attIndex],
                                      weights[attIndex], data);
      vals[attIndex] = gain(dists[attIndex], priorVal(dists[attIndex]));
      if (Utils.gr(vals[attIndex], 0)) gainFound = true;
    }

    // Find best attribute
    m_Attribute = Utils.maxIndex(vals);
    m_Distribution = dists[m_Attribute];

    // Any useful split found?
    if (Utils.gr(vals[m_Attribute], 0)) {

      // Build subtrees
      m_SplitPoint = splits[m_Attribute];
      m_Prop = props[m_Attribute];
      int[][][] subsetIndices =
        new int[m_Distribution.length][data.numAttributes()][0];
      double[][][] subsetWeights =
        new double[m_Distribution.length][data.numAttributes()][0];
      splitData(subsetIndices, subsetWeights, m_Attribute, m_SplitPoint,
                sortedIndices, weights, m_Distribution, data);
      m_Successors = new RandomTree[m_Distribution.length];
      for (int i = 0; i < m_Distribution.length; i++) {
        m_Successors[i] = new RandomTree();
        m_Successors[i].setKValue(m_KValue);
        m_Successors[i].setMaxDepth(getMaxDepth());
        m_Successors[i].buildTree(subsetIndices[i], subsetWeights[i], data,
                                  m_Distribution[i], header, m_MinNum, m_Debug,
                                  attIndicesWindow, random, depth + 1);
      }
    } else {

      // Make leaf
      m_Attribute = -1;
      m_Distribution = new double[1][m_ClassProbs.length];
      for (int i = 0; i < m_ClassProbs.length; i++) {
        m_Distribution[0][i] = m_ClassProbs[i];
      }
    }

    // Normalize class counts
    Utils.normalize(m_ClassProbs);
  }

  /**
   * Computes size of the tree (number of nodes, counting this one).
   *
   * @return the number of nodes
   */
  public int numNodes() {

    if (m_Attribute == -1) {
      return 1;
    } else {
      int size = 1;
      for (int i = 0; i < m_Successors.length; i++) {
        size += m_Successors[i].numNodes();
      }
      return size;
    }
  }

  /**
   * Splits instances into subsets, one per branch. Instances with a missing
   * value on the split attribute are distributed across all branches with
   * weight proportional to m_Prop. Arrays are over-allocated and trimmed
   * to the actual fill count at the end.
   *
   * @param subsetIndices the sorted indices of the subset (filled here)
   * @param subsetWeights the weights of the subset (filled here)
   * @param att the attribute index
   * @param splitPoint the splitpoint for numeric attributes
   * @param sortedIndices the sorted indices of the whole set
   * @param weights the weights of the whole set
   * @param dist the distribution
   * @param data the data to work with
   * @throws Exception if something goes wrong
   */
  protected void splitData(int[][][] subsetIndices, double[][][] subsetWeights,
                           int att, double splitPoint,
                           int[][] sortedIndices, double[][] weights,
                           double[][] dist, Instances data) throws Exception {

    int j;
    int[] num;

    // For each attribute (sorted order must be preserved per attribute)
    for (int i = 0; i < data.numAttributes(); i++) {
      if (i != data.classIndex()) {
        if (data.attribute(att).isNominal()) {

          // For nominal attributes: one subset per attribute value
          num = new int[data.attribute(att).numValues()];
          for (int k = 0; k < num.length; k++) {
            subsetIndices[k][i] = new int[sortedIndices[i].length];
            subsetWeights[k][i] = new double[sortedIndices[i].length];
          }
          for (j = 0; j < sortedIndices[i].length; j++) {
            Instance inst = data.instance(sortedIndices[i][j]);
            if (inst.isMissing(att)) {

              // Split instance up among branches with positive proportion
              for (int k = 0; k < num.length; k++) {
                if (Utils.gr(m_Prop[k], 0)) {
                  subsetIndices[k][i][num[k]] = sortedIndices[i][j];
                  subsetWeights[k][i][num[k]] = m_Prop[k] * weights[i][j];
                  num[k]++;
                }
              }
            } else {
              int subset = (int)inst.value(att);
              subsetIndices[subset][i][num[subset]] = sortedIndices[i][j];
              subsetWeights[subset][i][num[subset]] = weights[i][j];
              num[subset]++;
            }
          }
        } else {

          // For numeric attributes: two subsets, below / at-or-above split
          num = new int[2];
          for (int k = 0; k < 2; k++) {
            subsetIndices[k][i] = new int[sortedIndices[i].length];
            subsetWeights[k][i] = new double[weights[i].length];
          }
          for (j = 0; j < sortedIndices[i].length; j++) {
            Instance inst = data.instance(sortedIndices[i][j]);
            if (inst.isMissing(att)) {

              // Split instance up among branches with positive proportion
              for (int k = 0; k < num.length; k++) {
                if (Utils.gr(m_Prop[k], 0)) {
                  subsetIndices[k][i][num[k]] = sortedIndices[i][j];
                  subsetWeights[k][i][num[k]] = m_Prop[k] * weights[i][j];
                  num[k]++;
                }
              }
            } else {
              int subset = Utils.sm(inst.value(att), splitPoint) ? 0 : 1;
              subsetIndices[subset][i][num[subset]] = sortedIndices[i][j];
              subsetWeights[subset][i][num[subset]] = weights[i][j];
              num[subset]++;
            }
          }
        }

        // Trim arrays down to the actual number of entries filled
        for (int k = 0; k < num.length; k++) {
          int[] copy = new int[num[k]];
          System.arraycopy(subsetIndices[k][i], 0, copy, 0, num[k]);
          subsetIndices[k][i] = copy;
          double[] copyWeights = new double[num[k]];
          System.arraycopy(subsetWeights[k][i], 0, copyWeights, 0, num[k]);
          subsetWeights[k][i] = copyWeights;
        }
      }
    }
  }

  /**
   * Computes class distribution for an attribute. For numeric attributes
   * this scans the (pre-sorted) instances once, sweeping candidate split
   * points and keeping the one with the best gain. Instances with missing
   * values (sorted to the end, detected via break) have their class counts
   * distributed across branches proportionally to the branch weights.
   * Also fills props[att] with the normalized per-branch weight proportions
   * and dists[att] with the resulting per-branch class distribution.
   *
   * @param props per-attribute branch proportions (output)
   * @param dists per-attribute class distributions (output)
   * @param att the attribute index
   * @param sortedIndices the sorted indices of the data for this attribute
   * @param weights the instance weights aligned with sortedIndices
   * @param data the data to work with
   * @return the chosen split point (NaN for nominal attributes or if no
   *         split point was found)
   * @throws Exception if something goes wrong
   */
  protected double distribution(double[][] props, double[][][] dists, int att,
                                int[] sortedIndices,
                                double[] weights, Instances data)
    throws Exception {

    double splitPoint = Double.NaN;
    Attribute attribute = data.attribute(att);
    double[][] dist = null;
    int i;
    if (attribute.isNominal()) {

      // For nominal attributes: tally weight per (value, class) cell.
      // Missing values sort last, so break leaves i at the first missing.
      dist = new double[attribute.numValues()][data.numClasses()];
      for (i = 0; i < sortedIndices.length; i++) {
        Instance inst = data.instance(sortedIndices[i]);
        if (inst.isMissing(att)) {
          break;
        }
        dist[(int)inst.value(att)][(int)inst.classValue()] += weights[i];
      }
    } else {

      // For numeric attributes
      double[][] currDist = new double[2][data.numClasses()];
      dist = new double[2][data.numClasses()];

      // Move all instances into second subset
      for (int j = 0; j < sortedIndices.length; j++) {
        Instance inst = data.instance(sortedIndices[j]);
        if (inst.isMissing(att)) {
          break;
        }
        currDist[1][(int)inst.classValue()] += weights[j];
      }
      double priorVal = priorVal(currDist);
      for (int j = 0; j < currDist.length; j++) {
        System.arraycopy(currDist[j], 0, dist[j], 0, dist[j].length);
      }

      // Try all possible split points: evaluate gain at each value change,
      // then shift the current instance from subset 1 to subset 0.
      double currSplit = data.instance(sortedIndices[0]).value(att);
      double currVal, bestVal = -Double.MAX_VALUE;
      for (i = 0; i < sortedIndices.length; i++) {
        Instance inst = data.instance(sortedIndices[i]);
        if (inst.isMissing(att)) {
          break;
        }
        if (Utils.gr(inst.value(att), currSplit)) {
          currVal = gain(currDist, priorVal);
          if (Utils.gr(currVal, bestVal)) {
            bestVal = currVal;
            // Split point is the midpoint between adjacent distinct values
            splitPoint = (inst.value(att) + currSplit) / 2.0;
            for (int j = 0; j < currDist.length; j++) {
              System.arraycopy(currDist[j], 0, dist[j], 0, dist[j].length);
            }
          }
        }
        currSplit = inst.value(att);
        currDist[0][(int)inst.classValue()] += weights[i];
        currDist[1][(int)inst.classValue()] -= weights[i];
      }
    }

    // Compute weights (branch proportions); uniform fallback if all zero
    props[att] = new double[dist.length];
    for (int k = 0; k < props[att].length; k++) {
      props[att][k] = Utils.sum(dist[k]);
    }
    if (Utils.eq(Utils.sum(props[att]), 0)) {
      for (int k = 0; k < props[att].length; k++) {
        props[att][k] = 1.0 / (double)props[att].length;
      }
    } else {
      Utils.normalize(props[att]);
    }

    // Any instances with missing values ?
    if (i < sortedIndices.length) {

      // Distribute counts proportionally to the branch weights
      while (i < sortedIndices.length) {
        Instance inst = data.instance(sortedIndices[i]);
        for (int j = 0; j < dist.length; j++) {
          dist[j][(int)inst.classValue()] += props[att][j] * weights[i];
        }
        i++;
      }
    }

    // Return distribution and split point
    dists[att] = dist;
    return splitPoint;
  }

  /**
   * Computes value of splitting criterion before split
   * (entropy over the class columns).
   *
   * @param dist the distributions
   * @return the splitting criterion
   */
  protected double priorVal(double[][] dist) {

    return ContingencyTables.entropyOverColumns(dist);
  }

  /**
   * Computes value of splitting criterion after split
   * (information gain: prior entropy minus conditional entropy).
   *
   * @param dist the distributions
   * @param priorVal the splitting criterion before the split
   * @return the gain after the split
   */
  protected double gain(double[][] dist, double priorVal) {

    return priorVal - ContingencyTables.entropyConditionedOnRows(dist);
  }

  /**
   * Main method for this class: runs Weka's standard evaluation harness
   * on a fresh RandomTree with the given command-line options.
   *
   * @param argv the commandline parameters
   */
  public static void main(String[] argv) {

    try {
      System.out.println(Evaluation.evaluateModel(new RandomTree(), argv));
    } catch (Exception e) {
      e.printStackTrace();
      System.err.println(e.getMessage());
    }
  }
}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -