⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 wajue.cpp

📁 利用VC++实现决策树分类算法
💻 CPP
字号:
// wajue.cpp : Defines the entry point for the console application.
//

#include "stdafx.h"

#include <algorithm>
#include <fstream>
#include <iostream>
#include <sstream>

#include "DTree.h"


using namespace std;

DTree				*root = NULL;		// root of the decision tree (built by creatTree, released by freeTree)

vector<StoreData>	trainAll;			// all training data
vector<StoreData>	testAll;			// all test data
vector<StoreData>	train;				// selected training data, encoded by changeData
vector<StoreData>	test;				// test data, encoded by changeData

vector<OriganData>	OtrainAll;			// raw training records read from "crx.train"
vector<OriganData>	OtestAll;			// raw test records read from "crx.test"
vector<OriganData>	Otrain;				// raw records of the selected training subset

vector<int>			attributes;			// candidate attribute indices (0 .. MaxAttr-1, filled by init)
ifstream			fin;				// global input file stream
set<int>			trainSet;			// indices (into the source vector) of the selected training records
int					sortKind = 0;		// attribute number the comparator `header` sorts by
double				conSpit[6] = {0};	// thresholds of continuous attributes A2, A3, A8, A11, A14, A15 (C4.5 method)
// NOTE(review): with `using namespace std` under C++17, an unqualified `size` is ambiguous
// with std::size; refer to this variable as `::size`.
int					size = 0;			// number of nodes released by freeTree, printed by outResult

void init()
{
	readData(OtrainAll, "crx.train");
	readData(OtestAll, "crx.test");
	unsigned int selectDataNum = 350;
	selectData(OtrainAll, Otrain, selectDataNum, (int)OtrainAll.size());
	processConValue();
	changeData(Otrain, train);
	changeData(OtestAll, test);
	
	for (int i = 0; i < MaxAttr; i++)
	{
		attributes.push_back(i);
	}
}

// Subtracts `count` from the record counter that belongs to the file being read
// (trainAllNum for the training file, testAllNum otherwise).
static void discountRecords(bool isTrain, int count)
{
	if (isTrain)
		trainAllNum -= count;
	else
		testAllNum -= count;
}

/*
 * Reads comma-separated records from `fileName` and appends them to `data`.
 *
 * The number of records to read is taken from the global trainAllNum when
 * fileName[5] == 'r' (as in "crx.train") and from testAllNum otherwise (e.g. "crx.test").
 * Each record is one whitespace-free token. Commas become spaces, and the 16 fields are
 * parsed into an OriganData (A1..A15, label).
 *
 * A record is dropped, and the matching global count decremented, when it contains
 * a '?' anywhere after its first character or when it fails to parse. A '?' as the very
 * first character (field A1) is accepted; changeData maps that to a default code.
 * If the file cannot be opened or ends early, the count is reduced by the number of
 * records never read. Afterwards it therefore equals the number of records appended.
 */
void readData(vector<OriganData> &data, const char* fileName)
{
	const bool isTrain = fileName[5] == 'r';
	const int iterNum = isTrain ? trainAllNum : testAllNum;

	// A local stream starts from a clean state on every call. Reusing one global
	// ifstream can carry eof/fail bits from the previous file into the next open().
	ifstream in(fileName);
	if (!in)
	{
		cout << "cannot open " << fileName << endl;
		discountRecords(isTrain, iterNum);
		return;
	}

	string line;
	for (int i = 0; i < iterNum; i++)
	{
		if (!(in >> line))
		{
			// File shorter than expected: stop instead of re-using the last line.
			discountRecords(isTrain, iterNum - i);
			break;
		}
		replace(line.begin(), line.end(), ',', ' ');

		OriganData d;
		istringstream stream(line);
		if (line.find('?', 1) == string::npos &&
			stream >> d.A1 >> d.A2 >> d.A3 >> d.A4 >> d.A5 >> d.A6 >> d.A7 >> d.A8 >>
				d.A9 >> d.A10 >> d.A11 >> d.A12 >> d.A13 >> d.A14 >> d.A15 >> d.label)
			data.push_back(d);
		else
			discountRecords(isTrain, 1);
	}
}

