📄 adtree.java
字号:
/* * This program is free software; you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation; either version 2 of the License, or * (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with this program; if not, write to the Free Software * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA. *//* * ADTree.java * Copyright (C) 2001 Richard Kirkby, Bernhard Pfahringer * */package weka.classifiers.trees;import weka.classifiers.Classifier;import weka.classifiers.IterativeClassifier;import weka.classifiers.trees.adtree.PredictionNode;import weka.classifiers.trees.adtree.ReferenceInstances;import weka.classifiers.trees.adtree.Splitter;import weka.classifiers.trees.adtree.TwoWayNominalSplit;import weka.classifiers.trees.adtree.TwoWayNumericSplit;import weka.core.AdditionalMeasureProducer;import weka.core.Attribute;import weka.core.Capabilities;import weka.core.Drawable;import weka.core.FastVector;import weka.core.Instance;import weka.core.Instances;import weka.core.Option;import weka.core.OptionHandler;import weka.core.SelectedTag;import weka.core.SerializedObject;import weka.core.Tag;import weka.core.TechnicalInformation;import weka.core.TechnicalInformationHandler;import weka.core.Utils;import weka.core.WeightedInstancesHandler;import weka.core.Capabilities.Capability;import weka.core.TechnicalInformation.Field;import weka.core.TechnicalInformation.Type;import java.util.Enumeration;import java.util.Random;import java.util.Vector;/** <!-- globalinfo-start --> * Class for generating an alternating decision tree. 
 * The basic algorithm is based on:<br/>
 * <br/>
 * Freund, Y., Mason, L.: The alternating decision tree learning algorithm. In:
 * Proceeding of the Sixteenth International Conference on Machine Learning,
 * Bled, Slovenia, 124-133, 1999.<br/>
 * <br/>
 * This version currently only supports two-class problems. The number of boosting
 * iterations needs to be manually tuned to suit the dataset and the desired
 * complexity/accuracy tradeoff. Induction of the trees has been optimized, and
 * heuristic search methods have been introduced to speed learning.
 * <p/>
 <!-- globalinfo-end -->
 *
 <!-- technical-bibtex-start -->
 * BibTeX:
 * <pre>
 * @inproceedings{Freund1999,
 *    address = {Bled, Slovenia},
 *    author = {Freund, Y. and Mason, L.},
 *    booktitle = {Proceeding of the Sixteenth International Conference on Machine Learning},
 *    pages = {124-133},
 *    title = {The alternating decision tree learning algorithm},
 *    year = {1999}
 * }
 * </pre>
 * <p/>
 <!-- technical-bibtex-end -->
 *
 <!-- options-start -->
 * Valid options are: <p/>
 *
 * <pre> -B <number of boosting iterations>
 *  Number of boosting iterations.
 *  (Default = 10)</pre>
 *
 * <pre> -E <-3|-2|-1|>=0>
 *  Expand nodes: -3(all), -2(weight), -1(z_pure), >=0 seed for random walk
 *  (Default = -3)</pre>
 *
 * <pre> -D
 *  Save the instance data with the model</pre>
 *
 <!-- options-end -->
 *
 * @author Richard Kirkby (rkirkby@cs.waikato.ac.nz)
 * @author Bernhard Pfahringer (bernhard@cs.waikato.ac.nz)
 * @version $Revision: 1.6 $
 */
