// File: wajue.cpp
// wajue.cpp : Defines the entry point for the console application.
//
#include "stdafx.h"
#include <algorithm>
#include <fstream>
#include <iostream>
#include <sstream>
#include "DTree.h"
using namespace std;
// Root of the decision tree; built in main() by creatTree. As a global it
// starts out NULL.
DTree *root;
vector<StoreData> trainAll, // all training data (encoded); not referenced elsewhere in this file
testAll, // all test data (encoded); not referenced elsewhere in this file
train, // the selected (sampled) training data, encoded by changeData
test; // the test data, encoded by changeData
vector<OriganData> OtrainAll, // raw training data read from crx.train
OtestAll, // raw test data read from crx.test
Otrain; // raw training records chosen by selectData
vector<int> attributes; // candidate attribute indices (init fills 0..MaxAttr-1)
// Input stream shared by readData for every file it reads.
ifstream fin;
set<int> trainSet; // indices of the selected training records (into the vector passed to selectData)
int sortKind; // which continuous attribute header() sorts by (2, 3, 8, 11, 14 or 15)
double conSpit[6]; // split thresholds of the continuous attributes, found with the C4.5 method (processConValue)
// Number of tree nodes deleted by freeTree; printed by outResult as the tree size.
// NOTE: with `using namespace std;`, C++17's std::size makes an unqualified
// `size` ambiguous at file scope, so refer to this variable as ::size.
int size = 0;
void init()
{
readData(OtrainAll, "crx.train");
readData(OtestAll, "crx.test");
unsigned int selectDataNum = 350;
selectData(OtrainAll, Otrain, selectDataNum, (int)OtrainAll.size());
processConValue();
changeData(Otrain, train);
changeData(OtestAll, test);
for (int i = 0; i < MaxAttr; i++)
{
attributes.push_back(i);
}
}
// Reads raw records from fileName and appends them to data.
// fileName[5] selects the record count: 'r' ("crx.train") reads trainAllNum
// records, anything else (e.g. "crx.test") reads testAllNum records.
// NOTE(review): this assumes fileName has at least 6 characters.
// Each record is one whitespace-free token of 16 comma-separated fields,
// A1..A15 followed by the label.
// Handling of records:
//  - Records containing '?' (missing values) are dropped, except that a
//    record whose first character is '?' is kept.
//    NOTE(review): confirm that keeping records with a missing A1 is intended.
//  - Records that fail to parse are dropped.
//  - If the file ends early (or cannot be opened), reading stops instead of
//    re-using the previous line.
// Every dropped or unread record decrements the matching global count, so
// the count ends up equal to the number of records appended.
void readData(vector<OriganData> &data, const char* fileName)
{
    const bool isTrain = fileName[5] == 'r';
    const int iterNum = isTrain ? trainAllNum : testAllNum;
    int skipped = 0;

    fin.clear(); // drop EOF/fail state left over from a previously read file
    fin.open(fileName);
    if (!fin.is_open())
        cerr << "readData: cannot open " << fileName << endl;

    string line;
    OriganData d;
    for (int i = 0; i < iterNum; i++)
    {
        if (!(fin >> line))
        {
            // Fewer records than expected (or no file): count the rest as skipped.
            skipped += iterNum - i;
            break;
        }
        std::replace(line.begin(), line.end(), ',', ' ');

        const string::size_type q = line.find('?');
        if (q != 0 && q != string::npos)
        {
            skipped++; // missing value somewhere after the first character
            continue;
        }

        istringstream stream(line);
        if (stream >> d.A1 >> d.A2 >> d.A3 >> d.A4 >> d.A5 >> d.A6 >> d.A7 >> d.A8 >>
            d.A9 >> d.A10 >> d.A11 >> d.A12 >> d.A13 >> d.A14 >> d.A15 >> d.label)
            data.push_back(d);
        else
            skipped++; // malformed record: do not store a half-filled struct
    }
    fin.close();

    if (isTrain)
        trainAllNum -= skipped;
    else
        testAllNum -= skipped;
}
// Pseudo-randomly draws selectDataNum distinct records (without replacement)
// from the first dataNum entries of data. The drawn records go into subdata
// (cleared first), and their indices go into the global trainSet (also
// cleared first).
// The request is capped at the number of indices that can actually be drawn,
// so the sampling loop always terminates. Without the cap it would loop
// forever when selectDataNum exceeded the available records.
void selectData(vector<OriganData> &data, vector<OriganData> &subdata, unsigned int selectDataNum, int dataNum)
{
    srand((unsigned)time(NULL));
    trainSet.clear();
    subdata.clear();

    // Only indices that exist in data can be drawn.
    if (dataNum > (int)data.size())
        dataNum = (int)data.size();
    if (dataNum <= 0)
        return; // nothing to draw from (also avoids rand() % 0)
    // rand() % dataNum can only reach RAND_MAX + 1 distinct indices.
    // The subtraction form avoids overflowing RAND_MAX + 1.
    if (dataNum - 1 > RAND_MAX)
        dataNum = RAND_MAX + 1;
    if (selectDataNum > (unsigned int)dataNum)
        selectDataNum = (unsigned int)dataNum;

    while (trainSet.size() < selectDataNum)
    {
        const int index = rand() % dataNum;
        if (trainSet.insert(index).second) // true only for a new index
            subdata.push_back(data.at(index));
    }
}
// Comparator used with stable_sort. It orders two raw records by the
// continuous attribute chosen by the global sortKind (2, 3, 8, 11, 14 or 15),
// comparing the values as doubles.
// It must be a strict '<': a '<=' comparator is not irreflexive, which
// violates stable_sort's strict-weak-ordering requirement (undefined behavior).
// For any other sortKind, every record compares equal.
bool header(const OriganData &d1, const OriganData &d2)
{
    double d1_v = 0, d2_v = 0;
    switch (sortKind)
    {
    case 2:
        d1_v = d1.A2;
        d2_v = d2.A2;
        break;
    case 3:
        d1_v = d1.A3;
        d2_v = d2.A3;
        break;
    case 8:
        d1_v = d1.A8;
        d2_v = d2.A8;
        break;
    case 11:
        d1_v = d1.A11;
        d2_v = d2.A11;
        break;
    case 14:
        d1_v = d1.A14;
        d2_v = d2.A14;
        break;
    case 15:
        d1_v = d1.A15;
        d2_v = d2.A15;
        break;
    }
    return d1_v < d2_v;
}
// Binary (two-class) entropy, in bits, of a sample of s items of which p are
// positive: -(p/s)log2(p/s) - (n/s)log2(n/s) with n = s - p.
// A class with zero count contributes nothing, so a pure sample (p == 0 or
// p == s) and the empty sample (s == 0) both yield 0.
double Entropy(double p, double s)
{
    const double negatives = s - p;
    const double ln2 = log(2.0);
    double h = 0;
    if (negatives != 0)
        h += -negatives / s * log(negatives / s) / ln2;
    if (p != 0)
        h += -p / s * log(p / s) / ln2;
    return h;
}
// Information gain of splitting a sample into two parts.
// p1 / s1 are the positives and size of the first part; p2 / s2 are the
// same for the second part. With s = s1 + s2:
//   Gain = H(p1 + p2, s) - (s1 / s) * H(p1, s1) - (s2 / s) * H(p2, s2)
// Each part is weighted by its share of the sample, not by its positive rate.
// Returns 0 for an empty sample.
double Gain(double p1, double s1, double p2, double s2)
{
    const double s = s1 + s2;
    if (s <= 0)
        return 0;
    return Entropy(p1 + p2, s) - s1 / s * Entropy(p1, s1) - s2 / s * Entropy(p2, s2);
}
void processConValue()
{
int con[6] = {2, 3, 8, 11, 14, 15};
for (int i = 0; i < 6; i++)
{
sortKind = con[i];
stable_sort(Otrain.begin(), Otrain.end(), header);
/*
for (vector<OriganData>::iterator it = Otrain.begin(); it != Otrain.end(); it++)
cout << (*it).A2 << (*it).label << '\t';
cout << endl;
*/
double bestGain = 0; //记录最佳的Gain。
double gain;
vector<OriganData>::iterator bestit = Otrain.end();
for (vector<OriganData>::iterator it = Otrain.begin(); it != Otrain.end() - 1; it++)
{
if ((*it).label != (*(it + 1)).label)
{
int p1 = 0, p2 = 0, n1 = 0, n2 = 0; //记录正反例的个数
for (vector<OriganData>::iterator jt = Otrain.begin(); jt != it + 1; jt++)
if ((*jt).label == '+')
p1++;
else n1++;
for (vector<OriganData>::iterator jt2 = it + 1; jt2 != Otrain.end(); jt2++)
if ((*jt2).label == '+')
p2++;
else n2++;
gain = Gain(p1, p1 + n1, p2, p2 + n2);
if (gain > bestGain)
{
bestGain = gain;
bestit = it;
}
}
}
if (bestit == Otrain.end())
bestit = Otrain.begin();
switch (sortKind)
{
case 2: conSpit[i] = ((*bestit).A2 + (*(bestit + 1)).A2) / 2;
break;
case 3: conSpit[i] = ((*bestit).A3 + (*(bestit + 1)).A3) / 2;
break;
case 8: conSpit[i] = ((*bestit).A8 + (*(bestit + 1)).A8) / 2;
break;
case 11: conSpit[i] = ((*bestit).A11 + (*(bestit + 1)).A11) / 2;
break;
case 14: conSpit[i] = ((*bestit).A14 + (*(bestit + 1)).A14) / 2;
break;
case 15: conSpit[i] = ((*bestit).A15 + (*(bestit + 1)).A15) / 2;
break;
}
}
}
// Encodes the raw records in Otrain as discrete StoreData records and appends
// them to train. Attribute A<k> becomes the integer code d.A[k-1]:
//  - Nominal attributes map each value to a fixed code. Unexpected values
//    get a default code; for A7, A9 and A10 that default is -1, which
//    testTree treats as unknown.
//  - Continuous attributes A2, A3, A8, A11, A14 and A15 become 1 when the
//    value is below their threshold (conSpit[0..5], in that order, set by
//    processConValue) and 2 otherwise.
// The label is copied unchanged.
// Note: the parameters shadow the globals with the same names.
void changeData(vector<OriganData> &Otrain, vector<StoreData> &train)
{
    StoreData d;
    for (vector<OriganData>::iterator it = Otrain.begin(); it != Otrain.end(); it++)
    {
        //A1
        switch ((*it).A1)
        {
        case 'a': d.A[0] = 1; break;
        case 'b': d.A[0] = 2; break;
        default: d.A[0] = 1;
        }
        //A2 (continuous)
        d.A[1] = (*it).A2 < conSpit[0] ? 1 : 2;
        //A3 (continuous)
        d.A[2] = (*it).A3 < conSpit[1] ? 1 : 2;
        //A4
        switch ((*it).A4)
        {
        case 'u': d.A[3] = 1; break;
        case 'y': d.A[3] = 2; break;
        case 'l': d.A[3] = 3; break;
        case 't': d.A[3] = 4; break;
        default: d.A[3] = 1;
        }
        //A5: 'g' -> 1, 'p' -> 2, "gg" -> 3. A5's own length tells 'g' from "gg".
        // The default keeps d.A[4] from carrying over the previous record's code.
        switch ((*it).A5[0])
        {
        case 'g': d.A[4] = (*it).A5.length() == 1 ? 1 : 3; break;
        case 'p': d.A[4] = 2; break;
        default: d.A[4] = 1;
        }
        //A6: 'c' -> 1 and "cc" -> 3 (distinguished by length); the others map by first letter.
        switch ((*it).A6[0])
        {
        case 'c': d.A[5] = (*it).A6.length() == 1 ? 1 : 3; break;
        case 'd': d.A[5] = 2; break;
        case 'i': d.A[5] = 4; break;
        case 'j': d.A[5] = 5; break;
        case 'k': d.A[5] = 6; break;
        case 'm': d.A[5] = 7; break;
        case 'r': d.A[5] = 8; break;
        case 'q': d.A[5] = 9; break;
        case 'w': d.A[5] = 10; break;
        case 'x': d.A[5] = 11; break;
        case 'e': d.A[5] = 12; break;
        case 'a': d.A[5] = 13; break;
        case 'f': d.A[5] = 14; break;
        default: d.A[5] = 1;
        }
        //A7 (by first letter)
        switch ((*it).A7[0])
        {
        case 'v': d.A[6] = 1; break;
        case 'h': d.A[6] = 2; break;
        case 'b': d.A[6] = 3; break;
        case 'j': d.A[6] = 4; break;
        case 'n': d.A[6] = 5; break;
        case 'z': d.A[6] = 6; break;
        case 'd': d.A[6] = 7; break;
        case 'f': d.A[6] = 8; break;
        case 'o': d.A[6] = 9; break;
        default: d.A[6] = -1;
        }
        //A8 (continuous): compared against its own threshold conSpit[2]
        d.A[7] = (*it).A8 < conSpit[2] ? 1 : 2;
        //A9
        switch ((*it).A9)
        {
        case 't': d.A[8] = 1; break;
        case 'f': d.A[8] = 2; break;
        default: d.A[8] = -1;
        }
        //A10
        switch ((*it).A10)
        {
        case 't': d.A[9] = 1; break;
        case 'f': d.A[9] = 2; break;
        default: d.A[9] = -1;
        }
        //A11 (continuous)
        d.A[10] = (*it).A11 < conSpit[3] ? 1 : 2;
        //A12
        switch ((*it).A12)
        {
        case 't': d.A[11] = 1; break;
        case 'f': d.A[11] = 2; break;
        default: d.A[11] = 1;
        }
        //A13
        if ((*it).A13 == 'g')
            d.A[12] = 1;
        else if ((*it).A13 == 'p')
            d.A[12] = 2;
        else
            d.A[12] = 3;
        //A14 (continuous)
        d.A[13] = (*it).A14 < conSpit[4] ? 1 : 2;
        //A15 (continuous)
        d.A[14] = (*it).A15 < conSpit[5] ? 1 : 2;
        d.label = (*it).label;
        train.push_back(d);
    }
}
void printV(const vector<StoreData> &samples)
{
for (vector<StoreData>::const_iterator it = samples.begin(); it != samples.end(); it++)
{
for (int i = 0; i < 15; i++)
cout << (*it).A[i] << ' ';
cout << (*it).label << endl;
}
}
int creatTree(DTree *&p, const vector<StoreData> &samples, vector<int> &attributes)
{
//if (samples.size() > 0)
// printV(samples);
if (p == NULL)
p = new DTree();
if (allTheSame(samples, '+'))
{
p->node.label = '+';
p->node.attrNum = 16;
p->childs.clear();
return 1;
}
if (allTheSame(samples, '-'))
{
p->node.label = '-';
p->node.attrNum = 16;
p->childs.clear();
return 1;
}
if (attributes.size() == 0)
{
p->node.label = mostCommonAttr(samples);
p->node.attrNum = 16;
p->childs.clear();
return 1;
}
p->node.attrNum = findBestArrt(samples, attributes);
if (p->node.attrNum == -1)
{
p->node.label = mostCommonAttr(samples);
p->node.attrNum = 16;
p->childs.clear();
return 1;
}
p->node.label = ' ';
vector<int> newAttributes;
for (vector<int>::iterator it = attributes.begin(); it != attributes.end(); it++)
if ((*it) != p->node.attrNum)
newAttributes.push_back((*it));
vector<StoreData> subSamples[16];
for (int i = 0; i < 16; i++)
subSamples[i].clear();
for (vector<StoreData>::const_iterator it1 = samples.begin(); it1 != samples.end(); it1++)
{
//cout << (*it).A[p->node.attrNum] << endl;
subSamples[(*it1).A[p->node.attrNum]].push_back((*it1));
}
DTree *child;
for (i = 1; i <= ArrtNum[p->node.attrNum]; i++)
{
child = new DTree;
child->node.attr = i;
if (subSamples[i].size() == 0)
child->node.label = mostCommonAttr(samples);
else
creatTree(child, subSamples[i], newAttributes);
p->childs.push_back(child);
}
return 0;
}
int findBestArrt(const vector<StoreData> &samples, vector<int> &attributes)
{
int attr,
bestAttr = 0,
p = 0,
s = (int)samples.size();
for (vector<StoreData>::const_iterator it = samples.begin(); it != samples.end(); it++) //统计抽样中正例的个数
{
if ((*it).label == '+')
p++;
}
double result;
double bestResult = 0;
int subN[15], subP[15]; //统计每种属性的值的正例、反例个数
for (vector<int>::iterator it2=attributes.begin(); it2!=attributes.end(); it2++)
{
attr = (*it2);
result = Entropy(p, s);
for (int i = 0; i < 15; i++)
{
subN[i] = 0;
subP[i] = 0;
}
for (vector<StoreData>::const_iterator jt = samples.begin(); jt != samples.end(); jt++)
{
//cout << (*jt).A[attr] << ' ';
if ((*jt).A[attr] > 14)
{
cout<< "Oh my god!" << endl;
exit(-1);
}
if ((*jt).label == '+')
subP[(*jt).A[attr]] ++;
else
subN[(*jt).A[attr]] ++;
}
for (i = 1; i <= ArrtNum[attr]; i++)
{
//cout << i << ' ';
if (i > 14)
{
cout<< "Oh my god!" << endl;
exit(-1);
}
result -= double(subP[i] + subN[i]) / s * Entropy(subP[i], subP[i] + subN[i]);
}
if (result > bestResult)
{
bestResult = result;
bestAttr = attr;
}
}
if (bestResult == 0)
{
return -1;
}
else return bestAttr;
}
bool allTheSame(const vector<StoreData> &samples, char ch)
{
for (vector<StoreData>::const_iterator it = samples.begin(); it != samples.end(); it++)
if ((*it).label != ch)
return false;
return true;
}
char mostCommonAttr(const vector<StoreData> &samples)
{
int p = 0, n = 0;
for (vector<StoreData>::const_iterator it = samples.begin(); it != samples.end(); it++)
if ((*it).label == '+')
p++;
else
n++;
if (p >= n)
return '+';
else
return '-';
}
// Classifies record d by walking the tree from p.
// Returns the label of the leaf reached ('+' or '-'), or ' ' when d cannot be
// routed. That happens when:
//  - the tree is NULL,
//  - d's code for a node's attribute is below 1 (e.g. the -1 "unknown" code), or
//  - the node has no child for that code.
// The out-of-range cases used to throw from childs.at().
char testTree(DTree *p, StoreData d)
{
    while (p != NULL)
    {
        if (p->node.label != ' ')
            return p->node.label; // leaf
        const int code = d.A[p->node.attrNum];
        if (code < 1 || (size_t)code > p->childs.size())
            return ' ';
        p = p->childs[code - 1]; // child k-1 holds code k
    }
    return ' ';
}
void testData()
{
int miss = 0;
int i = 0;
for (vector<StoreData>::iterator it = test.begin(); it != test.end(); it++)
{
i++;
if (testTree(root, (*it)) != (*it).label)
miss++;
}
miss = test.size() - miss;
cout << "fight:";
cout << double(miss) / test.size() << endl;
}
// Prints the subtree rooted at p as indented text.
// A leaf prints its label on its own line. An internal node prints
// " : <attrNum>"; then, for each child, it prints `depth` tab characters and
// the child's attribute value, followed by the child's subtree at depth + 1.
void printTree(DTree *p, int depth)
{
    const bool isInternal = p->node.label == ' ';
    if (!isInternal)
    {
        cout << p->node.label << endl;
        return;
    }

    cout << " : " << p->node.attrNum << endl;
    for (size_t c = 0; c < p->childs.size(); ++c)
    {
        DTree *child = p->childs[c];
        for (int tab = depth; tab > 0; --tab)
            cout << '\t';
        cout << child->node.attr;
        printTree(child, depth + 1);
    }
}
// Recursively deletes the tree rooted at p, children first. Each deleted node
// increments the global node counter. A NULL p is a no-op.
void freeTree(DTree *p)
{
    if (p == NULL)
        return;
    for (vector<DTree*>::iterator it = p->childs.begin(); it != p->childs.end(); it++)
        freeTree(*it);
    delete p;
    // Qualified: with `using namespace std;`, C++17's std::size makes a bare
    // `size` ambiguous.
    ++::size;
}
void outResult()
{
cout << "Tree size:" << size << endl;
}
// Program entry point. Runs the whole pipeline in order:
//  1. load, sample and encode the data;
//  2. build the decision tree from the sampled training set;
//  3. print the tree and the accuracy on the test set;
//  4. free the tree and print how many nodes it had.
int main()
{
    init();                              // fills train / test / attributes
    creatTree(root, train, attributes);  // root is a NULL global; creatTree allocates it
    printTree(root, 0);                  // tree dump, starting at depth 0
    testData();                          // prints "fight:<accuracy>"
    freeTree(root);                      // deletes every node and counts them
    outResult();                         // prints "Tree size:<count>"
    return 0;
}
/*int main(int argc, char* argv[])
{
return 0;
}
*/
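/*
 * Web-page residue (code viewer keyboard-shortcut panel), kept as a comment:
 *   Copy code: Ctrl + C        Search code: Ctrl + F
 *   Full screen: F11           Toggle theme: Ctrl + Shift + D
 *   Show shortcuts: ?          Increase font size: Ctrl + =
 *   Decrease font size: Ctrl + -
 */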