gentreenode.java
/*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 2 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
*/
/**
* Title: XELOPES Data Mining Library
* Description: The XELOPES library is an open platform-independent and data-source-independent library for Embedded Data Mining.
* Copyright: Copyright (c) 2002 Prudential Systems Software GmbH
* Company: ZSoft (www.zsoft.ru), Prudsys (www.prudsys.com)
* @author Valentine Stepanenko (ValentineStepanenko@zsoft.ru)
* @version 1.0
*/
package com.prudsys.pdm.Models.Classification.DecisionTree.Algorithms.GenTree;
import com.prudsys.pdm.Core.CategoricalAttribute;
import com.prudsys.pdm.Core.Category;
import com.prudsys.pdm.Core.MiningAttribute;
import com.prudsys.pdm.Core.MiningDataSpecification;
import com.prudsys.pdm.Core.MiningException;
import com.prudsys.pdm.Core.NumericAttribute;
import com.prudsys.pdm.Input.MiningInputStream;
import com.prudsys.pdm.Input.MiningStoredData;
import com.prudsys.pdm.Input.MiningVector;
import com.prudsys.pdm.Input.Predicates.Predicate;
import com.prudsys.pdm.Input.Predicates.SimplePredicate;
import com.prudsys.pdm.Models.Classification.DecisionTree.DecisionTreeNode;
import com.prudsys.pdm.Models.Classification.DecisionTree.DecisionTreeSettings;
/**
* Recursive implementation of the general tree algorithm. The implementation
* is not yet optimized for speed.
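*
* A minimal usage sketch (hypothetical setup: the algorithm, training data
* stream, meta data and target attribute are assumed to be configured
* elsewhere):
* <pre>{@code
* GenTreeNode root = new GenTreeNode(algorithm, trainingData, metaData, target);
* root.calculateScore();              // class distribution, score, leaf check
* if (!root.isLeaf())
*     root.buildTree();               // recursively creates the child nodes
* }</pre>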
*/
public class GenTreeNode extends DecisionTreeNode
{
/** Reference to algorithm which created this object. */
private GenTreeAlgorithm genTreeAlgorithm;
/** Part of training data set for this node. */
private MiningStoredData miningVectors;
/** Attribute used for splitting. */
private MiningAttribute splittingAttribute = null;
/** Splitting attribute index. */
private int splittingAttributeIndex = -1;
/** Splitting attribute index of the parent node. */
private int prevSplittingAttributeIndex = -1;
/** Class distribution of node. */
private double[] distribution;
/** Current child node processed. */
public int currChild = 0;
/** Debug. 0 - no debugging. */
private int debug = 0;
/**
* Empty constructor.
*/
public GenTreeNode()
{
}
/**
* Constructor with initial data.
*
* @param genTreeAlgorithm reference to algorithm which created this node
* @param miningStoredData mining input stream with the training data for this node
* @param metaData meta data
* @param target target attribute, must be categorical
* @throws MiningException if the node could not be initialized
*/
public GenTreeNode(GenTreeAlgorithm genTreeAlgorithm,
MiningInputStream miningStoredData,
MiningDataSpecification metaData,
MiningAttribute target)
throws MiningException
{
this.metaData = metaData;
this.target = target;
this.genTreeAlgorithm = genTreeAlgorithm;
this.miningVectors = new MiningStoredData( miningStoredData );
}
/**
* Calculates the score (predicted class value) of this node from the class
* distribution of its mining vectors and marks the node as a leaf if a
* termination criterion is met.
*
* @exception MiningException if the score cannot be calculated in this node
*/
public void calculateScore() throws MiningException {
// Init:
DecisionTreeSettings dtSettings = (DecisionTreeSettings) genTreeAlgorithm.getMiningSettings();
CategoricalAttribute classificationAttribute = (CategoricalAttribute)target;
int numberOfMiningVectors = miningVectors.size();
int indexOfClassAttribute = metaData.getAttributeIndex(classificationAttribute);
int numberOfClassValues = classificationAttribute.getCategoriesNumber();
boolean storeScoreDist = genTreeAlgorithm.isStoreScoreDistribution();
distribution = new double[numberOfClassValues];
// No mining vectors in node => score of parent, exit:
if ( numberOfMiningVectors == 0 )
{
leaf = true;
double predictedValue = Category.MISSING_VALUE;
GenTreeNode parent = (GenTreeNode) getParent();
if (parent != null)
predictedValue = parent.getScore();
setScore(predictedValue);
if (storeScoreDist) setDistribution( (double[]) distribution.clone() );
return;
};
// Calculate class distribution in node:
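// (The class values are assumed to be the category indices 0 .. numberOfClassValues-1,
// so they can be used directly as positions in the distribution array.)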
for (int i = 0; i < numberOfMiningVectors; i++)
{
MiningVector vector = (MiningVector)miningVectors.get( i );
distribution[(int) vector.getValue( indexOfClassAttribute )]++;
};
if (storeScoreDist) setDistribution( (double[]) distribution.clone() );
GenTreeUtils.normalize(distribution);
// Get predicted value:
double predictedValue = GenTreeUtils.maxIndex(distribution);
setScore(predictedValue);
// Termination criteria block:
if (numberOfMiningVectors == 1 ||
numberOfMiningVectors < dtSettings.getMinNodeSize() ||
getTotalNumberOfParents() == dtSettings.getMaxDepth() )
{
leaf = true;
return;
};
}
/**
* Builds the decision tree below this node: selects the splitting attribute
* with the best impurity measure, splits the data and recursively builds the
* child nodes (unless the algorithm runs in iterative mode).
*
* @exception MiningException if the decision tree cannot be built successfully
*/
public void buildTree() throws MiningException
{
// Initializations:
DecisionTreeSettings dtSettings = (DecisionTreeSettings) genTreeAlgorithm.getMiningSettings();
CategoricalAttribute classificationAttribute = (CategoricalAttribute)target;
int numberOfMiningAttributes = metaData.getAttributesNumber();
int numberOfMiningVectors = miningVectors.size();
int indexOfClassAttribute = metaData.getAttributeIndex(classificationAttribute);
int numberOfClassValues = classificationAttribute.getCategoriesNumber();
splittingAttribute = null;
// Compute attribute with maximum information gain:
double[] impMeas = new double[numberOfMiningAttributes];
double[] cuts = new double[numberOfMiningAttributes];
for (int i = 0; i < numberOfMiningAttributes; i++)
{
// Class attribute => ignore:
if (i == indexOfClassAttribute)
{
impMeas[i] = -1;
}
// Compute information gain for current attribute:
else
{
MiningAttribute calculatedAttribute = metaData.getMiningAttribute(i);
double[] cut = new double[1];
impMeas[i] = GenTreeUtils.computeImpMeas(miningVectors,
calculatedAttribute, classificationAttribute, cut,
genTreeAlgorithm.impurityMeasureType,
genTreeAlgorithm.discreteType, 0);
cuts[i] = cut[0];
if (debug > 0)
System.out.println("attribute: " + calculatedAttribute.getName() +
" impMeas: " + impMeas[i] + " cutpoint: " + cuts[i]);
};
};
if (debug > 0) System.out.println();
// Find best splitting attribute:
int index = GenTreeUtils.maxIndex(impMeas);
double maxImpMeas = impMeas[index];
double bestCut = cuts[index];
splittingAttribute = metaData.getMiningAttribute( index );
splittingAttributeIndex = index;
// Make leaf if impurity measure is too small:
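// (GenTreeUtils.SMALL is presumably a small epsilon that absorbs floating-point
// noise around the minimum-decrease threshold.)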
if ( maxImpMeas <= dtSettings.getMinDecreaseInImpurity() + GenTreeUtils.SMALL )
{
leaf = true;
miningVectors = null;
return;
};
// Create successors:
leaf = false;
// Split data:
MiningStoredData[] splitData = null;
if (splittingAttribute instanceof CategoricalAttribute)
splitData = GenTreeUtils.splitData(miningVectors, (CategoricalAttribute) splittingAttribute);
else {
double[] cut = {bestCut};
splitData = GenTreeUtils.splitData(miningVectors, (NumericAttribute) splittingAttribute, cut);
};
int numberOfSplitCategories = 0;
if (splittingAttribute instanceof CategoricalAttribute)
numberOfSplitCategories = ((CategoricalAttribute)splittingAttribute).getCategoriesNumber();
else
numberOfSplitCategories = 2;
// Create child nodes:
children = new GenTreeNode[numberOfSplitCategories];
for(int j = 0; j < numberOfSplitCategories; j++)
{
// New child:
GenTreeNode GTNode = new GenTreeNode(genTreeAlgorithm, splitData[j], metaData, target);
// Set topology of child:
GTNode.setParent( this );
GTNode.setLevel( level+1 );
// Set predicate of child:
Predicate pred = null;
if (splittingAttribute instanceof CategoricalAttribute)
pred = new SimplePredicate(splittingAttribute,
((CategoricalAttribute)splittingAttribute).getCategory(j).toString(),SimplePredicate.EQUAL);
else {
int op = SimplePredicate.LESS;
if (j == 1)
op = SimplePredicate.GREATER_OR_EQUAL;
pred = new SimplePredicate(splittingAttribute, String.valueOf(bestCut), op);
};
GTNode.setPredicate(pred);
GTNode.prevSplittingAttributeIndex = splittingAttributeIndex;
// Calculate score for child:
GTNode.calculateScore();
// Build tree again for child (if not leaf):
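// (In iterative mode the expansion is presumably driven step by step by
// GenTreeAlgorithm itself, using currChild to track the child being processed.)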
if (! genTreeAlgorithm.isIterativeMode() && ! GTNode.isLeaf() )
GTNode.buildTree();
// Assign child to parent:
children[j] = GTNode;
};
miningVectors = null;
}
/**
* Computes class distribution for vector using decision tree.
*
* @param vector the mining vector for which distribution is to be computed
* @return the class distribution for the given vector
*/
public double[] distributionForVector( MiningVector vector )
{
int index;
if( leaf )
{
return distribution;
}
else
{
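// The attribute value is used directly as the child index; this matches the
// categorical case, where splitData() creates one child node per category.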
index = (int)vector.getValue( splittingAttributeIndex );
return ((GenTreeNode)children[index]).distributionForVector( vector );
}
}
/**
* Returns the DT algorithm which has created this node.
*
* @return DT algorithm which has created this node
*/
public GenTreeAlgorithm getGenTreeAlgorithm()
{
return genTreeAlgorithm;
}
/**
* Sets DT algorithm which has created this node.
*
* @param genTreeAlgorithm DT algorithm which has created this node
*/
public void setGenTreeAlgorithm(GenTreeAlgorithm genTreeAlgorithm)
{
this.genTreeAlgorithm = genTreeAlgorithm;
}
/**
* Returns mining vectors dataset of this node.
*
* @return mining vectors dataset of this node
*/
public MiningStoredData getMiningVectors()
{
return miningVectors;
}
/**
* Sets the mining vectors dataset belonging to this node.
*
* @param miningVectors mining vectors dataset of this node
*/
public void setMiningVectors(MiningStoredData miningVectors)
{
this.miningVectors = miningVectors;
}
/**
* Returns the class distribution of this node, i.e. the relative frequency of
* training vectors assigned to each class (the raw counts are normalized in
* calculateScore).
*
* @return class distribution
*/
public double[] getDistribution()
{
return distribution;
}
/**
* Returns the splitting attribute of this node, if it exists.
* If the node is a leaf, null is returned.
*
* @return splitting attribute of this node
*/
public MiningAttribute getSplittingAttribute()
{
return splittingAttribute;
}
/**
* Returns a textual description of this node: "root" for the root node,
* otherwise its splitting predicate and, for leaves, the predicted class.
*
* @return a textual description of the node
*/
public String toString()
{
StringBuffer text = new StringBuffer();
if( leaf )
{
text.append( "if " + predicate.toString() + " then class = '" + getScoreString() + "'" );
}
else
{
if (parent == null)
text.append( "root" );
else
text.append( "if " + predicate.toString() + " then" );
}
return text.toString();
}
}