public class ADTree
  extends Classifier
  implements OptionHandler, Drawable, AdditionalMeasureProducer,
             WeightedInstancesHandler, IterativeClassifier,
             TechnicalInformationHandler {

  /** for serialization */
  static final long serialVersionUID = -1532264837167690683L;

  /**
   * Returns a string describing this classifier.
   *
   * @return a description suitable for displaying in the
   * explorer/experimenter gui
   */
  public String globalInfo() {
    return "Class for generating an alternating decision tree. The basic "
      + "algorithm is based on:\n\n"
      + getTechnicalInformation().toString()
      + "\n\n"
      + "This version currently only supports two-class problems. The number of boosting "
      + "iterations needs to be manually tuned to suit the dataset and the desired "
      + "complexity/accuracy tradeoff. Induction of the trees has been optimized, and heuristic "
      + "search methods have been introduced to speed learning.";
  }

  /** search mode: Expand all paths */
  public static final int SEARCHPATH_ALL = 0;

  /** search mode: Expand the heaviest path */
  public static final int SEARCHPATH_HEAVIEST = 1;

  /** search mode: Expand the best z-pure path */
  public static final int SEARCHPATH_ZPURE = 2;

  /** search mode: Expand a random path */
  public static final int SEARCHPATH_RANDOM = 3;

  /** The search modes, indexed by the SEARCHPATH_* constants above */
  public static final Tag [] TAGS_SEARCHPATH = {
    new Tag(SEARCHPATH_ALL, "Expand all paths"),
    new Tag(SEARCHPATH_HEAVIEST, "Expand the heaviest path"),
    new Tag(SEARCHPATH_ZPURE, "Expand the best z-pure path"),
    new Tag(SEARCHPATH_RANDOM, "Expand a random path")
  };

  /** The instances used to train the tree */
  protected Instances m_trainInstances;

  /** The root of the tree */
  protected PredictionNode m_root = null;

  /** The random number generator - used for the random search heuristic */
  protected Random m_random = null;

  /** The number of the last splitter added to the tree */
  protected int m_lastAddedSplitNum = 0;

  /** An array containing the indices to the numeric attributes in the data */
  protected int[] m_numericAttIndices;

  /** An array containing the indices to the nominal attributes in the data */
  protected int[] m_nominalAttIndices;

  /** The total weight of the instances - used to speed Z calculations */
  protected double m_trainTotalWeight;

  /** The training instances with positive class - referencing the training dataset */
  protected ReferenceInstances m_posTrainInstances;

  /** The training instances with negative class - referencing the training dataset */
  protected ReferenceInstances m_negTrainInstances;

  /** The best node to insert under, as found so far by the latest search */
  protected PredictionNode m_search_bestInsertionNode;

  /** The best splitter to insert, as found so far by the latest search */
  protected Splitter m_search_bestSplitter;

  /** The smallest Z value found so far by the latest search */
  protected double m_search_smallestZ;

  /** The positive instances that apply to the best path found so far */
  protected Instances m_search_bestPathPosInstances;

  /** The negative instances that apply to the best path found so far */
  protected Instances m_search_bestPathNegInstances;

  /** Statistics - the number of prediction nodes investigated during search */
  protected int m_nodesExpanded = 0;

  /** Statistics - the number of instances processed during search */
  protected int m_examplesCounted = 0;

  /** Option - the number of boosting iterations to perform */
  protected int m_boostingIterations = 10;

  /** Option - the search mode */
  protected int m_searchPath = 0;

  /** Option - the seed to use for a random search */
  protected int m_randomSeed = 0;

  /** Option - whether the tree should remember the instance data */
  protected boolean m_saveInstanceData = false;

  /**
   * Returns an instance of a TechnicalInformation object, containing
   * detailed information about the technical background of this class,
   * e.g., paper reference or book this class is based on.
   *
   * @return the technical information about this class
   */
  public TechnicalInformation getTechnicalInformation() {
    TechnicalInformation result;

    result = new TechnicalInformation(Type.INPROCEEDINGS);
    result.setValue(Field.AUTHOR, "Freund, Y. and Mason, L.");
    result.setValue(Field.YEAR, "1999");
    result.setValue(Field.TITLE, "The alternating decision tree learning algorithm");
    result.setValue(Field.BOOKTITLE, "Proceeding of the Sixteenth International Conference on Machine Learning");
    result.setValue(Field.ADDRESS, "Bled, Slovenia");
    result.setValue(Field.PAGES, "124-133");

    return result;
  }

  /**
   * Sets up the tree ready to be trained, using the two-class optimized method.
   *
   * @param instances the instances to train the tree with
   * @exception Exception if training data is unsuitable
   */
  public void initClassifier(Instances instances) throws Exception {

    // clear stats
    m_nodesExpanded = 0;
    m_examplesCounted = 0;
    m_lastAddedSplitNum = 0;

    // prepare the random generator
    m_random = new Random(m_randomSeed);

    // create training set
    m_trainInstances = new Instances(instances);

    // create positive/negative subsets
    // (class value 0 is treated as the negative class, everything else as positive)
    m_posTrainInstances = new ReferenceInstances(m_trainInstances,
                                                 m_trainInstances.numInstances());
    m_negTrainInstances = new ReferenceInstances(m_trainInstances,
                                                 m_trainInstances.numInstances());
    for (Enumeration e = m_trainInstances.enumerateInstances(); e.hasMoreElements(); ) {
      Instance inst = (Instance) e.nextElement();
      if ((int) inst.classValue() == 0)
	m_negTrainInstances.addReference(inst); // belongs in negative class
      else
	m_posTrainInstances.addReference(inst); // belongs in positive class
    }
    m_posTrainInstances.compactify();
    m_negTrainInstances.compactify();

    // create the root prediction node
    double rootPredictionValue = calcPredictionValue(m_posTrainInstances,
						     m_negTrainInstances);
    m_root = new PredictionNode(rootPredictionValue);

    // pre-adjust weights
    updateWeights(m_posTrainInstances, m_negTrainInstances, rootPredictionValue);

    // pre-calculate what we can
    generateAttributeIndicesSingle();
  }

  /**
   * Performs one iteration.
* * @param iteration the index of the current iteration (0-based) * @exception Exception if this iteration fails */ public void next(int iteration) throws Exception { boost(); } /** * Performs a single boosting iteration, using two-class optimized method. * Will add a new splitter node and two prediction nodes to the tree * (unless merging takes place). * * @exception Exception if try to boost without setting up tree first or there are no * instances to train with */ public void boost() throws Exception { if (m_trainInstances == null || m_trainInstances.numInstances() == 0) throw new Exception("Trying to boost with no training data"); // perform the search searchForBestTestSingle(); if (m_search_bestSplitter == null) return; // handle empty instances // create the new nodes for the tree, updating the weights for (int i=0; i<2; i++) { Instances posInstances = m_search_bestSplitter.instancesDownBranch(i, m_search_bestPathPosInstances); Instances negInstances = m_search_bestSplitter.instancesDownBranch(i, m_search_bestPathNegInstances); double predictionValue = calcPredictionValue(posInstances, negInstances); PredictionNode newPredictor = new PredictionNode(predictionValue); updateWeights(posInstances, negInstances, predictionValue); m_search_bestSplitter.setChildForBranch(i, newPredictor); } // insert the new nodes m_search_bestInsertionNode.addChild((Splitter) m_search_bestSplitter, this); // free memory m_search_bestPathPosInstances = null; m_search_bestPathNegInstances = null; m_search_bestSplitter = null; } /** * Generates the m_nominalAttIndices and m_numericAttIndices arrays to index * the respective attribute types in the training data. 
* */ private void generateAttributeIndicesSingle() { // insert indices into vectors FastVector nominalIndices = new FastVector(); FastVector numericIndices = new FastVector(); for (int i=0; i<m_trainInstances.numAttributes(); i++) { if (i != m_trainInstances.classIndex()) { if (m_trainInstances.attribute(i).isNumeric()) numericIndices.addElement(new Integer(i)); else nominalIndices.addElement(new Integer(i)); } } // create nominal array m_nominalAttIndices = new int[nominalIndices.size()]; for (int i=0; i<nominalIndices.size(); i++) m_nominalAttIndices[i] = ((Integer)nominalIndices.elementAt(i)).intValue(); // create numeric array m_numericAttIndices = new int[numericIndices.size()]; for (int i=0; i<numericIndices.size(); i++) m_numericAttIndices[i] = ((Integer)numericIndices.elementAt(i)).intValue(); } /** * Performs a search for the best test (splitter) to add to the tree, by aiming to * minimize the Z value. * * @exception Exception if search fails */ private void searchForBestTestSingle() throws Exception { // keep track of total weight for efficient wRemainder calculations m_trainTotalWeight = m_trainInstances.sumOfWeights(); m_search_smallestZ = Double.POSITIVE_INFINITY; searchForBestTestSingle(m_root, m_posTrainInstances, m_negTrainInstances); } /** * Recursive function that carries out search for the best test (splitter) to add to * this part of the tree, by aiming to minimize the Z value. Performs Z-pure cutoff to * reduce search space. * * @param currentNode the root of the subtree to be searched, and the current node * being considered as parent of a new split * @param posInstances the positive-class instances that apply at this node * @param negInstances the negative-class instances that apply at this node
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -