📄 create_tree.m
字号:
%
% create_tree.m
%
% Sunghwan Yoo
%
% Create a tree structure for id3
function [node] = create_tree(dataset, level, cutoff)
% CREATE_TREE  Recursively build an ID3 decision tree.
%
% Inputs:
%   dataset : numeric matrix, one row per example. The last column is the
%             class label; leaves are labelled by comparing the number of
%             rows with label 1 against the number with label 0.
%   level   : depth of the node being built (children get level+1).
%   cutoff  : p-value threshold for chi-square pruning. The test is only
%             run when cutoff < 1; the node becomes a leaf when the
%             p-value of the best split exceeds cutoff.
%
% Output: a struct with the fields
%   node.subnode0/1/2     : child for split value 0 / 1 / 2 ([] if absent)
%   node.split_attribute  : column chosen for the split (0 on a leaf)
%   node.split_vars       : distinct values found in the split column
%   node.info_gain        : information gain of the chosen split
%   node.count_child_node : number of distinct split values (0 on a leaf)
%   node.level            : level of an internal node (leaves store 0)
%   node.classvar         : predicted class on a leaf (1 or 0),
%                           -1 on an internal node or an empty leaf
%
% Relies on info_gain(dataset), which is defined outside this file.

ncols = size(dataset, 2);
class_attr = ncols;

% Stop if we have reached the maximum level.
if level >= ncols - 1
    node = make_leaf_node(majority_class(dataset(:, class_attr)));
    return;
end

% Pick the attribute with the largest information gain on this dataset.
g = info_gain(dataset);
[max_info_gain, split_attr] = max(g);

% Compute the p-value of the candidate split only if pruning is enabled.
pvalue = 0;
if max_info_gain > 0 && cutoff < 1
    pvalue = chisq_test(dataset, class_attr, split_attr);
end

% Stop growing if nothing is gained or the p-value exceeds the cutoff.
if max_info_gain <= 0 || pvalue > cutoff
    node = make_leaf_node(majority_class(dataset(:, class_attr)));
    return;
end

% Internal node: classvar stays -1 because the decision is in the children.
uvals = unique(dataset(:, split_attr));
node.subnode0 = [];
node.subnode1 = [];
node.subnode2 = [];
node.split_attribute = split_attr;
node.split_vars = uvals;
node.info_gain = max_info_gain;
node.count_child_node = length(uvals);
node.level = level;
node.classvar = -1;

for i = 1 : node.count_child_node
    % Rows whose split attribute equals the i-th distinct value.
    new_dataset = dataset(dataset(:, split_attr) == uvals(i), :);

    if isempty(new_dataset)
        % No rows matched (e.g. a NaN value never compares equal):
        % create an empty leaf without a class.
        new_node = make_leaf_node(-1);
    else
        new_node = create_tree(new_dataset, level + 1, cutoff);
    end

    % Only the split values 0, 1 and 2 have a slot; others are not stored.
    switch uvals(i)
        case 0
            node.subnode0 = new_node;
        case 1
            node.subnode1 = new_node;
        case 2
            node.subnode2 = new_node;
    end
end

% Build a childless node carrying the given class label.
function [leaf] = make_leaf_node(classvar)
leaf.subnode0 = [];
leaf.subnode1 = [];
leaf.subnode2 = [];
leaf.split_attribute = 0;
leaf.split_vars = [];
leaf.info_gain = 0;
leaf.count_child_node = 0;
leaf.level = 0;
leaf.classvar = classvar;

% Return 1 when label 1 is strictly more frequent than label 0, else 0.
function [classvar] = majority_class(labels)
if sum(labels == 1) > sum(labels == 0)
    classvar = 1;
else
    classvar = 0;
end
% Perform a chi-square test of independence between the class attribute and
% the split attribute. This function only builds the contingency table and
% delegates the actual test to chisq_test_con.
function [pvalue] = chisq_test(dataset, class_attr, split_attr)
% Inputs:
%   dataset    : numeric matrix, one row per example
%   class_attr : column index of the class label
%   split_attr : column index of the candidate split attribute
% Output:
%   pvalue     : value returned by chisq_test_con for the table whose
%                entry (i,j) counts the rows having the i-th distinct class
%                value and the j-th distinct split value
uvar_class = unique(dataset(:, class_attr));
uvar_split = unique(dataset(:, split_attr));
ctable = zeros(length(uvar_class), length(uvar_split));
for i = 1:length(uvar_class)
    % Split-attribute values of the rows that belong to the i-th class.
    % (The column must be split_attr, not the class loop index.)
    split_vals = dataset(dataset(:, class_attr) == uvar_class(i), split_attr);
    for j = 1:length(uvar_split)
        ctable(i, j) = sum(split_vals == uvar_split(j));
    end
end
pvalue = chisq_test_con(ctable);
% Perform a chi-square test of independence on the given contingency table.
function [pvalue] = chisq_test_con(Y)
% Input:
%   Y      : contingency table of observed counts
% Output:
%   pvalue : p-value obtained from chiSquareProbQuad, or 1 whenever the
%            test cannot be carried out (no observations, an expected
%            count of zero, or fewer than two rows or two columns)

% Default answer for every "cannot test" exit below.
pvalue = 1;

[nRows, nCols] = size(Y);

% Marginal totals: one per row, one per column, and the grand total.
rowTotals = sum(Y, 2);
colTotals = sum(Y);
grandTotal = sum(colTotals);

if grandTotal == 0
    return;
end

% Expected counts under independence: outer product of the row totals
% with the relative column frequencies.
expected = rowTotals * (colTotals / grandTotal);

% A zero expected count would make the statistic divide by zero.
if any(expected(:) == 0)
    return;
end

% Chi-square statistic and its degrees of freedom.
deviation = Y - expected;
q = sum(sum((deviation .^ 2) ./ expected));
dof = (nCols - 1) * (nRows - 1);

if dof >= 1 && nRows > 1 && nCols > 1
    % Upper-tail probability of the statistic.
    pvalue = chiSquareProbQuad(dof, q, inf);
end
function prob = chiSquareProbQuad(r, xl, xu)
% Function chiSquareProbQuad.m
% Computes probabilities related to the chi-square
% distribution with r degrees of freedom.
%
% This function computes and returns:
%
% P(xl < X < xu) = \int_{xl}^{xu} f(x) dx
%
% When xu is inf the result is computed as 1 - \int_{0}^{xl} f(x) dx.
%
% Usage: prob = chiSquareProbQuad(r, xl, xu)
%
% Input:
% r - scalar representing the number of degrees of freedom
% xl - scalar lower limit of integration (negative values are treated as 0)
% xu - scalar upper limit of integration (may be inf)
%
% Note: This function uses the built-in Matlab function quad
%
% Author: Ernest E. Rothman.
% Created 10/12/2005
% Last modified: 10/21/2005
%
% This work is licensed under a Creative Commons License.
% See http://creativecommons.org/licenses/by-nc-sa/2.5/
%

% Silence warnings during the integration, but remember the caller's
% warning configuration so it can be put back exactly as it was instead of
% unconditionally switching every warning on.
savedWarnings = warning;
warning('off', 'all');

xl = max(xl, 0.0);
rOver2 = 0.5*r;
gammaval = gamma(rOver2);
% Chi-square probability density with r degrees of freedom.
F = @(x) x.^(rOver2-1).*exp(-0.5*x)/gammaval/2^rOver2;

try
    if xu == inf
        prob = 1 - quad(F, 0.0, xl);
    else
        prob = quad(F, xl, xu);
    end
catch err
    % Restore the warning state even when the integration fails.
    warning(savedWarnings);
    rethrow(err);
end

warning(savedWarnings);
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -