// decisiontree.cpp
// (Web viewer residue, not part of the program: "Font size:")
#include <Stdafx.h>
#include "DecisionTree.h"
// Constructor: starts with a freshly allocated root node, zeroed counters
// and no training data attached.
DecisionTree::DecisionTree()
{
    // Counters.
    numOfIns = numOfAttr = numOfNodes = 0;

    // Buffers are attached later; nothing is allocated for them here.
    trainingSet = NULL;
    weight = NULL;
    used = NULL;

    // Placeholder root node.
    root = new Node();
}
// Destructor: frees the tree nodes and the arrays released by the original
// implementation (trainingSet and used).
DecisionTree::~DecisionTree()
{
    // root can legitimately be NULL: BuildTree stores whatever GenerateTree
    // returns (NULL when there is no training set) and LoadTree stores
    // whatever LoadNode returns (NULL at end of file). DeleteTree dereferences
    // its argument, so guard the call.
    if(root!=NULL)
    {
        DeleteTree(root);
        root=NULL;
    }

    // delete[] on a null pointer is a no-op, so no explicit checks are needed.
    // NOTE(review): only the outer array of trainingSet is released here, as in
    // the original; confirm who owns the individual rows. `weight` is likewise
    // not released here - confirm its ownership.
    delete []trainingSet;
    trainingSet=NULL;
    delete []used;
    used=NULL;
}
// Returns the root node of the tree. May be NULL if a previous build or load
// produced no tree.
Node* DecisionTree::GetRoot()
{
    // `this` is a pointer in C++, so members are accessed with `->`, not `.`.
    return this->root;
}
// Returns the number of training instances (rows); BuildTree uses
// numOfIns-1 as the index of the last row of the training set.
int DecisionTree::GetnumOfIns()
{
    // `this` is a pointer in C++, so members are accessed with `->`, not `.`.
    return this->numOfIns;
}
// Returns the number of attributes (columns); the tree-building code treats
// column numOfAttr-1 as the class label.
int DecisionTree::GetnumOfAttr()
{
    // `this` is a pointer in C++, so members are accessed with `->`, not `.`.
    return this->numOfAttr;
}
// Sets the number of training instances (rows).
void DecisionTree::SetnumOfIns(int numOfIns)
{
    // The parameter shadows the member, so the member must be qualified;
    // `this` is a pointer in C++, hence `->` rather than `.`.
    this->numOfIns=numOfIns;
}
// Sets the number of attributes (columns, including the class column).
void DecisionTree::SetnumOfAttr(int numOfAttr)
{
    // The parameter shadows the member, so the member must be qualified;
    // `this` is a pointer in C++, hence `->` rather than `.`.
    this->numOfAttr=numOfAttr;
}
// Returns the node counter, which GenerateTree increments once for every
// node it creates.
int DecisionTree::GetnumOfNodes()
{
    // `this` is a pointer in C++, so members are accessed with `->`, not `.`.
    return this->numOfNodes;
}
// Recursively frees the subtree rooted at `root`: children first, then the
// node itself. Passing NULL is a no-op.
// Note: the parameter shadows the member of the same name; the member pointer
// is not modified here, so callers must reset it themselves.
void DecisionTree::DeleteTree(Node* root)
{
    if(root==NULL)
        return;
    // The original called `deleteTree` (lower-case d), which is not the name
    // of this method.
    DeleteTree(root->leftchild);
    DeleteTree(root->rightchild);
    delete root;
}
/*double DecisionTree:: Entropy(int attribute,int start,int end,double** trainingSet)
{ // Entropy of the subsets obtained by partitioning on the non-class attribute `attribute`
int number=Count(attribute,start,end,trainingSet);
int* count=new int[number];// number of rows for each attribute value
for(i=0;i<number;i++)
count[i]=0;
int* attrs=new int[number];// the distinct attribute values
int n=0;
attrs[0]=trainingSet[start][attribute];
n++;
for(i=start+1;i<=end&&n<number;i++)
{
for(j=0;j<n;j++)
{
if(trainingSet[i][attribute]==attrs[j])
break;
}
if(j==n)
{
attrs[n]=trainingSet[i][attribute];
n++;
}
}
int** temp=new int*[number];
for(i=0;i<number;i++)
temp[i]=new int[2];
for(i=start;i<=end;i++)
{
for(j=0;j<number;j++)
{
if(trainingSet[i][attribute]==attrs[j])
{
count[j]++;
temp[j][trainingSet[i][numOfAttr-1]]++;
}
}
}
double entropy;
for(i=0;i<number;i++)
{
entropy+=count[i]/(end-start+1)*(-temp[i][0]/count[i]*(log(temp[i][0]/count[i])/log(2))-temp[i][1]/count[i]*(log(temp[i][1]/count[i])/log(2)));
}
return entropy;
}*/
// Binary entropy (in bits) of column `attribute` over rows [start, end].
//
// Only the values 0 and 1 are counted; rows holding any other value contribute
// to the denominator (end-start+1) but to neither count, as in the original.
// The callers in this file pass the class column (numOfAttr-1), so the result
// is the class entropy of the row range.
//
// Parameters:
//   attribute   - column index to evaluate
//   start, end  - inclusive row range
//   trainingSet - row-major table of samples
// Returns 0 for an empty range or a NULL table.
//
// An earlier, commented-out multi-valued implementation that lived in this
// body has been removed; it was never compiled.
double DecisionTree::SplitInfo(int attribute,int start,int end,double** trainingSet)
{
    const int total=end-start+1;
    if(trainingSet==NULL||total<=0)
        return 0.0;

    int countneg=0; // rows whose value is exactly 0
    int countpos=0; // rows whose value is exactly 1
    for(int i=start;i<=end;i++)
    {
        const double v=trainingSet[i][attribute];
        if(v==0)
            countneg++;
        else if(v==1)
            countpos++;
    }

    // The original divided int by int, so every proportion collapsed to 0 or 1
    // and log(0) turned the result into NaN. Use floating-point proportions
    // and skip empty classes (the limit of p*log(p) as p->0 is 0).
    const int counts[2]={countneg,countpos};
    double splitinfo=0.0;
    for(int k=0;k<2;k++)
    {
        if(counts[k]>0)
        {
            const double p=static_cast<double>(counts[k])/total;
            splitinfo-=p*(log(p)/log(2.0));
        }
    }
    return splitinfo;
}
/*double DecisionTree:: Gain(int attribute,int start,int end, double** trainingSet)
{ // Information gain
double gain=0;
sort(attribute,start,end,trainingSet);
int i;
double* gain=new double[end-start];
double info=SplitInfo(numOfAttr-1,start,end,trainingSet);
for(i=start;i<end;i++)
{
int entropy=(i-start+1)/(end-start+1)*SplitInfo(numOfAttr-1,start,i,trainingSet)+(end-i)/(end-start+1)*SplitInfo(numOfAttr-1,i+1,end,trainingSet);
gain[i-start]=info-entropy;
if(gain[i-start]>gain)
{
gain=gain[i-start];
}
}
return gain;
}*/
// Gain ratio of the best binary split of rows [start, end] on `attribute`.
//
// Side effect: the rows are sorted in place by `attribute` (via sort()).
// Every position i in [start, end) is tried as a split into [start, i] and
// [i+1, end]; the information gain of the class column (numOfAttr-1) is
// maximised and then divided by the split information of the chosen
// partition sizes.
//
// Returns 0 when the range has fewer than two rows, when the table is NULL,
// or when no split yields a positive gain.
double DecisionTree::GainRatio(int attribute, int start,int end,double** trainingSet)
{
    if(trainingSet==NULL||end<=start)
        return 0.0;

    sort(attribute,start,end,trainingSet);

    const int classCol=numOfAttr-1;
    const double total=end-start+1; // double on purpose: avoids integer division below
    const double info=SplitInfo(classCol,start,end,trainingSet);

    // The original stored every candidate gain in a heap array that was never
    // freed; only the running maximum is needed.
    double gain=0.0;
    int splitpoint=-1; // -1 = no split with positive gain found yet
    for(int i=start;i<end;i++)
    {
        const double leftWeight=(i-start+1)/total;
        const double rightWeight=(end-i)/total;
        // The original kept this in an `int`, truncating the entropy.
        const double entropy=leftWeight*SplitInfo(classCol,start,i,trainingSet)
                            +rightWeight*SplitInfo(classCol,i+1,end,trainingSet);
        const double g=info-entropy;
        if(g>gain)
        {
            gain=g;
            splitpoint=i;
        }
    }
    if(splitpoint<0)
        return 0.0; // the original read an uninitialized splitpoint here

    // Split information of the two partition sizes. Both sides are non-empty
    // because start <= splitpoint < end, so both logs are finite.
    const double pLeft=(splitpoint-start+1)/total;
    const double pRight=(end-splitpoint)/total;
    const double splitinfo=-pLeft*(log(pLeft)/log(2.0))-pRight*(log(pRight)/log(2.0));
    if(splitinfo<=0.0)
        return 0.0; // defensive: never divide by zero

    return gain/splitinfo;
}
// In-place quicksort of rows [start, end] of trainingSet, ascending by column
// `attribute`. Row contents are copied (the row pointers themselves are not
// permuted).
//
// Assumes every row has at least numOfAttr columns - the pivot buffer was
// already sized that way in the original; the copy loops used a hard-coded 5
// instead, which read past the pivot buffer when numOfAttr < 5 and dropped
// columns when numOfAttr > 5.
void DecisionTree::sort(int attribute,int start,int end,double** trainingSet)
{
    if(trainingSet==NULL||start>=end)
        return;

    // Fast path: nothing to do if the range is already in order.
    int test;
    for(test=start; test<end; test++)
    {
        if(trainingSet[test][attribute]>trainingSet[test+1][attribute])
            break;
    }
    if(test==end)
        return;

    int i=start;
    int j=end;

    // Save the first row as the pivot; its slot becomes the initial "hole".
    double* pivot=new double[numOfAttr];
    for(int k=0;k<numOfAttr;k++)
        pivot[k]=trainingSet[start][k];

    while(i<j)
    {
        // Scan from the right for a row that belongs left of the pivot.
        while(i<j && pivot[attribute]<=trainingSet[j][attribute])
            j--;
        if(i<j)
        {
            for(int k=0;k<numOfAttr;k++)
                trainingSet[i][k]=trainingSet[j][k];
            i++;
        }
        // Scan from the left for a row that belongs right of the pivot.
        while(i<j && pivot[attribute]>trainingSet[i][attribute])
            i++;
        if(i<j)
        {
            for(int k=0;k<numOfAttr;k++)
                trainingSet[j][k]=trainingSet[i][k];
            j--;
        }
    }

    // i == j is the final hole; drop the pivot row into it.
    for(int k=0;k<numOfAttr;k++)
        trainingSet[i][k]=pivot[k];
    delete []pivot; // the original leaked this buffer on every call

    if(i-start>1)
        sort(attribute,start,i-1,trainingSet);
    if(end-j>1)
        sort(attribute,j+1,end,trainingSet);
}
// Reports whether column `attribute` takes a single value over rows
// [start, end]: returns 1 if every row equals the first row's value,
// otherwise 2. (It does not count beyond two distinct values.)
//
// The original truncated the reference value to int before comparing it with
// the double-valued rows, so a column of identical non-integer values (e.g.
// all 0.5) was reported as 2, and 1.7 followed by 1.0 was reported as 1.
// Values are now compared as doubles.
//
// An earlier, commented-out distinct-value counter that lived in this body
// has been removed; it was never compiled.
int DecisionTree::Count(int attribute,int start,int end,double** trainingSet)
{
    const double first=trainingSet[start][attribute];
    for(int i=start+1;i<=end;i++)
    {
        if(trainingSet[i][attribute]!=first)
            return 2;
    }
    return 1;
}
void DecisonTree:: BuildTree()
{
this->root=GenerateTree(0,numOfIns-1,trainingSet);
cout<<"This tree is built"<<endl;
}
// Recursively builds the subtree for rows [start, end] of trainingSet and
// returns its root (NULL if there is nothing to build from).
//
// Column numOfAttr-1 is the class label (0 or 1). Cases, in order:
//   1. All rows share one class        -> leaf; attribute = that class,
//                                         value = 1 for class 1, else 0.
//   2. Only the class column exists    -> leaf; attribute = majority class,
//                                         value = fraction of class-1 rows.
//   3. Otherwise pick the attribute with the largest gain ratio. The node
//      becomes a leaf if its creation number (numOfNodes after increment) is
//      one of 4,5,7,8,11,12,14,15 - the level-3 positions of a preorder walk
//      when the root is level 0 - or if no useful split exists; otherwise it
//      becomes an internal node that sends rows with
//      instance[attribute] <= value to the left child.
//
// Side effects: rows of trainingSet are re-sorted in place; numOfNodes is
// incremented once per created node.
//
// Fixes relative to the original: the stop rule used `=` instead of `==`;
// `gain` was undeclared and `splitpoint` was used outside the loop that
// declared it; leaf probabilities used integer division; `attribute` could be
// read uninitialized; the per-call `used` array (which shadowed the member and
// was always all-zero when tested) and the per-attribute gain arrays leaked.
Node* DecisionTree::GenerateTree(int start,int end,double** trainingSet)
{
    // Nothing to build from: no table, empty range, or no class column.
    if(trainingSet==NULL||start>end||numOfAttr<1)
    {
        return NULL;
    }

    const int classCol=numOfAttr-1;
    const int total=end-start+1;

    // Case 1: every sample belongs to the same class.
    if(Count(classCol,start,end,trainingSet)==1)
    {
        Node* node=new Node();
        node->isLeaf=1;
        node->attribute=int(trainingSet[start][classCol]);
        if(node->attribute==1)
            node->value=1;
        else
            node->value=0;
        numOfNodes++;
        return node;
    }

    // Class tallies for the range (only exact 0 / 1 labels are counted).
    int countneg=0;
    int countpos=0;
    for(int i=start;i<=end;i++)
    {
        if(trainingSet[i][classCol]==0)
            countneg++;
        else if(trainingSet[i][classCol]==1)
            countpos++;
    }

    // Case 2: only the class column exists - return a majority-class leaf.
    if(numOfAttr==1)
    {
        Node* node=new Node();
        node->isLeaf=1;
        node->attribute=(countneg>countpos)?0:1;
        node->value=static_cast<double>(countpos)/total;
        numOfNodes++;
        return node;
    }

    // Case 3: choose the attribute with the largest gain ratio.
    int attribute=-1; // -1 = no attribute with a positive gain ratio
    double gainratio=0;
    for(int i=0;i<classCol;i++)
    {
        const double ratio=GainRatio(i,start,end,trainingSet);
        if(ratio>gainratio)
        {
            gainratio=ratio;
            attribute=i;
        }
    }

    Node* node=new Node();
    numOfNodes++;

    // Stop growing at the preorder positions that correspond to level 3.
    bool makeLeaf=(attribute<0)
        ||numOfNodes==4||numOfNodes==5||numOfNodes==7||numOfNodes==8
        ||numOfNodes==11||numOfNodes==12||numOfNodes==14||numOfNodes==15;

    int splitpoint=-1;
    if(!makeLeaf)
    {
        // GainRatio left the rows sorted by the last attribute it examined,
        // so re-sort by the winner and locate its split on that exact order.
        sort(attribute,start,end,trainingSet);
        const double info=SplitInfo(classCol,start,end,trainingSet);
        double gain=0;
        for(int j=start;j<end;j++)
        {
            const double leftWeight=static_cast<double>(j-start+1)/total;
            const double rightWeight=static_cast<double>(end-j)/total;
            const double entropy=leftWeight*SplitInfo(classCol,start,j,trainingSet)
                                +rightWeight*SplitInfo(classCol,j+1,end,trainingSet);
            const double g=info-entropy;
            if(g>gain)
            {
                gain=g;
                splitpoint=j;
            }
        }
        if(splitpoint<0)
            makeLeaf=true; // no split improves on the parent
    }

    if(makeLeaf)
    {
        // Leaf: value is the fraction of class-1 rows; every row that is not
        // class 1 counts as negative for the majority vote, as in the original.
        const int numOfPos=countpos;
        const int numOfNeg=total-countpos;
        node->isLeaf=1;
        node->value=static_cast<double>(numOfPos)/total;
        if(numOfPos>=numOfNeg)
            node->attribute=1;
        else
            node->attribute=0;
        return node;
    }

    // Internal node: threshold is the attribute value of the last row that
    // goes left. Both child ranges are non-empty since start <= splitpoint < end.
    node->value=trainingSet[splitpoint][attribute];
    node->attribute=attribute;
    node->isLeaf=0;
    node->leftchild=GenerateTree(start,splitpoint,trainingSet);
    node->rightchild=GenerateTree(splitpoint+1,end,trainingSet);
    return node;
}
// Classifies one instance by walking from the root to a leaf: at each
// internal node, go left when instance[node->attribute] <= node->value,
// otherwise go right. Returns the leaf's value.
//
// Returns -1 when there is no tree, when `instance` is NULL, or when the walk
// reaches a missing child (possible if the tree was loaded from a truncated
// file, since LoadNode returns NULL at end of input).
double DecisionTree::ClassifyIns(double* instance)
{
    if(instance==NULL)
        return -1;

    Node* p=root;
    while(p!=NULL&&p->isLeaf==0)
    {
        if(instance[p->attribute]<=p->value)
            p=p->leftchild;
        else
            p=p->rightchild;
    }
    if(p==NULL)
        return -1;
    return p->value;
}
void DecisonTree:: SaveTree(ofstream& fout,Node* node)
{
if(node==NULL)
{
return;
}
fout<<node->isLeaf<<endl;
fout<<node->value<<endl;
fout<<node->attribute<<endl;
// cout<<"save a node"<<endl;
if(node->leftchild!=NULL)
SaveTree(fout,node->leftchild);
if(node->rightchild!=NULL)
SaveTree(fout,node->rightchild);
}
void DecisonTree:: SaveTree(char* fname)
{
ofstream fout(fname);
saveTree(fout,root);
fout.close();
}
// Reads one node (isLeaf, value, attribute - the order SaveTree writes them)
// from `fin` and, for a non-leaf node, recursively reads its left and right
// subtrees. Returns NULL when a complete node record cannot be read.
//
// The original tested fin.eof() *before* reading. eof() only becomes true
// after a read has hit the end, so at end of input it built a bogus all-zero
// non-leaf node, and on a stream that failed without reaching EOF (e.g. a
// file that could not be opened) it recursed without end. Checking the result
// of the extraction itself avoids both problems.
// (The original definition was also qualified with the misspelled class name
// `DecisonTree`.)
Node* DecisionTree::LoadNode(ifstream& fin)
{
    int isLeaf=0;
    double value=0;
    int attribute=0;
    if(!(fin>>isLeaf>>value>>attribute))
    {
        return NULL;
    }

    Node* node=new Node(value,attribute,isLeaf);
    if(node->isLeaf!=1)
    {
        node->leftchild=LoadNode(fin);
        node->rightchild=LoadNode(fin);
    }
    return node;
}
void CDecisonTree:: LoadTree(char* fname)
{
ifstream fin(fname);
root= LoadNode(fin);
fin.close();
}
void DecisonTree:: ShowTree(Node* node)
{
if(node ==NULL)
return;
cout<<"tree node"<<node->GetValue()<<" "<<node->GetAttribute()<<" "<<node->IsLeaf()<<endl;
if(node->leftchild!=NULL)
{
ShowTree(node->leftchild);
}
if(node->rightchild!=NULL)
ShowTree(node->rightchild);
}
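// ---- Web viewer residue (not part of the program) ----
// Keyboard shortcuts:
//   Copy code:          Ctrl + C
//   Search code:        Ctrl + F
//   Full-screen mode:   F11
//   Toggle theme:       Ctrl + Shift + D
//   Show shortcuts:     ?
//   Increase font size: Ctrl + =
//   Decrease font size: Ctrl + -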