📄 id3treeutils.java
字号:
/*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 2 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
*/
/**
* Title: XELOPES Data Mining Library
* Description: The XELOPES library is an open platform-independent and data-source-independent library for Embedded Data Mining.
* Copyright: Copyright (c) 2002 Prudential Systems Software GmbH
* Company: ZSoft (www.zsoft.ru), Prudsys (www.prudsys.com)
* @author Valentine Stepanenko (valentine.stepanenko@zsoft.ru)
* @version 1.0
*/
package com.prudsys.pdm.Models.Classification.DecisionTree.Algorithms.Id3;
import com.prudsys.pdm.Core.CategoricalAttribute;
import com.prudsys.pdm.Core.MiningException;
import com.prudsys.pdm.Input.MiningStoredData;
import com.prudsys.pdm.Input.MiningVector;
/**
* Class implementing some useful helper methods for the ID3 decision tree classifier.
* @author Valentine Stepanenko
*/
public class ID3TreeUtils
{
    /**
     * The natural logarithm of 2; used as the divisor in {@link #log2(double)}.
     * NOTE(review): this field is public and not final, so reassigning it
     * changes the result of {@link #log2(double)}. It is kept mutable only to
     * stay source-compatible with existing callers.
     */
    public static double log2 = Math.log(2);

    /** The small deviation allowed in double comparisons (see {@link #eq(double, double)}). */
    public static double SMALL = 1e-6;

    /**
     * Creates an instance. All methods of this class are static, so an
     * instance is never required; the constructor is kept for compatibility.
     */
    public ID3TreeUtils()
    {
    }

    /**
     * Computes the information gain obtained by splitting the data on an attribute:
     * the entropy of the whole data set minus the size-weighted entropies of
     * the subsets produced by {@link #splitData}.
     *
     * @param miningVectors the data set to evaluate
     * @param calculatedAttribute the attribute whose information gain is computed
     * @param classificationAttribute the attribute holding the class label
     * @return the information gain; 0 when the data set is empty
     * @throws MiningException declared for callers; propagated from
     *         {@link #computeEntropy}
     */
    public static double computeInfoGain( MiningStoredData miningVectors,
                                          CategoricalAttribute calculatedAttribute,
                                          CategoricalAttribute classificationAttribute)
        throws MiningException
    {
        int numberOfVectors = miningVectors.size();
        double infoGain = computeEntropy( miningVectors, classificationAttribute );
        MiningStoredData[] subsets = splitData( miningVectors, calculatedAttribute );
        for (int j = 0; j < subsets.length; j++)
        {
            int subsetSize = subsets[j].size();
            // Empty subsets carry zero weight and are skipped.
            if (subsetSize > 0)
            {
                double weight = (double) subsetSize / (double) numberOfVectors;
                infoGain -= weight * computeEntropy( subsets[j], classificationAttribute );
            }
        }
        return infoGain;
    }

    /**
     * Computes the entropy (in bits) of the class distribution of a data set.
     * The value is evaluated as {@code log2(N) - (1/N) * sum(c_j * log2(c_j))},
     * where {@code N} is the number of vectors and {@code c_j} the number of
     * vectors in class {@code j}; this is algebraically the same as
     * {@code -sum(p_j * log2(p_j))} with {@code p_j = c_j / N}.
     *
     * @param miningVectors the data set
     * @param classificationAttribute the attribute holding the class label
     * @return the entropy; 0 for an empty data set (previously this case
     *         produced NaN because of a 0/0 division and log2(0))
     * @throws MiningException declared by the signature.
     *         NOTE(review): no statement in this method visibly raises it - confirm.
     */
    public static double computeEntropy( MiningStoredData miningVectors, CategoricalAttribute classificationAttribute ) throws MiningException
    {
        int numberValues = classificationAttribute.getCategoriesNumber();
        int numberVectors = miningVectors.size();
        if (numberVectors == 0)
        {
            // Nothing to measure: define the entropy of an empty set as 0
            // instead of letting 0/0 and log2(0) produce NaN.
            return 0.0;
        }

        // Count how many vectors fall into each class.
        double[] classCounts = new double[numberValues];
        for (int i = 0; i < numberVectors; i++)
        {
            MiningVector vector = (MiningVector) miningVectors.get( i );
            // NOTE(review): the (int) cast maps a NaN value to index 0 -
            // confirm how missing class values are encoded.
            int index = (int) vector.getValue( classificationAttribute );
            classCounts[index]++;
        }

        double entropy = 0;
        for (int j = 0; j < numberValues; j++)
        {
            // Classes without vectors contribute nothing (and log2(0) is -Infinity).
            if (classCounts[j] > 0)
            {
                entropy -= classCounts[j] * ID3TreeUtils.log2( classCounts[j] );
            }
        }
        entropy /= (double) numberVectors;
        return entropy + ID3TreeUtils.log2( numberVectors );
    }

    /**
     * Splits a data set according to the values of a nominal attribute.
     * The result has one (possibly empty) data set per category of the
     * attribute; each vector is added to the data set whose index equals the
     * vector's value for that attribute, narrowed to {@code int}.
     *
     * @param miningVectors the data set to split
     * @param splitingAttribute the attribute to split on
     * @return an array of data sets, indexed by category
     */
    public static MiningStoredData[] splitData( MiningStoredData miningVectors, CategoricalAttribute splitingAttribute )
    {
        int numberValues = splitingAttribute.getCategoriesNumber();
        int numberVectors = miningVectors.size();
        MiningStoredData[] splitData = new MiningStoredData[ numberValues ];
        for (int j = 0; j < numberValues; j++)
        {
            splitData[j] = new MiningStoredData();
        }
        for (int i = 0; i < numberVectors; i++)
        {
            MiningVector vector = (MiningVector) miningVectors.get( i );
            // NOTE(review): the (int) cast maps a NaN value to index 0 -
            // confirm how missing attribute values are encoded.
            int attributeValue = (int) vector.getValue( splitingAttribute );
            splitData[attributeValue].add( vector );
        }
        return splitData;
    }

    /**
     * Returns index of maximum element in a given
     * array of doubles. First maximum is returned.
     *
     * @param doubles the array of doubles
     * @return the index of the maximum element; 0 if the array is empty
     */
    public static int maxIndex(double [] doubles)
    {
        int best = 0;
        // Strict comparison keeps the first of several equal maxima.
        for (int i = 1; i < doubles.length; i++)
        {
            if (doubles[i] > doubles[best])
            {
                best = i;
            }
        }
        return best;
    }

    /**
     * Returns index of maximum element in a given
     * array of integers. First maximum is returned.
     *
     * @param ints the array of integers
     * @return the index of the maximum element; 0 if the array is empty
     */
    public static int maxIndex(int [] ints)
    {
        int best = 0;
        // Strict comparison keeps the first of several equal maxima.
        for (int i = 1; i < ints.length; i++)
        {
            if (ints[i] > ints[best])
            {
                best = i;
            }
        }
        return best;
    }

    /**
     * Tests if a is equal to b, allowing a deviation smaller than {@link #SMALL}.
     *
     * @param a a double
     * @param b a double
     * @return true if the values are exactly equal (this includes two
     *         infinities of the same sign) or differ by less than
     *         {@link #SMALL}; false if either value is NaN
     */
    public static boolean eq(double a, double b)
    {
        // The exact check covers equal infinities, whose difference is NaN
        // and would otherwise be reported as "not equal".
        return (a == b) || (Math.abs(a - b) < SMALL);
    }

    /**
     * Normalizes the doubles in the array by their sum.
     *
     * @param doubles the array of double; modified in place
     * @exception IllegalArgumentException if sum is Zero or NaN
     */
    public static void normalize(double[] doubles)
    {
        double sum = 0;
        for (double value : doubles)
        {
            sum += value;
        }
        normalize(doubles, sum);
    }

    /**
     * Normalizes the doubles in the array using the given value.
     *
     * @param doubles the array of double; modified in place
     * @param sum the value by which the doubles are to be normalized
     * @exception IllegalArgumentException if sum is zero or NaN
     */
    public static void normalize(double[] doubles, double sum)
    {
        if (Double.isNaN(sum))
        {
            throw new IllegalArgumentException("Can't normalize array. Sum is NaN.");
        }
        if (sum == 0)
        {
            // Maybe this should just be a return.
            throw new IllegalArgumentException("Can't normalize array. Sum is zero.");
        }
        for (int i = 0; i < doubles.length; i++)
        {
            doubles[i] /= sum;
        }
    }

    /**
     * Returns the logarithm of a for base 2.
     *
     * @param a a double
     * @return {@code Math.log(a)} divided by the {@link #log2} field
     */
    public static double log2(double a)
    {
        return Math.log(a) / log2;
    }
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -