⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 decisionstump.java

📁 6个java实现的分类方法(检测过
💻 JAVA
📖 第 1 页 / 共 2 页
字号:
    // Dispatch on the type of the class attribute.
    if (m_Instances.classAttribute().isNominal()) {
      return findSplitNominalNominal(index);
    } else {
      return findSplitNominalNumeric(index);
    }
  }

  /**
   * Finds best split for nominal attribute and nominal class
   * and returns value.
   *
   * <p>Every attribute value is tried as a one-vs-rest split. Row 0 of the
   * distribution holds the class counts of the instances that have the
   * candidate value, row 1 the counts of all other non-missing instances and
   * row 2 the counts of the instances whose attribute value is missing. The
   * candidate with the smallest
   * {@code ContingencyTables.entropyConditionedOnRows} value wins; its index is
   * stored in {@code m_SplitPoint} and its distribution replaces
   * {@code m_Distribution}.
   *
   * @param index attribute index
   * @return value of criterion for the best split
   *         ({@code Double.MAX_VALUE} if no candidate was evaluated)
   * @throws Exception if something goes wrong
   */
  private double findSplitNominalNominal(int index) throws Exception {

    final int numValues = m_Instances.attribute(index).numValues();
    final int numClasses = m_Instances.numClasses();
    final int numInstances = m_Instances.numInstances();

    double bestVal = Double.MAX_VALUE, currVal;
    // One extra row (index numValues) collects the instances with a missing value.
    double[][] counts = new double[numValues + 1][numClasses];
    double[] sumCounts = new double[numClasses];
    double[][] bestDist = new double[3][numClasses];
    int numMissing = 0;

    // Compute counts for all the values
    for (int i = 0; i < numInstances; i++) {
      Instance inst = m_Instances.instance(i);
      if (inst.isMissing(index)) {
        numMissing++;
        counts[numValues][(int) inst.classValue()] += inst.weight();
      } else {
        counts[(int) inst.value(index)][(int) inst.classValue()] += inst.weight();
      }
    }

    // Compute sum of counts (missing-value row excluded)
    for (int i = 0; i < numValues; i++) {
      for (int j = 0; j < numClasses; j++) {
        sumCounts[j] += counts[i][j];
      }
    }

    // Make split counts for each possible split and evaluate
    System.arraycopy(counts[numValues], 0, m_Distribution[2], 0, numClasses);
    for (int i = 0; i < numValues; i++) {
      for (int j = 0; j < numClasses; j++) {
        m_Distribution[0][j] = counts[i][j];
        m_Distribution[1][j] = sumCounts[j] - counts[i][j];
      }
      currVal = ContingencyTables.entropyConditionedOnRows(m_Distribution);
      if (currVal < bestVal) {
        bestVal = currVal;
        m_SplitPoint = (double) i;
        for (int j = 0; j < 3; j++) {
          System.arraycopy(m_Distribution[j], 0, bestDist[j], 0, numClasses);
        }
      }
    }

    // No missing values in training data: use the overall class counts
    // for the "missing" branch instead of an all-zero row.
    if (numMissing == 0) {
      System.arraycopy(sumCounts, 0, bestDist[2], 0, numClasses);
    }

    m_Distribution = bestDist;
    return bestVal;
  }

  /**
   * Finds best split for nominal attribute and numeric class
   * and returns value.
   *
   * <p>Every attribute value is tried as a one-vs-rest split and scored with
   * {@link #variance(double[][], double[], double[])}. For the best candidate,
   * each of the three subsets (value matches, value differs, value missing)
   * gets its weighted mean class value; a subset without weight gets the
   * weighted mean over all instances. The result replaces
   * {@code m_Distribution} and the winning value index is stored in
   * {@code m_SplitPoint}.
   *
   * @param index attribute index
   * @return value of criterion for the best split; {@code Double.MAX_VALUE} if
   *         the total instance weight is not positive (in that case
   *         {@code m_Distribution} is not replaced) or no candidate was evaluated
   * @throws Exception if something goes wrong
   */
  private double findSplitNominalNumeric(int index) throws Exception {

    final int numValues = m_Instances.attribute(index).numValues();
    final int numInstances = m_Instances.numInstances();

    double bestVal = Double.MAX_VALUE, currVal;
    double[] sumsSquaresPerValue = new double[numValues],
      sumsPerValue = new double[numValues],
      weightsPerValue = new double[numValues];
    double totalSumSquaresW = 0, totalSumW = 0, totalSumOfWeightsW = 0,
      totalSumOfWeights = 0, totalSum = 0;
    double[] sumsSquares = new double[3], sumOfWeights = new double[3];
    double[][] bestDist = new double[3][1];

    // Compute counts for all the values
    for (int i = 0; i < numInstances; i++) {
      Instance inst = m_Instances.instance(i);
      if (inst.isMissing(index)) {
        m_Distribution[2][0] += inst.classValue() * inst.weight();
        sumsSquares[2] += inst.classValue() * inst.classValue()
          * inst.weight();
        sumOfWeights[2] += inst.weight();
      } else {
        weightsPerValue[(int) inst.value(index)] += inst.weight();
        sumsPerValue[(int) inst.value(index)] += inst.classValue()
          * inst.weight();
        sumsSquaresPerValue[(int) inst.value(index)] +=
          inst.classValue() * inst.classValue() * inst.weight();
      }
      totalSumOfWeights += inst.weight();
      totalSum += inst.classValue() * inst.weight();
    }

    // Check if the total weight is zero
    if (totalSumOfWeights <= 0) {
      return bestVal;
    }

    // Compute sum of counts without missing ones
    for (int i = 0; i < numValues; i++) {
      totalSumOfWeightsW += weightsPerValue[i];
      totalSumSquaresW += sumsSquaresPerValue[i];
      totalSumW += sumsPerValue[i];
    }

    // Make split counts for each possible split and evaluate
    for (int i = 0; i < numValues; i++) {

      m_Distribution[0][0] = sumsPerValue[i];
      sumsSquares[0] = sumsSquaresPerValue[i];
      sumOfWeights[0] = weightsPerValue[i];
      m_Distribution[1][0] = totalSumW - sumsPerValue[i];
      sumsSquares[1] = totalSumSquaresW - sumsSquaresPerValue[i];
      sumOfWeights[1] = totalSumOfWeightsW - weightsPerValue[i];
      currVal = variance(m_Distribution, sumsSquares, sumOfWeights);

      if (currVal < bestVal) {
        bestVal = currVal;
        m_SplitPoint = (double) i;
        for (int j = 0; j < 3; j++) {
          if (sumOfWeights[j] > 0) {
            bestDist[j][0] = m_Distribution[j][0] / sumOfWeights[j];
          } else {
            // Empty subset: fall back to the overall weighted mean.
            bestDist[j][0] = totalSum / totalSumOfWeights;
          }
        }
      }
    }

    m_Distribution = bestDist;
    return bestVal;
  }

  /**
   * Finds best split for numeric attribute and returns value.
   *
   * @param index attribute index
   * @return value of criterion for the best split
   * @throws Exception if something goes wrong
   */
  private double findSplitNumeric(int index) throws Exception {

    // Dispatch on the type of the class attribute.
    if (m_Instances.classAttribute().isNominal()) {
      return findSplitNumericNominal(index);
    } else {
      return findSplitNumericNumeric(index);
    }
  }

  /**
   * Finds best split for numeric attribute and nominal class
   * and returns value.
   *
   * <p>The instances are sorted on the attribute and every midpoint between
   * two distinct consecutive values is tried as a cut point. Row 0 of the
   * distribution holds the class counts at or below the cut, row 1 the counts
   * above it and row 2 the counts of the instances whose attribute value is
   * missing. The cut with the smallest
   * {@code ContingencyTables.entropyConditionedOnRows} value is stored in
   * {@code m_SplitPoint}; if no cut is evaluated the initial distribution
   * (everything non-missing in row 1) is kept as the best one.
   *
   * @param index attribute index
   * @return value of criterion for the best split
   *         ({@code Double.MAX_VALUE} if no cut point was evaluated)
   * @throws Exception if something goes wrong
   */
  private double findSplitNumericNominal(int index) throws Exception {

    final int numClasses = m_Instances.numClasses();
    final int numInstances = m_Instances.numInstances();

    double bestVal = Double.MAX_VALUE, currVal, currCutPoint;
    int numMissing = 0;
    double[] sum = new double[numClasses];
    double[][] bestDist = new double[3][numClasses];

    // Compute counts for all the values
    for (int i = 0; i < numInstances; i++) {
      Instance inst = m_Instances.instance(i);
      if (!inst.isMissing(index)) {
        m_Distribution[1][(int) inst.classValue()] += inst.weight();
      } else {
        m_Distribution[2][(int) inst.classValue()] += inst.weight();
        numMissing++;
      }
    }
    System.arraycopy(m_Distribution[1], 0, sum, 0, numClasses);

    // Save current distribution as best distribution
    for (int j = 0; j < 3; j++) {
      System.arraycopy(m_Distribution[j], 0, bestDist[j], 0, numClasses);
    }

    // Sort instances
    // NOTE(review): the loop bound below presumes sort(index) places the
    // instances with a missing value last -- confirm against Instances.sort.
    m_Instances.sort(index);

    // Make split counts for each possible split and evaluate
    for (int i = 0; i < numInstances - (numMissing + 1); i++) {
      Instance inst = m_Instances.instance(i);
      Instance instPlusOne = m_Instances.instance(i + 1);
      // Move the current instance from the "above" row to the "below" row.
      m_Distribution[0][(int) inst.classValue()] += inst.weight();
      m_Distribution[1][(int) inst.classValue()] -= inst.weight();
      // Only a change in attribute value yields a usable cut point.
      if (inst.value(index) < instPlusOne.value(index)) {
        currCutPoint = (inst.value(index) + instPlusOne.value(index)) / 2.0;
        currVal = ContingencyTables.entropyConditionedOnRows(m_Distribution);
        if (currVal < bestVal) {
          m_SplitPoint = currCutPoint;
          bestVal = currVal;
          for (int j = 0; j < 3; j++) {
            System.arraycopy(m_Distribution[j], 0, bestDist[j], 0, numClasses);
          }
        }
      }
    }

    // No missing values in training data: use the overall class counts
    // for the "missing" branch instead of an all-zero row.
    if (numMissing == 0) {
      System.arraycopy(sum, 0, bestDist[2], 0, numClasses);
    }

    m_Distribution = bestDist;
    return bestVal;
  }

  /**
   * Finds best split for numeric attribute and numeric class
   * and returns value.
   *
   * <p>The instances are sorted on the attribute and every midpoint between
   * two distinct consecutive values is tried as a cut point, scored with
   * {@link #variance(double[][], double[], double[])}. For the best cut, each
   * of the three subsets (at or below the cut, above it, value missing) gets
   * its weighted mean class value; a subset without weight gets the weighted
   * mean over all instances. The result replaces {@code m_Distribution} and
   * the cut is stored in {@code m_SplitPoint}.
   *
   * @param index attribute index
   * @return value of criterion for the best split; {@code Double.MAX_VALUE} if
   *         the total instance weight is not positive (in that case
   *         {@code m_Distribution} is not replaced) or no cut point was evaluated
   * @throws Exception if something goes wrong
   */
  private double findSplitNumericNumeric(int index) throws Exception {

    final int numInstances = m_Instances.numInstances();

    double bestVal = Double.MAX_VALUE, currVal, currCutPoint;
    int numMissing = 0;
    double[] sumsSquares = new double[3], sumOfWeights = new double[3];
    double[][] bestDist = new double[3][1];
    double totalSum = 0, totalSumOfWeights = 0;

    // Compute counts for all the values
    for (int i = 0; i < numInstances; i++) {
      Instance inst = m_Instances.instance(i);
      if (!inst.isMissing(index)) {
        m_Distribution[1][0] += inst.classValue() * inst.weight();
        sumsSquares[1] += inst.classValue() * inst.classValue()
          * inst.weight();
        sumOfWeights[1] += inst.weight();
      } else {
        m_Distribution[2][0] += inst.classValue() * inst.weight();
        sumsSquares[2] += inst.classValue() * inst.classValue()
          * inst.weight();
        sumOfWeights[2] += inst.weight();
        numMissing++;
      }
      totalSumOfWeights += inst.weight();
      totalSum += inst.classValue() * inst.weight();
    }

    // Check if the total weight is zero
    if (totalSumOfWeights <= 0) {
      return bestVal;
    }

    // Sort instances
    // NOTE(review): the loop bound below presumes sort(index) places the
    // instances with a missing value last -- confirm against Instances.sort.
    m_Instances.sort(index);

    // Make split counts for each possible split and evaluate
    for (int i = 0; i < numInstances - (numMissing + 1); i++) {
      Instance inst = m_Instances.instance(i);
      Instance instPlusOne = m_Instances.instance(i + 1);
      // Move the current instance from the "above" subset to the "below" subset.
      m_Distribution[0][0] += inst.classValue() * inst.weight();
      sumsSquares[0] += inst.classValue() * inst.classValue() * inst.weight();
      sumOfWeights[0] += inst.weight();
      m_Distribution[1][0] -= inst.classValue() * inst.weight();
      sumsSquares[1] -= inst.classValue() * inst.classValue() * inst.weight();
      sumOfWeights[1] -= inst.weight();
      // Only a change in attribute value yields a usable cut point.
      if (inst.value(index) < instPlusOne.value(index)) {
        currCutPoint = (inst.value(index) + instPlusOne.value(index)) / 2.0;
        currVal = variance(m_Distribution, sumsSquares, sumOfWeights);
        if (currVal < bestVal) {
          m_SplitPoint = currCutPoint;
          bestVal = currVal;
          for (int j = 0; j < 3; j++) {
            if (sumOfWeights[j] > 0) {
              bestDist[j][0] = m_Distribution[j][0] / sumOfWeights[j];
            } else {
              // Empty subset: fall back to the overall weighted mean.
              bestDist[j][0] = totalSum / totalSumOfWeights;
            }
          }
        }
      }
    }

    m_Distribution = bestDist;
    return bestVal;
  }

  /**
   * Computes variance for subsets.
   *
   * <p>For every subset {@code i} with positive weight this adds
   * {@code sS[i] - s[i][0]^2 / sumOfWeights[i]} to the result; subsets
   * without weight contribute nothing.
   *
   * @param s per-subset weighted sums of the class value (read from column 0)
   * @param sS per-subset weighted sums of the squared class value
   * @param sumOfWeights per-subset sums of instance weights
   * @return the variance
   */
  private double variance(double[][] s, double[] sS, double[] sumOfWeights) {

    double var = 0;

    for (int i = 0; i < s.length; i++) {
      if (sumOfWeights[i] > 0) {
        var += sS[i] - ((s[i][0] * s[i][0]) / sumOfWeights[i]);
      }
    }

    return var;
  }

  /**
   * Returns the subset an instance falls into.
   *
   * <p>Subset 2 is used when the split attribute is missing. Otherwise a
   * nominal attribute maps to subset 0 when its value index equals the split
   * point and to subset 1 when it does not, and any other attribute maps to
   * subset 0 when its value is less than or equal to the split point and to
   * subset 1 when it is greater.
   *
   * @param instance the instance to check
   * @return the subset the instance falls into
   * @throws Exception if something goes wrong
   */
  private int whichSubset(Instance instance) throws Exception {

    if (instance.isMissing(m_AttIndex)) {
      return 2;
    } else if (instance.attribute(m_AttIndex).isNominal()) {
      if ((int) instance.value(m_AttIndex) == m_SplitPoint) {
        return 0;
      } else {
        return 1;
      }
    } else {
      if (instance.value(m_AttIndex) <= m_SplitPoint) {
        return 0;
      } else {
        return 1;
      }
    }
  }

  /**
   * Main method for testing this class.
   *
   * @param argv the options
   */
  public static void main(String [] argv) {
    runClassifier(new DecisionStump(), argv);
  }
}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -