⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 reptree.java

📁 为了下东西 随便发了个 datamining 的源代码
💻 JAVA
📖 第 1 页 / 共 4 页
字号:
	  text.append(m_Info.attribute(m_Attribute).name() + " >= " +
		      Utils.doubleToString(m_SplitPoint, 2));
	  text.append(m_Successors[1].toString(level + 1, this));
	}
      
	return text.toString();
      } catch (Exception e) {
	e.printStackTrace();
	return "Decision tree: tree can't be printed";
      }
    }     

    /**
     * Recursively grows the subtree rooted at this node.
     *
     * The node becomes a leaf when it has no training instances, when
     * its total weight is below twice <code>minNum</code>, when a nominal
     * class is pure, when the prior variance per unit weight of a numeric
     * class falls below <code>minVariance</code>, when the depth limit is
     * reached, or when no attribute yields a positive splitting value with
     * at least two branches of weight <code>minNum</code> or more.
     * Otherwise the best attribute is selected, the instances are split
     * and one successor is built per branch.
     *
     * @param sortedIndices for every attribute, the indices (into
     * <code>data</code>) of the instances at this node
     * @param weights for every attribute, the instance weights parallel
     * to <code>sortedIndices</code>
     * @param data the full set of training instances
     * @param totalWeight the total weight of the instances at this node
     * @param classProbs the class statistics of this node; copied, never
     * modified
     * @param header the dataset structure stored with the node
     * @param minNum the minimum weight required for a branch
     * @param minVariance the variance threshold used for a numeric class
     * @param depth the depth of this node
     * @param maxDepth the depth at which growing stops (only applied
     * when <code>m_MaxDepth</code> is non-negative)
     * @exception Exception if a helper called during building fails
     */
    protected void buildTree(int[][] sortedIndices, double[][] weights,
                             Instances data, double totalWeight, 
                             double[] classProbs, Instances header,
                             double minNum, double minVariance,
                             int depth, int maxDepth) 
      throws Exception {

      // Remember the dataset structure and reserve room for the
      // statistics that are collected from the pruning data
      m_Info = header;
      m_HoldOutDist = new double[data.numClasses()];

      // All non-class attributes are inspected through the first one
      // that is not the class
      int refAtt = (data.classIndex() == 0) ? 1 : 0;

      // No training instances: empty leaf
      if (sortedIndices[refAtt].length == 0) {
        int distLength = 
          data.classAttribute().isNumeric() ? 2 : data.numClasses();
        m_Distribution = new double[distLength];
        m_ClassProbs = null;
        return;
      }

      final boolean nominalClass = data.classAttribute().isNominal();
      final boolean numericClass = data.classAttribute().isNumeric();

      // Prior variance is only needed for a numeric class
      double priorVar = 0;
      if (numericClass) {
        double sum = 0, sumSquared = 0, sumOfWeights = 0;
        for (int pos = 0; pos < sortedIndices[refAtt].length; pos++) {
          double target = 
            data.instance(sortedIndices[refAtt][pos]).classValue();
          double w = weights[refAtt][pos];
          sum += target * w;
          sumSquared += target * target * w;
          sumOfWeights += w;
        }
        priorVar = singleVariance(sum, sumSquared, sumOfWeights);
      }

      // Keep a private copy of the class statistics
      m_ClassProbs = new double[classProbs.length];
      System.arraycopy(classProbs, 0, m_ClassProbs, 0, classProbs.length);

      // Decide whether growing stops here. The tests are evaluated one
      // after the other so that later ones are skipped once one holds.
      boolean stopHere = totalWeight < (2 * minNum);
      if (!stopHere && nominalClass) {

        // Pure node: the largest class count equals the total
        stopHere = Utils.eq(m_ClassProbs[Utils.maxIndex(m_ClassProbs)],
                            Utils.sum(m_ClassProbs));
      }
      if (!stopHere && numericClass) {
        stopHere = (priorVar / totalWeight) < minVariance;
      }
      if (!stopHere) {
        stopHere = (m_MaxDepth >= 0) && (depth >= maxDepth);
      }

      if (stopHere) {

        // Make leaf
        m_Attribute = -1;
        if (nominalClass) {
          m_Distribution = new double[m_ClassProbs.length];
          System.arraycopy(m_ClassProbs, 0, m_Distribution, 0, 
                           m_ClassProbs.length);
          Utils.normalize(m_ClassProbs);
        } else {
          m_Distribution = new double[] {priorVar, totalWeight};
        }
        return;
      }

      // Evaluate the splitting criterion for every non-class attribute
      final int numAtts = data.numAttributes();
      double[] vals = new double[numAtts];
      double[][][] dists = new double[numAtts][0][0];
      double[][] props = new double[numAtts][0];
      double[][] totalSubsetWeights = new double[numAtts][0];
      double[] splits = new double[numAtts];
      for (int a = 0; a < numAtts; a++) {
        if (a == data.classIndex()) {
          continue;
        }
        if (nominalClass) {
          splits[a] = distribution(props, dists, a, sortedIndices[a], 
                                   weights[a], totalSubsetWeights, data);
          vals[a] = gain(dists[a], priorVal(dists[a]));
        } else {

          // The numeric variant stores the criterion value in vals itself
          splits[a] = 
            numericDistribution(props, dists, a, sortedIndices[a], 
                                weights[a], totalSubsetWeights, data, 
                                vals);
        }
      }

      // Attribute with the largest criterion value
      m_Attribute = Utils.maxIndex(vals);
      int numAttVals = dists[m_Attribute].length;
      double[] branchWeights = totalSubsetWeights[m_Attribute];

      // Count branches that are heavy enough; two are sufficient
      int viable = 0;
      for (int b = 0; (b < numAttVals) && (viable < 2); b++) {
        if (branchWeights[b] >= minNum) {
          viable++;
        }
      }

      if (!(vals[m_Attribute] > 0) || (viable < 2)) {

        // No useful split: make leaf
        m_Attribute = -1;
      } else {

        // Split the data and grow one successor per branch
        m_SplitPoint = splits[m_Attribute];
        m_Prop = props[m_Attribute];
        int[][][] subsetIndices = new int[numAttVals][numAtts][0];
        double[][][] subsetWeights = new double[numAttVals][numAtts][0];
        splitData(subsetIndices, subsetWeights, m_Attribute, m_SplitPoint, 
                  sortedIndices, weights, data);
        m_Successors = new Tree[numAttVals];
        for (int b = 0; b < numAttVals; b++) {
          m_Successors[b] = new Tree();
          m_Successors[b].buildTree(subsetIndices[b], subsetWeights[b], 
                                    data, 
                                    totalSubsetWeights[m_Attribute][b],
                                    dists[m_Attribute][b], header, 
                                    minNum, minVariance, 
                                    depth + 1, maxDepth);
        }
      }

      // Store the raw statistics, then normalize the class counts
      if (nominalClass) {
        m_Distribution = new double[m_ClassProbs.length];
        System.arraycopy(m_ClassProbs, 0, m_Distribution, 0, 
                         m_ClassProbs.length);
        Utils.normalize(m_ClassProbs);
      } else {
        m_Distribution = new double[] {priorVar, totalWeight};
      }
    }

    /**
     * Counts the nodes of the subtree rooted at this node, including
     * the node itself.
     *
     * @return 1 for a leaf (attribute index -1), otherwise 1 plus the
     * node counts of all successors
     */
    protected int numNodes() {

      // A leaf only counts itself
      if (m_Attribute == -1) {
        return 1;
      }

      // Inner node: itself plus everything below each successor
      int total = 1;
      int next = 0;
      while (next < m_Successors.length) {
        total += m_Successors[next].numNodes();
        next++;
      }
      return total;
    }

    /**
     * Distributes the instances of this node over the branches of a
     * split, once for the index list of every non-class attribute.
     *
     * A nominal split attribute produces one branch per attribute value
     * and an instance goes to the branch of its value. A numeric split
     * attribute produces two branches: branch 0 for values below
     * <code>splitPoint</code>, branch 1 otherwise. An instance whose
     * value for the split attribute is missing is added to every branch
     * <code>k</code> with <code>m_Prop[k] > 0</code>, its weight scaled
     * by <code>m_Prop[k]</code>. Within each list the original order of
     * the instances is kept.
     *
     * @param subsetIndices receives the index lists, addressed as
     * [branch][attribute]
     * @param subsetWeights receives the matching weights, addressed as
     * [branch][attribute]
     * @param att the index of the attribute to split on
     * @param splitPoint the threshold used for a numeric split attribute
     * @param sortedIndices the per-attribute index lists of this node
     * @param weights the per-attribute weights of this node
     * @param data the full set of training instances
     * @exception Exception if the split cannot be performed
     */
    protected void splitData(int[][][] subsetIndices, 
                             double[][][] subsetWeights,
                             int att, double splitPoint, 
                             int[][] sortedIndices, double[][] weights, 
                             Instances data) throws Exception {

      for (int a = 0; a < data.numAttributes(); a++) {
        if (a == data.classIndex()) {
          continue;
        }

        boolean nominalSplit = data.attribute(att).isNominal();

        // Number of entries stored so far, per branch
        int[] filled = 
          new int[nominalSplit ? data.attribute(att).numValues() : 2];

        // Give every branch room for all instances; trimmed below
        for (int b = 0; b < filled.length; b++) {
          subsetIndices[b][a] = new int[sortedIndices[a].length];
          subsetWeights[b][a] = 
            new double[nominalSplit 
                       ? sortedIndices[a].length 
                       : weights[a].length];
        }

        for (int pos = 0; pos < sortedIndices[a].length; pos++) {
          Instance inst = data.instance(sortedIndices[a][pos]);
          if (!inst.isMissing(att)) {

            // Known value: exactly one branch receives the instance
            int target;
            if (nominalSplit) {
              target = (int)inst.value(att);
            } else {
              target = (inst.value(att) < splitPoint) ? 0 : 1;
            }
            subsetIndices[target][a][filled[target]] = 
              sortedIndices[a][pos];
            subsetWeights[target][a][filled[target]] = weights[a][pos];
            filled[target]++;
          } else {

            // Missing value: split instance up
            for (int b = 0; b < filled.length; b++) {
              if (m_Prop[b] > 0) {
                subsetIndices[b][a][filled[b]] = sortedIndices[a][pos];
                subsetWeights[b][a][filled[b]] = 
                  m_Prop[b] * weights[a][pos];
                filled[b]++;
              }
            }
          }
        }

        // Trim arrays to the number of entries actually stored
        for (int b = 0; b < filled.length; b++) {
          int used = filled[b];
          int[] trimmedIndices = new int[used];
          System.arraycopy(subsetIndices[b][a], 0, 
                           trimmedIndices, 0, used);
          subsetIndices[b][a] = trimmedIndices;
          double[] trimmedWeights = new double[used];
          System.arraycopy(subsetWeights[b][a], 0, 
                           trimmedWeights, 0, used);
          subsetWeights[b][a] = trimmedWeights;
        }
      }
    }

    /**
     * Computes the per-branch class distribution for one attribute and,
     * for a numeric attribute, the split point with the largest gain.
     *
     * Instances are processed in the order given by
     * <code>sortedIndices</code>. Processing of known values stops at the
     * first instance whose value for the attribute is missing; that
     * instance and all following ones are spread over the branches in
     * proportion to the branch weights found so far (uniformly if those
     * weights do not sum to a positive value).
     *
     * @param props entry <code>att</code> receives the branch proportions
     * @param dists entry <code>att</code> receives the class distribution
     * of every branch
     * @param att the index of the attribute to evaluate
     * @param sortedIndices the instance indices for this attribute
     * @param weights the instance weights parallel to
     * <code>sortedIndices</code>
     * @param subsetWeights entry <code>att</code> receives the total
     * weight of every branch
     * @param data the full set of training instances
     * @return the chosen split point for a numeric attribute;
     * <code>Double.NaN</code> for a nominal attribute or if no split
     * point was selected
     * @exception Exception if the distribution cannot be computed
     */
    protected double distribution(double[][] props,
                                  double[][][] dists, int att, 
                                  int[] sortedIndices,
                                  double[] weights, 
                                  double[][] subsetWeights, 
                                  Instances data) 
      throws Exception {

      double bestSplit = Double.NaN;
      Attribute attribute = data.attribute(att);
      double[][] counts = null;

      // Position of the first instance that was not handled as a known
      // value; everything from here on is treated as missing
      int firstMissing;

      if (attribute.isNominal()) {

        // Nominal attribute: one branch per value
        counts = new double[attribute.numValues()][data.numClasses()];
        for (firstMissing = 0; firstMissing < sortedIndices.length; 
             firstMissing++) {
          Instance inst = data.instance(sortedIndices[firstMissing]);
          if (inst.isMissing(att)) {
            break;
          }
          counts[(int)inst.value(att)][(int)inst.classValue()] += 
            weights[firstMissing];
        }
      } else {

        // Numeric attribute: two branches
        double[][] running = new double[2][data.numClasses()];
        counts = new double[2][data.numClasses()];

        // Start with every known instance in the second branch
        for (int j = 0; j < sortedIndices.length; j++) {
          Instance inst = data.instance(sortedIndices[j]);
          if (inst.isMissing(att)) {
            break;
          }
          running[1][(int)inst.classValue()] += weights[j];
        }
        double prior = priorVal(running);
        System.arraycopy(running[1], 0, counts[1], 0, counts[1].length);

        // Walk through the instances, moving them one by one into the
        // first branch and scoring the boundary between distinct values
        double previous = data.instance(sortedIndices[0]).value(att);
        double bestGain = -Double.MAX_VALUE;
        for (firstMissing = 0; firstMissing < sortedIndices.length; 
             firstMissing++) {
          Instance inst = data.instance(sortedIndices[firstMissing]);
          if (inst.isMissing(att)) {
            break;
          }
          if (inst.value(att) > previous) {
            double candidate = gain(running, prior);
            if (candidate > bestGain) {
              bestGain = candidate;
              bestSplit = (inst.value(att) + previous) / 2.0;
              for (int j = 0; j < running.length; j++) {
                System.arraycopy(running[j], 0, counts[j], 0, 
                                 counts[j].length);
              }
            } 
          } 
          previous = inst.value(att);
          running[0][(int)inst.classValue()] += weights[firstMissing];
          running[1][(int)inst.classValue()] -= weights[firstMissing];
        }
      }

      // Branch proportions derived from the known values
      props[att] = new double[counts.length];
      for (int k = 0; k < props[att].length; k++) {
        props[att][k] = Utils.sum(counts[k]);
      }
      if (Utils.sum(props[att]) > 0) {
        Utils.normalize(props[att]);
      } else {

        // Nothing to go by: use equal proportions
        for (int k = 0; k < props[att].length; k++) {
          props[att][k] = 1.0 / (double)props[att].length;
        }
      }
    
      // Spread the remaining instances over all branches
      for (int rest = firstMissing; rest < sortedIndices.length; rest++) {
        Instance inst = data.instance(sortedIndices[rest]);
        for (int j = 0; j < counts.length; j++) {
          counts[j][(int)inst.classValue()] += 
            props[att][j] * weights[rest];
        }
      }

      // Total weight of every branch
      subsetWeights[att] = new double[counts.length];
      for (int j = 0; j < counts.length; j++) {
        subsetWeights[att][j] += Utils.sum(counts[j]);
      }

      // Hand back the distribution and the split point
      dists[att] = counts;
      return bestSplit;
    }      

    /**
     * Computes class distribution for an attribute.
     */
    protected double numericDistribution(double[][] props, 
					 double[][][] dists, int att, 
					 int[] sortedIndices,
					 double[] weights, 
					 double[][] subsetWeights, 
					 Instances data,
					 double[] vals) 
      throws Exception {

      double splitPoint = Double.NaN;
      Attribute attribute = data.attribute(att);
      double[][] dist = null;
      double[] sums = null;
      double[] sumSquared = null;
      double[] sumOfWeights = null;
      double totalSum = 0, totalSumSquared = 0, totalSumOfWeights = 0;

      int i;

      if (attribute.isNominal()) {

	// For nominal attributes
	sums = new double[attribute.numValues()];
        sumSquared = new double[attribute.numValues()];
	sumOfWeights = new double[attribute.numValues()];
	int attVal;
	for (i = 0; i < sortedIndices.length; i++) {
	  Instance inst = data.instance(sortedIndices[i]);
	  if (inst.isMissing(att)) {
	    break;
	  }
	  attVal = (int)inst.value(att);
	  sums[attVal] += inst.classValue() * weights[i];
	  sumSquared[attVal] += 
	    inst.classValue() * inst.classValue() * weights[i];
	  sumOfWeights[attVal] += weights[i];
	}
	totalSum = Utils.sum(sums);
	totalSumSquared = Utils.sum(sumSquared);
	totalSumOfWeights = Utils.sum(sumOfWeights);
      } else {

	// For numeric attributes
	sums = new double[2];
        sumSquared = new double[2];
	sumOfWeights = new double[2];
	double[] currSums = new double[2];
        double[] currSumSquared = new double[2];
	double[] currSumOfWeights = new double[2];

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -