// decisiontree.cs
using System;
using System.Data;
using System.Collections;
namespace DecisionTreeID3
{
/// <summary>
/// A discrete attribute (a sample column) together with the values it may take.
/// The same type also tags leaf nodes of the tree: a leaf tag carries only a class
/// label, has an empty name and has no values.
/// </summary>
public class Attribute
{
    ArrayList mValues;
    string mName;
    object mLabel;

    /// <summary>
    /// Creates a splitting attribute.
    /// </summary>
    /// <param name="name">Name of the attribute (the sample column it refers to).</param>
    /// <param name="values">Possible values of the attribute. They are stored sorted with
    /// the default comparer so that <see cref="indexValue"/> can use binary search.</param>
    public Attribute(string name, string[] values)
    {
        mName = name;
        mValues = new ArrayList(values);
        mValues.Sort();
    }

    /// <summary>
    /// Creates a leaf tag that carries only a class label (no name, no values).
    /// </summary>
    /// <param name="Label">The class label predicted by the leaf.</param>
    public Attribute(object Label)
    {
        mLabel = Label;
        mName = string.Empty;
        mValues = null;
    }

    /// <summary>
    /// Name of the attribute; empty for a leaf tag.
    /// </summary>
    public string AttributeName
    {
        get
        {
            return mName;
        }
    }

    /// <summary>
    /// The attribute's values in sorted order, as a new array on every access,
    /// or null for a leaf tag.
    /// </summary>
    public string[] values
    {
        get
        {
            if (mValues != null)
                return (string[])mValues.ToArray(typeof(string));
            else
                return null;
        }
    }

    /// <summary>
    /// Tells whether <paramref name="value"/> is one of this attribute's values.
    /// </summary>
    /// <param name="value">Value to look up.</param>
    /// <returns>True if the value is present; always false for a leaf tag.</returns>
    public bool isValidValue(string value)
    {
        return indexValue(value) >= 0;
    }

    /// <summary>
    /// Returns the position of <paramref name="value"/> in the sorted value list.
    /// </summary>
    /// <param name="value">Value to look up.</param>
    /// <returns>The zero-based index when found. When not found, a negative number
    /// (the <c>ArrayList.BinarySearch</c> complement of the insertion point).
    /// Always -1 for a leaf tag.</returns>
    public int indexValue(string value)
    {
        if (mValues != null)
            return mValues.BinarySearch(value);
        else
            return -1;
    }

    /// <summary>
    /// Returns the attribute name. When the name is empty (leaf tag), returns the label's
    /// text instead, or an empty string when there is no label.
    /// </summary>
    public override string ToString()
    {
        if (mName != string.Empty)
        {
            return mName;
        }
        // Previously dereferenced mLabel unconditionally and crashed when it was null
        // (a null label, or a splitting attribute created with an empty name).
        return (mLabel != null) ? mLabel.ToString() : string.Empty;
    }
}
/// <summary>
/// A node of the decision tree. A node tagged with a splitting attribute has one child
/// slot per attribute value, in the attribute's sorted value order. A node tagged with a
/// label-only attribute (a leaf) has a single child slot.
/// </summary>
public class TreeNode
{
    private ArrayList mChilds = null;
    private Attribute mAttribute;

    /// <summary>
    /// Creates a node for <paramref name="attribute"/> with every child slot empty (null).
    /// </summary>
    /// <param name="attribute">Splitting attribute or leaf tag of this node.</param>
    /// <exception cref="ArgumentNullException"><paramref name="attribute"/> is null.</exception>
    public TreeNode(Attribute attribute)
    {
        if (attribute == null)
            throw new ArgumentNullException("attribute");

        // Read once: the values property builds a new array on every access.
        string[] values = attribute.values;
        int slots = (values != null) ? values.Length : 1;

        mChilds = new ArrayList(slots);
        for (int i = 0; i < slots; i++)
            mChilds.Add(null);
        mAttribute = attribute;
    }

    /// <summary>
    /// Attaches <paramref name="treeNode"/> as the child for branch <paramref name="ValueName"/>.
    /// </summary>
    /// <param name="treeNode">Subtree to attach (replaces any previous child of that branch).</param>
    /// <param name="ValueName">Attribute value that labels the branch.</param>
    /// <exception cref="ArgumentOutOfRangeException">The value is not one of this node's
    /// attribute values.</exception>
    public void AddTreeNode(TreeNode treeNode, string ValueName)
    {
        mChilds[branchIndex(ValueName, "ValueName")] = treeNode;
    }

    /// <summary>
    /// Number of child slots (filled or not).
    /// </summary>
    public int totalChilds
    {
        get
        {
            return mChilds.Count;
        }
    }

    /// <summary>
    /// Returns the child in slot <paramref name="index"/>, or null if that slot is empty.
    /// </summary>
    public TreeNode getChild(int index)
    {
        return (TreeNode)mChilds[index];
    }

    /// <summary>
    /// The splitting attribute (or leaf tag) of this node.
    /// </summary>
    public Attribute attribute
    {
        get
        {
            return mAttribute;
        }
    }

    /// <summary>
    /// Returns the child reached through the branch labelled <paramref name="branchName"/>,
    /// or null if no subtree was attached to it.
    /// </summary>
    /// <exception cref="ArgumentOutOfRangeException">The value is not one of this node's
    /// attribute values.</exception>
    public TreeNode getChildByBranchName(string branchName)
    {
        return (TreeNode)mChilds[branchIndex(branchName, "branchName")];
    }

    // Maps a branch value to its child slot. Throws a descriptive
    // ArgumentOutOfRangeException instead of indexing the list with the negative
    // "not found" result of the binary search.
    private int branchIndex(string value, string paramName)
    {
        int index = mAttribute.indexValue(value);
        if (index < 0)
            throw new ArgumentOutOfRangeException(paramName, value,
                "Value is not a branch of attribute '" + mAttribute.AttributeName + "'.");
        return index;
    }
}
/// <summary>
/// Builds a decision tree from a table of samples with the ID3 algorithm: each node
/// splits on the remaining attribute with the highest information gain. Attribute
/// columns are read as strings and the target column as bool (positive/negative class).
/// </summary>
public class DecisionTreeID3
{
    // Table passed to mountTree; assigned there and not read elsewhere in this class.
    private DataTable mSamples;
    // The fields below describe the sample set of the internalMountTree call that is
    // currently choosing a split; countTotalPositives(), getValuesToAttribute() and gain()
    // read them.
    private int mTotalPositives = 0;
    private int mTotal = 0;
    private string mTargetAttribute = "result";
    private double mEntropySet = 0.0;

    /// <summary>
    /// Counts the rows of <paramref name="samples"/> whose target column is true.
    /// </summary>
    private int countTotalPositives(DataTable samples)
    {
        int result = 0;
        foreach (DataRow aRow in samples.Rows)
        {
            if ((bool)aRow[mTargetAttribute])
                result++;
        }
        return result;
    }

    /// <summary>
    /// Binary entropy, in bits, of a set with the given numbers of positive and
    /// negative samples. An empty set has entropy 0.
    /// </summary>
    private double calcEntropy(int positives, int negatives)
    {
        int total = positives + negatives;
        // Without this guard 0/0 yields NaN, which propagates into gain() and prevents
        // any attribute with a value absent from the subset from ever being chosen.
        if (total == 0)
            return 0.0;
        return entropyTerm((double)positives / total) + entropyTerm((double)negatives / total);
    }

    // -p * log2(p), taking 0 * log2(0) as 0.
    private static double entropyTerm(double ratio)
    {
        if (ratio == 0.0)
            return 0.0;
        return -ratio * Math.Log(ratio, 2);
    }

    /// <summary>
    /// Counts the positive and negative samples among the rows whose
    /// <paramref name="attribute"/> column equals <paramref name="value"/>.
    /// </summary>
    private void getValuesToAttribute(DataTable samples, Attribute attribute, string value, out int positives, out int negatives)
    {
        positives = 0;
        negatives = 0;
        foreach (DataRow aRow in samples.Rows)
        {
            if (!hasValue(aRow, attribute.AttributeName, value))
                continue;
            if ((bool)aRow[mTargetAttribute])
                positives++;
            else
                negatives++;
        }
    }

    // Case-sensitive string match of a row's column against a value. The same test is
    // used both to compute gains and to partition rows, so the two always agree.
    private static bool hasValue(DataRow row, string columnName, string value)
    {
        return (string)row[columnName] == value;
    }

    /// <summary>
    /// Information gain of splitting <paramref name="samples"/> on <paramref name="attribute"/>:
    /// the set entropy (mEntropySet) minus the size-weighted entropy of each value's subset.
    /// Expects mTotal and mEntropySet to describe <paramref name="samples"/>.
    /// </summary>
    private double gain(DataTable samples, Attribute attribute)
    {
        double remainder = 0.0;
        foreach (string value in attribute.values)
        {
            int positives, negatives;
            getValuesToAttribute(samples, attribute, value, out positives, out negatives);
            remainder += (double)(positives + negatives) / mTotal * calcEntropy(positives, negatives);
        }
        return mEntropySet - remainder;
    }

    /// <summary>
    /// Returns the attribute with the highest information gain; ties go to the earliest.
    /// Never returns null for a non-empty <paramref name="attributes"/> array, even when
    /// every gain is 0.
    /// </summary>
    private Attribute getBestAttribute(DataTable samples, Attribute[] attributes)
    {
        Attribute best = null;
        double bestGain = 0.0;
        foreach (Attribute candidate in attributes)
        {
            double candidateGain = gain(samples, candidate);
            // Accept the first candidate unconditionally. Otherwise, when no split has a
            // positive gain, the result would be null and building the node would crash.
            if (best == null || candidateGain > bestGain)
            {
                best = candidate;
                bestGain = candidateGain;
            }
        }
        return best;
    }

    /// <summary>
    /// True when no row has a false target value (vacuously true for an empty table).
    /// </summary>
    private bool allSamplesPositives(DataTable samples, string targetAttribute)
    {
        foreach (DataRow row in samples.Rows)
        {
            if (!(bool)row[targetAttribute])
                return false;
        }
        return true;
    }

    /// <summary>
    /// True when no row has a true target value (vacuously true for an empty table).
    /// </summary>
    private bool allSamplesNegatives(DataTable samples, string targetAttribute)
    {
        foreach (DataRow row in samples.Rows)
        {
            if ((bool)row[targetAttribute])
                return false;
        }
        return true;
    }

    /// <summary>
    /// Returns the distinct values of a column, in order of first occurrence.
    /// </summary>
    /// <param name="samples">Table of samples.</param>
    /// <param name="targetAttribute">Column to inspect.</param>
    /// <returns>An ArrayList with the distinct values.</returns>
    private ArrayList getDistinctValues(DataTable samples, string targetAttribute)
    {
        ArrayList distinctValues = new ArrayList(samples.Rows.Count);
        foreach (DataRow row in samples.Rows)
        {
            if (distinctValues.IndexOf(row[targetAttribute]) == -1)
                distinctValues.Add(row[targetAttribute]);
        }
        return distinctValues;
    }

    /// <summary>
    /// Returns the most frequent value of a column; ties go to the value seen first.
    /// <paramref name="samples"/> must contain at least one row.
    /// </summary>
    private object getMostCommonValue(DataTable samples, string targetAttribute)
    {
        ArrayList distinctValues = getDistinctValues(samples, targetAttribute);
        int[] count = new int[distinctValues.Count];
        foreach (DataRow row in samples.Rows)
        {
            int index = distinctValues.IndexOf(row[targetAttribute]);
            count[index]++;
        }
        int MaxIndex = 0;
        int MaxCount = 0;
        for (int i = 0; i < count.Length; i++)
        {
            if (count[i] > MaxCount)
            {
                MaxCount = count[i];
                MaxIndex = i;
            }
        }
        return distinctValues[MaxIndex];
    }

    /// <summary>
    /// Recursively grows the subtree for <paramref name="samples"/>.
    /// </summary>
    private TreeNode internalMountTree(DataTable samples, string targetAttribute, Attribute[] attributes)
    {
        // Stopping criteria: a pure sample set, or no attributes left to split on.
        if (allSamplesPositives(samples, targetAttribute))
            return new TreeNode(new Attribute(true));
        if (allSamplesNegatives(samples, targetAttribute))
            return new TreeNode(new Attribute(false));
        if (attributes.Length == 0)
            return new TreeNode(new Attribute(getMostCommonValue(samples, targetAttribute)));

        // State read by gain(). Recursive calls below overwrite it. That is harmless,
        // because it is only needed while this node's split is being chosen.
        mTotal = samples.Rows.Count;
        mTargetAttribute = targetAttribute;
        mTotalPositives = countTotalPositives(samples);
        mEntropySet = calcEntropy(mTotalPositives, mTotal - mTotalPositives);

        Attribute bestAttribute = getBestAttribute(samples, attributes);
        TreeNode root = new TreeNode(bestAttribute);
        // Children may split on every attribute except the one used here.
        Attribute[] remaining = attributesExcept(attributes, bestAttribute);

        foreach (string value in bestAttribute.values)
        {
            DataTable subset = rowsWithValue(samples, bestAttribute.AttributeName, value);
            TreeNode child;
            if (subset.Rows.Count == 0)
            {
                // No sample takes this value. Add a leaf predicting the parent set's
                // majority class and keep building the other branches.
                child = new TreeNode(new Attribute(getMostCommonValue(samples, targetAttribute)));
            }
            else
            {
                child = internalMountTree(subset, targetAttribute, remaining);
            }
            root.AddTreeNode(child, value);
        }
        return root;
    }

    // Copies the rows whose column equals value into a new table with the same schema.
    // Filtering in code avoids DataTable.Select, whose string-built filter breaks on
    // quotes or special column names and is case-insensitive by default.
    private static DataTable rowsWithValue(DataTable samples, string columnName, string value)
    {
        DataTable subset = samples.Clone();
        foreach (DataRow row in samples.Rows)
        {
            if (hasValue(row, columnName, value))
                subset.Rows.Add(row.ItemArray);
        }
        return subset;
    }

    // Returns the attributes whose name differs from excluded's name.
    private static Attribute[] attributesExcept(Attribute[] attributes, Attribute excluded)
    {
        ArrayList kept = new ArrayList(attributes.Length);
        foreach (Attribute candidate in attributes)
        {
            if (candidate.AttributeName != excluded.AttributeName)
                kept.Add(candidate);
        }
        return (Attribute[])kept.ToArray(typeof(Attribute));
    }

    /// <summary>
    /// Builds the decision tree for <paramref name="samples"/>.
    /// </summary>
    /// <param name="samples">Training samples: one string column per attribute and a bool target column.</param>
    /// <param name="targetAttribute">Name of the bool target column.</param>
    /// <param name="attributes">Attributes that may be used for splitting.</param>
    /// <returns>The root node of the tree.</returns>
    /// <exception cref="ArgumentNullException">An argument is null.</exception>
    public TreeNode mountTree(DataTable samples, string targetAttribute, Attribute[] attributes)
    {
        if (samples == null)
            throw new ArgumentNullException("samples");
        if (targetAttribute == null)
            throw new ArgumentNullException("targetAttribute");
        if (attributes == null)
            throw new ArgumentNullException("attributes");
        mSamples = samples;
        return internalMountTree(mSamples, targetAttribute, attributes);
    }
}
}
// end of decisiontree.cs