// ID3TreeNode.java
/*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 2 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
*/
/**
* Title: XELOPES Data Mining Library
* Description: The XELOPES library is an open platform-independent and data-source-independent library for Embedded Data Mining.
* Copyright: Copyright (c) 2002 Prudential Systems Software GmbH
* Company: ZSoft (www.zsoft.ru), Prudsys (www.prudsys.com)
* @author Valentine Stepanenko (valentine.stepanenko@zsoft.ru)
* @version 1.0
*/
package com.prudsys.pdm.Models.Classification.DecisionTree.Algorithms.Id3;
import com.prudsys.pdm.Core.CategoricalAttribute;
import com.prudsys.pdm.Core.Category;
import com.prudsys.pdm.Core.CategoryProperty;
import com.prudsys.pdm.Core.MiningAttribute;
import com.prudsys.pdm.Core.MiningDataSpecification;
import com.prudsys.pdm.Core.MiningException;
import com.prudsys.pdm.Input.MiningInputStream;
import com.prudsys.pdm.Input.MiningStoredData;
import com.prudsys.pdm.Input.MiningVector;
import com.prudsys.pdm.Input.Predicates.SimplePredicate;
import com.prudsys.pdm.Models.Classification.DecisionTree.DecisionTreeNode;
/**
* Recursive implementation of ID3 algorithm.
*/
public class ID3TreeNode extends DecisionTreeNode
{
    /** Total number of attributes in the meta data (including the class attribute). */
    private int numberOfMiningAttributes;
    /** The categorical target (class) attribute. */
    private CategoricalAttribute classificationAttribute;
    /** Number of distinct class values. */
    private int numberOfClassValues;
    /** Index of the class attribute in the meta data. Name kept (with its historical typo) for API compatibility. */
    private int indesOfClassAttribute;
    /** Training vectors reaching this node; released (set to null) after buildTree(). */
    private MiningStoredData miningVectors;
    /** Attribute used for splitting. */
    private CategoricalAttribute splittingAttribute = null;
    /** Splitting attribute index. */
    private int splittingAttributeIndex = -1;
    /** Class value if node is leaf. */
    private double predictedValue;
    /** Class distribution if node is leaf. */
    private double[] distribution;
    /** Index of the attribute used in this node's predicate; set by the parent, excluded from re-splitting. */
    private int predicateAttributeIndex = -1;

    /**
     * Empty constructor.
     */
    public ID3TreeNode()
    {
    }

    /**
     * Constructor with given data.
     *
     * @param miningStoredData mining input stream data
     * @param metaData meta data
     * @param target target attribute
     * @throws MiningException could not run the constructor
     */
    public ID3TreeNode( MiningInputStream miningStoredData, MiningDataSpecification metaData, MiningAttribute target ) throws MiningException
    {
        this.miningVectors = new MiningStoredData( miningStoredData );
        this.metaData = metaData;
        this.classificationAttribute = (CategoricalAttribute)target;
        this.target = target;
        numberOfClassValues = classificationAttribute.getCategoriesNumber();
        numberOfMiningAttributes = metaData.getAttributesNumber();
        indesOfClassAttribute = metaData.getAttributeIndex( classificationAttribute );
    }

    /**
     * Recursively builds the decision tree from the vectors stored in this node.
     * On return the training data held by this node is released.
     *
     * @exception MiningException if decision tree can't be built successfully
     */
    public void buildTree() throws MiningException
    {
        int numberOfMiningVectors = miningVectors.size();
        if( numberOfMiningVectors == 0 )
        {
            // No instances have reached this node: empty leaf predicting a missing value.
            leaf = true;
            predictedValue = Category.MISSING_VALUE;
            distribution = new double[numberOfClassValues];
            splittingAttribute = null;
        }
        else
        {
            // Choose the attribute with maximum information gain.
            double[] infoGains = computeInfoGains();
            int index = ID3TreeUtils.maxIndex( infoGains );
            double maxInfoGain = infoGains[index];
            splittingAttribute = (CategoricalAttribute)metaData.getMiningAttribute( index );
            splittingAttributeIndex = index;
            // Make leaf if information gain is zero, otherwise create successors.
            if( ID3TreeUtils.eq( maxInfoGain, 0 ) )
            {
                makeLeaf( numberOfMiningVectors );
            }
            else
            {
                createChildren();
            }
        }
        // Training data is no longer needed once the subtree is built.
        miningVectors = null;
    }

    /**
     * Computes the information gain of every attribute. The class attribute and the
     * attribute already used in this node's predicate are excluded with gain -1 so
     * they can never be selected by maxIndex.
     */
    private double[] computeInfoGains() throws MiningException
    {
        double[] infoGains = new double[numberOfMiningAttributes];
        for(int i = 0; i < numberOfMiningAttributes; i++)
        {
            if( i == indesOfClassAttribute || i == predicateAttributeIndex )
            {
                infoGains[i] = -1;
            }
            else
            {
                infoGains[i] = computeInfoGain( (CategoricalAttribute)metaData.getMiningAttribute( i ) );
            }
        }
        return infoGains;
    }

    /**
     * Information gain of one attribute: entropy of the node minus the
     * size-weighted entropies of the partitions induced by the attribute.
     */
    private double computeInfoGain( CategoricalAttribute attribute ) throws MiningException
    {
        int numberOfVectors = miningVectors.size();
        double infoGain = ID3TreeUtils.computeEntropy( miningVectors, classificationAttribute );
        MiningStoredData[] splitByCategory = splitByAttribute( attribute );
        for(int j = 0; j < splitByCategory.length; j++)
        {
            int size = splitByCategory[j].size();
            if( size > 0 )
            {
                infoGain -= ((double)size / (double)numberOfVectors)
                          * ID3TreeUtils.computeEntropy( splitByCategory[j], classificationAttribute );
            }
        }
        return infoGain;
    }

    /**
     * Partitions this node's vectors by the categories of the given attribute.
     * Slot j holds all vectors whose attribute value equals category j.
     */
    private MiningStoredData[] splitByAttribute( CategoricalAttribute attribute ) throws MiningException
    {
        int numberOfCategories = attribute.getCategoriesNumber();
        int numberOfVectors = miningVectors.size();
        MiningStoredData[] split = new MiningStoredData[ numberOfCategories ];
        for(int j = 0; j < numberOfCategories; j++)
        {
            split[j] = new MiningStoredData();
        }
        for(int j = 0; j < numberOfVectors; j++)
        {
            MiningVector vector = (MiningVector)miningVectors.get( j );
            int attributeValue = (int)vector.getValue( attribute );
            split[attributeValue].add( vector );
        }
        return split;
    }

    /**
     * Turns this node into a leaf: computes the normalized class distribution,
     * predicts the majority class and stores its key as the node score.
     */
    private void makeLeaf( int numberOfMiningVectors ) throws MiningException
    {
        leaf = true;
        distribution = new double[numberOfClassValues];
        for(int i = 0; i < numberOfMiningVectors; i++)
        {
            MiningVector vector = (MiningVector)miningVectors.get( i );
            distribution[(int)vector.getValue( indesOfClassAttribute )]++;
        }
        ID3TreeUtils.normalize( distribution );
        predictedValue = ID3TreeUtils.maxIndex( distribution );
        String predictedScore = classificationAttribute.getCategory( predictedValue ).toString();
        MiningAttribute target = metaData.getMiningAttribute( indesOfClassAttribute );
        setScore( ((CategoricalAttribute)target).getKey( new Category( predictedScore, predictedScore, new CategoryProperty( CategoryProperty.VALID ) ) ) );
    }

    /**
     * Creates one child node per category of the splitting attribute and
     * recursively builds each child on its partition of the training data.
     */
    private void createChildren() throws MiningException
    {
        leaf = false;
        // Non-leaf nodes still carry a score attribute so the generated PMML passes XML validation.
        setScore( -1. );
        MiningStoredData[] split = splitByAttribute( splittingAttribute );
        int numberOfSplitCategories = splittingAttribute.getCategoriesNumber();
        children = new ID3TreeNode[numberOfSplitCategories];
        for(int j = 0; j < numberOfSplitCategories; j++)
        {
            ID3TreeNode id3node = new ID3TreeNode();
            id3node.setParent( this );
            id3node.setMiningVectors( split[j] );
            id3node.setPredicate( new SimplePredicate( splittingAttribute, splittingAttribute.getCategory( j ).toString(), SimplePredicate.EQUAL ) );
            id3node.predicateAttributeIndex = splittingAttributeIndex;
            id3node.setMetaData( metaData );
            id3node.setClassificationAttribute( classificationAttribute );
            id3node.setNumberOfClassValues( numberOfClassValues );
            id3node.setIndesOfClassAttribute( indesOfClassAttribute );
            id3node.setNumberOfMiningAttributes( numberOfMiningAttributes );
            id3node.setTarget( target );
            id3node.buildTree();
            children[j] = id3node;
        }
    }

    /**
     * Classifies a given test vector using the decision tree.
     *
     * @param vector the vector to be classified
     * @return the predicted class value
     * @throws MiningException if the vector cannot be classified
     */
    public double classify( MiningVector vector ) throws MiningException
    {
        if( leaf )
        {
            return predictedValue;
        }
        // Descend into the child matching the vector's value of the splitting attribute.
        int index = (int)vector.getValue( splittingAttributeIndex );
        return ((ID3TreeNode)children[index]).classify( vector );
    }

    /**
     * Computes the class distribution for a vector using the decision tree.
     *
     * @param vector the vector for which the distribution is to be computed
     * @return the class distribution of the leaf the vector falls into
     */
    public double[] distributionForVector( MiningVector vector )
    {
        if( leaf )
        {
            return distribution;
        }
        int index = (int)vector.getValue( splittingAttributeIndex );
        return ((ID3TreeNode)children[index]).distributionForVector( vector );
    }

    /** @return the class distribution of this leaf (null before buildTree or for inner nodes). */
    public double[] getDistribution()
    {
        return distribution;
    }

    /** @return the predicted class value of this leaf. */
    public double getPredictedValue()
    {
        return predictedValue;
    }

    /** @return the attribute this node splits on (null for leaves). */
    public CategoricalAttribute getSplittingAttribute()
    {
        return splittingAttribute;
    }

    /** @return the target (class) attribute. */
    public CategoricalAttribute getClassificationAttribute()
    {
        return classificationAttribute;
    }

    /** Sets the target (class) attribute. */
    public void setClassificationAttribute(CategoricalAttribute classificationAttribute)
    {
        this.classificationAttribute = classificationAttribute;
    }

    /**
     * Returns the training vectors held by this node (null after buildTree()).
     * Name kept for API compatibility.
     */
    public MiningStoredData reads()
    {
        return miningVectors;
    }

    /** Sets the training vectors of this node. */
    public void setMiningVectors(MiningStoredData miningVectors)
    {
        this.miningVectors = miningVectors;
    }

    /** Sets the total number of mining attributes. */
    public void setNumberOfMiningAttributes(int numberOfMiningAttributes)
    {
        this.numberOfMiningAttributes = numberOfMiningAttributes;
    }

    /** Sets the number of distinct class values. */
    public void setNumberOfClassValues(int numberOfClassValues)
    {
        this.numberOfClassValues = numberOfClassValues;
    }

    /** Sets the index of the class attribute. Name kept (typo included) for API compatibility. */
    public void setIndesOfClassAttribute(int indesOfClassAttribute)
    {
        this.indesOfClassAttribute = indesOfClassAttribute;
    }

    /**
     * Textual description of this node: its predicate rule, or "root" for the
     * top node of the tree.
     *
     * @return a textual description of the node
     */
    public String toString()
    {
        if( leaf )
        {
            return "if " + predicate.toString() + " then class = '" + getScoreString() + "'";
        }
        if( parent == null )
        {
            return "root";
        }
        return "if " + predicate.toString() + " then";
    }
}