📄 decisiontreealgorithm.java
字号:
*/
/**
 * Selects the attribute to split on at the current position in
 * the tree, according to the active splitting function.
 *
 * @param atts The vector of attributes still available for
 *        splitting at this position in the tree.
 *
 * @param stats An array of counts of the examples in each
 *        target category at this position in the tree.
 *
 * @param results A vector that is filled with the Double score
 *        computed for each attribute (in the same order as
 *        atts). Left empty when splitting randomly.
 *
 * @return The chosen Attribute, or null if atts is empty.
 */
public Attribute
chooseAttribute( Vector atts, int[] stats, Vector results )
{
    // If the list of available attributes is empty,
    // return null.
    if( atts.size() == 0 ) return null;
    int pos = 0;
    // Select an attribute, based on the current
    // splitting function.
    if( m_splitFun.equals( SPLIT_RANDOM ) ) {
        // Choose one of the unmasked attributes at random.
        // We leave the results vector empty in this case.
        pos = m_random.nextInt( atts.size() );
    }
    else {
        // Score every available attribute with the active
        // criterion, remembering the best one. All three
        // criteria are "larger is better".
        double best = 0.0;
        for( int i = 0; i < atts.size(); i++ ) {
            Attribute att = (Attribute)atts.elementAt( i );
            double score;
            if( m_splitFun.equals( SPLIT_GAIN ) )
                score = getGain( stats, att );
            else if( m_splitFun.equals( SPLIT_GAIN_RATIO ) )
                score = getGainRatio( stats, att );
            else if( m_splitFun.equals( SPLIT_GINI ) )
                score = getGINI( stats, att );
            else
                continue;  // Unknown criterion - skip.
            if( score > best ) {
                best = score;
                pos = i;
            }
            // Record the score for later display. (Fix: the
            // original omitted this for the GINI criterion,
            // leaving results out of sync with atts.)
            results.add( new Double( score ) );
        }
    }
    return (Attribute)atts.elementAt( pos );
}
// Private methods
/**
* Fills the supplied array with the number of examples
* from the current dataset that fall into each of the
* target categories (based on the attribute mask).
*
* @param mask The mask that determines which examples
* reach the current position in the decision
* tree.
*
* @param examples An iteration over a series of
* examples from the current dataset.
*
* @param counts The method expects the parameter to be
* an array with a size equal to the number of
* target attribute values. Each position in
* the array is filled with a corresponding count
* of the number of training examples that fall into
* that particular target class, at the current
* position in the decision tree.
*
* @param reachedHere The method expects the parameter
* to be an array with a size equal to the
* <i>total</i> number of examples being examined.
* Each cell in the array is set to true or
* false, depending on whether or not the
* corresponding example reaches the current
* position in the decision tree.
*/
private void getExampleCounts( AttributeMask mask,
    Iterator examples, int counts[], boolean[] reachedHere )
{
    // Reset all target-category counters before tallying.
    for( int c = 0; c < counts.length; c++ )
        counts[c] = 0;

    // Walk the examples, counting only those whose attribute
    // values match the mask (i.e. those that reach the current
    // position in the decision tree).
    int exampleIndex = 0;
    while( examples.hasNext() ) {
        int[] example = (int[])examples.next();
        if( mask.matchMask( example ) ) {
            // Position 0 of an example holds its target class.
            counts[ example[0] ]++;
            // Optionally record that this example reached here.
            if( reachedHere != null && exampleIndex < reachedHere.length )
                reachedHere[exampleIndex] = true;
        }
        exampleIndex++;
    }
}
/**
* Creates a formatted HTML message to display when a
* leaf node is added. This is a convenience method -
* it makes the learnDT() method less cluttered.
*
* @param step The current leaf addition step (which
* identifies the message that should be
* returned).
*
* @param label The label for the new leaf.
*
* @return An HTML text string, suitable for display
* in a GUI panel.
*/
private String createLeafMsg( int step, String label )
{
    // All messages share the same HTML envelope.
    StringBuffer msg =
        new StringBuffer( "<html><font size=\"-1\">" );
    if( step == 1 )
        msg.append( "No examples reach this position. " +
            "Adding new leaf with default target class " +
            "<font color=\"yellow\">" + label + "</font>." );
    else if( step == 2 )
        msg.append( "All examples have the same target " +
            "classification. Adding new leaf with common " +
            "target class <font color=\"yellow\">" + label +
            "</font>." );
    else if( step == 3 )
        // Fix: the original concatenation was missing the
        // space between "class" and the <font> tag.
        msg.append( "The set of attributes available for " +
            "splitting is empty. Adding new leaf " +
            "with most common target class " +
            "<font color=\"yellow\">" + label + "</font>." );
    return msg.append( "</font>" ).toString();
}
/**
* Creates a formatted HTML message to display when an
* internal node is added. This is a convenience method -
* it makes the learnDT() method less cluttered.
*
* @param step The current internal addition step (which
* identifies the message that should be
* returned).
*
* @param label The label for the new node.
*
* @return An HTML text string, suitable for display
* in a GUI panel.
*/
private String createInternalMsg( int step, String label )
{
    // Build the step-specific body of the message first, then
    // wrap it in the shared HTML envelope.
    String body;
    switch( step ) {
    case 1:
        body = "Choosing the best attribute to " +
            "split on, based on " + m_splitFun + " criteria.";
        break;
    case 2:
        body = "Creating new internal node for the best " +
            "attribute, <font color=\"yellow\">" + label +
            "</font>.";
        break;
    case 3:
        body = "Attached new subtree along branch " +
            "<font color=\"yellow\">" + label + "</font>.";
        break;
    default:
        body = "";
        break;
    }
    return "<html><font size=\"-1\">" + body + "</font>";
}
/**
* Computes entropy based on classification counts stored
* in the stats array.
*
* @param stats An array of integers, where each value
* indicates the number of examples that fall into
* a particular target category. The size of the
* array should be the same as the number of possible
* target attribute values.
*
* @param numExamples The total number of examples (sum of
* all the counts in the stats array).
*
* @return The entropy as a double value.
*/
private double entropy( int[] stats, int numExamples )
{
    // An empty example set has zero entropy by definition.
    if( numExamples == 0 ) return 0;
    double result = 0;
    // Java's Math class only offers log base e, so convert
    // to base 2 by dividing by ln(2).
    double lnTwo = Math.log( 2.0 );
    for( int i = 0; i < stats.length; i++ ) {
        // Empty categories contribute nothing (and would
        // produce log(0) otherwise).
        if( stats[i] == 0 ) continue;
        double p = ((double)stats[i])/numExamples;
        result -= p * Math.log( p ) / lnTwo;
    }
    return result;
}
/**
* Computes the information gain, which is the
* "expected reduction in entropy caused by
* partitioning the examples according to [a certain]
* attribute".
*
* @param stats An array of integers, where each value
* indicates the number of examples that fall into
* a particular target category. The size of the
* array should be the same as the number of
* possible target attribute values. This array
* holds the statistics associated with the number
* of examples that reach a certain node/arc in the
* tree.
*
* @param att An Attribute that can be split on at the
* current position in the tree. The Attribute's
* internal storage space should already be
* populated with correct statistical information.
*
* @return The information gain value for the supplied
* attribute.
*/
private double getGain( int stats[], Attribute att )
{
    // Total number of examples reaching this node (sum of
    // the per-class counts in stats).
    int total = 0;
    for( int i = 0; i < stats.length; i++ )
        total += stats[i];
    if( total == 0 ) return 0;

    // Entropy of the example set before splitting.
    double before = entropy( stats, total );

    // Weighted entropy of the subsets produced by splitting
    // on att, one subset per attribute value.
    double after = 0;
    int[][] attStats = att.getStatsArray();
    for( int j = 0; j < attStats.length; j++ ) {
        int subsetSize = 0;
        for( int k = 0; k < attStats[j].length; k++ )
            subsetSize += attStats[j][k];
        after += ((double)subsetSize)/total*
            entropy( attStats[j], subsetSize );
    }
    // Gain is the expected reduction in entropy.
    return before - after;
}
/**
* Computes the gain ratio for a particular attribute.
* The gain ratio calculation includes a term called
* 'split information' that penalizes attributes that
* split data broadly and uniformly.
*
* @param stats An array of integers, where each value
* indicates the number of examples that fall into
* a particular target category. The size of the
* array should be the same as the number of
* possible target attribute values. This array
* holds the statistics associated with the number
* of examples that reach a certain node/arc in the
* tree.
*
* @param att An Attribute that can be split on at the
* current position in the tree. The Attribute's
* internal storage space should already be
* populated with correct statistical information.
*
* @return The gain ratio for the supplied attribute.
*/
private double getGainRatio( int stats[], Attribute att )
{
    // We recompute some of the same quantities
    // calculated in the gain method here - it would
    // be more efficient to merge the gain calculation
    // here.
    int numExamples = 0;
    for( int i = 0; i < stats.length; i++ )
        numExamples += stats[i];
    if( numExamples == 0 ) return 0;
    // Compute the gain.
    double gain = getGain( stats, att );
    // Compute the SplitInformation term (the entropy of the
    // examples with respect to the attribute values of att).
    int[] splitInfoStats = new int[ att.getNumValues() ];
    int[][] attStats = att.getStatsArray();
    // Loop over all possible values, counting the examples
    // that take each value.
    for( int j = 0; j < attStats.length; j++ ) {
        for( int k = 0; k < attStats[j].length; k++ )
            splitInfoStats[j] += attStats[j][k];
    }
    double splitInfo = entropy( splitInfoStats, numExamples );
    // Guard against division by zero: if every example has the
    // same value for att, splitInfo is 0 and the ratio is
    // undefined (the original returned NaN/Infinity here).
    // Such a split is useless, so score it 0.
    if( splitInfo == 0 ) return 0;
    return gain/splitInfo;
}
/**
* Computes the GINI score for a particular attribute.
*
* @param stats An array of integers, where each value
* indicates the number of examples that fall into
* a particular target category. The size of the
* array should be the same as the number of
* possible target attribute values. This array
* holds the statistics associated with the number
* of examples that reach a certain node/arc in the
* tree.
*
* @param att An Attribute that can be split on at the
* current position in the tree. The Attribute's
* internal storage space should already be
* populated with correct statistical information.
*
* @return the GINI score for the supplied attribute.
*/
private double getGINI( int stats[], Attribute att )
{
    // Total number of examples reaching this node.
    int total = 0;
    for( int i = 0; i < stats.length; i++ )
        total += stats[i];
    if( total == 0 ) return 0;

    double score = 0;
    int[][] attStats = att.getStatsArray();
    // First term: for each attribute value, add the sum of
    // squared per-class counts divided by the subset size.
    for( int j = 0; j < attStats.length; j++ ) {
        int subsetSize = 0;
        int sumOfSquares = 0;
        for( int k = 0; k < attStats[j].length; k++ ) {
            int count = attStats[j][k];
            sumOfSquares += count * count;
            subsetSize += count;
        }
        // Skip empty subsets to avoid dividing by zero.
        if( subsetSize != 0 )
            score += ((double)sumOfSquares)/subsetSize;
    }
    // Second term: subtract the squared per-class counts of
    // the unsplit set, scaled by the total.
    for( int c = 0; c < stats.length; c++ )
        score -= ((double)stats[c])*((double)stats[c]) / total;
    // Finally, normalize by the total number of examples.
    score /= total;
    return score;
}
/**
* Reduced error pruning error function.
*
* @param numExamplesReachParent The number of examples
* that reach the parent of a given position in the tree.
*
* @param numExamplesIncorrectClass The number of examples
* <i>in</i>correctly classified at the given position in
* the tree (this has meaning only for leaf nodes).
*/
private double errorRE( int numExamplesReachParent,
    int numExamplesIncorrectClass )
{
    // Laplacian correction: add 1 to the error count and 2 to
    // the sample size so the estimate is never exactly 0 or 1.
    double correctedErrors = 1 + numExamplesIncorrectClass;
    return correctedErrors/(numExamplesReachParent + 2);
}
/**
* Returns an error bar value based on the current
* confidence interval.
*
* @param mean The mean value used to calculate the error
* bar value.
*
* @param size The sample size.
*/
private double errorBar( double mean, int size )
{
    // Standard error of a proportion, scaled by the z-score
    // for the configured pessimistic-pruning confidence level.
    double variance = mean * (1 - mean) / size;
    return Math.sqrt( variance ) * m_pessPruneZScore;
}
/**
* Creates a formatted HTML message to display during
* reduced error pruning. This is a convenience method -
* it makes the pruneReducedErrorDT() method less cluttered.
*
* @param errCurrent The error produced by the current tree.
*
* @param errPrune The error that would result if the tree
* was pruned.
*
* @return An HTML text string, suitable for display
* in a GUI panel.
*/
private String createPruningMsg( double errCurrent, double errPrune )
{
    // Truncate each error value to at most four characters for
    // compact display (only when the full rendering is longer
    // than five characters).
    String currentText = Double.toString( errCurrent );
    if( currentText.length() > 5 )
        currentText = currentText.substring( 0, 4 );
    String pruneText = Double.toString( errPrune );
    if( pruneText.length() > 5 )
        pruneText = pruneText.substring( 0, 4 );
    // Assemble the HTML message.
    return "<html><font size=\"-1\">" +
        "Current error = " + currentText +
        ", pruning error = " + pruneText + "." +
        "</font>";
}
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -