⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 decisiontreealgorithm.java

📁 一个决策树的Applet(转载
💻 JAVA
📖 第 1 页 / 共 4 页
字号:
    i = m_algorithmListeners.iterator();

    while( i.hasNext() )
      ((AlgorithmListener)i.next()).notifyAlgorithmFinished();

    if( m_manager.getStatusBar() != null )
      m_manager.getStatusBar().postMessage( "Algorithm finished." );
  }

  /**
   * An implementation of the recursive decision tree
   * learning algorithm.  Given a parent node and an arc
   * number, the method will attach a new decision 'sub'-tree
   * below the parent node.
   *
   * <p>The method distinguishes four cases at the current position:
   * <ol>
   *   <li>no examples reach this position &mdash; add a leaf labelled
   *       with the parent's most common target value;</li>
   *   <li>all examples share a single target value &mdash; add a leaf
   *       with that value;</li>
   *   <li>no attributes remain to split on &mdash; add a leaf with the
   *       most common target value at this position;</li>
   *   <li>otherwise &mdash; choose a split attribute, add an internal
   *       node, and recurse down every arc of that node.</li>
   * </ol>
   *
   * <p>Registered <code>AlgorithmListener</code>s are notified as leaf
   * nodes are created, and <code>handleBreakpoint</code> is consulted
   * at each numbered step; a <code>false</code> result from any
   * breakpoint aborts the whole construction (presumably a user stop
   * request &mdash; the contract of <code>handleBreakpoint</code> is
   * not visible in this file).
   *
   * @param parent The parent node for the new decision tree, or
   *        <code>null</code> to build at the root of the tree.
   *
   * @param arcNum The arc number (or path) along which the
   *        new subtree will be attached (unused when
   *        <code>parent</code> is <code>null</code>).
   *
   * @return true if an entire subtree was successfully added,
   *         false otherwise (breakpoint abort or internal error).
   */
  public boolean learnDT( DecisionTreeNode parent, int arcNum )
  {
    AttributeMask mask;
    Iterator i;

    if( parent == null )
      // We have to add at the root.
      mask = new AttributeMask( m_dataset.getNumAttributes() );
    else {
      // Copy the parent's mask, then additionally fix the attribute
      // the parent splits on to the value selected by arcNum, so the
      // mask describes exactly the examples reaching this position.
      mask = new AttributeMask( parent.getMask() );

      // Mask off the specified arc number.
      try {
        mask.mask(
          m_dataset.getAttributePosition( parent.getLabel() ), arcNum );
      }
      catch( NonexistentAttributeException e ) {
        e.printStackTrace();
        return false;
      }
    }

    //----------------- Breakpoint -----------------
    if( !handleBreakpoint( 1, null ) ) return false;

    // Now, classify the examples at the current position.
    // 'conclusion' is an out-parameter filled in by classifyExamples().
    // From the way the slots are read below: [0] = most common (or
    // unanimous) training target value index; [1]-[3] = training
    // statistics forwarded to addLeafNode; [4] = best testing target
    // index; [5] = number of testing examples reaching this node;
    // [6]/[7] = testing/training examples correctly classified.
    // (Slot meanings inferred from the local variable names below
    // -- confirm against classifyExamples().)
    int[] conclusion = new int[8];
    int   result     = classifyExamples( mask, conclusion, null, null, null );

    Attribute target  = m_dataset.getTargetAttribute();
    int numTargetVals = target.getNumValues();
    String label;

    //----------------- Breakpoint -----------------
    if( !handleBreakpoint( 2, null ) ) return false;

    if( result == DATASET_EMPTY ) {
      // If no examples reach our current position
      // we add a leaf with the most common target
      // classification for the parent node.

      // Save testing results (computed for the current mask) before
      // the re-classification below overwrites 'conclusion'.
      int numTestingExamplesReachHere     =  conclusion[5];
      int bestTestingTargetIndex          =  conclusion[4];
      int numTestingExamplesCorrectClass  =  conclusion[6];
      int numTrainingExamplesCorrectClass =  conclusion[7];

      // Re-classify with the parent's mask so conclusion[0] becomes
      // the parent's most common target value.
      // NOTE(review): 'parent' is dereferenced here without a null
      // check -- this assumes DATASET_EMPTY can never occur on the
      // root call (parent == null, i.e. an empty training set).
      // Confirm, or guard.
      classifyExamples( parent.getMask(), conclusion, null, null, null );

      try {
        label = target.getAttributeValueByNum( conclusion[0] );
      }
      catch( NonexistentAttributeValueException e ) {
        e.printStackTrace();
        return false;
      }

      //--------------------- Debug ---------------------
      if( DEBUG_ON ) {
        System.out.println();
        System.out.println( "DecisionTreeAlgorithm::learnDT: " +
          "No examples reach the current position." );
      }

      // We have to grab the counts again for the testing data...
      int[] currTestingCounts = new int[ target.getNumValues() ];
      getExampleCounts( mask,
        m_dataset.getTestingExamples(), currTestingCounts, null );

      // Mask target value and add a leaf to the tree.
      mask.mask( 0, conclusion[0] );

      // No training examples reach this node, hence the zero counts.
      DecisionTreeNode node =
        m_tree.addLeafNode( parent,
                            arcNum,
                            label,
                            mask,
                            0,
                            conclusion[0],
                            0,
                            currTestingCounts[conclusion[0]],
                            numTestingExamplesReachHere,
                            bestTestingTargetIndex,
                            numTestingExamplesCorrectClass,
                            numTrainingExamplesCorrectClass );

      // Notify listeners that a step (leaf creation) occurred; verbose
      // mode attaches a human-readable description of the new leaf.
      i = m_algorithmListeners.iterator();

      if( m_verboseFlag )
        while( i.hasNext() )
          ((AlgorithmListener)
            i.next()).notifyAlgorithmStepStart( createLeafMsg( 1, label ) );
      else
        while( i.hasNext() )
          ((AlgorithmListener)
            i.next()).notifyAlgorithmStepStart();

      return true;
    }

    //----------------- Breakpoint -----------------
    if( !handleBreakpoint( 3, null ) ) return false;

    if( result == DATASET_IDENT_CONCL ) {
      // Pure result - we can add a leaf node with the
      // correct target attribute value.
      try {
        label = target.getAttributeValueByNum( conclusion[0] );
      }
      catch( NonexistentAttributeValueException e ) {
        e.printStackTrace();
        return false;
      }

      //--------------------- Debug ---------------------
      if( DEBUG_ON ) {
        System.out.println();
        System.out.println( "DecisionTreeAlgorithm::learnDT: " +
          "All examples at the current position " +
          "have the same target class '" + label + "'." );
      }

      // Mask target value and add a leaf to the tree.
      mask.mask( 0, conclusion[0] );

      DecisionTreeNode node =
        m_tree.addLeafNode( parent,
                            arcNum,
                            label,
                            mask,
                            conclusion[1],
                            conclusion[0],
                            conclusion[2],
                            conclusion[3],
                            conclusion[5],
                            conclusion[4],
                            conclusion[6],
                            conclusion[7] );

      // Notify listeners of the new leaf (see the DATASET_EMPTY case).
      i = m_algorithmListeners.iterator();

      if( m_verboseFlag )
        while( i.hasNext() )
          ((AlgorithmListener)
            i.next()).notifyAlgorithmStepStart( createLeafMsg( 2, label ) );
      else
        while( i.hasNext() )
          ((AlgorithmListener)
            i.next()).notifyAlgorithmStepStart();

      return true;
    }

    // Mixed conclusion - so we have to select
    // an attribute to split on, and then build a
    // new internal node with that attribute.

    // First, generate statistics - this may take awhile.
    int[]  nodeStats = new int[ numTargetVals ];
    Vector availableAtts = generateStats( mask, nodeStats );

    //----------------- Breakpoint -----------------
    if( !handleBreakpoint( 4, null ) ) return false;

    if( availableAtts.size() == 0 ) {
      // No attributes left to split on - so use
      // the most common target value at the current position.
      try {
        label = target.getAttributeValueByNum( conclusion[0] );
      }
      catch( NonexistentAttributeValueException e ) {
        e.printStackTrace();
        return false;
      }

      //--------------------- Debug ---------------------
      if( DEBUG_ON ) {
        System.out.println();
        System.out.println( "DecisionTreeAlgorithm::learnDT: " +
          "No attributes left to split on at current position." );
      }

      // Mask target value and add a leaf to the tree.
      mask.mask( 0, conclusion[0] );

      DecisionTreeNode node =
        m_tree.addLeafNode( parent,
                            arcNum,
                            label,
                            mask,
                            conclusion[1],
                            conclusion[0],
                            conclusion[2],
                            conclusion[3],
                            conclusion[5],
                            conclusion[4],
                            conclusion[6],
                            conclusion[7] );

      // Notify listeners of the new leaf (see the DATASET_EMPTY case).
      i = m_algorithmListeners.iterator();

      if( m_verboseFlag )
        while( i.hasNext() )
          ((AlgorithmListener)
            i.next()).notifyAlgorithmStepStart( createLeafMsg( 3, label ) );
      else
        while( i.hasNext() )
          ((AlgorithmListener)
            i.next()).notifyAlgorithmStepStart();

      return true;
    }

    // Choose an attribute, based on the set of
    // available attributes.  'results' receives the per-attribute
    // split-function scores (parallel to availableAtts).
    Vector results = new Vector();
    Attribute att  = chooseAttribute( availableAtts, nodeStats, results );

    //--------------------- Debug ---------------------
    if( DEBUG_ON ) {
      System.out.println();
      System.out.println( "DecisionTreeAlgorithm::learnDT: " +
        "Preparing to split..." );
      System.out.println();
      System.out.println( "Available attributes and " +
        "associated " + m_splitFun + " values:" );

      for( int j = 0; j < availableAtts.size(); j++ ) {
        System.out.print(
          ((Attribute)availableAtts.elementAt( j )).getName() );

        if( j < results.size() ) {
          System.out.print( " - " );
          System.out.println( (Double)results.elementAt(j) );
        }
        else {
          System.out.println();
        }
      }

      System.out.println();
      System.out.println( "Selected " + att.getName() + "." );
      System.out.println();
    }

    //----------------- Breakpoint -----------------
    if( !handleBreakpoint( 5, createInternalMsg(1, att.getName() )) )
      return false;

    //----------------- Breakpoint -----------------
    if( !handleBreakpoint( 6, createInternalMsg(2, att.getName() )) )
      return false;

    int attPos;

    try {
      attPos = m_dataset.getAttributePosition( att.getName() );
    }
    catch( NonexistentAttributeException e ) {
      e.printStackTrace();
      return false;
    }

    // Add the internal node for the chosen split attribute; the
    // recursion below fills in one subtree per attribute value.
    DecisionTreeNode newParent =
      m_tree.addInternalNode( parent,
                              arcNum,
                              attPos,
                              att,
                              mask,
                              conclusion[1],
                              conclusion[0],
                              conclusion[2],
                              conclusion[3],
                              conclusion[5],
                              conclusion[4],
                              conclusion[6],
                              conclusion[7] );

    // Now, recursively descend along each branch of the new node.
    for( int j = 0; j < newParent.getArcLabelCount(); j++ ) {
      //----------------- Breakpoint -----------------
      if( !handleBreakpoint( 7, null ) ) return false;

      //----------------- Breakpoint -----------------
      if( !handleBreakpoint( 8, null ) ) return false;

      //----------------- Breakpoint -----------------
      if( !handleBreakpoint( 9, null ) ) return false;

      //--------------------- Debug ---------------------
      if( DEBUG_ON ) {
        System.out.println();
        System.out.println( "DecisionTreeAlgorithm::learnDT: " +
          "Descending along branch " + (j+1) + " of " +
          newParent.getArcLabelCount() + "." );
      }

      // Recursive call; a false return (abort/error) propagates up
      // and cancels the entire construction.
      if( !learnDT( newParent, j ) ) return false;

      //----------------- Breakpoint -----------------
      if( !handleBreakpoint( 10, null ) ) return false;

      //----------------- Breakpoint -----------------
      if( !handleBreakpoint( 11,
        createInternalMsg(3, newParent.getArcLabel(j)) ) )
        return false;
    }

    //----------------- Breakpoint -----------------
    if( !handleBreakpoint( 12, null ) ) return false;

    return true;
  }

  /**
   * An implementation of the recursive decision tree
   * reduced error pruning algorithm.  Given a parent
   * node, the method will prune all the branches
   * below the node.
   *
   * @param node The root node of the tree to prune.
   *
   * @param error A <code>double</code> array of size 1. The
   *        array is used to store the current error value.
   *
   * @return <code>true</code> if an entire subtree was successfully
   *         pruned, or <code>false</code> otherwise.
   */
  public boolean pruneReducedErrorDT( DecisionTreeNode node, double[] error )
  {
    // Post-order walk through the tree, marking
    // our path as we go along.
    //----------------- Breakpoint -----------------
    if( !handleBreakpoint( 1, null ) ) return false;

    //----------------- Breakpoint -----------------
    if( !handleBreakpoint( 2, null ) ) return false;

    if( node.isLeaf() ) {
      error[0] = node.getTestingEgsAtNode() -
        node.getTestingEgsCorrectClassUsingBestTrainingIndex();
      return true;
    }

    //----------------- Breakpoint -----------------
    if( !handleBreakpoint( 3, null ) ) return false;

    // We're at an internal node, so compute the error
    // of the children and use the result to determine
    // if we prune or not.
    double errorSum = 0;

    for( int i = 0; i < node.getArcLabelCount(); i++ ) {
      // Mark our current path.
      m_tree.flagNode( node, i );

      //----------------- Breakpoint -----------------
      if( !handleBreakpoint( 4, null ) ) return false;

      //----------------- Breakpoint -----------------
      if( !handleBreakpoint( 5, null ) ) return false;

      //----------------- Breakpoint -----------------
      if( !handleBreakpoint( 6, null ) ) return false;

      if( !pruneReducedErrorDT( node.getChild( i ), error ) ) {
        m_tree.flagNode( node, -2 );
        return false;
      }

      errorSum += error[0];
    }

    // Mark the node as our current target.
    m_tree.flagNode( node, -1 );

    //----------------- Breakpoint -----------------
    if( !handleBreakpoint( 8, null ) ) return false;

    // Get the best-case performance of this node.
    double errorBest = node.getTestingEgsAtNode() -
      node.getTestingEgsCorrectClassUsingBestTestingIndex();

    DecisionTreeNode newNode = node;

    //----------------- Breakpoint -----------------
    if( !handleBreakpoint( 9,
      createPruningMsg( errorSum, errorBest ) ) ) return false;

    if( errorBest < errorSum ) {
      // We need to "prune" this node to a leaf.
      DecisionTreeNode parent = node.getParent();
      int arcNum = -1;

      if( parent != null )
        arcNum = parent.getChildPosition( node );

      //--------------------- Debug ---------------------
      if( DEBUG_ON ) {
        System.out.println();
        System.out.println( "DecisionTreeAlgorithm::pruneReduceErrorDT: " +
                            " Pruning node " + node.getLabel() + "." );
      }

      m_tree.pruneSubtree( node );

      // Figure out the label for the new leaf.
      String label = null;

      try {
        label =
          m_dataset.getTargetAttribute().getAttributeValueByNum(
            node.getTestingBestTarget() );
      }
      catch( NonexistentAttributeValueException e ) {
        // Should never happen.
        e.printStackTrace();
      }

      node.getMask().mask( 0, node.getTestingBestTarget() );

      newNode =
        m_tree.addLeafNode( parent, arcNum, label,
          node.getMask(),
          node.getTrainingEgsAtNode(),
          node.getTestingBestTarget(),
          node.getTrainingEgsCorrectClassUsingBestTestingIndex(),
          node.getTestingEgsCorrectClassUsingBestTestingIndex(),
          node.getTestingEgsAtNode(),
          node.getTestingBestTarget(),
          node.getTestingEgsCorrectClassUsingBestTestingIndex(),

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -