// DecisionTreeAlgorithm.java
// (page-scrape header "📄 decisiontreealgorithm.java / 字号:" removed)
package ai.decision.algorithm;
import java.util.*;
import ai.decision.gui.*;
import ai.common.*;
/**
* An implementation of a decision tree learning algorithm.
* (See Mitchell, <i>Machine Learning</i>, pg. 56;
* Russell and Norvig, <i>Artificial Intelligence:
* A Modern Approach</i>, pg. 537 )
*
* <p>
* When building a decision tree, this class relies on a
* Dataset instance to provide relevant statistics for
* attribute selection.
*
* <p>
* <b>Change History:</b>
*
* <p><pre>
* Name: Date: Change:
* =============================================================
* J. Kelly Sep-26-2000 Ground-up rewrite using
* AlgorithmFramework class.
* J. Kelly Oct-02-2000 Added reduced-error and
* pessimistic pruning.
* </pre>
*
* Copyright 2000 University of Alberta.
*
* <!--
* This file is part of the Decision Tree Applet.
*
* The Decision Tree Applet is free software; you can redistribute it
* and/or modify it under the terms of the GNU General Public License as
* published by the Free Software Foundation; either version 2 of the
* License, or (at your option) any later version.
*
* The Decision Tree Applet is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with the Decision Tree Applet; if not, write to the Free Software
* Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
* -->
*/
public class DecisionTreeAlgorithm
extends AlgorithmFramework
{
// Class data members
// Possible splitting criteria
/**
* Indicates that attributes for splitting are selected
* at random.
*/
public static final String SPLIT_RANDOM = "Random";
/**
* Indicates that attributes for splitting are selected
* based on maximum information gain.
*/
public static final String SPLIT_GAIN = "Gain";
/**
* Indicates that attributes for splitting are selected
* based on maximum gain ratio.
*/
public static final String SPLIT_GAIN_RATIO = "Gain Ratio";
/**
* Indicates that attributes for splitting are selected
* based on maximum GINI score.
*/
public static final String SPLIT_GINI = "GINI";
/**
* An array of the available splitting functions.
* (Order matches the constants above; used to populate
* GUI choice lists.)
*/
public static final String[] SPLIT_FUNCTIONS =
{ SPLIT_RANDOM, SPLIT_GAIN, SPLIT_GAIN_RATIO, SPLIT_GINI };
//-------------------------------------------------------
// Classification results for a set of training examples.
/**
* Indicates that the examples have mixed target
* attribute values.
*/
public static final int DATASET_MIXED_CONCL = 0;
/**
* Indicates that all examples share one common target
* attribute value.
*/
public static final int DATASET_IDENT_CONCL = 1;
/**
* Indicates that the set of training examples is empty.
*/
public static final int DATASET_EMPTY = 2;
//-------------------------------------------------------
// Possible pruning algorithms
/**
* Indicates that the decision tree should not be
* pruned.
*/
public static final String PRUNING_NONE = "None";
/**
* Indicates that the decision tree should be
* pruned using the reduced-error pruning
* algorithm.
*/
public static final String PRUNING_REDUCED_ERROR = "Reduced-error";
/**
* Indicates that the decision tree should be
* pruned using the pessimistic pruning
* algorithm.
*/
public static final String PRUNING_PESSIMISTIC = "Pessimistic";
/**
* An array of the available pruning algorithms.
*/
public static final String[] PRUNING_ALGORITHMS =
{ PRUNING_NONE, PRUNING_REDUCED_ERROR, PRUNING_PESSIMISTIC };
/**
* Default pessimistic pruning z-score - 95% confidence
* interval.
*/
public static final double DEFAULT_Z_SCORE = 1.96;
//-------------------------------------------------------
// Instance data members
Dataset m_dataset; // Data set used to build tree.
DecisionTree m_tree; // Current decision tree.
String m_splitFun; // Current splitting function.
String m_pruneAlg; // Current pruning algorithm.
Random m_random; // Random number generator.
double m_pessPruneZScore; // Pessimistic pruning Z-score.
ComponentManager m_manager; // Supplies GUI components (e.g. the
// visual tree panel) that must be re-registered
// as listeners whenever a new tree is created.
// Constructors
/**
* Prepares to run the decision tree algorithm by
* creating an empty decision tree.
*
* <p>
* The default splitting function is set to "Random", the
* default pruning algorithm to "None", and the pessimistic
* pruning Z-score to DEFAULT_Z_SCORE.
*
* @param dataset A Dataset object that is already initialized.
* @param manager The ComponentManager that supplies GUI
*        components; retained so that replacement trees can
*        re-register their tree-change listeners.
*
* @throws NullPointerException If the supplied Dataset or
*         ComponentManager is null.
*/
public DecisionTreeAlgorithm( Dataset dataset, ComponentManager manager )
{
super();
if( dataset == null || manager == null )
throw new
NullPointerException( "Dataset or component manager is null." );
m_manager = manager;
m_dataset = dataset;
m_splitFun = SPLIT_RANDOM;
m_pruneAlg = PRUNING_NONE;
// Fixed seed: "random" attribute selection is reproducible
// from run to run (useful for a teaching applet).
m_random = new Random( 2389 );
m_tree = new DecisionTree();
m_pessPruneZScore = DEFAULT_Z_SCORE;
}
// Public methods
/**
 * Accessor for the dataset the algorithm is currently
 * building its tree from.
 *
 * @return A reference to the current dataset.
 */
public Dataset getDataset()
{
    return m_dataset;
}
/**
 * Replaces the dataset used by the algorithm.  Because the
 * existing tree was built from the old data, it is discarded
 * and an empty tree takes its place.
 *
 * @param dataset The new dataset.
 *
 * @throws NullPointerException if the supplied dataset
 *         is null.
 */
public void setDataset( Dataset dataset )
{
    if( dataset == null ) {
        throw new NullPointerException( "Dataset is null." );
    }

    m_tree = new DecisionTree();
    m_dataset = dataset;

    // The GUI's tree panel (when one exists) must observe the
    // replacement tree, since the old tree is now unreachable.
    if( m_manager.getVisualTreePanel() != null )
        m_tree.addTreeChangeListener( m_manager.getVisualTreePanel() );
}
/**
 * Discards the current decision tree, replacing it with a
 * fresh, empty one.  The dataset used to build the tree is
 * left untouched.
 */
public void reset()
{
    m_tree = new DecisionTree();

    // Hook the visual tree panel (when present) up to the
    // newly created tree so repaints continue to work.
    if( m_manager.getVisualTreePanel() != null )
        m_tree.addTreeChangeListener( m_manager.getVisualTreePanel() );
}
/**
 * Accessor for the decision tree being built.
 *
 * @return A reference to the decision tree data structure.
 */
public DecisionTree getTree()
{
    return m_tree;
}
/**
 * Sets the splitting function used to build the
 * decision tree. If the supplied function name is null or
 * does not correspond to one of the known functions,
 * the random 'function' is used by default.
 *
 * @param splitFun The new splitting function - this must
 *        be one of SPLIT_RANDOM, SPLIT_GAIN,
 *        SPLIT_GAIN_RATIO or SPLIT_GINI.
 */
public synchronized void setSplittingFunction( String splitFun )
{
    // Bug fix: a null argument used to throw NullPointerException
    // from splitFun.equals(...).  The documented contract is that
    // any unrecognized name falls back to random selection, so
    // treat null the same way.
    if( splitFun != null &&
        ( splitFun.equals( SPLIT_RANDOM ) ||
          splitFun.equals( SPLIT_GAIN ) ||
          splitFun.equals( SPLIT_GAIN_RATIO ) ||
          splitFun.equals( SPLIT_GINI ) ) )
        m_splitFun = splitFun;
    else
        m_splitFun = SPLIT_RANDOM;

    // Inform HighlightListeners that the splitting
    // function text may have changed.
    Iterator i = m_highlightListeners.iterator();
    while( i.hasNext() )
        ((HighlightListener)i.next())
            .setDynamicText( "LearnDT", "splitfun", m_splitFun );
}
/**
 * Sets the pruning algorithm. If the supplied algorithm
 * name is null or does not correspond to one of the known
 * algorithms, pruning is disabled.
 *
 * @param pruneAlg The new pruning algorithm - this must
 *        be one of PRUNING_NONE, PRUNING_REDUCED_ERROR or
 *        PRUNING_PESSIMISTIC.
 */
public synchronized void setPruningAlgorithm( String pruneAlg )
{
    // Bug fix: a null argument used to throw NullPointerException
    // from pruneAlg.equals(...).  Consistent with the documented
    // contract, treat null like any other unrecognized name and
    // disable pruning.
    if( pruneAlg != null &&
        ( pruneAlg.equals( PRUNING_NONE ) ||
          pruneAlg.equals( PRUNING_REDUCED_ERROR ) ||
          pruneAlg.equals( PRUNING_PESSIMISTIC ) ) )
        m_pruneAlg = pruneAlg;
    else
        m_pruneAlg = PRUNING_NONE;

    // Inform HighlightListeners that the pruning
    // algorithm text may have changed.
    Iterator i = m_highlightListeners.iterator();
    while( i.hasNext() )
        ((HighlightListener)i.next())
            .setDynamicText( "BuildDT", "prunealg", m_pruneAlg );
}
/**
 * Accessor for the splitting function currently in use.
 *
 * @return The current splitting function as a String
 *         (one of the SPLIT_* constants).
 */
public String getSplittingFunction()
{
    return m_splitFun;
}
/**
 * Accessor for the pruning algorithm currently in use.
 *
 * @return The current pruning algorithm as a String
 *         (one of the PRUNING_* constants).
 */
public String getPruningAlgorithm()
{
    return m_pruneAlg;
}
/**
 * Accessor for the Z-score used by the pessimistic
 * pruning algorithm.
 *
 * @return The current pessimistic pruning Z-score.
 */
public double getPessimisticPruningZScore()
{
    return m_pessPruneZScore;
}
/**
 * Sets the current pessimistic pruning Z-score.
 *
 * @param zScore The new Z-score to use when calculating
 *        error bars for the pessimistic pruning algorithm.
 *
 * @throws IllegalArgumentException If the supplied value
 *         is negative or NaN.
 */
public void setPessimisticPruningZScore( double zScore )
{
    // Bug fix: NaN slips past a plain "< 0" test (every
    // comparison involving NaN is false) and would silently
    // poison later error-bar arithmetic - reject it explicitly.
    if( Double.isNaN( zScore ) )
        throw new IllegalArgumentException( "Supplied Z-score is NaN." );
    if( zScore < 0 )
        throw new IllegalArgumentException( "Supplied Z-score < 0." );
    m_pessPruneZScore = zScore;
}
/**
* Runs the complete decision tree algorithm.
* The algorithm starts with the current state of
* the tree, and builds from there.
*
* <p>
* This method allows the algorithm to be run in a
* separate thread.
*/
public synchronized void run()
{
Iterator i;
// Inform HighlightListeners that we're about to start
// the main "BuildDT" routine.
i = m_highlightListeners.iterator();
while( i.hasNext() )
((HighlightListener)i.next()).displayFunction( "BuildDT" );
//----------------- Breakpoint -----------------
if( !handleBreakpoint( 1, null ) ) return;
//----------------- Breakpoint -----------------
if( !handleBreakpoint( 2, null ) ) return;
DecisionTreeNode parent;
int[] arcNum = new int[1];
//------------------ LearnDT ------------------
// Inform HighlightListeners that we're about to start
// the "LearnDT" routine.
i = m_highlightListeners.iterator();
while( i.hasNext() )
((HighlightListener)i.next()).displayFunction( "LearnDT" );
// Build the tree, looping until it's complete.
while( true ) {
// Determine where to start building
// (leftmost available position).
if( m_tree.isEmpty() )
parent = null;
else if( (parent =
m_tree.findIncompleteNode( m_tree.getRoot(), arcNum )) == null )
break;
// Build the tree - the call to learnDT()
// will fill in everything below the current branch.
if( !learnDT( parent, arcNum[0] ) ) break;
}
// Now check - if the tree is complete, we can start the
// pruning process. Otherwise, we've stopped with some
// nodes still missing.
if( !m_tree.isComplete() ) return;
// Inform HighlightListeners that we've finished "LearnDT",
// and are now moving on to "PruneDT".
i = m_highlightListeners.iterator();
while( i.hasNext() )
((HighlightListener)i.next()).displayFunction( "BuildDT" );
//----------------- Breakpoint -----------------
if( !handleBreakpoint( 3, null ) ) return;
//------------------ PruneDT ------------------
if( m_pruneAlg.equals( PRUNING_REDUCED_ERROR ) ) {
// Inform HighlightListeners that we're about to start
// the "PruneReducedErrorDT" routine.
i = m_highlightListeners.iterator();
while( i.hasNext() )
((HighlightListener)i.next()).displayFunction( "PruneReducedErrorDT" );
// pruneReducedErrorDT will return false if it's
// interrupted and is unable to finish pruning the tree.
if( !pruneReducedErrorDT( m_tree.getRoot(), new double[1] ) )
return;
}
else if( m_pruneAlg.equals( PRUNING_PESSIMISTIC ) ) {
// Inform HighlightListeners that we're about to start
// the "PrunePessimisticDT" routine.
i = m_highlightListeners.iterator();
while( i.hasNext() )
((HighlightListener)i.next()).displayFunction( "PrunePessimisticDT" );
// prunePessimisticDT will return false if it's
// interrupted and is unable to finish pruning the tree.
if( !prunePessimisticDT( m_tree.getRoot(), new double[1] ) )
return;
}
// Reset the function display to "BuildDT"
i = m_highlightListeners.iterator();
while( i.hasNext() )
((HighlightListener)i.next()).displayFunction( "BuildDT" );
// Tell anyone that might be tracking
// the state of the algorithm that we've finished.
// NOTE(review): the remainder of run() (its final statements and closing
// brace) and the class's closing brace were replaced by page-scrape
// residue (a keyboard-shortcut help panel) in this copy.  Recover the
// tail of the file from the original source before compiling.