⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 adtree.java

📁 为了下东西 随便发了个 datamining 的源代码
💻 JAVA
📖 第 1 页 / 共 4 页
字号:
/*
 *    This program is free software; you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation; either version 2 of the License, or
 *    (at your option) any later version.
 *
 *    This program is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with this program; if not, write to the Free Software
 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

/*
 *    ADTree.java
 *    Copyright (C) 2001 Richard Kirkby, Bernhard Pfahringer
 *
 */

package weka.classifiers.trees;

import java.util.Enumeration;
import java.util.Random;
import java.util.Vector;

import weka.classifiers.Classifier;
import weka.classifiers.Evaluation;
import weka.classifiers.IterativeClassifier;
import weka.classifiers.trees.adtree.PredictionNode;
import weka.classifiers.trees.adtree.ReferenceInstances;
import weka.classifiers.trees.adtree.Splitter;
import weka.classifiers.trees.adtree.TwoWayNominalSplit;
import weka.classifiers.trees.adtree.TwoWayNumericSplit;
import weka.core.AdditionalMeasureProducer;
import weka.core.Attribute;
import weka.core.Drawable;
import weka.core.FastVector;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.SelectedTag;
import weka.core.SerializedObject;
import weka.core.Tag;
import weka.core.UnassignedClassException;
import weka.core.UnsupportedAttributeTypeException;
import weka.core.UnsupportedClassTypeException;
import weka.core.Utils;
import weka.core.WeightedInstancesHandler;

/**
 * Class for generating an alternating decision tree. The basic algorithm is based on:<p>
 *
 * Freund, Y., Mason, L.: The alternating decision tree learning algorithm.
 * Proceedings of the Sixteenth International Conference on Machine Learning,
 * Bled, Slovenia, (1999) 124-133.</p>
 *
 * This version currently only supports two-class problems. The number of boosting
 * iterations needs to be manually tuned to suit the dataset and the desired 
 * complexity/accuracy tradeoff. Induction of the trees has been optimized, and heuristic
 * search methods have been introduced to speed learning.<p>
 *
 * Valid options are: <p>
 *
 * -B num <br>
 * Set the number of boosting iterations
 * (default 10) <p>
 *
 * -E num <br>
 * Set the nodes to expand: -3(all), -2(weight), -1(z_pure), >=0 seed for random walk
 * (default -3) <p>
 *
 * -D <br>
 * Save the instance data with the model <p>
 *
 * @author Richard Kirkby (rkirkby@cs.waikato.ac.nz)
 * @author Bernhard Pfahringer (bernhard@cs.waikato.ac.nz)
 * @version $Revision$
 */
public class ADTree
  extends Classifier implements OptionHandler, Drawable,
				AdditionalMeasureProducer,
				WeightedInstancesHandler,
				IterativeClassifier
{

  /**
   * Builds the human-readable description of this classifier shown in the
   * Explorer/Experimenter GUI: a one-line summary, the reference the
   * algorithm is based on, and a note on its current limitations.
   *
   * @return the description text
   */
  public String globalInfo() {

    String citation =
      "Freund, Y., Mason, L.: \"The alternating decision tree learning algorithm\". "
      + "Proceeding of the Sixteenth International Conference on Machine Learning, "
      + "Bled, Slovenia, (1999) 124-133.";

    String limitations =
      "This version currently only supports two-class problems. The number of boosting "
      + "iterations needs to be manually tuned to suit the dataset and the desired "
      + "complexity/accuracy tradeoff. Induction of the trees has been optimized, and heuristic "
      + "search methods have been introduced to speed learning.";

    return "Class for generating an alternating decision tree. The basic algorithm is based on:"
      + "\n\n" + citation
      + "\n\n" + limitations;
  }

  /** The search modes. Search mode: expand all paths. */
  public static final int SEARCHPATH_ALL = 0;
  /** Search mode: expand the heaviest path. */
  public static final int SEARCHPATH_HEAVIEST = 1;
  /** Search mode: expand the best z-pure path. */
  public static final int SEARCHPATH_ZPURE = 2;
  /** Search mode: expand a random path. */
  public static final int SEARCHPATH_RANDOM = 3;
  /** Tags pairing each SEARCHPATH_* constant with a human-readable description. */
  public static final Tag [] TAGS_SEARCHPATH = {
    new Tag(SEARCHPATH_ALL, "Expand all paths"),
    new Tag(SEARCHPATH_HEAVIEST, "Expand the heaviest path"),
    new Tag(SEARCHPATH_ZPURE, "Expand the best z-pure path"),
    new Tag(SEARCHPATH_RANDOM, "Expand a random path")
  };

  /** The instances used to train the tree (a copy with missing-class instances removed) */
  protected Instances m_trainInstances;

  /** The root of the tree, created by initClassifier */
  protected PredictionNode m_root = null;

  /** The random number generator - used for the random search heuristic; seeded from m_randomSeed */
  protected Random m_random = null; 

  /** The number of the last splitter added to the tree (reset to 0 by initClassifier) */
  protected int m_lastAddedSplitNum = 0;

  /** An array containing the indices of the numeric (non-class) attributes in the data */
  protected int[] m_numericAttIndices;

  /** An array containing the indices of the non-numeric (non-class) attributes, treated as nominal */
  protected int[] m_nominalAttIndices;

  /** The total weight of the training instances, refreshed before each search - used to speed Z calculations */
  protected double m_trainTotalWeight;

  /** The training instances with positive class (class value != 0) - referencing the training dataset */
  protected ReferenceInstances m_posTrainInstances;

  /** The training instances with negative class (class value 0) - referencing the training dataset */
  protected ReferenceInstances m_negTrainInstances;

  /** The best node to insert under, as found so far by the latest search */
  protected PredictionNode m_search_bestInsertionNode;

  /** The best splitter to insert, as found so far by the latest search (null if none) */
  protected Splitter m_search_bestSplitter;

  /** The smallest Z value found so far by the latest search (reset to +infinity per search) */
  protected double m_search_smallestZ;

  /** The positive instances that apply to the best path found so far */
  protected Instances m_search_bestPathPosInstances;

  /** The negative instances that apply to the best path found so far */
  protected Instances m_search_bestPathNegInstances;

  /** Statistics - the number of prediction nodes investigated during search */
  protected int m_nodesExpanded = 0;

  /** Statistics - the number of instances processed during search */
  protected int m_examplesCounted = 0;

  /** Option - the number of boosting iterations to perform (default 10) */
  protected int m_boostingIterations = 10;

  /** Option - the search mode (default 0, i.e. SEARCHPATH_ALL) */
  protected int m_searchPath = 0;

  /** Option - the seed used to create m_random for a random search */
  protected int m_randomSeed = 0; 

  /** Option - whether the tree should remember the instance data */
  protected boolean m_saveInstanceData = false; 

  /**
   * Sets up the tree ready to be trained, using the two-class optimized method.
   * <p>
   * The data is validated and copied, and instances with a missing class are
   * removed from the copy. The copy is then split into negative (class value 0)
   * and positive (any other class value) reference subsets. Next, the root
   * prediction node is created and the instance weights are pre-adjusted by the
   * root prediction. Finally, the numeric and nominal attribute indices are
   * generated.
   *
   * @param instances the instances to train the tree with
   * @exception Exception if training data is unsuitable: no class assigned,
   * string attributes present, a non-nominal class, or not exactly two classes
   */
  public void initClassifier(Instances instances) throws Exception {

    // clear stats
    m_nodesExpanded = 0;
    m_examplesCounted = 0;
    m_lastAddedSplitNum = 0;

    // make sure training data is suitable
    if (instances.classIndex() < 0) {
      throw new UnassignedClassException("ADTree: Needs a class to be assigned");
    }
    if (instances.checkForStringAttributes()) {
      throw new UnsupportedAttributeTypeException("ADTree: Can't handle string attributes");
    }
    if (!instances.classAttribute().isNominal()) {
      throw new UnsupportedClassTypeException("ADTree: Class must be nominal");
    }
    if (instances.numClasses() != 2) {
      throw new UnsupportedClassTypeException("ADTree: Must be a two-class problem");
    }

    // prepare the random generator (used by the random search heuristic)
    m_random = new Random(m_randomSeed);

    // create the training set; missing-class instances are deleted from the copy only
    m_trainInstances = new Instances(instances);
    m_trainInstances.deleteWithMissingClass();

    // create positive/negative subsets that reference the training set.
    // Fixed: the original called the non-existent "emerateInstances()"; iterate by index instead.
    int numTrain = m_trainInstances.numInstances();
    m_posTrainInstances = new ReferenceInstances(m_trainInstances, numTrain);
    m_negTrainInstances = new ReferenceInstances(m_trainInstances, numTrain);
    for (int i = 0; i < numTrain; i++) {
      Instance inst = m_trainInstances.instance(i);
      if ((int) inst.classValue() == 0) {
        m_negTrainInstances.addReference(inst); // belongs in negative class
      } else {
        m_posTrainInstances.addReference(inst); // belongs in positive class
      }
    }
    m_posTrainInstances.compactify();
    m_negTrainInstances.compactify();

    // create the root prediction node
    double rootPredictionValue = calcPredictionValue(m_posTrainInstances,
                                                     m_negTrainInstances);
    m_root = new PredictionNode(rootPredictionValue);

    // pre-adjust weights according to the root prediction
    updateWeights(m_posTrainInstances, m_negTrainInstances, rootPredictionValue);

    // pre-calculate the numeric/nominal attribute index arrays
    generateAttributeIndicesSingle();
  }

  /**
   * Carries out a single iteration of the iterative training process by
   * delegating to {@link #boost()}. The iteration index itself is not used.
   *
   * @param iteration the 0-based index of the current iteration (unused)
   * @exception Exception if the boosting step fails
   */
  public void next(int iteration) throws Exception {
    this.boost();
  }

  /**
   * Runs one boosting round with the two-class optimized method. The method
   * searches for the best splitter and gives it two new prediction nodes, one
   * per branch, updating instance weights along the way. It then attaches the
   * splitter under the best insertion node found by the search. If the search
   * finds no splitter, the tree is left unchanged.
   *
   * @exception Exception if there is no training data (the tree was not set up,
   * or it holds zero instances), or if the search or insertion fails
   */
  public void boost() throws Exception {

    boolean noTrainingData =
      (m_trainInstances == null) || (m_trainInstances.numInstances() == 0);
    if (noTrainingData) {
      throw new Exception("Trying to boost with no training data");
    }

    // look for the splitter (and insertion point) minimizing Z
    searchForBestTestSingle();

    // nothing was found to split on
    if (m_search_bestSplitter == null) {
      return;
    }

    // build one prediction node per branch, adjusting weights as we go
    for (int branch = 0; branch < 2; branch++) {
      Instances branchPos =
        m_search_bestSplitter.instancesDownBranch(branch, m_search_bestPathPosInstances);
      Instances branchNeg =
        m_search_bestSplitter.instancesDownBranch(branch, m_search_bestPathNegInstances);
      double branchValue = calcPredictionValue(branchPos, branchNeg);
      PredictionNode branchNode = new PredictionNode(branchValue);
      updateWeights(branchPos, branchNeg, branchValue);
      m_search_bestSplitter.setChildForBranch(branch, branchNode);
    }

    // hook the new splitter (with its children) into the tree
    m_search_bestInsertionNode.addChild(m_search_bestSplitter, this);

    // drop references to the search results so they can be garbage collected
    m_search_bestPathPosInstances = null;
    m_search_bestPathNegInstances = null;
    m_search_bestSplitter = null;
  }

  /**
   * Generates the m_nominalAttIndices and m_numericAttIndices arrays, which
   * index the respective attribute types in the training data. The class
   * attribute is skipped. Any attribute for which isNumeric() is false is
   * treated as nominal. Indices are stored in increasing order.
   * <p>
   * The arrays are built directly as primitive ints. This replaces the old
   * approach of boxing through the deprecated Integer(int) constructor into
   * an untyped FastVector.
   */
  private void generateAttributeIndicesSingle() {

    int numAttributes = m_trainInstances.numAttributes();
    int classIndex = m_trainInstances.classIndex();

    // first pass: count each kind so the arrays can be sized exactly
    int numNumeric = 0;
    int numNominal = 0;
    for (int i = 0; i < numAttributes; i++) {
      if (i == classIndex) {
        continue;
      }
      if (m_trainInstances.attribute(i).isNumeric()) {
        numNumeric++;
      } else {
        numNominal++;
      }
    }

    m_nominalAttIndices = new int[numNominal];
    m_numericAttIndices = new int[numNumeric];

    // second pass: fill both arrays in increasing attribute-index order
    int nextNumeric = 0;
    int nextNominal = 0;
    for (int i = 0; i < numAttributes; i++) {
      if (i == classIndex) {
        continue;
      }
      if (m_trainInstances.attribute(i).isNumeric()) {
        m_numericAttIndices[nextNumeric++] = i;
      } else {
        m_nominalAttIndices[nextNominal++] = i;
      }
    }
  }

  /**
   * Starts a fresh search for the splitter that minimizes Z. First, the total
   * training weight is cached. Next, the best Z found so far is reset to
   * positive infinity. Finally, the recursive search is launched from the root
   * with the full positive and negative training subsets.
   *
   * @exception Exception if the recursive search fails
   */
  private void searchForBestTestSingle() throws Exception {

    // cached once per search for efficient wRemainder calculations
    m_trainTotalWeight = m_trainInstances.sumOfWeights();

    // any candidate found by the search will beat this initial bound
    m_search_smallestZ = Double.POSITIVE_INFINITY;

    PredictionNode start = m_root;
    searchForBestTestSingle(start, m_posTrainInstances, m_negTrainInstances);
  }

  /**
   * Recursive function that carries out search for the best test (splitter) to add to
   * this part of the tree, by aiming to minimize the Z value. Performs Z-pure cutoff to
   * reduce search space.
   *
   * @param currentNode the root of the subtree to be searched, and the current node 
   * being considered as parent of a new split
   * @param posInstances the positive-class instances that apply at this node
   * @param negInstances the negative-class instances that apply at this node
   * @exception Exception if search fails
   */
  private void searchForBestTestSingle(PredictionNode currentNode,
				       Instances posInstances, Instances negInstances)
    throws Exception {

    // don't investigate pure or empty nodes any further
    if (posInstances.numInstances() == 0 || negInstances.numInstances() == 0) return;

    // do z-pure cutoff
    if (calcZpure(posInstances, negInstances) >= m_search_smallestZ) return;

    // keep stats
    m_nodesExpanded++;

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -