%
% naive_bayes.m
%
% created by Sunghwan Yoo
%
%
% Goal: predict label Y given X
%
% Naive Bayes
% ===========
%
% Learning:
% Compute prior probabilities p(Y=1), and p(Y=0)
% For all y = [0 1]
% For all i
% Compute p(X_i | y)
% Smooth p(X_i | y) if desired
%
% Note:
% p(Y=1) = fraction of training data with label Y=1
% p(X_i=1 | Y=0) = fraction of training data with label Y=0 that have 1
% for the value of attribute i.
%
% Classification:
% To classify Xhat:
% p0 = P(Y=0) * prod_i p(Xhat_i | Y=0)
% p1 = P(Y=1) * prod_i p(Xhat_i | Y=1)
% if p0 > p1 then Yhat = 0,
% else Yhat = 1 (or implement by using max() function)
%
% Data import and series of experiments will be taken care of by main.m
% Train a naive Bayes classifier on training_set and evaluate it.
%
% Inputs:
%   training_set - N x (d+1) matrix; columns 1..d are attributes, the last
%                  column is the class label
%   test_set     - M x (d+1) matrix with the same column layout
%
% Outputs:
%   param        - learned parameter struct (see naive_bayes_learn)
%   training_acc - classification accuracy on training_set
%   test_acc     - classification accuracy on test_set
function [param, training_acc, test_acc] = naive_bayes(training_set, test_set)
param = naive_bayes_learn(training_set);
% The original zero-initializations of the accuracies were dead code
% (immediately overwritten) and have been removed.
training_acc = test_naive_bayes( param, training_set );
test_acc = test_naive_bayes( param, test_set );
% Learn naive Bayes parameters from a labeled dataset.
% The last column of dataset is the class label; all other columns are
% discrete attributes.
%
% Output param struct fields:
%   uvalues_per_col - cell array: unique values observed in each column
%   uclassifiers    - sorted unique class labels (from the label column)
%   pr_y            - 1 x C row vector of priors P(Y = y_j)
%   pr_cond_xy      - cell array; pr_cond_xy{i}(k,j) = P(X_i = v_k | Y = y_j)
%                     estimated with add-one (Laplace) smoothing
function [param] = naive_bayes_learn( dataset )
[nrows, ncols] = size(dataset);
num_attributes = ncols - 1;
for i=1:ncols
    param.uvalues_per_col{i} = unique(dataset(:,i)); % unique values seen in each column
end
%% initialize parameters
param.uclassifiers = unique(dataset(:,ncols));
num_of_classifiers = length(param.uclassifiers);
param.pr_y = zeros(1, num_of_classifiers);
param.pr_cond_xy = cell(1, num_attributes);
for i=1:num_attributes
    param.pr_cond_xy{i} = zeros( length(param.uvalues_per_col{i}), num_of_classifiers );
end
%% estimate P(Y=y_j) and P(X_i | Y=y_j)
for j = 1:num_of_classifiers
    % Hoisted: the original recomputed this find() three times per class.
    class_mask = ( dataset(:,ncols) == param.uclassifiers(j) );
    n_yj = sum(class_mask);               % number of rows labeled y_j
    % prior = fraction of training rows with label y_j
    param.pr_y(j) = n_yj / nrows;
    yj = dataset(class_mask, :);          % rows belonging to class y_j
    for i = 1:num_attributes
        num_values = length(param.uvalues_per_col{i});
        for k = 1:num_values
            xval = param.uvalues_per_col{i}(k);
            n_xi_yj = sum(yj(:,i) == xval);
            % Laplace (add-one) smoothing: avoids zero probability for
            % attribute values unseen within this class.
            param.pr_cond_xy{i}(k,j) = ( n_xi_yj + 1 ) / ( n_yj + num_values );
        end
    end
end
% test naive bayes classifier with given parameter and test sets
% Evaluate the naive Bayes classifier on a labeled dataset.
%
% Inputs:
%   param    - parameter struct from naive_bayes_learn
%   dataset  - N x (d+1) matrix; last column is the true class label
%
% Output:
%   test_acc - fraction of rows classified correctly
function [test_acc] = test_naive_bayes( param, dataset )
[nrows, ncols] = size( dataset );
success = 0;
for i=1:nrows
    % classify the attribute part of the row, compare against the label
    if( test_naive_bayes_case( param, dataset(i,1:ncols-1) ) == dataset(i,ncols) )
        success = success + 1;
    end
end
% The separate failure counter was redundant: failure == nrows - success.
test_acc = success / nrows;
% Classify a single example with learned naive Bayes parameters, using the
% MAP rule: argmax_j P(Y=y_j) * prod_i P(X_i = new_xi(i) | Y=y_j).
%
% Inputs:
%   param  - struct from naive_bayes_learn (uses fields pr_y, pr_cond_xy,
%            uvalues_per_col, uclassifiers)
%   new_xi - 1 x num_attributes row vector of attribute values
%            (label column already stripped by the caller)
%
% Output:
%   class_res - predicted class label (one of param.uclassifiers)
function [class_res] = test_naive_bayes_case(param, new_xi)
% BUG FIX: new_xi is already label-free, so there are numel(new_xi)
% attributes. The original looped i = 1:ncols-1 over size(new_xi), which
% silently skipped the LAST attribute in every classification.
num_attributes = numel(new_xi);
p = ones(1, length(param.pr_y));
for i = 1 : num_attributes
    % Values in uvalues_per_col{i} come from unique(), so at most one match.
    k = find( param.uvalues_per_col{i} == new_xi(i), 1 );
    if ~isempty(k)
        p = p .* param.pr_cond_xy{i}(k,:); % multiply in P(x_i | y_j) for all j
    end
    % NOTE(review): attribute values never seen in training contribute no
    % factor (matches original behavior); a smoothed fallback may be better.
end
[junk, class_id] = max(param.pr_y .* p); % MAP decision
class_res = param.uclassifiers(class_id);