⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 randomtree.java

📁 一个数据挖掘软件ALPHAMINERR的整个过程的JAVA版源代码
💻 JAVA
📖 第 1 页 / 共 2 页
字号:
   */
  public int toGraph(StringBuffer text, int num) throws Exception {

    // Appends this node, the edges to its successors and, recursively, the
    // successors themselves to text as graph statements of the form
    // N<hash> [label="..."] and N<hash>->N<hash> [label="..."].
    // num is the number of nodes written so far; the updated count is
    // returned so the caller can continue numbering with it.

    // buildTree stores null class probabilities for a leaf that received no
    // training instances. Label such a node with class index 0 instead of
    // handing a null array to Utils.maxIndex.
    int maxIndex = (m_ClassProbs == null) ? 0 : Utils.maxIndex(m_ClassProbs);
    String classValue = m_Info.classAttribute().value(maxIndex);
    String nodeId = "N" + Integer.toHexString(hashCode());

    num++;
    if (m_Attribute == -1) {

      // Leaf node. The blank in front of "shape" separates it from the
      // closing quote of the label attribute.
      text.append(nodeId).append(" [label=\"").append(num).append(": ")
        .append(classValue).append("\" shape=box]\n");
      return num;
    }

    // Inner node: the node itself, then one labelled edge per successor,
    // each followed by the successor's own subtree.
    text.append(nodeId).append(" [label=\"").append(num).append(": ")
      .append(classValue).append("\"]\n");
    for (int i = 0; i < m_Successors.length; i++) {
      text.append(nodeId).append("->").append("N")
        .append(Integer.toHexString(m_Successors[i].hashCode()))
        .append(" [label=\"").append(m_Info.attribute(m_Attribute).name());
      if (m_Info.attribute(m_Attribute).isNumeric()) {

        // First successor is the "below split point" branch
        text.append(i == 0 ? " < " : " >= ")
          .append(Utils.doubleToString(m_SplitPoint, 2));
      } else {
        text.append(" = ").append(m_Info.attribute(m_Attribute).value(i));
      }
      text.append("\"]\n");
      num = m_Successors[i].toGraph(text, num);
    }

    return num;
  }
  
  /**
   * Returns a textual description of the tree rooted at this node.
   *
   * @return a fixed notice when this node has no successor array,
   * otherwise a header, the recursively printed tree and the node count
   */
  public String toString() {

    if (m_Successors != null) {
      String body = toString(0);
      int size = numNodes();
      return "\nRandomTree\n==========\n" + body + "\n" +
        "\nSize of the tree : " + size;
    }
    return "RandomTree: no model has been built yet.";
  }

  /**
   * Formats the description of a leaf: the class value with the largest
   * entry in the first row of the distribution, followed in parentheses by
   * the row total and the part of the total that does not belong to that
   * class, both printed with two decimals.
   *
   * @return the leaf description, starting with " : "
   * @throws Exception declared for callers; not thrown explicitly here
   */
  protected String leafString() throws Exception {

    double[] counts = m_Distribution[0];
    int majority = Utils.maxIndex(counts);
    double total = Utils.sum(counts);
    double remainder = total - counts[majority];

    return " : " + m_Info.classAttribute().value(majority) +
      " (" + Utils.doubleToString(total, 2) + "/" +
      Utils.doubleToString(remainder, 2) + ")";
  }
  
  /**
   * Recursively prints the subtree rooted at this node.
   *
   * @param level depth of this node; each line written for the node is
   * prefixed with one "|   " per level
   * @return the printed subtree, or a fixed error message if anything
   * throws while printing (the stack trace is printed as well)
   */
  protected String toString(int level) {

    try {
      if (m_Attribute == -1) {

        // Leaves are rendered by leafString() alone
        return leafString();
      }

      boolean nominal = m_Info.attribute(m_Attribute).isNominal();

      // Every branch line starts the same way: newline, indentation,
      // attribute name
      StringBuffer lead = new StringBuffer("\n");
      for (int depth = 0; depth < level; depth++) {
        lead.append("|   ");
      }
      String prefix = lead.toString() + m_Info.attribute(m_Attribute).name();

      StringBuffer out = new StringBuffer();
      if (nominal) {

        // One branch per attribute value
        for (int branch = 0; branch < m_Successors.length; branch++) {
          out.append(prefix).append(" = ")
            .append(m_Info.attribute(m_Attribute).value(branch));
          out.append(m_Successors[branch].toString(level + 1));
        }
      } else {

        // Two branches around the split point
        String threshold = Utils.doubleToString(m_SplitPoint, 2);
        out.append(prefix).append(" < ").append(threshold);
        out.append(m_Successors[0].toString(level + 1));
        out.append(prefix).append(" >= ").append(threshold);
        out.append(m_Successors[1].toString(level + 1));
      }

      return out.toString();
    } catch (Exception e) {
      e.printStackTrace();
      return "RandomTree: tree can't be printed";
    }
  }     

  /**
   * Recursively generates a tree.
   *
   * @param sortedIndices for each attribute, the indices (into data) of the
   * instances that reached this node; the entry of the class attribute is
   * left empty by splitData. Presumably ordered by attribute value with
   * missing values last -- the ordering is not established in this method.
   * @param weights for each attribute, the instance weights parallel to
   * sortedIndices
   * @param data the training instances
   * @param classProbs class weights at this node; copied, the copy is
   * normalized and kept as this node's class probabilities
   * @param header dataset structure to keep with the node
   * @param minNum minimum-instances setting; a node whose total weight is
   * smaller than twice this value becomes a leaf
   * @param debug debug flag to keep with the node
   * @param attIndicesWindow candidate attribute indices; permuted in place
   * while attributes are drawn
   * @param random source of randomness for drawing attributes
   * @throws Exception if computing distributions or building a subtree fails
   */
  protected void buildTree(int[][] sortedIndices, double[][] weights,
			 Instances data, double[] classProbs, 
			 Instances header, double minNum, boolean debug,
			 int[] attIndicesWindow, Random random) 
    throws Exception {

    // Store structure of dataset, set minimum number of instances
    m_Info = header;
    m_Debug = debug;
    m_MinNum = minNum;

    // Make leaf if there are no training instances. The slot of the class
    // attribute is never filled by splitData, so when the class attribute
    // is attribute 0 the emptiness test has to look at attribute 1;
    // otherwise every child node would be mistaken for an empty one.
    int probe = (data.classIndex() == 0) ? 1 : 0;
    if (probe >= sortedIndices.length || sortedIndices[probe].length == 0) {
      m_Attribute = -1;
      m_Distribution = new double[1][data.numClasses()];
      m_ClassProbs = null;
      return;
    }

    // Check if node doesn't contain enough instances or is pure
    m_ClassProbs = new double[classProbs.length];
    System.arraycopy(classProbs, 0, m_ClassProbs, 0, classProbs.length);
    if (Utils.sm(Utils.sum(m_ClassProbs), 2 * m_MinNum) ||
        Utils.eq(m_ClassProbs[Utils.maxIndex(m_ClassProbs)],
                 Utils.sum(m_ClassProbs))) {

      // Make leaf; keep the raw counts before they are normalized
      m_Attribute = -1;
      m_Distribution = new double[1][m_ClassProbs.length];
      System.arraycopy(m_ClassProbs, 0, m_Distribution[0], 0,
                       m_ClassProbs.length);
      Utils.normalize(m_ClassProbs);
      return;
    }

    // Compute class distributions and value of splitting
    // criterion for each attribute
    double[] vals = new double[data.numAttributes()];
    double[][][] dists = new double[data.numAttributes()][0][0];
    double[][] props = new double[data.numAttributes()][0];
    double[] splits = new double[data.numAttributes()];

    // Investigate K random attributes; keep drawing beyond K while no
    // attribute with positive gain has been seen and candidates remain
    int attIndex = 0;
    int windowSize = attIndicesWindow.length;
    int k = m_KValue;
    boolean gainFound = false;
    while ((windowSize > 0) && (k-- > 0 || !gainFound)) {

      int chosenIndex = random.nextInt(windowSize);
      attIndex = attIndicesWindow[chosenIndex];
      
      // shift chosen attIndex out of window
      attIndicesWindow[chosenIndex] = attIndicesWindow[windowSize-1];
      attIndicesWindow[windowSize-1] = attIndex;
      windowSize--;

      splits[attIndex] = distribution(props, dists, attIndex,
                                      sortedIndices[attIndex], 
                                      weights[attIndex], data);
      vals[attIndex] = gain(dists[attIndex], priorVal(dists[attIndex]));

      if (Utils.gr(vals[attIndex], 0)) {
        gainFound = true;
      }
    }

    // Find best attribute
    m_Attribute = Utils.maxIndex(vals);
    m_Distribution = dists[m_Attribute];

    // Any useful split found?
    if (Utils.gr(vals[m_Attribute], 0)) {

      // Build subtrees
      m_SplitPoint = splits[m_Attribute];
      m_Prop = props[m_Attribute];
      int[][][] subsetIndices = 
        new int[m_Distribution.length][data.numAttributes()][0];
      double[][][] subsetWeights = 
        new double[m_Distribution.length][data.numAttributes()][0];
      splitData(subsetIndices, subsetWeights, m_Attribute, m_SplitPoint, 
                sortedIndices, weights, m_Distribution, data);
      m_Successors = new RandomTree[m_Distribution.length];
      for (int i = 0; i < m_Distribution.length; i++) {
        m_Successors[i] = new RandomTree();
        m_Successors[i].setKValue(m_KValue);
        m_Successors[i].buildTree(subsetIndices[i], subsetWeights[i], data, 
                                  m_Distribution[i], header, m_MinNum, m_Debug,
                                  attIndicesWindow, random);
      }
    } else {
      
      // Make leaf
      m_Attribute = -1;
      m_Distribution = new double[1][m_ClassProbs.length];
      System.arraycopy(m_ClassProbs, 0, m_Distribution[0], 0,
                       m_ClassProbs.length);
    }

    // Normalize class counts
    Utils.normalize(m_ClassProbs);
  }

  /**
   * Counts the nodes of the subtree rooted at this node.
   *
   * @return 1 for a leaf, otherwise 1 plus the node counts of all successors
   */
  public int numNodes() {

    int count = 1;
    if (m_Attribute != -1) {
      for (int child = 0; child < m_Successors.length; child++) {
        count += m_Successors[child].numNodes();
      }
    }
    return count;
  }

  /**
   * Splits the per-attribute index and weight arrays of a node into one
   * set of arrays per subset of the split on attribute att.
   *
   * @param subsetIndices output: [subset][attribute] instance indices
   * @param subsetWeights output: [subset][attribute] instance weights
   * @param att index of the attribute the node splits on
   * @param splitPoint threshold used when att is not nominal
   * @param sortedIndices per-attribute instance indices of the node
   * @param weights per-attribute instance weights of the node
   * @param dist not used by this method
   * @param data the training instances
   * @throws Exception declared for callers; not thrown explicitly here
   */
  protected void splitData(int[][][] subsetIndices, double[][][] subsetWeights,
			 int att, double splitPoint, 
			 int[][] sortedIndices, double[][] weights,
			 double[][] dist, Instances data) throws Exception {

    for (int a = 0; a < data.numAttributes(); a++) {

      // The class attribute gets no index arrays
      if (a == data.classIndex()) {
        continue;
      }

      // Nominal splits have one subset per value, all others have two
      boolean nominal = data.attribute(att).isNominal();
      int numSubsets = nominal ? data.attribute(att).numValues() : 2;
      int[] filled = new int[numSubsets];

      // Allocate every subset at full size; trimmed below
      for (int s = 0; s < numSubsets; s++) {
        subsetIndices[s][a] = new int[sortedIndices[a].length];
        subsetWeights[s][a] =
          new double[nominal ? sortedIndices[a].length : weights[a].length];
      }

      for (int pos = 0; pos < sortedIndices[a].length; pos++) {
        Instance inst = data.instance(sortedIndices[a][pos]);
        if (inst.isMissing(att)) {

          // Value unknown: send a fraction of the instance to every
          // subset that has a positive proportion
          for (int s = 0; s < numSubsets; s++) {
            if (Utils.gr(m_Prop[s], 0)) {
              subsetIndices[s][a][filled[s]] = sortedIndices[a][pos];
              subsetWeights[s][a][filled[s]] = m_Prop[s] * weights[a][pos];
              filled[s]++;
            }
          }
        } else {
          int target;
          if (nominal) {
            target = (int)inst.value(att);
          } else {
            target = Utils.sm(inst.value(att), splitPoint) ? 0 : 1;
          }
          subsetIndices[target][a][filled[target]] = sortedIndices[a][pos];
          subsetWeights[target][a][filled[target]] = weights[a][pos];
          filled[target]++;
        }
      }

      // Cut each array down to the number of entries actually written
      for (int s = 0; s < numSubsets; s++) {
        int[] trimmedIndices = new int[filled[s]];
        double[] trimmedWeights = new double[filled[s]];
        System.arraycopy(subsetIndices[s][a], 0, trimmedIndices, 0, filled[s]);
        System.arraycopy(subsetWeights[s][a], 0, trimmedWeights, 0, filled[s]);
        subsetIndices[s][a] = trimmedIndices;
        subsetWeights[s][a] = trimmedWeights;
      }
    }
  }

  /**
   * Computes class distribution for an attribute.
   *
   * For a nominal attribute the result has one row per attribute value;
   * otherwise it has two rows (below / not below the chosen split point).
   * Instances are read in the order given by sortedIndices and processing
   * of known values stops at the first instance whose value for att is
   * missing; every instance from that position on is treated as missing.
   * NOTE(review): this presumes sortedIndices is ordered by attribute value
   * with missing values last -- the ordering is not established here.
   *
   * @param props output: props[att] receives the row sums of the
   * distribution, normalized (or equal shares if the sum is zero)
   * @param dists output: dists[att] receives the class distribution
   * @param att index of the attribute to evaluate
   * @param sortedIndices indices into data of the instances to use
   * @param weights instance weights, parallel to sortedIndices
   * @param data the training instances
   * @return the chosen split point; NaN for nominal attributes and when no
   * candidate split point was found
   * @throws Exception declared for callers; not thrown explicitly here
   */
  protected double distribution(double[][] props, double[][][] dists, int att, 
			      int[] sortedIndices,
			      double[] weights, Instances data) 
    throws Exception {

    double splitPoint = Double.NaN;
    Attribute attribute = data.attribute(att);
    double[][] dist = null;
    // Shared by both branches: after them it holds the position of the
    // first instance with a missing value (or sortedIndices.length)
    int i;

    if (attribute.isNominal()) {

      // For nominal attributes
      dist = new double[attribute.numValues()][data.numClasses()];
      for (i = 0; i < sortedIndices.length; i++) {
	Instance inst = data.instance(sortedIndices[i]);
	if (inst.isMissing(att)) {
	  break;
	}
	dist[(int)inst.value(att)][(int)inst.classValue()] += weights[i];
      }
    } else {

      // For numeric attributes
      double[][] currDist = new double[2][data.numClasses()];
      dist = new double[2][data.numClasses()];

      // Move all instances into second subset
      for (int j = 0; j < sortedIndices.length; j++) {
	Instance inst = data.instance(sortedIndices[j]);
	if (inst.isMissing(att)) {
	  break;
	}
	currDist[1][(int)inst.classValue()] += weights[j];
      }
      // Criterion value before splitting; also the default result if no
      // candidate split point turns up below
      double priorVal = priorVal(currDist);
      for (int j = 0; j < currDist.length; j++) {
	System.arraycopy(currDist[j], 0, dist[j], 0, dist[j].length);
      }

      // Try all possible split points
      // NOTE: reads sortedIndices[0] unconditionally, so this branch needs
      // at least one instance
      double currSplit = data.instance(sortedIndices[0]).value(att);
      double currVal, bestVal = -Double.MAX_VALUE;
      for (i = 0; i < sortedIndices.length; i++) {
	Instance inst = data.instance(sortedIndices[i]);
	if (inst.isMissing(att)) {
	  break;
	}
	// A candidate exists only where the value increases; it is scored
	// with the instances seen so far in the first subset
	if (Utils.gr(inst.value(att), currSplit)) {
	  currVal = gain(currDist, priorVal);
	  if (Utils.gr(currVal, bestVal)) {
	    bestVal = currVal;
	    // Split halfway between the previous and the current value
	    splitPoint = (inst.value(att) + currSplit) / 2.0;
	    for (int j = 0; j < currDist.length; j++) {
	      System.arraycopy(currDist[j], 0, dist[j], 0, dist[j].length);
	    }
	  } 
	} 
	// Move the current instance from the second to the first subset
	currSplit = inst.value(att);
	currDist[0][(int)inst.classValue()] += weights[i];
	currDist[1][(int)inst.classValue()] -= weights[i];
      }
    }

    // Compute weights
    props[att] = new double[dist.length];
    for (int k = 0; k < props[att].length; k++) {
      props[att][k] = Utils.sum(dist[k]);
    }
    if (Utils.eq(Utils.sum(props[att]), 0)) {
      // No weight with a known value: give every subset an equal share
      for (int k = 0; k < props[att].length; k++) {
	props[att][k] = 1.0 / (double)props[att].length;
      }
    } else {
      Utils.normalize(props[att]);
    }
    
    // Any instances with missing values ?
    if (i < sortedIndices.length) {
	
      // Distribute counts
      // Each remaining instance adds its weight to every subset in
      // proportion to props[att]
      while (i < sortedIndices.length) {
	Instance inst = data.instance(sortedIndices[i]);
	for (int j = 0; j < dist.length; j++) {
	  dist[j][(int)inst.classValue()] += props[att][j] * weights[i];
	}
	i++;
      }
    }

    // Return distribution and split point
    dists[att] = dist;
    return splitPoint;
  }      

  /**
   * Computes the value of the splitting criterion before a split by
   * delegating to ContingencyTables.entropyOverColumns.
   *
   * @param dist class distribution, one row per subset
   * @return the value reported by ContingencyTables.entropyOverColumns
   */
  protected double priorVal(double[][] dist) {

    double beforeSplit = ContingencyTables.entropyOverColumns(dist);
    return beforeSplit;
  }

  /**
   * Computes the improvement of the splitting criterion achieved by a
   * split: the prior value minus the value reported by
   * ContingencyTables.entropyConditionedOnRows for the distribution.
   *
   * @param dist class distribution, one row per subset
   * @param priorVal criterion value before the split
   * @return priorVal minus the conditional value after the split
   */
  protected double gain(double[][] dist, double priorVal) {

    double afterSplit = ContingencyTables.entropyConditionedOnRows(dist);
    return priorVal - afterSplit;
  }

  /**
   * Command-line entry point: evaluates a fresh RandomTree with the given
   * options and prints the result. Any exception is reported on standard
   * error (stack trace followed by its message).
   *
   * @param argv the options handed to Evaluation.evaluateModel
   */
  public static void main(String[] argv) {

    try {
      RandomTree tree = new RandomTree();
      System.out.println(Evaluation.evaluateModel(tree, argv));
    } catch (Exception e) {
      e.printStackTrace();
      System.err.println(e.getMessage());
    }
  }
}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -