ADTree.java
/*
 *    This program is free software; you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation; either version 2 of the License, or
 *    (at your option) any later version.
 *
 *    This program is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with this program; if not, write to the Free Software
 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

/*
 *    ADTree.java
 *    Copyright (C) 2001 Richard Kirkby, Bernhard Pfahringer
 *
 */

package weka.classifiers.adtree;

import weka.classifiers.*;
import weka.core.*;
import java.io.*;
import java.util.*;

/**
 * Class for generating an alternating decision tree. The basic algorithm is based on:<p>
 *
 * Freund, Y., Mason, L.: The alternating decision tree learning algorithm.
 * Proceedings of the Sixteenth International Conference on Machine Learning,
 * Bled, Slovenia, (1999) 124-133.</p>
 *
 * This version currently only supports two-class problems. The number of boosting
 * iterations needs to be manually tuned to suit the dataset and the desired
 * complexity/accuracy tradeoff. Induction of the trees has been optimized, and heuristic
 * search methods have been introduced to speed learning.<p>
 *
 * Valid options are: <p>
 *
 * -B num <br>
 * Set the number of boosting iterations
 * (default 10) <p>
 *
 * -E num <br>
 * Set the nodes to expand: -3(all), -2(weight), -1(z_pure), >=0 seed for random walk
 * (default -3) <p>
 *
 * -D <br>
 * Save the instance data with the model <p>
 *
 * @author Richard Kirkby (rkirkby@cs.waikato.ac.nz)
 * @author Bernhard Pfahringer (bernhard@cs.waikato.ac.nz)
 * @version $Revision: 1.1.2.2 $
 */
public class ADTree extends DistributionClassifier implements OptionHandler, Drawable,
  AdditionalMeasureProducer, WeightedInstancesHandler, IterativeClassifier {

  /** The search modes */
  public final static int SEARCHPATH_ALL = 0;
  public final static int SEARCHPATH_HEAVIEST = 1;
  public final static int SEARCHPATH_ZPURE = 2;
  public final static int SEARCHPATH_RANDOM = 3;
  public static final Tag [] TAGS_SEARCHPATH = {
    new Tag(SEARCHPATH_ALL, "Expand all paths"),
    new Tag(SEARCHPATH_HEAVIEST, "Expand the heaviest path"),
    new Tag(SEARCHPATH_ZPURE, "Expand the best z-pure path"),
    new Tag(SEARCHPATH_RANDOM, "Expand a random path")
  };

  /** The instances used to train the tree */
  protected Instances m_trainInstances;

  /** The root of the tree */
  protected PredictionNode m_root = null;

  /** The random number generator - used for the random search heuristic */
  protected Random m_random = null;

  /** The number of the last splitter added to the tree */
  protected int m_lastAddedSplitNum = 0;

  /** An array containing the indices to the numeric attributes in the data */
  protected int[] m_numericAttIndices;

  /** An array containing the indices to the nominal attributes in the data */
  protected int[] m_nominalAttIndices;

  /** The total weight of the instances - used to speed Z calculations */
  protected double m_trainTotalWeight;

  /** The training instances with positive class - referencing the training dataset */
  protected ReferenceInstances m_posTrainInstances;

  /** The training instances with negative class - referencing the training dataset */
  protected ReferenceInstances m_negTrainInstances;

  /** The best node to insert under, as found so far by the latest search */
  protected PredictionNode m_search_bestInsertionNode;

  /** The best splitter to insert, as found so far by the latest search */
  protected Splitter m_search_bestSplitter;

  /** The smallest Z value found so far by the latest search */
  protected double m_search_smallestZ;

  /** The positive instances that apply to the best path found so far */
  protected Instances m_search_bestPathPosInstances;

  /** The negative instances that apply to the best path found so far */
  protected Instances m_search_bestPathNegInstances;

  /** Statistics - the number of prediction nodes investigated during search */
  protected int m_nodesExpanded = 0;

  /** Statistics - the number of instances processed during search */
  protected int m_examplesCounted = 0;

  /** Option - the number of boosting iterations to perform */
  protected int m_boostingIterations = 10;

  /** Option - the search mode */
  protected int m_searchPath = 0;

  /** Option - the seed to use for a random search */
  protected int m_randomSeed = 0;

  /** Option - whether the tree should remember the instance data */
  protected boolean m_saveInstanceData = false;

  /**
   * Sets up the tree ready to be trained, using two-class optimized method.
   *
   * @param instances the instances to train the tree with
   * @exception Exception if training data is unsuitable
   */
  public void initClassifier(Instances instances) throws Exception {

    // clear stats
    m_nodesExpanded = 0;
    m_examplesCounted = 0;
    m_lastAddedSplitNum = 0;

    // make sure training data is suitable
    if (instances.classIndex() < 0) {
      throw new UnassignedClassException("ADTree: Needs a class to be assigned");
    }
    if (instances.checkForStringAttributes()) {
      throw new UnsupportedAttributeTypeException("ADTree: Can't handle string attributes");
    }
    if (!instances.classAttribute().isNominal()) {
      throw new UnsupportedClassTypeException("ADTree: Class must be nominal");
    }
    if (instances.numClasses() != 2) {
      throw new UnsupportedClassTypeException("ADTree: Must be a two-class problem");
    }

    // prepare the random generator
    m_random = new Random(m_randomSeed);

    // create training set
    m_trainInstances = new Instances(instances);
    m_trainInstances.deleteWithMissingClass();

    // create positive/negative subsets
    m_posTrainInstances =
      new ReferenceInstances(m_trainInstances, m_trainInstances.numInstances());
    m_negTrainInstances =
      new ReferenceInstances(m_trainInstances, m_trainInstances.numInstances());
    for (Enumeration e = m_trainInstances.enumerateInstances(); e.hasMoreElements(); ) {
      Instance inst = (Instance) e.nextElement();
      if ((int) inst.classValue() == 0)
        m_negTrainInstances.addReference(inst); // belongs in negative class
      else
        m_posTrainInstances.addReference(inst); // belongs in positive class
    }
    m_posTrainInstances.compactify();
    m_negTrainInstances.compactify();

    // create the root prediction node
    double rootPredictionValue = calcPredictionValue(m_posTrainInstances, m_negTrainInstances);
    m_root = new PredictionNode(rootPredictionValue);

    // pre-adjust weights
    updateWeights(m_posTrainInstances, m_negTrainInstances, rootPredictionValue);

    // pre-calculate what we can
    generateAttributeIndicesSingle();
  }

  /**
   * Performs one iteration.
   *
   * @param iteration the index of the current iteration (0-based)
   * @exception Exception if this iteration fails
   */
  public void next(int iteration) throws Exception {

    boost();
  }

  /**
   * Performs a single boosting iteration, using two-class optimized method.
   * Will add a new splitter node and two prediction nodes to the tree
   * (unless merging takes place).
   *
   * @exception Exception if trying to boost without setting up the tree first,
   * or there are no instances to train with
   */
  public void boost() throws Exception {

    if (m_trainInstances == null || m_trainInstances.numInstances() == 0)
      throw new Exception("Trying to boost with no training data");

    // perform the search
    searchForBestTestSingle();

    if (m_search_bestSplitter == null) return; // handle empty instances

    // create the new nodes for the tree, updating the weights
    for (int i = 0; i < 2; i++) {
      Instances posInstances =
        m_search_bestSplitter.instancesDownBranch(i, m_search_bestPathPosInstances);
      Instances negInstances =
        m_search_bestSplitter.instancesDownBranch(i, m_search_bestPathNegInstances);
      double predictionValue = calcPredictionValue(posInstances, negInstances);
      PredictionNode newPredictor = new PredictionNode(predictionValue);
      updateWeights(posInstances, negInstances, predictionValue);
      m_search_bestSplitter.setChildForBranch(i, newPredictor);
    }

    // insert the new nodes
    m_search_bestInsertionNode.addChild((Splitter) m_search_bestSplitter, this);

    // free memory
    m_search_bestPathPosInstances = null;
    m_search_bestPathNegInstances = null;
    m_search_bestSplitter = null;
  }

  /**
   * Generates the m_nominalAttIndices and m_numericAttIndices arrays to index
   * the respective attribute types in the training data.
   */
  private void generateAttributeIndicesSingle() {

    // insert indices into vectors
    FastVector nominalIndices = new FastVector();
    FastVector numericIndices = new FastVector();
    for (int i = 0; i < m_trainInstances.numAttributes(); i++) {
      if (i != m_trainInstances.classIndex()) {
        if (m_trainInstances.attribute(i).isNumeric())
          numericIndices.addElement(new Integer(i));
        else
          nominalIndices.addElement(new Integer(i));
      }
    }

    // create nominal array
    m_nominalAttIndices = new int[nominalIndices.size()];
    for (int i = 0; i < nominalIndices.size(); i++)
      m_nominalAttIndices[i] = ((Integer) nominalIndices.elementAt(i)).intValue();

    // create numeric array
    m_numericAttIndices = new int[numericIndices.size()];
    for (int i = 0; i < numericIndices.size(); i++)
      m_numericAttIndices[i] = ((Integer) numericIndices.elementAt(i)).intValue();
  }

  /**
   * Performs a search for the best test (splitter) to add to the tree, by aiming to
   * minimize the Z value.
   *
   * @exception Exception if search fails
   */
  private void searchForBestTestSingle() throws Exception {

    // keep track of total weight for efficient wRemainder calculations
    m_trainTotalWeight = m_trainInstances.sumOfWeights();

    m_search_smallestZ = Double.POSITIVE_INFINITY;
    searchForBestTestSingle(m_root, m_posTrainInstances, m_negTrainInstances);
  }
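  // --------------------------------------------------------------------------
  // Note (illustrative, not part of the original ADTree.java source): the Z
  // value minimized by the search methods below follows the cited Freund &
  // Mason (1999) paper. For a precondition (path) p and a candidate condition
  // c, with weights summed over positive/negative instances, the paper's
  // criterion is roughly
  //
  //   Z(p, c) = 2 * (Math.sqrt(wPosTrue * wNegTrue)
  //                + Math.sqrt(wPosFalse * wNegFalse)) + wRemainder;
  //
  // where wPosTrue/wNegTrue are the weights reaching p that satisfy c,
  // wPosFalse/wNegFalse those that do not, and wRemainder is the weight not
  // reaching p at all (which is why m_trainTotalWeight is cached above).
  // The prediction values attached to the two new PredictionNodes are of the
  // form 0.5 * Math.log(wPos / wNeg), smoothed by a small constant. The exact
  // smoothing lives in calcPredictionValue/calcZpure, which are defined later
  // in the full source and not shown in this excerpt.
  // --------------------------------------------------------------------------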
  /**
   * Recursive function that carries out search for the best test (splitter) to add to
   * this part of the tree, by aiming to minimize the Z value. Performs Z-pure cutoff to
   * reduce search space.
   *
   * @param currentNode the root of the subtree to be searched, and the current node
   * being considered as parent of a new split
   * @param posInstances the positive-class instances that apply at this node
   * @param negInstances the negative-class instances that apply at this node
   * @exception Exception if search fails
   */
  private void searchForBestTestSingle(PredictionNode currentNode,
                                       Instances posInstances, Instances negInstances)
    throws Exception {

    // don't investigate pure or empty nodes any further
    if (posInstances.numInstances() == 0 || negInstances.numInstances() == 0) return;

    // do z-pure cutoff
    if (calcZpure(posInstances, negInstances) >= m_search_smallestZ) return;

    // keep stats
    m_nodesExpanded++;
    m_examplesCounted += posInstances.numInstances() + negInstances.numInstances();

    // evaluate static splitters (nominal)
    for (int i = 0; i < m_nominalAttIndices.length; i++)
      evaluateNominalSplitSingle(m_nominalAttIndices[i], currentNode,
                                 posInstances, negInstances);

    // evaluate dynamic splitters (numeric)
    if (m_numericAttIndices.length > 0) {

      // merge the two sets of instances into one
      Instances allInstances = new Instances(posInstances);
      for (Enumeration e = negInstances.enumerateInstances(); e.hasMoreElements(); )
        allInstances.add((Instance) e.nextElement());

      // use method of finding the optimal Z split-point
      for (int i = 0; i < m_numericAttIndices.length; i++)
        evaluateNumericSplitSingle(m_numericAttIndices[i], currentNode,
                                   posInstances, negInstances, allInstances);
    }

    if (currentNode.getChildren().size() == 0) return;

    // keep searching
    switch (m_searchPath) {
    case SEARCHPATH_ALL:
      goDownAllPathsSingle(currentNode, posInstances, negInstances);
      break;
    case SEARCHPATH_HEAVIEST:
      goDownHeaviestPathSingle(currentNode, posInstances, negInstances);
      break;
    case SEARCHPATH_ZPURE:
      goDownZpurePathSingle(currentNode, posInstances, negInstances);
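Usage sketch (not part of the ADTree.java source above, which breaks off inside the recursive searchForBestTestSingle method): the fragment below shows one plausible way to drive the IterativeClassifier methods defined in the excerpt, initClassifier() and next(). The file name train.arff and the iteration count are illustrative assumptions, and classifyInstance() comes from the DistributionClassifier/Classifier hierarchy rather than from the code shown here.

// Illustrative driver, assuming a two-class ARFF file "train.arff" is available.
import weka.classifiers.adtree.ADTree;
import weka.core.Instance;
import weka.core.Instances;
import java.io.BufferedReader;
import java.io.FileReader;

public class ADTreeExample {

  public static void main(String[] args) throws Exception {

    // load the dataset and mark the last attribute as the (nominal, two-valued) class
    Instances data = new Instances(new BufferedReader(new FileReader("train.arff")));
    data.setClassIndex(data.numAttributes() - 1);

    // drive the iterative interface directly: set up the root prediction node,
    // then perform one boosting step per call to next()
    ADTree tree = new ADTree();
    tree.initClassifier(data);
    for (int i = 0; i < 10; i++) {  // 10 matches the default for the -B option
      tree.next(i);
    }

    // classify the first training instance (classifyInstance is inherited)
    Instance first = data.instance(0);
    System.out.println("Predicted class index: " + tree.classifyInstance(first));
  }
}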