// DecisionTreeAlgorithm.java
// (partial source — recovered from a code-sharing website; header/footer junk removed)
i = m_algorithmListeners.iterator();
while( i.hasNext() )
((AlgorithmListener)i.next()).notifyAlgorithmFinished();
if( m_manager.getStatusBar() != null )
m_manager.getStatusBar().postMessage( "Algorithm finished." );
}
/**
 * An implementation of the recursive decision tree learning algorithm
 * (ID3-style). Given a parent node and an arc number, the method attaches
 * a new decision 'sub'-tree below the parent node. When {@code parent} is
 * {@code null}, the subtree is built at the root of the tree.
 *
 * @param parent The parent node for the new decision tree, or
 *               {@code null} to build at the root.
 *
 * @param arcNum The arc number (or path) along which the new subtree
 *               will be attached; ignored when {@code parent} is
 *               {@code null}.
 *
 * @return true if an entire subtree was successfully added, false if the
 *         user aborted at a breakpoint or a dataset lookup failed.
 */
public boolean learnDT( DecisionTreeNode parent, int arcNum )
{
    AttributeMask mask;

    if( parent == null ) {
        // We have to add at the root - start from a fresh, all-open mask.
        mask = new AttributeMask( m_dataset.getNumAttributes() );
    }
    else {
        // Copy the parent's mask, then mask off the arc number we are
        // descending along so classification only sees matching examples.
        mask = new AttributeMask( parent.getMask() );

        try {
            mask.mask(
                m_dataset.getAttributePosition( parent.getLabel() ), arcNum );
        }
        catch( NonexistentAttributeException e ) {
            e.printStackTrace();
            return false;
        }
    }

    //----------------- Breakpoint -----------------
    if( !handleBreakpoint( 1, null ) ) return false;

    // Now, classify the examples at the current position. classifyExamples
    // fills 'conclusion' with training/testing counts and best-target
    // indices (indices 0-7; exact layout defined by classifyExamples).
    int[] conclusion = new int[8];
    int result = classifyExamples( mask, conclusion, null, null, null );

    Attribute target = m_dataset.getTargetAttribute();
    int numTargetVals = target.getNumValues();
    String label;

    //----------------- Breakpoint -----------------
    if( !handleBreakpoint( 2, null ) ) return false;

    if( result == DATASET_EMPTY ) {
        // If no examples reach our current position, we add a leaf with
        // the most common target classification of the PARENT node.

        // Save the testing results for the current (empty) position before
        // the re-classification below overwrites 'conclusion'.
        int numTestingExamplesReachHere = conclusion[5];
        int bestTestingTargetIndex = conclusion[4];
        int numTestingExamplesCorrectClass = conclusion[6];
        int numTrainingExamplesCorrectClass = conclusion[7];

        // NOTE(review): 'parent' is dereferenced here without a null check.
        // If the dataset is empty at the root (parent == null) this throws
        // NullPointerException - confirm the algorithm is never started on
        // an empty dataset.
        classifyExamples( parent.getMask(), conclusion, null, null, null );

        try {
            label = target.getAttributeValueByNum( conclusion[0] );
        }
        catch( NonexistentAttributeValueException e ) {
            e.printStackTrace();
            return false;
        }

        //--------------------- Debug ---------------------
        if( DEBUG_ON ) {
            System.out.println();
            System.out.println( "DecisionTreeAlgorithm::learnDT: " +
                "No examples reach the current position." );
        }

        // We have to grab the counts again for the testing data...
        int[] currTestingCounts = new int[ target.getNumValues() ];
        getExampleCounts( mask,
            m_dataset.getTestingExamples(), currTestingCounts, null );

        // Mask target value and add a leaf to the tree.
        mask.mask( 0, conclusion[0] );

        m_tree.addLeafNode( parent,
                            arcNum,
                            label,
                            mask,
                            0,
                            conclusion[0],
                            0,
                            currTestingCounts[conclusion[0]],
                            numTestingExamplesReachHere,
                            bestTestingTargetIndex,
                            numTestingExamplesCorrectClass,
                            numTrainingExamplesCorrectClass );

        fireLeafStepStart( 1, label );
        return true;
    }

    //----------------- Breakpoint -----------------
    if( !handleBreakpoint( 3, null ) ) return false;

    if( result == DATASET_IDENT_CONCL ) {
        // Pure result - we can add a leaf node with the correct
        // target attribute value.
        try {
            label = target.getAttributeValueByNum( conclusion[0] );
        }
        catch( NonexistentAttributeValueException e ) {
            e.printStackTrace();
            return false;
        }

        //--------------------- Debug ---------------------
        if( DEBUG_ON ) {
            System.out.println();
            System.out.println( "DecisionTreeAlgorithm::learnDT: " +
                "All examples at the current position " +
                "have the same target class '" + label + "'." );
        }

        // Mask target value and add a leaf to the tree.
        mask.mask( 0, conclusion[0] );

        m_tree.addLeafNode( parent,
                            arcNum,
                            label,
                            mask,
                            conclusion[1],
                            conclusion[0],
                            conclusion[2],
                            conclusion[3],
                            conclusion[5],
                            conclusion[4],
                            conclusion[6],
                            conclusion[7] );

        fireLeafStepStart( 2, label );
        return true;
    }

    // Mixed conclusion - so we have to select an attribute to split on,
    // and then build a new internal node with that attribute.

    // First, generate statistics - this may take awhile.
    int[] nodeStats = new int[ numTargetVals ];
    Vector availableAtts = generateStats( mask, nodeStats );

    //----------------- Breakpoint -----------------
    if( !handleBreakpoint( 4, null ) ) return false;

    if( availableAtts.size() == 0 ) {
        // No attributes left to split on - so use the most common
        // target value at the current position.
        try {
            label = target.getAttributeValueByNum( conclusion[0] );
        }
        catch( NonexistentAttributeValueException e ) {
            e.printStackTrace();
            return false;
        }

        //--------------------- Debug ---------------------
        if( DEBUG_ON ) {
            System.out.println();
            System.out.println( "DecisionTreeAlgorithm::learnDT: " +
                "No attributes left to split on at current position." );
        }

        // Mask target value and add a leaf to the tree.
        mask.mask( 0, conclusion[0] );

        m_tree.addLeafNode( parent,
                            arcNum,
                            label,
                            mask,
                            conclusion[1],
                            conclusion[0],
                            conclusion[2],
                            conclusion[3],
                            conclusion[5],
                            conclusion[4],
                            conclusion[6],
                            conclusion[7] );

        fireLeafStepStart( 3, label );
        return true;
    }

    // Choose an attribute, based on the set of available attributes.
    Vector results = new Vector();
    Attribute att = chooseAttribute( availableAtts, nodeStats, results );

    //--------------------- Debug ---------------------
    if( DEBUG_ON ) {
        System.out.println();
        System.out.println( "DecisionTreeAlgorithm::learnDT: " +
            "Preparing to split..." );
        System.out.println();
        System.out.println( "Available attributes and " +
            "associated " + m_splitFun + " values:" );

        for( int j = 0; j < availableAtts.size(); j++ ) {
            System.out.print(
                ((Attribute)availableAtts.elementAt( j )).getName() );

            if( j < results.size() ) {
                System.out.print( " - " );
                System.out.println( (Double)results.elementAt(j) );
            }
            else {
                System.out.println();
            }
        }

        System.out.println();
        System.out.println( "Selected " + att.getName() + "." );
        System.out.println();
    }

    //----------------- Breakpoint -----------------
    if( !handleBreakpoint( 5, createInternalMsg(1, att.getName() )) )
        return false;

    //----------------- Breakpoint -----------------
    if( !handleBreakpoint( 6, createInternalMsg(2, att.getName() )) )
        return false;

    int attPos;

    try {
        attPos = m_dataset.getAttributePosition( att.getName() );
    }
    catch( NonexistentAttributeException e ) {
        e.printStackTrace();
        return false;
    }

    DecisionTreeNode newParent =
        m_tree.addInternalNode( parent,
                                arcNum,
                                attPos,
                                att,
                                mask,
                                conclusion[1],
                                conclusion[0],
                                conclusion[2],
                                conclusion[3],
                                conclusion[5],
                                conclusion[4],
                                conclusion[6],
                                conclusion[7] );

    // Now, recursively descend along each branch of the new node.
    for( int j = 0; j < newParent.getArcLabelCount(); j++ ) {

        //----------------- Breakpoint -----------------
        if( !handleBreakpoint( 7, null ) ) return false;

        //----------------- Breakpoint -----------------
        if( !handleBreakpoint( 8, null ) ) return false;

        //----------------- Breakpoint -----------------
        if( !handleBreakpoint( 9, null ) ) return false;

        //--------------------- Debug ---------------------
        if( DEBUG_ON ) {
            System.out.println();
            System.out.println( "DecisionTreeAlgorithm::learnDT: " +
                "Descending along branch " + (j+1) + " of " +
                newParent.getArcLabelCount() + "." );
        }

        // Recursive call.
        if( !learnDT( newParent, j ) ) return false;

        //----------------- Breakpoint -----------------
        if( !handleBreakpoint( 10, null ) ) return false;

        //----------------- Breakpoint -----------------
        if( !handleBreakpoint( 11,
            createInternalMsg(3, newParent.getArcLabel(j)) ) )
            return false;
    }

    //----------------- Breakpoint -----------------
    if( !handleBreakpoint( 12, null ) ) return false;

    return true;
}

/**
 * Notifies every registered algorithm listener that a step has started.
 * In verbose mode the listeners receive a descriptive leaf-creation
 * message; otherwise they receive the plain step-start notification.
 * (Extracted from learnDT, where this loop appeared three times.)
 *
 * @param msgNum The message number passed to createLeafMsg, identifying
 *               which leaf-creation case occurred.
 *
 * @param label  The target-value label of the newly created leaf.
 */
private void fireLeafStepStart( int msgNum, String label )
{
    Iterator i = m_algorithmListeners.iterator();

    if( m_verboseFlag )
        while( i.hasNext() )
            ((AlgorithmListener)
                i.next()).notifyAlgorithmStepStart( createLeafMsg( msgNum, label ) );
    else
        while( i.hasNext() )
            ((AlgorithmListener)
                i.next()).notifyAlgorithmStepStart();
}
/**
* An implementation of the recursive decision tree
* reduced error pruning algorithm. Given a parent
* node, the method will prune all the branches
* below the node.
*
* @param node The root node of the tree to prune.
*
* @param error A <code>double</code> array of size 1. The
* array is used to store the current error value.
*
* @return <code>true</code> if an entire subtree was successfully
* pruned, or <code>false</code> otherwise.
*/
public boolean pruneReducedErrorDT( DecisionTreeNode node, double[] error )
{
// Post-order walk through the tree, marking
// our path as we go along.
//----------------- Breakpoint -----------------
if( !handleBreakpoint( 1, null ) ) return false;
//----------------- Breakpoint -----------------
if( !handleBreakpoint( 2, null ) ) return false;
if( node.isLeaf() ) {
error[0] = node.getTestingEgsAtNode() -
node.getTestingEgsCorrectClassUsingBestTrainingIndex();
return true;
}
//----------------- Breakpoint -----------------
if( !handleBreakpoint( 3, null ) ) return false;
// We're at an internal node, so compute the error
// of the children and use the result to determine
// if we prune or not.
double errorSum = 0;
for( int i = 0; i < node.getArcLabelCount(); i++ ) {
// Mark our current path.
m_tree.flagNode( node, i );
//----------------- Breakpoint -----------------
if( !handleBreakpoint( 4, null ) ) return false;
//----------------- Breakpoint -----------------
if( !handleBreakpoint( 5, null ) ) return false;
//----------------- Breakpoint -----------------
if( !handleBreakpoint( 6, null ) ) return false;
if( !pruneReducedErrorDT( node.getChild( i ), error ) ) {
m_tree.flagNode( node, -2 );
return false;
}
errorSum += error[0];
}
// Mark the node as our current target.
m_tree.flagNode( node, -1 );
//----------------- Breakpoint -----------------
if( !handleBreakpoint( 8, null ) ) return false;
// Get the best-case performance of this node.
double errorBest = node.getTestingEgsAtNode() -
node.getTestingEgsCorrectClassUsingBestTestingIndex();
DecisionTreeNode newNode = node;
//----------------- Breakpoint -----------------
if( !handleBreakpoint( 9,
createPruningMsg( errorSum, errorBest ) ) ) return false;
if( errorBest < errorSum ) {
// We need to "prune" this node to a leaf.
DecisionTreeNode parent = node.getParent();
int arcNum = -1;
if( parent != null )
arcNum = parent.getChildPosition( node );
//--------------------- Debug ---------------------
if( DEBUG_ON ) {
System.out.println();
System.out.println( "DecisionTreeAlgorithm::pruneReduceErrorDT: " +
" Pruning node " + node.getLabel() + "." );
}
m_tree.pruneSubtree( node );
// Figure out the label for the new leaf.
String label = null;
try {
label =
m_dataset.getTargetAttribute().getAttributeValueByNum(
node.getTestingBestTarget() );
}
catch( NonexistentAttributeValueException e ) {
// Should never happen.
e.printStackTrace();
}
node.getMask().mask( 0, node.getTestingBestTarget() );
newNode =
m_tree.addLeafNode( parent, arcNum, label,
node.getMask(),
node.getTrainingEgsAtNode(),
node.getTestingBestTarget(),
node.getTrainingEgsCorrectClassUsingBestTestingIndex(),
node.getTestingEgsCorrectClassUsingBestTestingIndex(),
node.getTestingEgsAtNode(),
node.getTestingBestTarget(),
node.getTestingEgsCorrectClassUsingBestTestingIndex(),
// NOTE(review): source truncated here by the hosting website - the final
// argument and closing of this addLeafNode(...) call, plus the remainder of
// pruneReducedErrorDT() (breakpoint handling, tree re-flagging, error[0]
// update, and return), are missing from this chunk and must be recovered
// from the original file.