/*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 2 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
*/
/*
* ADTree.java
* Copyright (C) 2001 Richard Kirkby, Bernhard Pfahringer
*
*/
package weka.classifiers.trees;
import java.util.Enumeration;
import java.util.Random;
import java.util.Vector;
import weka.classifiers.Classifier;
import weka.classifiers.Evaluation;
import weka.classifiers.IterativeClassifier;
import weka.classifiers.trees.adtree.PredictionNode;
import weka.classifiers.trees.adtree.ReferenceInstances;
import weka.classifiers.trees.adtree.Splitter;
import weka.classifiers.trees.adtree.TwoWayNominalSplit;
import weka.classifiers.trees.adtree.TwoWayNumericSplit;
import weka.core.AdditionalMeasureProducer;
import weka.core.Attribute;
import weka.core.Drawable;
import weka.core.FastVector;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.SelectedTag;
import weka.core.SerializedObject;
import weka.core.Tag;
import weka.core.UnassignedClassException;
import weka.core.UnsupportedAttributeTypeException;
import weka.core.UnsupportedClassTypeException;
import weka.core.Utils;
import weka.core.WeightedInstancesHandler;
/**
* Class for generating an alternating decision tree. The basic algorithm is based on:<p>
*
* Freund, Y., Mason, L.: The alternating decision tree learning algorithm.
* Proceeding of the Sixteenth International Conference on Machine Learning,
* Bled, Slovenia, (1999) 124-133.</p>
*
* This version currently only supports two-class problems. The number of boosting
* iterations needs to be manually tuned to suit the dataset and the desired
* complexity/accuracy tradeoff. Induction of the trees has been optimized, and heuristic
* search methods have been introduced to speed learning.<p>
*
* Valid options are: <p>
*
* -B num <br>
* Set the number of boosting iterations
* (default 10) <p>
*
* -E num <br>
* Set the nodes to expand: -3(all), -2(weight), -1(z_pure), >=0 seed for random walk
* (default -3) <p>
*
* -D <br>
* Save the instance data with the model <p>
*
* @author Richard Kirkby (rkirkby@cs.waikato.ac.nz)
* @author Bernhard Pfahringer (bernhard@cs.waikato.ac.nz)
* @version $Revision$
*/
public class ADTree
extends Classifier implements OptionHandler, Drawable,
AdditionalMeasureProducer,
WeightedInstancesHandler,
IterativeClassifier
{
/**
 * Returns a string describing this classifier.
 *
 * @return a description suitable for displaying in the
 *         explorer/experimenter gui
 */
public String globalInfo() {
  // Built as a single constant expression; the compiler folds the
  // concatenation of literals, so this is as cheap as one literal.
  final String description =
      "Class for generating an alternating decision tree. The basic "
      + "algorithm is based on:\n\n"
      + "Freund, Y., Mason, L.: \"The alternating decision tree learning algorithm\". "
      + "Proceeding of the Sixteenth International Conference on Machine Learning, "
      + "Bled, Slovenia, (1999) 124-133.\n\n"
      + "This version currently only supports two-class problems. The number of boosting "
      + "iterations needs to be manually tuned to suit the dataset and the desired "
      + "complexity/accuracy tradeoff. Induction of the trees has been optimized, and heuristic "
      + "search methods have been introduced to speed learning.";
  return description;
}
/** The search modes: strategies for choosing which paths of the tree are
 *  explored when looking for the best place to add the next splitter. */
public static final int SEARCHPATH_ALL = 0;
public static final int SEARCHPATH_HEAVIEST = 1;
public static final int SEARCHPATH_ZPURE = 2;
public static final int SEARCHPATH_RANDOM = 3;
/** Tags pairing each search-mode constant with its GUI display string. */
public static final Tag [] TAGS_SEARCHPATH = {
new Tag(SEARCHPATH_ALL, "Expand all paths"),
new Tag(SEARCHPATH_HEAVIEST, "Expand the heaviest path"),
new Tag(SEARCHPATH_ZPURE, "Expand the best z-pure path"),
new Tag(SEARCHPATH_RANDOM, "Expand a random path")
};
/** The instances used to train the tree */
protected Instances m_trainInstances;
/** The root of the tree */
protected PredictionNode m_root = null;
/** The random number generator - used for the random search heuristic */
protected Random m_random = null;
/** The number of the last splitter added to the tree */
protected int m_lastAddedSplitNum = 0;
/** An array containing the indices of the numeric attributes in the data */
protected int[] m_numericAttIndices;
/** An array containing the indices of the nominal attributes in the data */
protected int[] m_nominalAttIndices;
/** The total weight of the instances - used to speed Z calculations */
protected double m_trainTotalWeight;
/** The training instances with positive class - referencing the training dataset */
protected ReferenceInstances m_posTrainInstances;
/** The training instances with negative class - referencing the training dataset */
protected ReferenceInstances m_negTrainInstances;
/** The best node to insert under, as found so far by the latest search */
protected PredictionNode m_search_bestInsertionNode;
/** The best splitter to insert, as found so far by the latest search */
protected Splitter m_search_bestSplitter;
/** The smallest Z value found so far by the latest search */
protected double m_search_smallestZ;
/** The positive instances that apply to the best path found so far */
protected Instances m_search_bestPathPosInstances;
/** The negative instances that apply to the best path found so far */
protected Instances m_search_bestPathNegInstances;
/** Statistics - the number of prediction nodes investigated during search */
protected int m_nodesExpanded = 0;
/** Statistics - the number of instances processed during search */
protected int m_examplesCounted = 0;
/** Option - the number of boosting iterations to perform */
protected int m_boostingIterations = 10;
/** Option - the search mode (one of the SEARCHPATH_* constants) */
protected int m_searchPath = 0;
/** Option - the seed to use for a random search */
protected int m_randomSeed = 0;
/** Option - whether the tree should remember the instance data */
protected boolean m_saveInstanceData = false;
/**
 * Sets up the tree ready to be trained, using the two-class optimized method.
 * Validates the training data (nominal two-class problem, no string
 * attributes), splits it into positive/negative reference subsets, creates
 * the root prediction node, and pre-computes the attribute index arrays.
 *
 * @param instances the instances to train the tree with
 * @exception Exception if training data is unsuitable
 */
public void initClassifier(Instances instances) throws Exception {
  // clear stats
  m_nodesExpanded = 0;
  m_examplesCounted = 0;
  m_lastAddedSplitNum = 0;
  // make sure training data is suitable
  if (instances.classIndex() < 0) {
    throw new UnassignedClassException("ADTree: Needs a class to be assigned");
  }
  if (instances.checkForStringAttributes()) {
    throw new UnsupportedAttributeTypeException("ADTree: Can't handle string attributes");
  }
  if (!instances.classAttribute().isNominal()) {
    throw new UnsupportedClassTypeException("ADTree: Class must be nominal");
  }
  if (instances.numClasses() != 2) {
    throw new UnsupportedClassTypeException("ADTree: Must be a two-class problem");
  }
  // prepare the random generator (only used by the random-walk search mode)
  m_random = new Random(m_randomSeed);
  // create training set - copy so the caller's data is never modified
  m_trainInstances = new Instances(instances);
  m_trainInstances.deleteWithMissingClass();
  // create positive/negative subsets that reference (not copy) the training data
  m_posTrainInstances = new ReferenceInstances(m_trainInstances,
                                               m_trainInstances.numInstances());
  m_negTrainInstances = new ReferenceInstances(m_trainInstances,
                                               m_trainInstances.numInstances());
  // BUGFIX: was "emerateInstances()", which is not a method of
  // weka.core.Instances - the correct API call is enumerateInstances()
  for (Enumeration e = m_trainInstances.enumerateInstances(); e.hasMoreElements(); ) {
    Instance inst = (Instance) e.nextElement();
    if ((int) inst.classValue() == 0)
      m_negTrainInstances.addReference(inst); // belongs in negative class
    else
      m_posTrainInstances.addReference(inst); // belongs in positive class
  }
  m_posTrainInstances.compactify();
  m_negTrainInstances.compactify();
  // create the root prediction node
  double rootPredictionValue = calcPredictionValue(m_posTrainInstances,
                                                   m_negTrainInstances);
  m_root = new PredictionNode(rootPredictionValue);
  // pre-adjust weights to reflect the root's prediction
  updateWeights(m_posTrainInstances, m_negTrainInstances, rootPredictionValue);
  // pre-calculate what we can (nominal/numeric attribute index arrays)
  generateAttributeIndicesSingle();
}
/**
 * Performs one iteration of the IterativeClassifier interface by carrying
 * out a single boosting step.
 *
 * @param iteration the index of the current iteration (0-based); not used
 *        here - boost() maintains all of its own state in the fields
 * @exception Exception if this iteration fails
 */
public void next(int iteration) throws Exception {
boost();
}
/**
 * Performs a single boosting iteration, using the two-class optimized method.
 * Adds a new splitter node with two fresh prediction nodes to the tree
 * (unless merging takes place inside addChild), updating instance weights
 * as it goes.
 *
 * @exception Exception if boosting is attempted before the tree is set up,
 *            or there are no instances to train with
 */
public void boost() throws Exception {
  if (m_trainInstances == null || m_trainInstances.numInstances() == 0) {
    throw new Exception("Trying to boost with no training data");
  }
  // locate the best insertion point and test, recorded in the m_search_* fields
  searchForBestTestSingle();
  if (m_search_bestSplitter == null) {
    return; // nothing found to add - e.g. no instances reached any open path
  }
  // build a prediction node for each of the splitter's two branches,
  // re-weighting the instances that flow down that branch
  for (int branch = 0; branch < 2; branch++) {
    Instances branchPos =
        m_search_bestSplitter.instancesDownBranch(branch, m_search_bestPathPosInstances);
    Instances branchNeg =
        m_search_bestSplitter.instancesDownBranch(branch, m_search_bestPathNegInstances);
    double prediction = calcPredictionValue(branchPos, branchNeg);
    PredictionNode leaf = new PredictionNode(prediction);
    updateWeights(branchPos, branchNeg, prediction);
    m_search_bestSplitter.setChildForBranch(branch, leaf);
  }
  // graft the completed splitter into the tree at the chosen insertion point
  m_search_bestInsertionNode.addChild(m_search_bestSplitter, this);
  // drop the references held by the search so they can be garbage collected
  m_search_bestPathPosInstances = null;
  m_search_bestPathNegInstances = null;
  m_search_bestSplitter = null;
}
/**
 * Generates the m_nominalAttIndices and m_numericAttIndices arrays, indexing
 * the respective attribute types in the training data. The class attribute
 * is excluded; attribute order is preserved within each array.
 */
private void generateAttributeIndicesSingle() {
  int classIndex = m_trainInstances.classIndex();
  int numAtts = m_trainInstances.numAttributes();
  // first pass: count attributes of each type so arrays can be sized exactly
  int numericCount = 0;
  int nominalCount = 0;
  for (int att = 0; att < numAtts; att++) {
    if (att == classIndex) continue;
    if (m_trainInstances.attribute(att).isNumeric()) numericCount++;
    else nominalCount++;
  }
  m_numericAttIndices = new int[numericCount];
  m_nominalAttIndices = new int[nominalCount];
  // second pass: fill the arrays in attribute order
  int numericPos = 0;
  int nominalPos = 0;
  for (int att = 0; att < numAtts; att++) {
    if (att == classIndex) continue;
    if (m_trainInstances.attribute(att).isNumeric())
      m_numericAttIndices[numericPos++] = att;
    else
      m_nominalAttIndices[nominalPos++] = att;
  }
}
/**
 * Performs a search for the best test (splitter) to add to the tree, by
 * aiming to minimize the Z value. Results are left in the m_search_* fields;
 * the actual traversal is delegated to the recursive overload, starting from
 * the root with the full positive/negative training subsets.
 *
 * @exception Exception if search fails
 */
private void searchForBestTestSingle() throws Exception {
// keep track of total weight for efficient wRemainder calculations
m_trainTotalWeight = m_trainInstances.sumOfWeights();
// reset the best-so-far Z so any real candidate will beat it
m_search_smallestZ = Double.POSITIVE_INFINITY;
searchForBestTestSingle(m_root, m_posTrainInstances, m_negTrainInstances);
}
/**
 * Recursive function that carries out search for the best test (splitter) to add to
 * this part of the tree, by aiming to minimize the Z value. Performs Z-pure cutoff to
 * reduce search space.
 *
 * NOTE(review): this method's body is truncated in the visible portion of the
 * file - only the guard clauses and the statistics update are shown below.
 *
 * @param currentNode the root of the subtree to be searched, and the current node
 * being considered as parent of a new split
 * @param posInstances the positive-class instances that apply at this node
 * @param negInstances the negative-class instances that apply at this node
 * @exception Exception if search fails
 */
private void searchForBestTestSingle(PredictionNode currentNode,
Instances posInstances, Instances negInstances)
throws Exception {
// don't investigate pure or empty nodes any further
if (posInstances.numInstances() == 0 || negInstances.numInstances() == 0) return;
// do z-pure cutoff: a split here can never beat the best Z found so far
if (calcZpure(posInstances, negInstances) >= m_search_smallestZ) return;
// keep stats
m_nodesExpanded++;