DecisionTreeAlgorithm.java
node.getTrainingEgsCorrectClassUsingBestTestingIndex() );
}
// Update the count.
if( newNode.isLeaf() )
error[0] = errorBest;
else
error[0] = errorSum;
//----------------- Breakpoint -----------------
if( !handleBreakpoint( 10, null ) ) return false;
// All finished, unmark the node if it still exists.
m_tree.flagNode( node, -2 );
return true;
}
/**
* An implementation of the recursive decision tree
* pessimistic pruning algorithm. Given a parent
* node, the method walks the subtree below it and
* replaces an internal node with a leaf whenever the
* node's own pessimistic error estimate is lower than
* the combined estimate of its children.
*
* @param node The root node of the tree to prune.
*
* @param error A <code>double</code> array of size 1. The
* array is used to return the pessimistic error estimate
* of the subtree rooted at <code>node</code>.
*
* @return <code>true</code> if the entire subtree was processed,
* or <code>false</code> if the algorithm was interrupted
* at a breakpoint.
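*
* <p>
* A minimal calling sketch (the <code>root</code> variable is
* illustrative only; it stands for whatever node the caller
* wants to prune below):
*
* <pre>
*     double[] error = new double[1];
*     boolean completed = prunePessimisticDT( root, error );
*     // On success, error[0] holds the pessimistic error
*     // estimate of the (possibly pruned) subtree.
* </pre>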
*/
public boolean prunePessimisticDT( DecisionTreeNode node, double[] error )
{
// Post-order walk through the tree, marking
// our path as we go along.
//----------------- Breakpoint -----------------
if( !handleBreakpoint( 1, null ) ) return false;
//----------------- Breakpoint -----------------
if( !handleBreakpoint( 2, null ) ) return false;
if( node.isLeaf() ) {
if( node.getTrainingEgsAtNode() == 0 ) {
error[0] = 0;
return true;
}
else {
// Express the error as an example count: the pessimistic
// error rate is multiplied by the number of examples that
// reach this leaf, so that the parent can simply sum the
// error counts of its children.
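// The pessimistic estimate computed below is
//   errors + N * errorBar( p, N ),
// where N is the number of training examples reaching the node
// and p = (errors + 1) / (N + 2) is a Laplace-style corrected
// error rate. errorBar() (defined elsewhere in this class) is
// assumed to supply the confidence term that makes the estimate
// pessimistic.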
int errors = node.getTrainingEgsAtNode() -
node.getTrainingEgsCorrectClassUsingBestTrainingIndex();
double p = (errors + 1.0) / (node.getTrainingEgsAtNode() + 2.0);
error[0] = node.getTrainingEgsAtNode() *
errorBar( p, node.getTrainingEgsAtNode() ) + errors;
return true;
}
}
//----------------- Breakpoint -----------------
if( !handleBreakpoint( 3, null ) ) return false;
// We're at an internal node, so compute the error
// of the children and use the result to determine
// if we prune or not.
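// errorSum accumulates the children's pessimistic error counts;
// each child's error is already weighted by the number of
// examples that reach it.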
double errorSum = 0;
for( int i = 0; i < node.getArcLabelCount(); i++ ) {
// Mark our current path.
m_tree.flagNode( node, i );
//----------------- Breakpoint -----------------
if( !handleBreakpoint( 4, null ) ) return false;
//----------------- Breakpoint -----------------
if( !handleBreakpoint( 5, null ) ) return false;
//----------------- Breakpoint -----------------
if( !handleBreakpoint( 6, null ) ) return false;
if( !prunePessimisticDT( node.getChild( i ), error ) ) {
m_tree.flagNode( node, -2 );
return false;
}
errorSum += error[0];
}
// Mark the node as our current target.
m_tree.flagNode( node, -1 );
//----------------- Breakpoint -----------------
if( !handleBreakpoint( 8, null ) ) return false;
// Get the worst-case performance of this node.
double errorWorst;
if( node.getTrainingEgsAtNode() == 0 ) {
error[0] = 0;
return true;
}
int errors = node.getTrainingEgsAtNode() -
node.getTrainingEgsCorrectClassUsingBestTrainingIndex();
double p = (errors + 1.0) / (node.getTrainingEgsAtNode() + 2.0);
errorWorst = node.getTrainingEgsAtNode() *
errorBar( p, node.getTrainingEgsAtNode() ) + errors;
DecisionTreeNode newNode = node;
//----------------- Breakpoint -----------------
if( !handleBreakpoint( 9,
createPruningMsg( errorSum, errorWorst ) ) ) return false;
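// Prune only when treating this node as a leaf gives a strictly
// lower pessimistic error than the combined error of its subtree.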
if( errorWorst < errorSum ) {
// We need to "prune" this node to a leaf.
DecisionTreeNode parent = node.getParent();
int arcNum = -1;
if( parent != null )
arcNum = parent.getChildPosition( node );
//--------------------- Debug ---------------------
if( DEBUG_ON ) {
System.out.println();
System.out.println( "DecisionTreeAlgorithm::prunePessimisticDT: " +
" Pruning node " + node.getLabel() + "." );
}
m_tree.pruneSubtree( node );
// Figure out the label for the new leaf.
String label = null;
try {
label =
m_dataset.getTargetAttribute().getAttributeValueByNum(
node.getTrainingBestTarget() );
}
catch( NonexistentAttributeValueException e ) {
// Should never happen.
e.printStackTrace();
}
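// Record the leaf's classification in the mask; position 0
// holds the target attribute.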
node.getMask().mask( 0, node.getTrainingBestTarget() );
newNode =
m_tree.addLeafNode( parent, arcNum, label,
node.getMask(),
node.getTrainingEgsAtNode(),
node.getTrainingBestTarget(),
node.getTrainingEgsCorrectClassUsingBestTrainingIndex(),
node.getTestingEgsCorrectClassUsingBestTrainingIndex(),
node.getTestingEgsAtNode(),
node.getTestingBestTarget(),
node.getTestingEgsCorrectClassUsingBestTestingIndex(),
node.getTrainingEgsCorrectClassUsingBestTestingIndex() );
}
// Update the count.
if( newNode.isLeaf() )
error[0] = errorWorst;
else
error[0] = errorSum;
//----------------- Breakpoint -----------------
if( !handleBreakpoint( 10, null ) ) return false;
// All finished, unmark the node if it still exists.
m_tree.flagNode( node, -2 );
return true;
}
/**
* Classifies all examples in the current example set
* by target attribute value. The
* attribute mask determines which examples from the
* dataset form the current example set.
*
* @param mask The current attribute mask that
* indicates which examples from the dataset
* should be considered.
*
* @param conclusion The method expects the parameter
* to be an array of size 8. Positions in
* the array are filled with the following
* values.
*
* <ul>
* <li><i>Position 0</i> - Records the
* index of the most common target attribute
* value from the training dataset.
*
* <li><i>Position 1</i> - Records the number
* of training examples from the dataset that
* reach the current position.
*
* <li><i>Position 2</i> - Records the number
* of training examples from the dataset
* that would be correctly classified
* <i>if a leaf with the most common training
* target classification</i> were added at the
* current position.
*
* <li><i>Position 3</i> - Records the number
* of testing examples from the dataset
* that would be correctly classified
* <i>if a leaf with the most common training
* target classification</i> were added at the
* current position.
*
* <li><i>Position 4</i> - Records the index
* of the most common target attribute
* value from the testing dataset.
*
* <li><i>Position 5</i> - Records the number
* of testing examples from the dataset that
* reach the current position.
*
* <li><i>Position 6</i> - Records the number
* of testing examples from the dataset
* that would be correctly classified
* <i>if a leaf with the most common testing
* target classification</i> were added at the
* current position.
*
* <li><i>Position 7</i> - Records the number
* of training examples from the dataset
* that would be correctly classified
* <i>if a leaf with the most common testing
* target classification</i> were added at the
* current position.
* </ul>
*
* @param trainingCounts The method expects the parameter to be
* an array with a size equal to the number of
* target attribute values. Each position in
* the array is filled with a corresponding count
* of the number of training examples that fall into
* that particular target class, at the current
* position in the tree. This parameter can be null
* if training count data is not required.
*
* @param testingCounts The method expects the parameter to be
* an array with a size equal to the number of
* target attribute values. Each position in the
* array is filled with a corresponding count of
* the number of testing examples that fall into
* that particular target class, at the current
* position in the tree. This parameter can be null
* if testing count data is not required.
*
* @param examples The method expects the parameter to be
* an array with a size equal to the number of
* training examples in the dataset. Each entry in
* the array is marked with true or false, depending
* on whether or not a particular example reaches
* the current position.
*
* @return DATASET_MIXED_CONCL if the examples have
* multiple, different target attribute values.
* DATASET_IDENT_CONCL if all the examples have
* the same target attribute value.
* DATASET_EMPTY if there are no examples in the
* current example set.
*
* <p>
* If the result is DATASET_IDENT_CONCL, the
* index of the single target attribute value
* is returned in <code>conclusion[0]</code>. If
* the result is DATASET_EMPTY, the index of the
* most common target attribute value is returned
* in <code>conclusion[0]</code>.
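*
* <p>
* A minimal calling sketch (the <code>mask</code> variable is
* illustrative; it is assumed to describe the caller's current
* position in the tree):
*
* <pre>
*     int[] conclusion = new int[8];
*     int result =
*         classifyExamples( mask, conclusion, null, null, null );
*     if( result == DATASET_IDENT_CONCL ) {
*         // Every training example at this position shares the
*         // target value whose index is in conclusion[0].
*     }
* </pre>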
*/
public int
classifyExamples( AttributeMask mask,
int[] conclusion,
int[] trainingCounts,
int[] testingCounts,
boolean[] examples )
{
if( mask == null || conclusion == null )
throw new
NullPointerException( "Mask or conclusion array is null." );
// Determine the number of target attribute values
// and create some storage space for our counts.
int[] currTrainingCounts = null;
int[] currTestingCounts = null;
if( trainingCounts != null )
currTrainingCounts = trainingCounts;
else
currTrainingCounts = new
int[ m_dataset.getTargetAttribute().getNumValues() ];
if( testingCounts != null )
currTestingCounts = testingCounts;
else
currTestingCounts = new int[ currTrainingCounts.length ];
getExampleCounts( mask,
m_dataset.getTrainingExamples(), currTrainingCounts, examples );
getExampleCounts( mask,
m_dataset.getTestingExamples(), currTestingCounts, null );
// Init results.
conclusion[0] = 0; // Training target attribute value index.
conclusion[1] = 0; // Total number of training examples
// reaching node.
conclusion[2] = 0; // Number of training examples correctly
// classified if this node were a leaf
// with most common training target value.
conclusion[3] = 0; // Number of testing examples correctly
// classified if this node were a leaf
// with most common training target value.
conclusion[4] = 0; // Testing target attribute value index.
conclusion[5] = 0; // Total number of testing examples
// reaching node.
conclusion[6] = 0; // Number of testing examples correctly
// classified if this node were a leaf
// with most common testing target value.
conclusion[7] = 0; // Number of training examples correctly
// classified if this node were a leaf
// with most common testing target value.
// Examine the results and determine the conclusion.
int result = DATASET_EMPTY;
for( int i = 0; i < currTrainingCounts.length; i++ ) {
// Increment # of examples that reach this position.
conclusion[1] += currTrainingCounts[i];
conclusion[5] += currTestingCounts[i];
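// The result is upgraded from DATASET_EMPTY to DATASET_IDENT_CONCL
// when the first non-empty training class is seen, and to
// DATASET_MIXED_CONCL when a second non-empty class is seen.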
if( result == DATASET_EMPTY && currTrainingCounts[i] != 0 )
result = DATASET_IDENT_CONCL;
else if( result == DATASET_IDENT_CONCL && currTrainingCounts[i] != 0 )
result = DATASET_MIXED_CONCL;
if( currTrainingCounts[i] >= currTrainingCounts[ conclusion[0] ] ) {
// This target value is more common in the training set.
conclusion[0] = i;
conclusion[2] = currTrainingCounts[i];
conclusion[3] = currTestingCounts[i];
}
if( currTestingCounts[i] >= currTestingCounts[ conclusion[4] ] ) {
// This target value is more common in the testing set.
conclusion[4] = i;
conclusion[6] = currTestingCounts[i];
conclusion[7] = currTrainingCounts[i];
}
}
return result;
}
/**
* Generates statistics (used for splitting) based on the
* current position in the tree (as defined by an
* attribute mask).
*
* @param mask The attribute mask that defines the current
* position in the tree.
*
* @param stats An array with a size equal to the number of
* target attribute values. The array is filled with counts
* of the number of training examples that fall into each of
* the target classes at the current position in the tree.
*
* @return A Vector that contains the available attributes.
* Each attribute's internal statistics array is
* populated with appropriate data.
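*
* <p>
* A minimal calling sketch (the <code>mask</code> variable is
* illustrative):
*
* <pre>
*     int[] stats =
*         new int[ m_dataset.getTargetAttribute().getNumValues() ];
*     Vector available = generateStats( mask, stats );
*     // 'available' holds the UNUSED attributes, each with its
*     // statistics array filled for the current position.
* </pre>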
*/
public Vector generateStats( AttributeMask mask, int stats[] )
{
// First, we fill the stats array - this is not the
// most efficient approach, since we're looping through
// the data several times.
getExampleCounts( mask, m_dataset.getTrainingExamples(), stats, null );
// Now, we have to go through the attribute mask
// and locate the attributes that are still available.
Vector results = new Vector();
// Create a new mask that we can modify.
AttributeMask newMask = new AttributeMask( mask );
// We don't use position 0, that's where the target attribute is.
for( int i = 1; i < mask.getNumAttributes(); i++ ) {
if( newMask.isMasked( i ) == AttributeMask.UNUSED ) {
// This attribute is available, so we calculate stats for it.
Attribute att = null;
try {
att = m_dataset.getAttributeByNum( i );
}
catch( NonexistentAttributeException e ) {
// This can't happen!!
e.printStackTrace();
return null;
}
int[][] attStats = att.getStatsArray();
// Modify the mask and fill in the arrays.
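// attStats[j] receives, for attribute value j, the per-target-class
// counts of training examples that reach this position and have
// that value.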
for( int j = 0; j < att.getNumValues(); j++ ) {
newMask.mask( i, j );
getExampleCounts( newMask, m_dataset.getTrainingExamples(),
attStats[j], null );
}
// Reset the mask.
newMask.unmask( i );
results.add( att );
}
}
return results;
}
/**
* Choose an available attribute to split on,
* based on the supplied attribute mask. The
* current splitting function is used to select
* a particular attribute.
*
* @param atts A vector that contains all available (UNUSED)
* attributes. The attributes in the vector have
* statistics attached that are used in the selection
* process.
*
* @param stats An array that contains classification
* results for a particular path through the
* tree, <i>before</i> splitting.
*
* @param results An initially empty vector that,
* depending on the splitting function, can be
* filled with splitting results for each of the
* available attributes. Entries in the Vector
* are Double objects. If a particular splitting
* function does not return numerical results
* (e.g. the 'Random' function), the Vector
* remains empty.
*
* @return The attribute that <i>best</i> classifies
* current examples (examples that are valid
* at the current position in the tree).