⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 decisionstump.java

📁 一个数据挖掘软件ALPHAMINERR的整个过程的JAVA版源代码
💻 JAVA
📖 第 1 页 / 共 2 页
字号:
    double[] sumCounts = new double[m_Instances.numClasses()];
    double[][] bestDist = new double[3][m_Instances.numClasses()];
    int numMissing = 0;

    // Compute counts for all the values
    for (int i = 0; i < m_Instances.numInstances(); i++) {
      Instance inst = m_Instances.instance(i);
      if (inst.isMissing(index)) {
	numMissing++;
	counts[m_Instances.attribute(index).numValues()]
	  [(int)inst.classValue()] += inst.weight();
      } else {
	counts[(int)inst.value(index)][(int)inst.classValue()] += inst
	  .weight();
      }
    }

    // Compute sum of counts
    for (int i = 0; i < m_Instances.attribute(index).numValues(); i++) {
      for (int j = 0; j < m_Instances.numClasses(); j++) {
	sumCounts[j] += counts[i][j];
      }
    }
    
    // Make split counts for each possible split and evaluate
    System.arraycopy(counts[m_Instances.attribute(index).numValues()], 0,
		     m_Distribution[2], 0, m_Instances.numClasses());
    for (int i = 0; i < m_Instances.attribute(index).numValues(); i++) {
      for (int j = 0; j < m_Instances.numClasses(); j++) {
	m_Distribution[0][j] = counts[i][j];
	m_Distribution[1][j] = sumCounts[j] - counts[i][j];
      }
      currVal = ContingencyTables.entropyConditionedOnRows(m_Distribution);
      if (currVal < bestVal) {
	bestVal = currVal;
	m_SplitPoint = (double)i;
	for (int j = 0; j < 3; j++) {
	  System.arraycopy(m_Distribution[j], 0, bestDist[j], 0, 
			   m_Instances.numClasses());
	}
      }
    }

    // No missing values in training data.
    if (numMissing == 0) {
      System.arraycopy(sumCounts, 0, bestDist[2], 0, 
		       m_Instances.numClasses());
    }
   
    m_Distribution = bestDist;
    return bestVal;
  }

  /**
   * Scores every "attribute == v" versus "attribute != v" split of a nominal
   * attribute against a numeric class and keeps the best one.
   *
   * <p>Each candidate is scored with {@code variance(...)} over three
   * subsets: 0 = instances whose value is v, 1 = instances with any other
   * known value, 2 = instances whose value is missing. Lower is better.
   * When a candidate wins, {@code m_SplitPoint} is set to v. The weighted
   * mean class value of each subset is recorded; a subset with no weight
   * gets the overall weighted mean instead. On normal return
   * {@code m_Distribution} is replaced by that 3 x 1 table of means. It is
   * all zeros if no candidate scored below {@code Double.MAX_VALUE}.
   *
   * <p>NOTE(review): {@code m_Distribution[2][0]} is accumulated into. It is
   * therefore presumably zero-filled by the caller beforehand -- confirm.
   *
   * @param index attribute index
   * @return value of criterion for the best split. Returns
   *         {@code Double.MAX_VALUE} immediately, without replacing
   *         {@code m_Distribution}, when the total instance weight is not
   *         positive.
   * @exception Exception if something goes wrong
   */
  private double findSplitNominalNumeric(int index) throws Exception {

    final int numValues = m_Instances.attribute(index).numValues();

    // Per attribute value: total weight, weighted class sum, and weighted
    // sum of squared class values.
    double[] weightOf = new double[numValues];
    double[] sumOf = new double[numValues];
    double[] sumSqOf = new double[numValues];

    // Per-subset sums of squares and weights handed to variance(). The
    // matching class sums are kept in m_Distribution[k][0].
    double[] subsetSumSq = new double[3];
    double[] subsetWeight = new double[3];

    double grandWeight = 0;
    double grandSum = 0;

    for (int n = 0; n < m_Instances.numInstances(); n++) {
      Instance inst = m_Instances.instance(n);
      double y = inst.classValue();
      double w = inst.weight();
      if (!inst.isMissing(index)) {
        int v = (int) inst.value(index);
        weightOf[v] += w;
        sumOf[v] += y * w;
        sumSqOf[v] += y * y * w;
      } else {
        // Missing attribute value: goes straight into subset 2.
        m_Distribution[2][0] += y * w;
        subsetSumSq[2] += y * y * w;
        subsetWeight[2] += w;
      }
      grandWeight += w;
      grandSum += y * w;
    }

    // Without positive weight there is nothing meaningful to split.
    if (grandWeight <= 0) {
      return Double.MAX_VALUE;
    }

    // Totals over the instances whose attribute value is known.
    double knownWeight = 0, knownSum = 0, knownSumSq = 0;
    for (int v = 0; v < numValues; v++) {
      knownWeight += weightOf[v];
      knownSumSq += sumSqOf[v];
      knownSum += sumOf[v];
    }

    double overallMean = grandSum / grandWeight;
    double bestVal = Double.MAX_VALUE;
    double[][] bestDist = new double[3][1];

    for (int v = 0; v < numValues; v++) {
      // Subset 0 = value v, subset 1 = every other known value.
      m_Distribution[0][0] = sumOf[v];
      m_Distribution[1][0] = knownSum - sumOf[v];
      subsetSumSq[0] = sumSqOf[v];
      subsetSumSq[1] = knownSumSq - sumSqOf[v];
      subsetWeight[0] = weightOf[v];
      subsetWeight[1] = knownWeight - weightOf[v];

      double score = variance(m_Distribution, subsetSumSq, subsetWeight);
      if (!(score < bestVal)) {
        continue;
      }
      bestVal = score;
      m_SplitPoint = v;
      for (int k = 0; k < 3; k++) {
        bestDist[k][0] = (subsetWeight[k] > 0)
            ? m_Distribution[k][0] / subsetWeight[k]
            : overallMean;
      }
    }

    m_Distribution = bestDist;
    return bestVal;
  }

  /**
   * Finds the best threshold split for a numeric attribute and returns the
   * value of the split criterion. The work is delegated by class type: a
   * nominal class uses the entropy-based search, any other class uses the
   * variance-based search.
   *
   * @param index attribute index
   * @return value of criterion for the best split
   * @exception Exception if something goes wrong
   */
  private double findSplitNumeric(int index) throws Exception {

    return m_Instances.classAttribute().isNominal()
        ? findSplitNumericNominal(index)
        : findSplitNumericNumeric(index);
  }

  /**
   * Finds the best binary threshold split of a numeric attribute for a
   * nominal class and returns the value of the split criterion.
   *
   * <p>The criterion is {@code ContingencyTables.entropyConditionedOnRows}
   * over a 3 x numClasses table of class weights. Row 0 holds attribute
   * values at or below the cut point, row 1 values above it, and row 2
   * missing values. Lower is better. Candidate cut points lie between
   * consecutive distinct values of the sorted attribute.
   *
   * <p>When a cut point wins, {@code m_SplitPoint} is set to it. On return
   * {@code m_Distribution} is replaced by the table of the best split. If no
   * cut point won, it is the initial table with every known value in row 1.
   * If the training data has no missing values, row 2 is filled with the
   * class totals of all instances.
   *
   * <p>NOTE(review): {@code m_Distribution} is accumulated into. It is
   * therefore presumably a zero-filled 3 x numClasses array prepared by the
   * caller -- confirm. Side effect: {@code m_Instances} is re-sorted on
   * {@code index}.
   *
   * @param index attribute index
   * @return value of criterion for the best split, or
   *         {@code Double.MAX_VALUE} if no candidate improved on it (for
   *         example, when all known attribute values are equal)
   * @exception Exception if something goes wrong
   */
  private double findSplitNumericNominal(int index) throws Exception {

    double bestVal = Double.MAX_VALUE, currVal, currCutPoint;
    int numMissing = 0;
    double[] sum = new double[m_Instances.numClasses()];
    double[][] bestDist = new double[3][m_Instances.numClasses()];

    // Count class weights. Known values start in row 1 ("above the cut");
    // missing values go to row 2.
    for (int i = 0; i < m_Instances.numInstances(); i++) {
      Instance inst = m_Instances.instance(i);
      if (!inst.isMissing(index)) {
        m_Distribution[1][(int)inst.classValue()] += inst.weight();
      } else {
        m_Distribution[2][(int)inst.classValue()] += inst.weight();
        numMissing++;
      }
    }
    // Remember the class totals of the instances with known values.
    System.arraycopy(m_Distribution[1], 0, sum, 0, m_Instances.numClasses());

    // Save current distribution as best distribution
    for (int j = 0; j < 3; j++) {
      System.arraycopy(m_Distribution[j], 0, bestDist[j], 0,
                       m_Instances.numClasses());
    }

    // Sort instances. The scan below stops numMissing instances early, so it
    // relies on missing values being sorted to the end.
    // NOTE(review): assumed behaviour of Instances.sort -- confirm.
    m_Instances.sort(index);

    // Move instances one at a time from row 1 to row 0. Evaluate a cut point
    // wherever the attribute value changes.
    for (int i = 0; i < m_Instances.numInstances() - (numMissing + 1); i++) {
      Instance inst = m_Instances.instance(i);
      Instance instPlusOne = m_Instances.instance(i + 1);
      m_Distribution[0][(int)inst.classValue()] += inst.weight();
      m_Distribution[1][(int)inst.classValue()] -= inst.weight();
      double lower = inst.value(index);
      double upper = instPlusOne.value(index);
      if (lower < upper) {
        // Halve before adding so two large values cannot overflow to
        // Infinity. Away from overflow and underflow this gives exactly the
        // same result as (lower + upper) / 2.0.
        currCutPoint = lower / 2.0 + upper / 2.0;
        // For adjacent doubles the midpoint can round up to 'upper', and
        // -Infinity/+Infinity yields NaN. whichSubset sends values <= the
        // cut point to subset 0, so fall back to 'lower'. That keeps 'upper'
        // on the side (row 1) where it was counted here.
        if (Double.isNaN(currCutPoint) || currCutPoint >= upper) {
          currCutPoint = lower;
        }
        currVal = ContingencyTables.entropyConditionedOnRows(m_Distribution);
        if (currVal < bestVal) {
          m_SplitPoint = currCutPoint;
          bestVal = currVal;
          for (int j = 0; j < 3; j++) {
            System.arraycopy(m_Distribution[j], 0, bestDist[j], 0,
                             m_Instances.numClasses());
          }
        }
      }
    }

    // No missing values in training data: row 2 (the subset whichSubset uses
    // for missing values) gets the class totals of all training instances.
    if (numMissing == 0) {
      System.arraycopy(sum, 0, bestDist[2], 0, m_Instances.numClasses());
    }

    m_Distribution = bestDist;
    return bestVal;
  }

  /**
   * Finds the best binary threshold split of a numeric attribute for a
   * numeric class and returns the value of the split criterion.
   *
   * <p>The criterion is {@code variance(...)} over three subsets: 0 =
   * attribute value at or below the cut point, 1 = above it, 2 = value
   * missing. Lower is better. Candidate cut points lie between consecutive
   * distinct values of the sorted attribute.
   *
   * <p>When a cut point wins, {@code m_SplitPoint} is set to it. The
   * weighted mean class value of each subset is recorded; a subset with no
   * weight gets the overall weighted mean instead. {@code m_Distribution} is
   * then replaced by that 3 x 1 table. It is all zeros if no cut point won.
   *
   * <p>NOTE(review): column 0 of {@code m_Distribution} is accumulated into.
   * It is therefore presumably zero-filled by the caller -- confirm. Side
   * effect: {@code m_Instances} is re-sorted on {@code index}, unless the
   * method returns early.
   *
   * @param index attribute index
   * @return value of criterion for the best split, or
   *         {@code Double.MAX_VALUE} in two cases: the total instance weight
   *         is not positive (then {@code m_Distribution} is not replaced),
   *         or no cut point improved on it
   * @exception Exception if something goes wrong
   */
  private double findSplitNumericNumeric(int index) throws Exception {

    double bestVal = Double.MAX_VALUE, currVal, currCutPoint;
    int numMissing = 0;
    double[] sumsSquares = new double[3], sumOfWeights = new double[3];
    double[][] bestDist = new double[3][1];
    double totalSum = 0, totalSumOfWeights = 0;

    // Accumulate weighted class sums, sums of squares and weights. Known
    // values start in subset 1; missing values go to subset 2.
    for (int i = 0; i < m_Instances.numInstances(); i++) {
      Instance inst = m_Instances.instance(i);
      if (!inst.isMissing(index)) {
        m_Distribution[1][0] += inst.classValue() * inst.weight();
        sumsSquares[1] += inst.classValue() * inst.classValue()
          * inst.weight();
        sumOfWeights[1] += inst.weight();
      } else {
        m_Distribution[2][0] += inst.classValue() * inst.weight();
        sumsSquares[2] += inst.classValue() * inst.classValue()
          * inst.weight();
        sumOfWeights[2] += inst.weight();
        numMissing++;
      }
      totalSumOfWeights += inst.weight();
      totalSum += inst.classValue() * inst.weight();
    }

    // Without positive weight there is nothing meaningful to split.
    if (totalSumOfWeights <= 0) {
      return bestVal;
    }

    // Sort instances. The scan below stops numMissing instances early, so it
    // relies on missing values being sorted to the end.
    // NOTE(review): assumed behaviour of Instances.sort -- confirm.
    m_Instances.sort(index);

    // Move instances one at a time from subset 1 to subset 0. Evaluate a cut
    // point wherever the attribute value changes.
    for (int i = 0; i < m_Instances.numInstances() - (numMissing + 1); i++) {
      Instance inst = m_Instances.instance(i);
      Instance instPlusOne = m_Instances.instance(i + 1);
      m_Distribution[0][0] += inst.classValue() * inst.weight();
      sumsSquares[0] += inst.classValue() * inst.classValue() * inst.weight();
      sumOfWeights[0] += inst.weight();
      m_Distribution[1][0] -= inst.classValue() * inst.weight();
      sumsSquares[1] -= inst.classValue() * inst.classValue() * inst.weight();
      sumOfWeights[1] -= inst.weight();
      double lower = inst.value(index);
      double upper = instPlusOne.value(index);
      if (lower < upper) {
        // Halve before adding so two large values cannot overflow to
        // Infinity. Away from overflow and underflow this gives exactly the
        // same result as (lower + upper) / 2.0.
        currCutPoint = lower / 2.0 + upper / 2.0;
        // For adjacent doubles the midpoint can round up to 'upper', and
        // -Infinity/+Infinity yields NaN. whichSubset sends values <= the
        // cut point to subset 0, so fall back to 'lower'. That keeps 'upper'
        // in subset 1, where it was counted here.
        if (Double.isNaN(currCutPoint) || currCutPoint >= upper) {
          currCutPoint = lower;
        }
        currVal = variance(m_Distribution, sumsSquares, sumOfWeights);
        if (currVal < bestVal) {
          m_SplitPoint = currCutPoint;
          bestVal = currVal;
          // Record each subset's weighted mean class value. A subset with
          // no weight gets the overall weighted mean.
          for (int j = 0; j < 3; j++) {
            if (sumOfWeights[j] > 0) {
              bestDist[j][0] = m_Distribution[j][0] / sumOfWeights[j];
            } else {
              bestDist[j][0] = totalSum / totalSumOfWeights;
            }
          }
        }
      }
    }

    m_Distribution = bestDist;
    return bestVal;
  }

  /**
   * Returns the pooled within-subset sum of squares used as the split
   * criterion for a numeric class.
   *
   * <p>For every subset k with positive weight it adds
   * {@code sS[k] - s[k][0]^2 / sumOfWeights[k]}. This equals the weighted
   * sum of squared deviations of that subset's class values from the
   * subset's weighted mean. Subsets with non-positive weight contribute
   * nothing.
   *
   * @param s per-subset weighted class sums (only column 0 is read); its
   *          length determines how many subsets are visited
   * @param sS per-subset weighted sums of squared class values
   * @param sumOfWeights per-subset total instance weight
   * @return the summed contributions of all weighted subsets
   */
  private double variance(double[][] s,double[] sS,double[] sumOfWeights) {

    double total = 0;
    int subset = 0;
    while (subset < s.length) {
      double weight = sumOfWeights[subset];
      if (weight > 0) {
        double classSum = s[subset][0];
        total += sS[subset] - ((classSum * classSum) / weight);
      }
      subset++;
    }
    return total;
  }

  /**
   * Returns the subset an instance falls into under the current split.
   *
   * <ul>
   * <li>2: the split attribute's value is missing.</li>
   * <li>Nominal attribute: 0 if the value index equals
   *     {@code m_SplitPoint}, otherwise 1.</li>
   * <li>Numeric attribute: 0 if the value is at most {@code m_SplitPoint},
   *     otherwise 1.</li>
   * </ul>
   *
   * @param instance the instance to route
   * @return 0, 1 or 2 as described above
   * @exception Exception if something goes wrong
   */
  private int whichSubset(Instance instance) throws Exception {

    if (instance.isMissing(m_AttIndex)) {
      return 2;
    }

    boolean nominal = instance.attribute(m_AttIndex).isNominal();
    double attValue = instance.value(m_AttIndex);
    boolean goesLeft = nominal
        ? (int) attValue == m_SplitPoint
        : attValue <= m_SplitPoint;
    return goesLeft ? 0 : 1;
  }
 
  /**
   * Command-line entry point: builds a DecisionStump and evaluates it with
   * {@code Evaluation.evaluateModel} using the given options. The evaluation
   * output is printed to standard output.
   *
   * <p>On failure the exception's message is printed to standard error.
   * If the exception has no message, its {@code toString()} is printed
   * instead, so the user sees at least the exception type rather than the
   * word "null".
   *
   * @param argv the options
   */
  public static void main(String [] argv) {

    try {
      Classifier scheme = new DecisionStump();
      System.out.println(Evaluation.evaluateModel(scheme, argv));
    } catch (Exception e) {
      // The message is printed as-is. NOTE(review): it presumably carries
      // Evaluation's option/usage text -- confirm.
      String message = e.getMessage();
      System.err.println(message != null ? message : e.toString());
    }
  }
}


⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -