% generate_decision_tree.m
% Font size (字号): web-page UI residue from where this file was copied, not part of the program
function Result=Generate_decision_tree(DataName,WhereSen,ForecastSen,attributName,i,j)
%GENERATE_DECISION_TREE Recursively grow a decision tree from a database table.
%   Result = Generate_decision_tree(DataName,WhereSen,ForecastSen,attributName,i,j)
%   creates node i of the tree (and, recursively, its subtree) in the global
%   struct array Node and returns Node.
%
%   DataName     - name of the table to query.
%   WhereSen     - SQL WHERE clause selecting the rows that reach this node,
%                  i.e. the conditions on the path from the root ('' at root).
%   ForecastSen  - name of the class column to predict.
%   attributName - 1xN cell array of candidate attribute names; attributes
%                  already used on the path to this node are empty entries.
%                  Ignored when i==1: the column list is then read from the
%                  database catalogue.
%   i            - index of the node to create. i is also declared global and
%                  is the node counter shared by all recursive calls: the
%                  global value replaces the argument, which is only used
%                  when the global has not been initialised (is empty).
%   j            - index of the parent node; not used when i==1.
%
%   Result       - the global Node struct array after this call.
%
%   Node fields: attributSen (inner node: split attribute; leaf: comma-
%   separated list of the attributes still available), whereSen (path
%   condition), divSen (class label, leaves only), parent (0 for the root).
%
%   NOTE(review): declaring the input i as global triggers a MATLAB warning
%   (the global value overrides the argument); kept because the recursion
%   relies on the shared counter. DecisionTreeTestCon is defined elsewhere and
%   is assumed to return 0 or the name of the best split attribute -- confirm.
iArg=i;
global Node;
global i;
if isempty(i)
    % The caller never initialised the global node counter: start from the
    % index that was passed in.
    i=iArg;
end

logintimeout(15);
conn=database('DecisionTreeTest','','');
% Close this call's connection on every exit path (one was leaked per node).
connCleanup=onCleanup(@() close(conn)); %#ok<NASGU>

if i==1
    % Root: read the table's column names from the SQL Server catalogue.
    setdbprefs('DataReturnFormat','cellarray');   % names must come back as strings
    curs1=exec(conn,['select b.name from sysobjects a inner join syscolumns b on a.id=b.id where a.name=''',strrep(DataName,'''',''''''),'''']);
    curs1=fetch(curs1);
    attributeNameArray=curs1.data;
    close(curs1);
    % One name per row -> 1xN row cell array.
    attributName=reshape(attributeNameArray,1,[]);
end

% Comma-separated list of the attributes still available at this node;
% it is stored on leaf nodes.
isUsed=cellfun(@isempty,attributName);
remaining=attributName(~isUsed);
if isempty(remaining)
    attributNameList='';
else
    attributNameList=sprintf('%s,',remaining{:});
    attributNameList=attributNameList(1:end-1);   % drop the trailing comma
end

if i==1
    parentIdx=0;
else
    parentIdx=j;
end

% Distinct class values among the rows reaching this node.
curs2=exec(conn,['select',' ',ForecastSen,' from',' ',DataName,' ',WhereSen,' ','group by',' ',ForecastSen]);
setdbprefs('DataReturnFormat','cellarray');
curs2=fetch(curs2);
DiffClus=curs2.data;
close(curs2);

% All rows share one class: leaf labelled with that class (stored as the
% value itself, like the majority-vote leaves, not as a 1x1 cell).
if size(DiffClus,1)==1
    storeNode(i,parentIdx,attributNameList,WhereSen,DiffClus{1});
    Result=Node;
    return;
end

% No attribute left to split on: majority-vote leaf. Two entries are never
% removed from attributName (presumably the class column and a key column
% -- TODO confirm against the table schema).
if sum(isUsed)==length(attributName)-2
    storeNode(i,parentIdx,attributNameList,WhereSen,MostChoose(DataName,WhereSen,ForecastSen));
    Result=Node;
    return;
end

% Choose the split attribute by information gain; 0 means the gain is too
% small to be worth a split, so make a majority-vote leaf instead.
sz=DecisionTreeTestCon(DataName,WhereSen,ForecastSen,attributName);
if sz==0
    storeNode(i,parentIdx,attributNameList,WhereSen,MostChoose(DataName,WhereSen,ForecastSen));
    Result=Node;
    return;
end

% Inner node splitting on sz (reused instead of calling DecisionTreeTestCon again).
splitAttr=sz;
nodeIdx=i;
nodeWhere=WhereSen;
storeNode(nodeIdx,parentIdx,splitAttr,nodeWhere);

% Every value the split attribute takes anywhere in the table gets a branch.
curs=exec(conn,['select ',' ',splitAttr,' from',' ',DataName,' group by',' ',splitAttr]);
setdbprefs('DataReturnFormat','cellarray');
curs=fetch(curs);   % no row limit: a limit of 3 silently dropped further branches
NewDiv=curs.data;
close(curs);

if nodeIdx==1
    childWhere=['where',' ',splitAttr,'='];
else
    childWhere=[nodeWhere,' and',' ',splitAttr,'='];
end

% The split attribute is unavailable in every child subtree. Removed once,
% here, because the global i no longer points at this node once the first
% child subtree has been built.
childAttributes=attributName;
childAttributes(strcmp(childAttributes,splitAttr))={{}};

for k=1:size(NewDiv,1)
    value=NewDiv{k,1};
    if isnumeric(value) || islogical(value)
        value=num2str(value);                 % numbers must be text in the SQL literal
    else
        value=strrep(value,'''','''''');      % escape quotes inside the literal
    end
    WhereSen=[childWhere,'''',value,''''];
    i=i+1;   % next free node index (global; already advanced by earlier subtrees)
    curs=exec(conn,['select count(*) from',' ',DataName,' ',WhereSen]);
    setdbprefs('DataReturnFormat','numeric');
    curs=fetch(curs);
    RecordNum=curs.data;
    close(curs);
    if RecordNum(1)<2
        % Fewer than two rows on this branch: leaf labelled with the majority
        % class of this (parent) node's rows.
        storeNode(i,nodeIdx,attributNameList,WhereSen,MostChoose(DataName,nodeWhere,ForecastSen));
    else
        Generate_decision_tree(DataName,WhereSen,ForecastSen,childAttributes,i,nodeIdx);
    end
end

Result=Node;

function storeNode(idx,parentIdx,attributSen,whereSen,divSen)
% Write one node into the global Node array. divSen (the class label) is
% passed only for leaf nodes; inner nodes leave that field untouched.
global Node;
Node(idx).attributSen=attributSen;
Node(idx).whereSen=whereSen;
if nargin>=5
    Node(idx).divSen=divSen;
end
Node(idx).parent=parentIdx;
function modivflag=MostChoose(DataName,WhereSen,ForecastSen)
%MOSTCHOOSE Majority class of the rows selected by WhereSen.
%   modivflag = MostChoose(DataName,WhereSen,ForecastSen) counts, per value
%   of column ForecastSen, the rows of table DataName matching WhereSen and
%   returns the value with the largest count (the first one on ties).
%   Returns [] when there is no class row with a positive count; the old
%   loop raised an "undefined variable" error in that case, and also indexed
%   past the end of the result when only one class was present.
logintimeout(15);
conn=database('DecisionTreeTest','','');
% Close the connection on every exit path (it was previously never closed).
connCleanup=onCleanup(@() close(conn)); %#ok<NASGU>
curs3=exec(conn,['select ',' ',ForecastSen,',count(',ForecastSen,') from',' ',DataName,' ',WhereSen,' group by',' ',ForecastSen]);
setdbprefs('DataReturnFormat','cellarray');
% Fetch every class, not only the first three rows.
curs3=fetch(curs3);
test=curs3.data;
close(curs3);
modivflag=[];
% Without a second column there are no (class, count) rows -- presumably the
% toolbox's no-data placeholder; confirm for the toolbox version in use.
if size(test,2)<2
    return;
end
counts=cell2mat(test(:,2));
[maxCount,best]=max(counts);   % max returns the first index on ties
% Only a strictly positive count selects a class.
if maxCount>0
    modivflag=test{best,1};
end
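% ---- Web-page viewer UI residue (not part of the program) ----
% 快捷键说明 (keyboard shortcuts)
% 复制代码 (copy code)
% Ctrl + C
% 搜索代码 (search code)
% Ctrl + F
% 全屏模式 (full-screen mode)
% F11
% 切换主题 (toggle theme)
% Ctrl + Shift + D
% 显示快捷键 (show shortcuts)
% ?
% 增大字号 (increase font size)
% Ctrl + =
% 减小字号 (decrease font size)
% Ctrl + -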