
📄 decisiontreealgorithm.java

📁 A decision tree Applet (reposted)
💻 JAVA
  /**
   * Chooses an attribute to split on at the current position
   * in the decision tree, based on the current splitting
   * function.
   *
   * @param atts A Vector of the attributes that are still
   *        available for splitting.
   *
   * @param stats An array of integers, where each value
   *        indicates the number of examples that fall into
   *        a particular target category at the current
   *        position in the tree.
   *
   * @param results A Vector that is filled with the score
   *        computed for each available attribute (left empty
   *        when splitting randomly).
   *
   * @return The chosen Attribute, or null if no attributes
   *         are available.
   */
  public Attribute
    chooseAttribute( Vector atts, int[] stats, Vector results )
  {
    // If the list of available attributes is empty,
    // return null.
    if( atts.size() == 0 ) return null;

    int pos = 0;

    // Select an attribute, based on the current
    // splitting function.
    if( m_splitFun.equals( SPLIT_RANDOM ) )
      // Choose one of the unmasked attributes
      // at random.  We leave the results vector
      // empty in this case.
      pos = m_random.nextInt( atts.size() );
    else {
      // Calculate a result value for each
      // attribute that is available.
      double val = 0.0;
      double temp;

      for( int i = 0; i < atts.size(); i++ ) {
        if( m_splitFun.equals( SPLIT_GAIN ) ) {
          temp = getGain( stats, (Attribute)atts.elementAt( i ));

          if( temp > val ) {
            val = temp;
            pos = i;
          }

          results.add( new Double( temp ) );
        }
        else if( m_splitFun.equals( SPLIT_GAIN_RATIO ) ) {
          temp = getGainRatio( stats, (Attribute)atts.elementAt( i ) );

          if( temp > val ) {
            val = temp;
            pos = i;
          }

          results.add( new Double( temp ) );
        }
        else if( m_splitFun.equals( SPLIT_GINI ) ) {
          temp = getGINI( stats, (Attribute)atts.elementAt( i ));

          if( temp > val ) {
            val = temp;
            pos = i;
          }

          // Record the score, as in the other branches.
          results.add( new Double( temp ) );
        }
      }
    }

    return (Attribute)atts.elementAt( pos );
  }
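
  // A minimal usage sketch (the names "availableAtts" and
  // "targetCounts" are illustrative, not part of this class):
  //
  //   Vector results = new Vector();
  //   Attribute best =
  //     chooseAttribute( availableAtts, targetCounts, results );
  //
  // On return, results.elementAt( i ) holds the score computed
  // for the i-th attribute (results is left empty when the
  // splitting function is SPLIT_RANDOM).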

  // Private methods

  /**
   * Fills the supplied array with the number of examples
   * from the current dataset that fall into each of the
   * target categories (based on the attribute mask).
   *
   * @param mask The mask that determines which examples
   *        reach the current position in the decision
   *        tree.
   *
   * @param examples An iteration over a series of
   *        examples from the current dataset.
   *
   * @param counts The method expects the parameter to be
   *        an array with a size equal to the number of
   *        target attribute values.  Each position in
   *        the array is filled with a corresponding count
   *        of the number of training examples that fall into
   *        that particular target class, at the current
   *        position in the decision tree.
   *
   * @param reachedHere The method expects the parameter
   *        to be an array with a size equal to the
   *        <i>total</i> number of examples being examined.
   *        Each cell in the array is set to true or
   *        false, depending on whether or not the
   *        corresponding example reaches the current
   *        position in the decision tree.
   */
  private void getExampleCounts( AttributeMask mask,
    Iterator examples, int counts[], boolean[] reachedHere )
  {
    // Zero any values currently in stats.
    for( int i = 0; i < counts.length; i++ )
      counts[i] = 0;

    int i = 0;

    // Loop through and get totals.
    while( examples.hasNext() ) {
      int[] example = (int[])examples.next();

      if( mask.matchMask( example ) ) {
        counts[ example[0] ]++;   // Increment appropriate
                                  // target count.
        if( reachedHere != null && reachedHere.length > i )
          reachedHere[i] = true;
      }

      i++;
    }
  }
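
  // For example, with two target classes, if 3 examples matching
  // the mask have target index 0 and 5 have target index 1, the
  // method leaves counts = {3, 5}.  (As above, example[0] is the
  // example's target class index.)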

  /**
   * Creates a formatted HTML message to display when a
   * leaf node is added.  This is a convenience method -
   * it makes the learnDT() method less cluttered.
   *
   * @param step The current leaf addition step (which
   *        identifies the message that should be
   *        returned).
   *
   * @param label The label for the new leaf.
   *
   * @return An HTML text string, suitable for display
   *         in a GUI panel.
   */
  private String createLeafMsg( int step, String label )
  {
    StringBuffer msg =
      new StringBuffer( "<html><font size=\"-1\">" );

    if( step == 1 )
      msg.append( "No examples reach this position. " +
        "Adding new leaf with default target class " +
        "<font color=\"yellow\">" + label + "</font>." );
    else if( step == 2 )
      msg.append( "All examples have the same target " +
        "classification. Adding new leaf with common " +
        "target class <font color=\"yellow\">" + label +
        "</font>." );
    else if( step == 3 )
      msg.append( "The set of attributes available for " +
        "splitting is empty. Adding new leaf " +
        "with most common target class" +
        "<font color=\"yellow\">" + label + "</font>." );

    return msg.append( "</font>" ).toString();
  }

  /**
   * Creates a formatted HTML message to display when an
   * internal node is added.  This is a convenience method -
   * it makes the learnDT() method less cluttered.
   *
   * @param step The current internal addition step (which
   *        identifies the message that should be
   *        returned).
   *
   * @param label The label for the new node.
   *
   * @return An HTML text string, suitable for display
   *         in a GUI panel.
   */
  private String createInternalMsg( int step, String label )
  {
    StringBuffer msg =
      new StringBuffer( "<html><font size=\"-1\">" );

    if( step == 1 )
      msg.append( "Choosing the best attribute to " +
        "split on, based on " + m_splitFun + " criteria." );
    else if( step == 2 )
      msg.append( "Creating new internal node for the best " +
        "attribute, <font color=\"yellow\">" + label +
        "</font>." );
    else if( step == 3 )
      msg.append( "Attached new subtree along branch " +
        "<font color=\"yellow\">" + label + "</font>." );

    return msg.append( "</font>" ).toString();
  }

  /**
   * Computes entropy based on classification counts stored
   * in the stats array.
   *
   * @param stats An array of integers, where each value
   *        indicates the number of examples that fall into
   *        a particular target category.  The size of the
   *        array should be the same as the number of possible
   *        target attribute values.
   *
   * @param numExamples The total number of examples (sum of
   *        all the counts in the stats array).
   *
   * @return The entropy as a double value.
   */
  private double entropy( int[] stats, int numExamples )
  {
    double entropy = 0;

    if( numExamples == 0 ) return 0;

    for( int i = 0; i < stats.length; i++ ) {
      if( stats[i] == 0 ) continue;

      // java.lang.Math only provides the natural
      // logarithm, so we divide by log(2) to convert
      // the result to base 2.
      entropy -= ((double)stats[i])/numExamples*
        (Math.log( ((double)stats[i])/numExamples) ) / Math.log( 2.0 );
    }

    return entropy;
  }
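
  // Worked example: for stats = {9, 5} (14 examples),
  //   entropy = -(9/14)*log2(9/14) - (5/14)*log2(5/14) ≈ 0.940
  // A uniform distribution such as {7, 7} yields the two-class
  // maximum of 1.0; a pure node such as {14, 0} yields 0.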

  /**
   * Computes the information gain, which is the
   * &quot;expected reduction in entropy caused by
   * partitioning the examples according to [a certain]
   * attribute&quot;.
   *
   * @param stats An array of integers, where each value
   *        indicates the number of examples that fall into
   *        a particular target category.  The size of the
   *        array should be the same as the number of
   *        possible target attribute values.  This array
   *        holds the statistics associated with the number
   *        of examples that reach a certain node/arc in the
   *        tree.
   *
   * @param att An Attribute that can be split on at the
   *        current position in the tree.  The Attribute's
   *        internal storage space should already be
   *        populated with correct statistical information.
   *
   * @return The information gain value for the supplied
   *         attribute.
   */
  private double getGain( int stats[], Attribute att )
  {
    // First, we calculate the entropy of the set
    // of examples before splitting (using the counts
    // in the stats array).
    int numExamples = 0;

    for( int i = 0; i < stats.length; i++ )
      numExamples += stats[i];

    if( numExamples == 0 ) return 0;

    double originalEntropy = entropy( stats, numExamples );

    // Now, we determine the entropy after splitting.
    double splitEntropy = 0;
    int[][] attStats = att.getStatsArray();

    // Loop over all possible values.
    for( int j = 0; j < attStats.length; j++ ) {
      int numSubsetExamples = 0;

      // Determine number of examples along this path.
      for( int k = 0; k < attStats[j].length; k++ )
        numSubsetExamples += attStats[j][k];

      splitEntropy +=
        ((double)numSubsetExamples)/numExamples*
        entropy( attStats[j], numSubsetExamples );
    }

    return originalEntropy - splitEntropy;
  }
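
  // Worked example (illustrative counts, borrowed from the
  // classic 14-example "PlayTennis" data, not from this applet):
  // splitting a {9, 5} parent into subsets {6, 2} and {3, 3}
  // gives
  //   gain = 0.940 - (8/14)*0.811 - (6/14)*1.0 ≈ 0.048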

  /**
   * Computes the gain ratio for a particular attribute.
   * The gain ratio calculation includes a term called
   * 'split information' that penalizes attributes that
   * split data broadly and uniformly.
   *
   * @param stats An array of integers, where each value
   *        indicates the number of examples that fall into
   *        a particular target category.  The size of the
   *        array should be the same as the number of
   *        possible target attribute values.  This array
   *        holds the statistics associated with the number
   *        of examples that reach a certain node/arc in the
   *        tree.
   *
   * @param att An Attribute that can be split on at the
   *        current position in the tree.  The Attribute's
   *        internal storage space should already be
   *        populated with correct statistical information.
   *
   * @return The gain ratio for the supplied attribute.
   */
  private double getGainRatio( int stats[], Attribute att )
  {
    // We recompute some of the quantities already calculated
    // in getGain() - merging the two calculations would be
    // more efficient.
    int numExamples = 0;

    for( int i = 0; i < stats.length; i++ )
      numExamples += stats[i];

    if( numExamples == 0 ) return 0;

    // Compute the gain.
    double gain = getGain( stats, att );

    // Compute the SplitInformation term.
    // (which is the entropy of the examples with
    // respect to the attribute values of att).
    int[] splitInfoStats = new int[ att.getNumValues() ];

    int[][] attStats = att.getStatsArray();

    // Loop over all possible values.
    for( int j = 0; j < attStats.length; j++ ) {
      // Determine number of examples along this path.
      for( int k = 0; k < attStats[j].length; k++ )
        splitInfoStats[j] += attStats[j][k];
    }

    double splitInfo = entropy( splitInfoStats, numExamples );

    // Guard against division by zero: splitInfo is zero when
    // all examples share a single value of this attribute.
    if( splitInfo == 0 ) return 0;

    return gain/splitInfo;
  }
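
  // Worked example (continuing the illustrative case above):
  // subset sizes {8, 6} out of 14 give
  //   splitInfo = -(8/14)*log2(8/14) - (6/14)*log2(6/14) ≈ 0.985
  // so gainRatio ≈ 0.048 / 0.985 ≈ 0.049.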

  /**
   * Computes the GINI score for a particular attribute.
   *
   * @param stats An array of integers, where each value
   *        indicates the number of examples that fall into
   *        a particular target category.  The size of the
   *        array should be the same as the number of
   *        possible target attribute values.  This array
   *        holds the statistics associated with the number
   *        of examples that reach a certain node/arc in the
   *        tree.
   *
   * @param att An Attribute that can be split on at the
   *        current position in the tree.  The Attribute's
   *        internal storage space should already be
   *        populated with correct statistical information.
   *
   * @return The GINI score for the supplied attribute.
   */
  private double getGINI( int stats[], Attribute att )
  {
    // Determine the total number of examples.
    int numExamples = 0;

    for( int i = 0; i < stats.length; i++ )
      numExamples += stats[i];

    if( numExamples == 0 ) return 0;

    double giniScore = 0;
    int[][] attStats = att.getStatsArray();

    // Loop over all possible values.
    for( int j = 0; j < attStats.length; j++ ) {
      int numSubsetExamples = 0;
      int sumOfSquares      = 0;

      // Determine number of examples along this path.
      for( int k = 0; k < attStats[j].length; k++ ) {
        sumOfSquares += attStats[j][k] * attStats[j][k];
        numSubsetExamples += attStats[j][k];
      }

      if( numSubsetExamples != 0 )
        giniScore += ((double)sumOfSquares)/numSubsetExamples;
    }

    // Now, compute the second term in the GINI score.
    for( int l = 0; l < stats.length; l++ )
      giniScore -= ((double)stats[l])*((double)stats[l]) / numExamples;

    // Finally, divide by the total number of examples.
    giniScore /= numExamples;
    return giniScore;
  }
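
  // Algebraically, the value returned above is
  //   sum_j (n_j/N) * sum_k (p_jk)^2  -  sum_l (p_l)^2
  // which equals the parent's GINI impurity (1 - sum_l (p_l)^2)
  // minus the weighted GINI impurity of the child subsets - that
  // is, the reduction in impurity, so larger scores mark better
  // splits (consistent with the greater-than comparison in
  // chooseAttribute()).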

  /**
   * Reduced error pruning error function.
   *
   * @param numExamplesReachParent The number of examples
   *        that reach the parent of a given position in the tree.
   *
   * @param numExamplesIncorrectClass The number of examples
   *        <i>in</i>correctly classified at the given position in
   *        the tree (this has meaning only for leaf nodes).
   *
   * @return The estimated error rate at this position, with a
   *         Laplacian correction applied.
   */
  private double errorRE( int numExamplesReachParent,
                          int numExamplesIncorrectClass )
  {
    // Include Laplacian correction.
    int total = 1 + numExamplesIncorrectClass;

    return ((double)total)/(numExamplesReachParent + 2);
  }
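
  // Worked example: if 10 examples reach the parent and 2 are
  // misclassified at this position, the Laplacian-corrected
  // estimate is (1 + 2)/(10 + 2) = 0.25, rather than the raw
  // 2/10 = 0.2.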

  /**
   * Returns an error bar value based on the current
   * confidence interval.
   *
   * @param mean The mean value used to calculate the error
   *        bar value.
   *
   * @param size The sample size.
   *
   * @return The error bar (half the width of the confidence
   *         interval) for the supplied mean and sample size.
   */
  private double errorBar( double mean, int size )
  {
    return Math.sqrt( mean * (1 - mean) / size ) * m_pessPruneZScore;
  }
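
  // Worked example (assuming m_pessPruneZScore = 1.0 here, purely
  // for illustration): for a mean error of 0.25 over 100 samples,
  //   errorBar = sqrt( 0.25 * 0.75 / 100 ) * 1.0 ≈ 0.043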

  /**
   * Creates a formatted HTML message to display during
   * reduced error pruning.  This is a convenience method -
   * it makes the pruneReducedErrorDT() method less cluttered.
   *
   * @param errCurrent The error produced by the current tree.
   *
   * @param errPrune The error that would result if the tree
   *        was pruned.
   *
   * @return An HTML text string, suitable for display
   *         in a GUI panel.
   */
  private String createPruningMsg( double errCurrent, double errPrune )
  {
    String errCurrentString = Double.toString( errCurrent );

    if( errCurrentString.length() > 5 )
      errCurrentString = errCurrentString.substring( 0, 4 );

    String errPruneString = Double.toString( errPrune );

    if( errPruneString.length() > 5 )
      errPruneString = errPruneString.substring( 0, 4 );

    StringBuffer msg =
      new StringBuffer( "<html><font size=\"-1\">" );

    msg.append( "Current error = " + errCurrentString +
                ", pruning error = " + errPruneString + "." );

    return msg.append( "</font>" ).toString();
  }
}