void selectData(vector<OriganData> &data, vector<OriganData> &subdata, unsigned int selectDataNum, int dataNum)
{
	srand((unsigned)time(NULL)); 
	int index;
	trainSet.clear();
	subdata.clear();
	while (trainSet.size() < selectDataNum)
	{
		//cout << data.size() << endl; 
		index = rand() % dataNum;
		if (trainSet.count(index) == 0)
		{
			trainSet.insert(index);
			subdata.push_back(data.at(index));
		}
	}
}

/*
 * Comparator used with std::stable_sort in processConValue. It orders two raw records
 * by the continuous attribute selected by the global sortKind (2, 3, 8, 11, 14 or 15).
 *
 * It must be a strict weak ordering, hence "<". A "<=" comparator claims x < x, which
 * violates stable_sort's requirements (undefined behaviour, and equal keys lose
 * their relative order). For any other sortKind, all records compare equal.
 */
bool header(const OriganData &d1, const OriganData &d2)
{
	double d1_v = 0, d2_v = 0;
	switch (sortKind)
	{
	case 2:		d1_v = d1.A2;	d2_v = d2.A2;	break;
	case 3:		d1_v = d1.A3;	d2_v = d2.A3;	break;
	case 8:		d1_v = d1.A8;	d2_v = d2.A8;	break;
	case 11:	d1_v = d1.A11;	d2_v = d2.A11;	break;
	case 14:	d1_v = d1.A14;	d2_v = d2.A14;	break;
	case 15:	d1_v = d1.A15;	d2_v = d2.A15;	break;
	}
	return d1_v < d2_v;
}

/*
 * Two-class entropy, in bits, of a sample set.
 *   p : number of positive samples
 *   s : total number of samples
 * Returns 0 when either class is absent (including s == 0).
 */
double Entropy(double p, double s)
{
	const double n = s - p;			// negative samples
	const double ln2 = log(2.0);
	double h = 0;
	if (n != 0)
	{
		const double q = n / s;
		h += -q * log(q) / ln2;
	}
	if (p != 0)
	{
		const double q = p / s;
		h += -q * log(q) / ln2;
	}
	return h;
}

/*
 * Information gain of splitting a two-class sample set into two partitions.
 *   p1, s1 : positives and total size of the first partition
 *   p2, s2 : positives and total size of the second partition
 * Gain = H(all) - (s1/S) * H(part1) - (s2/S) * H(part2), with S = s1 + s2.
 * Each partition is weighted by its share of the samples, as in findBestArrt.
 * Empty partitions contribute nothing, and an empty set has gain 0.
 */
double Gain(double p1, double s1, double p2, double s2)
{
	const double total = s1 + s2;
	if (total == 0)
		return 0;
	double gain = Entropy(p1 + p2, total);
	if (s1 != 0)
		gain -= s1 / total * Entropy(p1, s1);
	if (s2 != 0)
		gain -= s2 / total * Entropy(p2, s2);
	return gain;
}

void processConValue()
{
	int con[6] = {2, 3, 8, 11, 14, 15};
	for (int i = 0; i < 6; i++)
	{
		sortKind = con[i];
		stable_sort(Otrain.begin(), Otrain.end(), header);
/*
		for (vector<OriganData>::iterator it = Otrain.begin(); it != Otrain.end(); it++)
			cout << (*it).A2 << (*it).label << '\t';
		cout << endl;
*/
		double bestGain = 0;		//记录最佳的Gain。
		double gain;
		vector<OriganData>::iterator bestit = Otrain.end();
		for (vector<OriganData>::iterator it = Otrain.begin(); it != Otrain.end() - 1; it++)
		{
			if ((*it).label != (*(it + 1)).label)
			{
				int p1 = 0, p2 = 0, n1 = 0, n2 = 0;		//记录正反例的个数
				for (vector<OriganData>::iterator jt = Otrain.begin(); jt != it + 1; jt++)
					if ((*jt).label == '+')
						p1++;
					else n1++;
				for (vector<OriganData>::iterator jt2 = it + 1; jt2 != Otrain.end(); jt2++)
					if ((*jt2).label == '+')
						p2++;
					else n2++;
				gain = Gain(p1, p1 + n1, p2, p2 + n2);
				if (gain > bestGain)
				{
					bestGain = gain;
					bestit = it;
				}
			}
		}
		if (bestit == Otrain.end())
			bestit = Otrain.begin();
		switch (sortKind)
		{
		case 2:	conSpit[i] = ((*bestit).A2 + (*(bestit + 1)).A2) / 2; 
			break;
		case 3: conSpit[i] = ((*bestit).A3 + (*(bestit + 1)).A3) / 2;
			break;
		case 8: conSpit[i] = ((*bestit).A8 + (*(bestit + 1)).A8) / 2;
			break;
		case 11: conSpit[i] = ((*bestit).A11 + (*(bestit + 1)).A11) / 2;
			break;
		case 14: conSpit[i] = ((*bestit).A14 + (*(bestit + 1)).A14) / 2;
			break;
		case 15: conSpit[i] = ((*bestit).A15 + (*(bestit + 1)).A15) / 2;
			break;
		}  
	}
}

/*
 * Encodes raw records as small integer codes for the tree learner and appends them
 * to `train` (existing contents are kept).
 *   Otrain : raw records to encode (not modified)
 *   train  : output, one StoreData per input record
 *
 * Nominal attributes map each letter code to 1..k. Continuous attributes map to 1 when
 * below their conSpit threshold and to 2 otherwise. Unrecognised codes become 1 for A1,
 * A4, A5, A6 and A12, and -1 for A7, A9 and A10 (testTree answers ' ' for negative codes).
 * A13 maps anything other than 'g'/'p' to 3.
 */
void changeData(vector<OriganData> &Otrain, vector<StoreData> &train)
{
	StoreData d;
	for (vector<OriganData>::iterator it = Otrain.begin(); it != Otrain.end(); it++)
	{
		//A1
		switch ((*it).A1)
		{
		case 'a':	d.A[0] = 1;	break;
		case 'b':	d.A[0] = 2;	break;
		default:	d.A[0] = 1;
		}
		//A2 (continuous, threshold conSpit[0])
		d.A[1] = (*it).A2 < conSpit[0] ? 1 : 2;
		//A3 (continuous, threshold conSpit[1])
		d.A[2] = (*it).A3 < conSpit[1] ? 1 : 2;
		//A4
		switch ((*it).A4)
		{
		case 'u':	d.A[3] = 1;	break;
		case 'y':	d.A[3] = 2;	break;
		case 'l':	d.A[3] = 3;	break;
		case 't':	d.A[3] = 4;	break;
		default:	d.A[3] = 1;
		}
		//A5: one-letter "g" -> 1, longer codes starting with 'g' -> 3, 'p' -> 2.
		// The length test must look at A5 itself. A default is required: without it,
		// an unrecognised code would keep the previous record's value in d.A[4].
		switch ((*it).A5[0])
		{
		case 'g':	d.A[4] = (*it).A5.length() == 1 ? 1 : 3;	break;
		case 'p':	d.A[4] = 2;	break;
		default:	d.A[4] = 1;
		}
		//A6: one-letter "c" -> 1, longer codes starting with 'c' -> 3
		switch ((*it).A6[0])
		{
		case 'c':	d.A[5] = (*it).A6.length() == 1 ?  1 : 3;	break;
		case 'd':	d.A[5] = 2;	break;
		case 'i':	d.A[5] = 4;	break;
		case 'j':	d.A[5] = 5;	break;
		case 'k':	d.A[5] = 6;	break;
		case 'm':	d.A[5] = 7;	break;
		case 'r':	d.A[5] = 8;	break;
		case 'q':	d.A[5] = 9;	break;
		case 'w':	d.A[5] = 10;	break;
		case 'x':	d.A[5] = 11;	break;
		case 'e':	d.A[5] = 12;	break;
		case 'a':	d.A[5] = 13;	break;
		case 'f':	d.A[5] = 14;	break;
		default:	d.A[5] = 1;
		}
		//A7
		switch ((*it).A7[0])
		{
		case 'v':	d.A[6] = 1;	break;
		case 'h':	d.A[6] = 2;	break;
		case 'b':	d.A[6] = 3;	break;
		case 'j':	d.A[6] = 4;	break;
		case 'n':	d.A[6] = 5;	break;
		case 'z':	d.A[6] = 6;	break;
		case 'd':	d.A[6] = 7;	break;
		case 'f':	d.A[6] = 8;	break;
		case 'o':	d.A[6] = 9;	break;
		default:	d.A[6] = -1;
		}
		//A8 (continuous, threshold conSpit[2], which processConValue computes from A8)
		d.A[7] = (*it).A8 < conSpit[2] ? 1 : 2;
		//A9
		switch ((*it).A9)
		{
		case 't':	d.A[8] = 1;	break;
		case 'f':	d.A[8] = 2;	break;
		default:	d.A[8] = -1;
		}
		//A10
		switch ((*it).A10)
		{
		case 't':	d.A[9] = 1;	break;
		case 'f':	d.A[9] = 2;	break;
		default:	d.A[9] = -1;
		}
		//A11 (continuous, threshold conSpit[3])
		d.A[10] = (*it).A11 < conSpit[3] ? 1 : 2;
		//A12
		switch ((*it).A12)
		{
		case 't':	d.A[11] = 1;	break;
		case 'f':	d.A[11] = 2;	break;
		default:	d.A[11] = 1;
		}
		//A13
		if ((*it).A13 == 'g')
			d.A[12] = 1;
		else if ((*it).A13 == 'p')
			d.A[12] = 2;
		else d.A[12] = 3;
		//A14 (continuous, threshold conSpit[4])
		d.A[13] = (*it).A14 < conSpit[4] ? 1 : 2;
		//A15 (continuous, threshold conSpit[5])
		d.A[14] = (*it).A15 < conSpit[5] ? 1 : 2;

		d.label = (*it).label;
		train.push_back(d);
	}
}

void printV(const vector<StoreData> &samples)
{
	for (vector<StoreData>::const_iterator it = samples.begin(); it != samples.end(); it++)
	{
		for (int i = 0; i < 15; i++)
			cout << (*it).A[i] << ' ';
		cout << (*it).label << endl;
	}	
}

int creatTree(DTree *&p, const vector<StoreData> &samples, vector<int> &attributes)
{
	//if (samples.size() > 0)
	//	printV(samples);
	if (p == NULL)
		p = new DTree();
	if (allTheSame(samples, '+'))
	{
		p->node.label = '+';
		p->node.attrNum = 16;
		p->childs.clear();
		return 1;
	}
	if (allTheSame(samples, '-'))
	{
		p->node.label = '-';
		p->node.attrNum = 16;
		p->childs.clear();
		return 1;
	}
	if (attributes.size() == 0)
	{
		p->node.label = mostCommonAttr(samples);
		p->node.attrNum = 16;
		p->childs.clear();
		return 1;
	}
	p->node.attrNum = findBestArrt(samples, attributes);
	if (p->node.attrNum == -1)
	{
		p->node.label = mostCommonAttr(samples);
		p->node.attrNum = 16;
		p->childs.clear();
		return 1;
	}
	p->node.label = ' ';
	
	vector<int> newAttributes;
	for (vector<int>::iterator it = attributes.begin(); it != attributes.end(); it++)
		if ((*it) != p->node.attrNum)
			newAttributes.push_back((*it));

	vector<StoreData> subSamples[16];
	for (int i = 0; i < 16; i++)
		subSamples[i].clear();

	for (vector<StoreData>::const_iterator it1 = samples.begin(); it1 != samples.end(); it1++)
	{
		//cout << (*it).A[p->node.attrNum] << endl;
		subSamples[(*it1).A[p->node.attrNum]].push_back((*it1));
	}

	DTree *child;
	for (i = 1; i <= ArrtNum[p->node.attrNum]; i++)
	{
		child = new DTree;	
		child->node.attr = i;
		if (subSamples[i].size() == 0)
			child->node.label = mostCommonAttr(samples);
		else
			creatTree(child, subSamples[i], newAttributes);
		p->childs.push_back(child);
	}
	return 0;
}

int findBestArrt(const vector<StoreData> &samples, vector<int> &attributes)
{
	int attr, 
		bestAttr = 0,
		p = 0,
		s = (int)samples.size();
	for (vector<StoreData>::const_iterator it = samples.begin(); it != samples.end(); it++)	//统计抽样中正例的个数
	{
		if ((*it).label == '+')
			p++;
	}
	double result;
	double bestResult = 0;
	int subN[15], subP[15];			//统计每种属性的值的正例、反例个数
	for (vector<int>::iterator it2=attributes.begin(); it2!=attributes.end(); it2++)
	{
		attr = (*it2);
		result = Entropy(p, s);
		for (int i = 0; i < 15; i++)
		{
			subN[i] = 0;
			subP[i] = 0;
		}
		for (vector<StoreData>::const_iterator jt = samples.begin(); jt != samples.end(); jt++)
		{
			//cout << (*jt).A[attr] << ' ';
			if ((*jt).A[attr] > 14)
			{
				cout<< "Oh my god!" << endl;
				exit(-1);
			}
			if ((*jt).label == '+')
				subP[(*jt).A[attr]] ++;
			else
				subN[(*jt).A[attr]] ++;
		}
		for (i = 1; i <= ArrtNum[attr]; i++)
		{
			//cout << i << ' ';
			if (i > 14)
			{
				cout<< "Oh my god!" << endl;
				exit(-1);
			}
			result -= double(subP[i] + subN[i]) / s * Entropy(subP[i], subP[i] + subN[i]);
		}
		if (result > bestResult)
		{
			bestResult = result;
			bestAttr = attr;
		}
	}

	if (bestResult == 0)
	{
		return -1;
	}
	else return bestAttr;
}

bool allTheSame(const vector<StoreData> &samples, char ch)
{
	for (vector<StoreData>::const_iterator it = samples.begin(); it != samples.end(); it++)
		if ((*it).label != ch)
			return false;
	return true;
}

char mostCommonAttr(const vector<StoreData> &samples)
{
	int p = 0, n = 0;
	for (vector<StoreData>::const_iterator it = samples.begin(); it != samples.end(); it++)
		if ((*it).label == '+')
			p++;
		else
			n++;
	if (p >= n)
		return '+';
	else
		return '-';
}

/*
 * Classifies encoded sample d with the tree rooted at p and returns the predicted label.
 * Starting at p, it follows child (value - 1) of each internal node (label ' ') until it
 * reaches a leaf. It returns ' ' ("cannot classify") when p is NULL or when d's value for
 * a split attribute has no child: that covers the -1 unknown code, 0, and values above
 * the child count. Such values previously made vector::at throw.
 */
char testTree(DTree *p, StoreData d)
{
	while (p != NULL && p->node.label == ' ')
	{
		const int value = d.A[p->node.attrNum];
		if (value < 1 || (vector<DTree*>::size_type)value > p->childs.size())
			return ' ';
		p = p->childs[value - 1];
	}
	return p == NULL ? ' ' : p->node.label;
}

void testData()
{
	int miss = 0;
	int i = 0;
	for (vector<StoreData>::iterator it = test.begin(); it != test.end(); it++)
	{
		i++;
		if (testTree(root, (*it)) != (*it).label)
			miss++;
	}
	miss = test.size() - miss;
	cout << "fight:";
	cout << double(miss) / test.size() << endl;
}

/*
 * Prints the subtree rooted at p to stdout.
 * A leaf prints its label and a newline. An internal node prints " : <split attribute>",
 * then, for each child, `depth` tabs, the child's branch value and the child's subtree
 * (printed one level deeper).
 */
void printTree(DTree *p, int depth)
{
	const bool isLeaf = p->node.label != ' ';
	if (isLeaf)
	{
		cout << p->node.label << endl;
		return;
	}

	cout << " : " << p->node.attrNum << endl;
	for (vector<DTree*>::size_type c = 0; c != p->childs.size(); ++c)
	{
		DTree *child = p->childs[c];
		for (int tabs = depth; tabs > 0; --tabs)
			cout << '\t';
		cout << child->node.attr;
		printTree(child, depth + 1);
	}
}

/*
 * Deletes the subtree rooted at p (children first) and adds the number of deleted
 * nodes to the global counter `size`. A NULL p is ignored.
 */
void freeTree(DTree *p)
{
	if (p == NULL)
		return;
	for (vector<DTree*>::iterator it = p->childs.begin(); it != p->childs.end(); ++it)
		freeTree(*it);
	delete p;
	// Qualified: with `using namespace std` in C++17 an unqualified `size`
	// is ambiguous between this global and std::size.
	::size++;
}

/*
 * Prints the node count accumulated in the global `size` (by freeTree).
 * `::size` is qualified because an unqualified `size` is ambiguous with std::size
 * under C++17 plus `using namespace std`.
 */
void outResult()
{
	cout << "Tree size:" << ::size << endl;
}

/*
 * Entry point: prepare the data, learn and print the tree, report test accuracy,
 * then free the tree and print how many nodes it had.
 */
int main()
{
	// Load, sample and encode the data set.
	init();

	// Learn the tree from the encoded training subset and show it.
	creatTree(root, train, attributes);
	printTree(root, 0);

	// Evaluate, release the tree, and report its size.
	testData();
	freeTree(root);
	outResult();

	return 0;
}
/*int main(int argc, char* argv[])
{
	return 0;
}
*/

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -