📄 dtid3.cpp
字号:
// DTID3.cpp: implementation of the CDTID3 class.
//
//////////////////////////////////////////////////////////////////////
#include "stdafx.h"
#include "mining.h"
#include "DTID3.h"
#ifdef _DEBUG
#undef THIS_FILE
static char THIS_FILE[]=__FILE__;
#define new DEBUG_NEW
#endif
// Global root of the decision tree. PrepareID3() titles it "first" and gives
// it a single child slot, which create_tree() fills with the first split.
TreeNode dt;

//////////////////////////////////////////////////////////////////////
// Construction/Destruction
//////////////////////////////////////////////////////////////////////

// Nothing to initialise here; PrepareID3() sets up all working state.
CDTID3::CDTID3() {}

// Nothing is released here.
CDTID3::~CDTID3() {}
// Builds an ID3 decision tree from a table of strings.
//
//   string - `height` rows of `width` C strings each; the last column is the
//            class label, compared against "及格" (pass) / "不及格" (fail).
//   height - number of rows.
//   width  - number of columns, class column included.
//   title  - column titles, printed for each split node.
//
// The textual tree is appended to d_tree.txt and the rule lines to
// d_pos.txt / d_neg.txt; the in-memory tree hangs off the global `dt`.
// Every exit path closes the files that were opened and frees the scratch
// buffers.
void CDTID3::PrepareID3(char ***string, int height, int width, char **title)
{
    string1 = string;
    attributes = M2 = M4 = width;
    rows = M1 = height;
    title1 = title;

    // M3 is the size of each per-value buffer that find_att(),
    // get_diff_att_types() and create_tree() strcpy() cell values into.
    // Keep the historical minimum of 10 bytes, but grow it to fit the
    // longest cell so those copies cannot overflow.
    M3 = 10;
    for (int r = 0; r < rows; ++r)
    {
        for (int c = 0; c < attributes; ++c)
        {
            if (string1[r][c] != NULL)
            {
                const int need = (int)strlen(string1[r][c]) + 1;
                if (need > M3)
                    M3 = need;
            }
        }
    }

    // Open the three output files; on failure close whatever is already open.
    if ((ofp = fopen("d_tree.txt","a+")) == NULL)
    {
        ::MessageBox(NULL,"不能打开d_tree.txt!","tishi",MB_OK);
        return;
    }
    if ((pfp = fopen("d_pos.txt","a+")) == NULL)
    {
        ::MessageBox(NULL,"不能打开d_pos.txt!","tishi",MB_OK);
        fclose(ofp);
        ofp = NULL;
        return;
    }
    if ((nfp = fopen("d_neg.txt","a+")) == NULL)
    {
        ::MessageBox(NULL,"不能打开d_neg.txt!","tishi",MB_OK);
        fclose(pfp);
        fclose(ofp);
        pfp = ofp = NULL;
        return;
    }
    fprintf(pfp,"rule\n");
    fprintf(nfp,"rule\n");
    fprintf(ofp,"\n");

    // A table whose rows all share one class needs no tree.
    const char *halt = NULL;
    if (check_all_positive())
        halt = "HALT:all_positive\n";
    else if (check_all_negative())
        halt = "HALT:all_negative\n";

    bool aborted = false;
    if (halt != NULL)
    {
        fputs(halt, ofp);
        fputs(halt, pfp);
        fputs(halt, nfp);
    }
    else
    {
        // Flag buffers: '*' = row / column still usable, '-' = rule slot unset.
        // Each gets a trailing NUL because helpers may treat them as C strings.
        char *valid = (char *)malloc(rows + 1);
        char *avail_att = (char *)malloc(M2 + 1);
        char *rule = (char *)malloc(M4 + 1);
        if (valid == NULL || avail_att == NULL || rule == NULL)
        {
            aborted = true;
        }
        else
        {
            memset(valid, '*', rows);
            valid[rows] = '\0';
            memset(avail_att, '*', M2);
            avail_att[M2] = '\0';
            memset(rule, '-', M4);
            rule[M4] = '\0';

            strcpy(dt.Title,"first");
            dt.nChildCount = 1;
            dt.child = (ChildNode *)malloc(sizeof(ChildNode));
            if (dt.child == NULL)
            {
                aborted = true;
            }
            else
            {
                // Stays NULL when create_tree() finds no column to split on.
                dt.child->node = NULL;
                // 999 means the build was abandoned; as before, no "rule_end"
                // marker is written in that case.
                aborted = (create_tree(dt.child, rule, avail_att, valid, 0) == 999);
            }
        }
        free(valid);
        free(avail_att);
        free(rule);
    }

    if (!aborted)
    {
        fprintf(pfp, "rule_end\n");
        fprintf(nfp, "rule_end\n");
    }
    fclose(ofp);
    fclose(pfp);
    fclose(nfp);
    ofp = pfp = nfp = NULL;
}
// Reports a fatal ID3 error code on stdout. Code 1 is the only one with a
// message; any other code is ignored.
void CDTID3::disaster(int i)
{
    if (i != 1)
        return;
    printf("** ID3 failure **\n");
}
// Returns 1 when no row carries the label "不及格" (fail) in the class
// (last) column, i.e. the table is all positive; 0 as soon as one does.
// An empty table counts as all positive.
int CDTID3::check_all_positive()
{
    int i;
    // Plain narrow literal, like every other label comparison in this file:
    // string1 holds char strings, so a _T() literal would become wchar_t in a
    // _UNICODE build and no longer match strcmp().
    for (i = 0; i < rows; ++i)
    {
        if (strcmp(string1[i][attributes-1], "不及格") == 0)
            return 0;
    }
    return 1;
}
int CDTID3::check_all_negative()
{
int i;
for(i=0;i<rows;++i)
{
if (strcmp(string1[i][attributes-1],"及格") == 0)
{
return 0;
}
}
return 1;
}
// Two-class Shannon entropy (natural log) of a yes/no count pair;
// 0 when either count is zero.
static double id3_binary_entropy(int yes, int no)
{
    if (yes == 0 || no == 0)
        return 0.0;
    const double total = (double)(yes + no);
    const double py = yes / total;
    const double pn = no / total;
    return 0.0 - (py * log(py)) - (pn * log(pn));
}

// Chooses the column to split on: the available attribute with the largest
// information gain over the current subset of rows.
//
//   avail_att - one flag per column; '*' means the column may still be used.
//   valid     - one flag per row (first M1 entries); '*' means the row is
//               in the current subset.
//
// Returns the chosen column index (ties go to the highest index), or -1 when
// no attribute column is still available. All scratch memory is released
// before returning.
int CDTID3::find_att(char *avail_att, char *valid)
{
    int i, j, k, l;
    int y_tot = 0, n_tot = 0;
    bool bHave = false;
    int att_no = 0;
    double max_inf_gain = -1.0;

    // Scratch buffers: gain per column, distinct values of one column
    // (one M3-byte slot per row) and the row mask of one value's partition.
    double *att_entropy = (double *)malloc(M2 * sizeof(double));
    char **att_names = (char **)malloc(M1 * sizeof(char *));
    for (k = 0; k < M1; ++k)
        att_names[k] = (char *)malloc(M3 * sizeof(char));
    char *valid_2 = (char *)malloc(M1 * sizeof(char));

    // -2.0 marks non-candidate columns; it is below the initial best (-1.0),
    // so such columns are never selected.
    for (i = 0; i < M2; ++i)
        att_entropy[i] = -2.0;

    // Entropy of the class column over the current subset.
    for (i = 0; i < M1; ++i)
    {
        if (valid[i] != '*')
            continue;
        if (strcmp(string1[i][attributes-1], "及格") == 0)
            ++y_tot;
        if (strcmp(string1[i][attributes-1], "不及格") == 0)
            ++n_tot;
    }
    const double entropy = id3_binary_entropy(y_tot, n_tot);

    // Information gain of every still-available attribute column.
    for (i = 0; i < attributes - 1; ++i)
    {
        if (avail_att[i] != '*')
            continue;
        bHave = true;
        double r_entropy_tot = 0.0;
        // Distinct values of column i, stored at the front of att_names.
        const int tot_diff_atts = get_diff_att_types(valid, att_names, i, rows);
        for (j = 0; j < tot_diff_atts; ++j)
        {
            // Rows of the subset whose column i equals att_names[j].
            memset(valid_2, ' ', M1);
            for (l = 0; l < rows; ++l)
            {
                if (valid[l] == '*' && strcmp(att_names[j], string1[l][i]) == 0)
                    valid_2[l] = '*';
            }
            int y_tot_2 = 0, n_tot_2 = 0;
            for (l = 0; l < M1; ++l)
            {
                if (valid_2[l] != '*')
                    continue;
                if (strcmp(string1[l][attributes-1], "及格") == 0)
                    ++y_tot_2;
                if (strcmp(string1[l][attributes-1], "不及格") == 0)
                    ++n_tot_2;
            }
            // Weighted entropy of this partition.
            const double entropy_2 = id3_binary_entropy(y_tot_2, n_tot_2);
            r_entropy_tot += entropy_2 * ((n_tot_2 + y_tot_2) / (double)(n_tot + y_tot));
        }
        att_entropy[i] = entropy - r_entropy_tot;
    }

    if (bHave)
    {
        for (l = 0; l < M2; ++l)
        {
            if (att_entropy[l] >= max_inf_gain)
            {
                max_inf_gain = att_entropy[l];
                att_no = l;
            }
        }
    }
    else
    {
        att_no = -1;
    }

    for (k = 0; k < M1; ++k)
        free(att_names[k]);
    free(att_names);
    free(att_entropy);
    free(valid_2);
    return att_no;
}
// Collects the distinct values of column `att` over the rows marked valid.
//
//   valid     - row flags; only rows whose flag is '*' are considered.
//   att_names - caller-owned array of at least `max_row` buffers of M3 bytes
//               each; receives the distinct values, in order of first
//               appearance, in att_names[0 .. count-1]. Slots past `count`
//               are left untouched.
//   att       - column index into string1.
//   max_row   - number of rows to scan.
//
// Returns the number of distinct values stored (`count`).
// Unlike the earlier version this allocates nothing and never compares
// unterminated '*'-filled buffers. Values beginning with '*' are kept too.
int CDTID3::get_diff_att_types(char *valid, char **att_names, int att, int max_row)
{
    int row, k;
    int count = 0;
    for (row = 0; row < max_row; ++row)
    {
        if (valid[row] != '*')
            continue;
        const char *value = string1[row][att];
        // Linear search among the values collected so far.
        for (k = 0; k < count; ++k)
        {
            if (strcmp(att_names[k], value) == 0)
                break;
        }
        if (k == count)
        {
            // New value; assumes it fits in M3 bytes (PrepareID3 sizes M3).
            strcpy(att_names[count], value);
            ++count;
        }
    }
    return count;
}
// Summarises the class labels of the rows marked '*' among the first M1
// entries of `valid`:
//   2 - no "不及格" row (all positive; also returned when no row is marked),
//   3 - no "及格" row but at least one "不及格" row (all negative),
//   1 - both labels occur, so the subset still needs splitting.
int CDTID3::not_all_same(char *valid)
{
    int i;
    int y_tot = 0, n_tot = 0;
    ASSERT(valid != NULL);
    // Read the mask in place. The former strlen()-sized heap copy relied on
    // NUL termination, was still indexed up to M1, and cost a malloc per call.
    for (i = 0; i < M1; ++i)
    {
        if (valid[i] != '*')
            continue;
        if (strcmp(string1[i][attributes-1], "及格") == 0)
            ++y_tot;
        if (strcmp(string1[i][attributes-1], "不及格") == 0)
            ++n_tot;
    }
    if (n_tot == 0)
        return 2;
    if (y_tot == 0)
        return 3;
    return 1;
}
// Majority vote over the rows marked '*' among the first M1 entries of
// `valid_2`. Returns 2 (positive, "YES") when "及格" rows are at least as
// many as "不及格" rows (ties and an empty subset go positive), otherwise
// 3 (negative, "NO").
int CDTID3::Chuli(char *valid_2)
{
    int i;
    int y_tot = 0, n_tot = 0;
    ASSERT(valid_2 != NULL);
    // Read the mask in place. The former strlen()-sized heap copy relied on
    // NUL termination, was still indexed up to M1, and cost a malloc per call.
    for (i = 0; i < M1; ++i)
    {
        if (valid_2[i] != '*')
            continue;
        const char *label = string1[i][attributes-1];
        if (strcmp(label, "及格") == 0)
            ++y_tot;
        if (strcmp(label, "不及格") == 0)
            ++n_tot;
    }
    return (y_tot >= n_tot) ? 2 : 3;
}
int CDTID3::create_tree(ChildNode *child, char *rule, char *avail_att, char *valid, int tab_cnt)
{
char **att_names;
char *valid_2;
char *avail_att_2;
char *rule_2;
int j, l, i, ret, tot_diff_atts, att_no,result;
char *rule_c,*avail_att_c,*valid_c;
TreeNode *dt;
ASSERT((rule!=NULL)&&(avail_att!=NULL)&&(valid!=NULL));
rule_c=(char *)malloc(strlen(rule)+1);
avail_att_c=(char *)malloc(strlen(avail_att)+1);
valid_c=(char *)malloc(strlen(valid)+1);
strcpy(rule_c,rule);
strcpy(avail_att_c,avail_att);
strcpy(valid_c,valid);
dt=(TreeNode *)malloc(sizeof(TreeNode));
//用于标识新的有效的行和列
valid_2=(char *)malloc(rows * sizeof(char));
avail_att_2=(char *)malloc(attributes * sizeof(char));
rule_2=(char *)malloc(M4 * sizeof(char));
//用于存放这一列中所有可能不同的值
att_names=(char **)malloc(M1 * sizeof(char *));
for(int k=0;k<M1;k++)
att_names[k]=(char *)malloc(M3 * sizeof(char));
for (i=0;i<tab_cnt+tab_cnt;++i)
{
fprintf(ofp,"\t");
}
tab_cnt++;
//找到信息增益最大的列(即信息熵最小的列)
att_no = find_att(avail_att_c,valid_c);
if (att_no== 999)
{
//printf("attno\n");
return 999;
}
else if(att_no==-1)
return 998;
child->node=dt;
rule_c[M4-1] = '\0';
avail_att_c[M2-1] = '\0';
strcpy(avail_att_2, avail_att_c);
//使找到的列无效
avail_att_2[att_no] =' ';
fprintf(ofp, "[%s]\n", title1[att_no]);
strcpy(dt->Title,title1[att_no]);
//得到此列中所有不同的值,存储于att_names中,tot_diff_atts是个数
tot_diff_atts = get_diff_att_types(valid_c,att_names, att_no, rows);
dt->nChildCount=tot_diff_atts;
dt->child=(ChildNode *)malloc(tot_diff_atts * sizeof(ChildNode));
for (j=0;j<tot_diff_atts;++j)
{
valid_c[M1-1] = '\0';
strcpy(valid_2,valid_c);
//使本字段不等于当前值的行无效,存于valid_2中
for (l=0;l<rows;++l)
{
if (strcmp(att_names[j],string1[l][att_no]) != 0)
{
valid_2[l] = ' ';
}
}
strcpy(dt->child[j].value,att_names[j]);
if ((ret = not_all_same(valid_2)) == 1)
{
for (i=0;i<tab_cnt+tab_cnt-1;++i)
{
fprintf(ofp,"\t");
}
fprintf(ofp,"%s\n",att_names[j]);
rule_c[att_no-1] = att_names[j][0];
strcpy(rule_2, rule_c);
result=create_tree(&dt->child[j],rule_2, avail_att_2,valid_2,tab_cnt);
if (result == 999)
{
return 999;
}
else
if(result ==998)
{
i=Chuli(valid_2);
for (i=0;i<tab_cnt+tab_cnt-1;++i)
{
fprintf(ofp,"\t");
}
if (i == 2)
{
TreeNode *node1;
node1=(TreeNode *)malloc(sizeof(TreeNode));
strcpy(node1->Title,"YES");
node1->nChildCount=-1;
node1->child=NULL;
dt->child[j].node=node1;
fprintf(ofp,"\t -YES\n");
rule_c[att_no-1] = att_names[j][0];
fprintf(pfp,"%s\n",rule_c);
}
else
{
TreeNode *node2;
node2=(TreeNode *)malloc(sizeof(TreeNode));
strcpy(node2->Title,"NO");
node2->nChildCount=-1;
node2->child=NULL;
dt->child[j].node=node2;
fprintf(ofp,"%s\t -NO\n",att_names[j]);
rule_c[att_no-1] = att_names[j][0];
fprintf(nfp,"%s\n",rule_c);
}
}
}
else
{
for (i=0;i<tab_cnt+tab_cnt-1;++i)
{
fprintf(ofp,"\t");
}
if (ret == 2)
{
TreeNode *node1;
node1=(TreeNode *)malloc(sizeof(TreeNode));
strcpy(node1->Title,"YES");
node1->nChildCount=-1;
node1->child=NULL;
dt->child[j].node=node1;
fprintf(ofp,"%s\t -YES\n",att_names[j]);
rule_c[att_no-1] = att_names[j][0];
fprintf(pfp,"%s\n",rule_c);
}
else
{
TreeNode *node2;
node2=(TreeNode *)malloc(sizeof(TreeNode));
strcpy(node2->Title,"NO");
node2->nChildCount=-1;
node2->child=NULL;
dt->child[j].node=node2;
fprintf(ofp,"%s\t -NO\n",att_names[j]);
rule_c[att_no-1] = att_names[j][0];
fprintf(nfp,"%s\n",rule_c);
}
}
}
/* free(rule_c);
free(valid_c);
free(avail_att_c);
free(rule_2);
free(valid_2);
free(avail_att_2);
free(att_names);*/
return 1;
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -