
decisiontree.cpp

Implementation of a C4.5 decision tree, applied to a medical diagnosis task.
Language: C++
#include "stdafx.h"        // precompiled header (Visual C++ convention)
#include "DecisionTree.h"
#include <iostream>        // cout, endl
#include <fstream>         // ofstream, ifstream
#include <cmath>           // log()
using namespace std;

DecisionTree::DecisionTree()
{ // constructor
	root=new Node();
	numOfIns=0;
	numOfAttr=0;
	trainingSet=NULL;
	weight=NULL;
	used=NULL;
	numOfNodes=0;
}

DecisionTree::~DecisionTree()
{ // destructor
	DeleteTree(root);
	if(trainingSet!=NULL)
	{
		delete []trainingSet;
		trainingSet=NULL;
	}
	if(used!=NULL)
	{
		delete []used;
		used=NULL;
	}
}

Node* DecisionTree::GetRoot()
{
	return this->root;
}

int DecisionTree::GetnumOfIns()
{
	return this->numOfIns;
}

int DecisionTree::GetnumOfAttr()
{
	return this->numOfAttr;
}

void DecisionTree::SetnumOfIns(int numOfIns)
{
	this->numOfIns=numOfIns;
}

void DecisionTree::SetnumOfAttr(int numOfAttr)
{
	this->numOfAttr=numOfAttr;
}

int DecisionTree::GetnumOfNodes()
{
	return this->numOfNodes;
}

void DecisionTree::DeleteTree(Node* root)
{
	if(root==NULL)
		return;
	if(root->leftchild!=NULL)
		DeleteTree(root->leftchild);
	if(root->rightchild!=NULL)
		DeleteTree(root->rightchild);
	delete root;
}

/*double DecisionTree::Entropy(int attribute,int start,int end,double** trainingSet)
{ // entropy of the subsets obtained by splitting on the non-class attribute 'attribute'
		int number=Count(attribute,start,end,trainingSet);
		int* count=new int[number];// number of occurrences of each attribute value
		for(i=0;i<number;i++)
			count[i]=0;
		int* attrs=new int[number];// the distinct attribute values
		int n=0;
		attrs[0]=trainingSet[start][attribute];
		n++;
		for(i=start+1;i<=end&&n<number;i++)
		{
			for(j=0;j<n;j++)
			{
				if(trainingSet[i][attribute]==attrs[j])
					break;
			}
			if(j==n)
			{
				attrs[n]=trainingSet[i][attribute];
				n++;
			}
		}
		int** temp=new int*[number];
		for(i=0;i<number;i++)
			temp[i]=new int[2];
		for(i=start;i<=end;i++)
		{
			for(j=0;j<number;j++)
			{
				if(trainingSet[i][attribute]==attrs[j])
				{
					count[j]++;
					temp[j][trainingSet[i][numOfAttr-1]]++;
				}
			}
		}
	    double entropy;
		for(i=0;i<number;i++)
		{
			entropy+=count[i]/(end-start+1)*(-temp[i][0]/count[i]*(log(temp[i][0]/count[i])/log(2))-temp[i][1]/count[i]*(log(temp[i][1]/count[i])/log(2)));
		}
		return entropy;	
}*/

double DecisionTree::SplitInfo(int attribute,int start,int end,double** trainingSet)
{ // split information: the entropy of S with respect to the values of 'attribute'; here only the binary 0/1 distribution of that column is used
		/*int number=Count(attribute,start,end,trainingSet);
		int* count=new int[number];// number of occurrences of each attribute value
		for(i=0;i<number;i++)
			count[i]=0;
		int* attrs=new int[number];// the distinct attribute values
		int n=0;
		attrs[0]=trainingSet[start][attribute];
		n++;
		for(i=start+1;i<=end&&n<number;i++)
		{
			for(j=0;j<n;j++)
			{
				if(trainingSet[i][attribute]==attrs[j])
					break;
			}
			if(j==n)
			{
				attrs[n]=trainingSet[i][attribute];
				n++;
			}
		}
		for(i=start;i<=end;i++)
		{
			for(j=0;j<number;j++)
			{
				if(trainingSet[i][attribute]==attrs[j])
					count[j]++;
			}
		}
		double splitinfo;
		for(i=0;i<=number;i++)
			splitinfo+=-count[i]/(end-start+1)*(log(count[i]/(end-start+1))/log(2);*/
	double splitinfo=0;
	int countneg=0;
	int countpos=0;
	for(int i=start;i<=end;i++)
	{
		if(trainingSet[i][attribute]==0)
			countneg++;
		else if(trainingSet[i][attribute]==1)
			countpos++;
	}
	double total=end-start+1;
	double pneg=countneg/total;   // fraction of 0-valued samples
	double ppos=countpos/total;   // fraction of 1-valued samples
	if(pneg>0)                    // skip the term to avoid log(0)
		splitinfo-=pneg*(log(pneg)/log(2.0));
	if(ppos>0)
		splitinfo-=ppos*(log(ppos)/log(2.0));
	return splitinfo;
}
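
// Illustrative sketch (not part of the original file): what SplitInfo computes over a
// binary column is the binary entropy H(p) = -p*log2(p) - (1-p)*log2(1-p), where p is
// the fraction of instances equal to 1 in [start, end]. A standalone helper, assuming
// <cmath> is available, could be written as:
inline double BinaryEntropy(double p)
{
	if(p<=0.0||p>=1.0)   // a pure split carries no entropy
		return 0.0;
	return -p*(log(p)/log(2.0))-(1.0-p)*(log(1.0-p)/log(2.0));
}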

/*double DecisionTree::Gain(int attribute,int start,int end, double** trainingSet)
{  // information gain
    double gain=0;
	sort(attribute,start,end,trainingSet);
	int i;
	double* gain=new double[end-start];
	double info=SplitInfo(numOfAttr-1,start,end,trainingSet);
	for(i=start;i<end;i++)
	{
		int entropy=(i-start+1)/(end-start+1)*SplitInfo(numOfAttr-1,start,i,trainingSet)+(end-i)/(end-start+1)*SplitInfo(numOfAttr-1,i+1,end,trainingSet);
		gain[i-start]=info-entropy;
		if(gain[i-start]>gain)
		{
			gain=gain[i-start];
		}
	}
    return gain;
}*/

double DecisionTree::GainRatio(int attribute, int start,int end,double** trainingSet)
{  // information gain ratio of the best binary split on 'attribute'
	double gainratio;
	double gain=0;
	sort(attribute,start,end,trainingSet);
	int i;
	int splitpoint=start;
	double* g=new double[end-start];
	double info=SplitInfo(numOfAttr-1,start,end,trainingSet);   // entropy of the class column over [start,end]
	double total=end-start+1;
	for(i=start;i<end;i++)
	{
		// class entropy after splitting [start,end] into [start,i] and [i+1,end]
		double entropy=(i-start+1)/total*SplitInfo(numOfAttr-1,start,i,trainingSet)+(end-i)/total*SplitInfo(numOfAttr-1,i+1,end,trainingSet);
		g[i-start]=info-entropy;
		if(g[i-start]>gain)
		{
			gain=g[i-start];
			splitpoint=i;
		}
	}
	delete []g;
	double left=splitpoint-start+1;   // size of the left partition
	double right=end-splitpoint;      // size of the right partition
	if(right==0)                      // degenerate split (single instance): no usable gain ratio
		return 0;
	double splitinfo;                 // split information of the chosen binary partition
	splitinfo=-left/total*(log(left/total)/log(2.0))-right/total*(log(right/total)/log(2.0));
	gainratio=gain/splitinfo;
	return gainratio;
}
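
// Illustrative worked example (the numbers are assumptions, not taken from the project's data):
// suppose [start,end] holds 10 instances, 5 positive and 5 negative, so info = 1.0 bit.
// A candidate threshold that puts 4 negatives on the left and the remaining 6 instances
// (5 positive, 1 negative) on the right gives
//   entropy   = 0.4*0 + 0.6*H(5/6) ≈ 0.6*0.650 ≈ 0.390
//   gain      = 1.0 - 0.390        ≈ 0.610
//   splitinfo = -0.4*log2(0.4) - 0.6*log2(0.6) ≈ 0.971
//   gainratio = 0.610 / 0.971      ≈ 0.63
// which is the quantity GainRatio() returns for the best such threshold.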

void DecisionTree::sort(int attribute,int start,int end,double** trainingSet)
{	// in-place quicksort of the rows [start,end] by the value of column 'attribute'
	int test;
	for(test=start; test<end; test++)   // already sorted? then nothing to do
		if( trainingSet[test][attribute]>trainingSet[test+1][attribute]){
			break;
		}
	if(test==end) return;
	int i,j;
	i=start;
	j=end;
	double* pivot=new double[numOfAttr];   // copy of the pivot row
	for(int kk=0;kk<numOfAttr;kk++)
		pivot[kk]=trainingSet[start][kk];

	while(i<j) {
		while(i<j && pivot[attribute] <= trainingSet[j][attribute]) j--;
		if(i<j)
		{
			int s = i++;
			for(int k=0;k<numOfAttr;k++)   // move the smaller row to the left
				trainingSet[s][k] = trainingSet[j][k];
		}
		while(i<j && pivot[attribute] > trainingSet[i][attribute]) i++;
		if(i<j)
		{
			int s = j--;
			for(int k=0;k<numOfAttr;k++)   // move the larger row to the right
				trainingSet[s][k]=trainingSet[i][k];
		}
	}
	for(int kk=0;kk<numOfAttr;kk++)        // drop the pivot row into its final slot
		trainingSet[i][kk] = pivot[kk];
	delete []pivot;
	if(i-start>1) sort(attribute,start,i-1,trainingSet);
	if(end-j>1) sort(attribute,j+1,end,trainingSet);
}

int DecisionTree::Count(int attribute,int start,int end,double** trainingSet)
{
	/*int number=end-start+1;   // count the distinct values of the discrete attribute 'attribute'
	int i;
	int j;
	for(i=start;i<=end;i++)
	{
		for(j=start;j<i;j++)
		{
			if(trainingSet[i][attribute]==trainingSet[j][attribute])
			{
				number--;
				break;
			}
		}
	}*/
	// Simplified for binary-valued attributes: returns 1 if every value in [start,end]
	// equals the first one, otherwise 2. GenerateTree uses Count(numOfAttr-1,...)==1
	// as its "all samples belong to one class" test.
	int number;
	int temp=int(trainingSet[start][attribute]);
	int i;
	for(i=start+1;i<=end;i++)
	{
		if(trainingSet[i][attribute]!=temp)
			break;
	}
	if(i<=end)
		number=2;
	else number=1;
	return number;
}

void DecisionTree::BuildTree()
{
	this->root=GenerateTree(0,numOfIns-1,trainingSet);
	cout<<"The tree has been built"<<endl;
}

Node* DecisionTree::GenerateTree(int start,int end,double** trainingSet)
{
	if(trainingSet==NULL)   // empty training set
	{
		return NULL;
	}
	if(Count(numOfAttr-1,start,end,trainingSet)==1)  // all samples belong to the same class
	{
		Node* node= new Node();
		node->isLeaf = 1;
		node->attribute = int(trainingSet[start][numOfAttr-1]);
		if(node->attribute==1)
			node->value=1;
		else node->value=0;
		numOfNodes++;
		return node;
	}
	if(numOfAttr==1)  // only the class attribute is left: return a leaf labelled with the majority class
	{
		int countneg=0;
		int countpos=0;
		for(int i=start;i<=end;i++)
		{
			if(trainingSet[i][numOfAttr-1]==0)
				countneg++;
			else if(trainingSet[i][numOfAttr-1]==1)
				countpos++;
		}
		int more;
		if(countneg>countpos)
			more=0;
		else
			more=1;
		Node* node=new Node();
		node->isLeaf=1;
		node->attribute=more;
		node->value=countpos/double(end-start+1);   // fraction of positive samples
		numOfNodes++;
		return node;
	}
	else    // split on the attribute with the largest gain ratio
	{
		int* used=new int[numOfAttr];   // local to this call; shadows the member of the same name
		for(int n=0;n<numOfAttr;n++)
			used[n]=0;
		double gainratio=0;
		int attribute=0;
		int* splitpoint=new int[numOfAttr-1];  // best split index found for each non-class attribute
		double total=end-start+1;
		for(int i=0;i<numOfAttr-1;i++)
		{
			sort(i,start,end,trainingSet);
			int j;
			double gain=0;
			double* g=new double[end-start];
			double info=SplitInfo(numOfAttr-1,start,end,trainingSet);
			splitpoint[i]=start;
			for(j=start;j<end;j++)
			{
				// class entropy after splitting [start,end] into [start,j] and [j+1,end]
				double entropy=(j-start+1)/total*SplitInfo(numOfAttr-1,start,j,trainingSet)+(end-j)/total*SplitInfo(numOfAttr-1,j+1,end,trainingSet);
				g[j-start]=info-entropy;
				if(g[j-start]>gain)
				{
					gain=g[j-start];
					splitpoint[i]=j;
				}
			}
			delete []g;
			double gr=GainRatio(i,start,end,trainingSet);
			if(!used[i]&&gainratio<gr)   // keep the attribute with the largest gain ratio
			{
				gainratio=gr;
				attribute=i;
			}
		}
		sort(attribute,start,end,trainingSet);
		Node* node=new Node();
		used[attribute]=1;
		numOfNodes++;
		// Stop growing and make the current node a leaf. Counted in pre-order (root = node 1),
		// nodes 4,5,7,8,11,12,14,15 are exactly the level-3 nodes of a complete binary tree
		// (root at level 0), so this caps the tree at depth 3.
		if(numOfNodes==4||numOfNodes==5||numOfNodes==7||numOfNodes==8||numOfNodes==11||numOfNodes==12||numOfNodes==14||numOfNodes==15)
		{
			node->isLeaf=1;
			int numOfPos=0;
			int numOfNeg=0;
			for(int i=start;i<=end;i++)
			{
				if(trainingSet[i][numOfAttr-1]==1)
					numOfPos++;
				else numOfNeg++;
			}
			node->value=numOfPos/double(end-start+1);   // probability of being diseased
			if(numOfPos>=numOfNeg)
				node->attribute=1;
			else node->attribute=0;
		}
		else   // keep growing
		{
			node->value=trainingSet[splitpoint[attribute]][attribute];
			node->attribute=attribute;
			node->isLeaf=0;
			node->leftchild=GenerateTree(start,splitpoint[attribute],trainingSet);
			node->rightchild=GenerateTree(splitpoint[attribute]+1,end,trainingSet);
		}
		delete []used;
		delete []splitpoint;
		return node;
	}
}

double DecisionTree::ClassifyIns(double* instance)
{	// walk the tree: go left while the instance's value for the node's attribute is <= the node's threshold
	if(root==NULL)
		return -1;
	Node* p=root;
	while(p->isLeaf==0)
	{
		if(instance[p->attribute]<=p->value)
			p=p->leftchild;
		else
			p=p->rightchild;
	}
	return p->value;
}

void DecisionTree::SaveTree(ofstream& fout,Node* node)
{	// write the tree to 'fout' in pre-order: isLeaf, value, attribute for each node
    if(node==NULL)
	{
		return;
	}
	fout<<node->isLeaf<<endl;
	fout<<node->value<<endl;
	fout<<node->attribute<<endl;
	// cout<<"save a node"<<endl;
	if(node->leftchild!=NULL)
    	SaveTree(fout,node->leftchild);
	if(node->rightchild!=NULL)
    	SaveTree(fout,node->rightchild);
}

void DecisionTree::SaveTree(char* fname)
{
	ofstream fout(fname);
	SaveTree(fout,root);
	fout.close();
}

Node* DecisionTree::LoadNode(ifstream& fin)
{	// read one node (and, recursively, its subtrees) in the pre-order layout written by SaveTree
	if(fin.eof())
	{
		return NULL;
	}
	else
	{
		int isLeaf = 0;
		double value =0;
		int attribute = 0;
		fin>>isLeaf;
		fin>>value;
		fin>>attribute;
		Node* node = new Node(value,attribute,isLeaf);
		if(node->isLeaf ==1)
		{
			return node;
		}
	    node->leftchild = LoadNode(fin);
		node->rightchild = LoadNode(fin);
		return node;
	}
}

void DecisionTree::LoadTree(char* fname)
{
	ifstream fin(fname);
	root= LoadNode(fin);
	fin.close();
}

void DecisionTree::ShowTree(Node* node)
{	// print the tree in pre-order: value, attribute, isLeaf
	if(node ==NULL)
		return;
	cout<<"tree node "<<node->GetValue()<<" "<<node->GetAttribute()<<" "<<node->IsLeaf()<<endl;
	if(node->leftchild!=NULL)
	{
		ShowTree(node->leftchild);
	}
	if(node->rightchild!=NULL)
		ShowTree(node->rightchild);
	
}
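
A minimal usage sketch, e.g. in a separate main.cpp, assuming the members used below are public in DecisionTree.h and that the surrounding project provides some way to load the medical training data (no loader appears in this file, so that step is left as a placeholder comment); the five-column layout (four features plus a 0/1 diagnosis label in the last column) is likewise inferred from the code above.

#include "DecisionTree.h"
#include <iostream>
using namespace std;

int main()
{
	DecisionTree tree;
	tree.SetnumOfAttr(5);              // assumed: 4 feature columns + 1 class label column
	tree.SetnumOfIns(100);             // assumed number of training instances
	// ... populate the training set here (no loader is shown in decisiontree.cpp) ...
	tree.BuildTree();                  // grow the depth-limited C4.5 tree
	tree.ShowTree(tree.GetRoot());     // dump the tree in pre-order
	char fname[] = "tree.txt";
	tree.SaveTree(fname);              // persist it for a later LoadTree(fname)

	double instance[5] = {1, 0, 1, 0, 0};   // hypothetical patient record
	cout << "P(disease) = " << tree.ClassifyIns(instance) << endl;
	return 0;
}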