📄 naivebayes.m
字号:
%田宏宇《机器学习》作业2
function NaiveBayes()
%Naive贝叶斯数据分类器
%clear all;
% calc xmean,sigma and its eigen decomposition
allsamples=[];%所有训练图像
load faceTemplet pattern;
%faceTemplet人脸包结构说明:
%pattern——1*40数组,包含40个人脸数据
% pattern(i).num——此样本集中有几个样本,典型为10
% pattern(i).figure——1*1结构,data
% pattern(i).figure(j).data——1*10结构,存储10个样本
% pattern(i).figure(j).data(k)——1*1结构,face
% pattern(i).figure(j).data(k).face——图像样本
nclass = 40; %选取类别数,以下的注释中假定为40
ntrain = 5; %训练数,以下的注释中假定为5
ntest = pattern(1).num - ntrain; %测试数,以下的注释中假定为5
for i=1:nclass
for j=1:ntrain
% imshow(a);
b=pattern(i).figure.data(j).face(:); % b是列矢量 1×N,其中N=10304,提取顺序是先列后行,即从上到下,从左到右
b=im2double(b'); %转为实数行向量
allsamples = [allsamples; b]; % allsamples 是一个M * 10304 矩阵,allsamples 中每一列数据代表一张图片,其中M为训练样本数
end
end
disp('1.训练样本导入完毕!');
tmp_T = cputime;
accu = 0;
%以下求类的判别函数
oneClassSamples = ones(ntrain, size(allsamples,2)); %一类中的训练图像Feature
discResult = ones(1, nclass); %预先构造一个判别结果数组
discStruct.funHandle = @sin; %预先构造一个判别函数结构体
discStruct.sigma = []; %最终此判别函数结构体的长度为nclass
discStruct.mu = [];
for i=1:nclass
for j=1:ntrain
oneClassSamples(j,:) = allsamples((i-1)*ntrain+j,:);
% OneClassSamples 是一个ntrain * N 矩阵
end
%求一类样本的均值以及协方差阵
meanClassSamples = mean(oneClassSamples); %均值行向量,1*N
sigma = cov(meanClassSamples);
%构造一个判别函数结构体
discStruct(i).funHandle = ...
@(x, MU, SIGMA)(-x*(inv(SIGMA))*x'/2 + MU*(inv(SIGMA))'*x' -MU*(inv(SIGMA))*MU'/2 - log(det(SIGMA))/2);
discStruct(i).sigma = sigma;
discStruct(i).mu = meanClassSamples;
end
disp('2.类判别函数构造完毕!');
%以下用NaiveBayes准则进行分类测试
% 测试过程
testsamples = []; %所有测试样本M*10304, M = nclass*ntest
for i = 1 : nclass
for j = ntrain+1 : ntrain+ntest % 读入ntest x 5 副测试图像
b = pattern(i).figure.data(j).face(:); % b是列矢量 1×N,其中N=10304,提取顺序是先列后行,即从上到下,从左到右
b = im2double(b'); % 转换为行向量
testsamples = [testsamples; b];
end
end
for i = 1 : nclass
for j = 1 : ntest
for k = 1 : nclass
discResult(k) = discStruct(k).funHandle(testsamples((i-1)*ntest+j,:), discStruct(k).mu, discStruct(k).sigma);
end
%利用排序方式得到样本属于第几类
[maxv, indx] = max(discResult);
result = ['第', num2str(i),'类中第', num2str(j), '张识别'];
if indx == i
accu = accu + 1;
result = [result, '正确'];
else
result = [result, '错误,被识别为了第', num2str(indx), '类'];
end
disp(result);
end
end
accuracy = accu / (nclass*ntest); %输出识别率
disp('>>> ');
disp(strcat('从',num2str(nclass),'类图像中选择前',num2str(ntrain),...
'张训练,后',num2str(ntest),'张测试,'));
disp(strcat('Naive贝叶斯分类, 共花费时间',num2str(cputime-tmp_T)));
disp('--- ');
disp(strcat('识别率为', num2str(accuracy*100), '%'));
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -