% bayes_2.m
% Example: Bayes classifier applied to the IRIS dataset
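clear;

% Load class_1, class_2, class_3 and ClassAllData
load('classalldatafile.mat');

% Transpose so that each column is one sample: rows 1-3 hold the three
% features used below, and row 4 of ClassAllData holds the true class label
ClassAllData=ClassAllData';
class_1=class_1';
class_2=class_2';
class_3=class_3';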
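
% Sketch of a data-layout check. It only asserts the layout the rest of this
% script already relies on (an assumption, not something the data file
% documents): each column is one sample, rows 1-3 are the features, row 4 of
% ClassAllData is the true label coded 1, 2 or 3 in the order
% class_1, class_2, class_3, and ClassAllData holds exactly those samples.
assert(size(ClassAllData,1) >= 4, 'ClassAllData: expected features in rows 1-3 and labels in row 4');
assert(all(ismember(ClassAllData(4,:), [1 2 3])), 'ClassAllData: labels in row 4 should be 1, 2 or 3');
assert(size(ClassAllData,2) == size(class_1,2) + size(class_2,2) + size(class_3,2), ...
       'ClassAllData: expected exactly the samples of class_1, class_2 and class_3');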

sampleNum_1 = size(class_1,2);   % number of samples in each class (sampleNum_1, _2, _3 = 500 here)
sampleNum_2 = size(class_2,2);
sampleNum_3 = size(class_3,2);

sampleAll = sampleNum_1 + sampleNum_2 + sampleNum_3;   % total number of samples

prior_1 = sampleNum_1/sampleAll;   % prior probabilities P(w_i), estimated from the class frequencies
prior_2 = sampleNum_2/sampleAll;
prior_3 = sampleNum_3/sampleAll;
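
% Sanity check: the priors are relative class frequencies, so they must sum
% to 1 (with equal class sizes each one is 1/3).
priorSum = prior_1 + prior_2 + prior_3;
assert(abs(priorSum - 1) < 1e-12, 'priors should sum to 1');
fprintf('Priors: %.4f  %.4f  %.4f\n', prior_1, prior_2, prior_3);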

mean_1 = sum(class_1(1:3,:),2)/sampleNum_1;   % mean vector of each class (textbook p. 55)
mean_2 = sum(class_2(1:3,:),2)/sampleNum_2;
mean_3 = sum(class_3(1:3,:),2)/sampleNum_3;
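
% Cross-check (illustrative): sum(...,2)/N above is the per-feature sample
% mean over the columns, i.e. what the built-in mean(...,2) returns.
means   = {mean_1, mean_2, mean_3};
classes = {class_1, class_2, class_3};
for c = 1:3
    ref = mean(classes{c}(1:3,:), 2);
    assert(max(abs(means{c} - ref)) <= 1e-10 * max(1, max(abs(ref))), ...
           'class %d mean does not match mean()', c);
end
clear means classes ref c;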

sigma_1 = (class_1(1:3,:)-repmat(mean_1,[1 sampleNum_1]))*(class_1(1:3,:)-repmat(mean_1,[1 sampleNum_1]))'/sampleNum_1;   % covariance matrix of each class (maximum-likelihood estimate, divided by N)
sigma_2 = (class_2(1:3,:)-repmat(mean_2,[1 sampleNum_2]))*(class_2(1:3,:)-repmat(mean_2,[1 sampleNum_2]))'/sampleNum_2;
sigma_3 = (class_3(1:3,:)-repmat(mean_3,[1 sampleNum_3]))*(class_3(1:3,:)-repmat(mean_3,[1 sampleNum_3]))'/sampleNum_3;
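
% Cross-check (illustrative): the expression above is the maximum-likelihood
% covariance estimate (divides by N, not N-1). The built-in cov with the
% normalization flag 1 computes the same matrix; cov expects one sample per
% row, hence the transpose.
sigmas  = {sigma_1, sigma_2, sigma_3};
classes = {class_1, class_2, class_3};
for c = 1:3
    ref = cov(classes{c}(1:3,:)', 1);
    assert(max(abs(sigmas{c}(:) - ref(:))) <= 1e-8 * max(1, max(abs(ref(:)))), ...
           'class %d covariance does not match cov(...,1)', c);
end
clear sigmas classes ref c;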

% Determinants and inverses of the covariance matrices, used by the discriminant functions
det_sigma_1=det(sigma_1);
det_sigma_2=det(sigma_2);
det_sigma_3=det(sigma_3);

inv_sigma_1=inv(sigma_1);
inv_sigma_2=inv(sigma_2);
inv_sigma_3=inv(sigma_3);
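
% Numerical note (a sketch, not part of the original method): det and inv are
% adequate for well-conditioned 3x3 covariance matrices. rcond flags a nearly
% singular one, and a Cholesky factor R (Sigma = R'*R) gives the
% log-determinant directly as 2*sum(log(diag(R))) without forming det.
% The thresholds 1e-8 and 1e-6 are arbitrary choices for this sketch.
sigmas = {sigma_1, sigma_2, sigma_3};
dets   = [det_sigma_1, det_sigma_2, det_sigma_3];
for c = 1:3
    if rcond(sigmas{c}) < 1e-8
        warning('Covariance of class %d is nearly singular; inv/det may be inaccurate.', c);
    else
        R = chol(sigmas{c});   % errors if Sigma is not positive definite
        logdet_chol = 2*sum(log(diag(R)));
        assert(abs(logdet_chol - log(dets(c))) <= 1e-6 * max(1, abs(logdet_chol)), ...
               'class %d: log-determinants disagree', c);
    end
end
clear sigmas dets R logdet_chol c;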

% Quadratic discriminant of each class (textbook p. 31, Eq. (2-88)):
%   g_i(x) = -1/2 (x-mu_i)' inv(Sigma_i) (x-mu_i) - 1/2 ln|Sigma_i| + ln P(w_i)
% The term -d/2 ln(2*pi) is the same for every class and is omitted.
% Only the quadratic form of each sample with itself is needed, so it is
% computed column by column instead of forming a sampleAll-by-sampleAll
% matrix and keeping its diagonal.
X = ClassAllData(1:3,:);
gx = zeros(3,sampleAll);
d = X - repmat(mean_1,[1 sampleAll]);
gx(1,:) = -0.5*sum(d.*(inv_sigma_1*d),1) - 0.5*log(det_sigma_1) + log(prior_1);
d = X - repmat(mean_2,[1 sampleAll]);
gx(2,:) = -0.5*sum(d.*(inv_sigma_2*d),1) - 0.5*log(det_sigma_2) + log(prior_2);
d = X - repmat(mean_3,[1 sampleAll]);
gx(3,:) = -0.5*sum(d.*(inv_sigma_3*d),1) - 0.5*log(det_sigma_3) + log(prior_3);
clear d;
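
% Sanity check (illustrative): evaluate Eq. (2-88) directly for one sample,
% here the arbitrarily chosen first column, and compare with column 1 of gx.
xj = ClassAllData(1:3,1);
g_check = zeros(3,1);
g_check(1) = -0.5*(xj-mean_1)'*inv_sigma_1*(xj-mean_1) - 0.5*log(det_sigma_1) + log(prior_1);
g_check(2) = -0.5*(xj-mean_2)'*inv_sigma_2*(xj-mean_2) - 0.5*log(det_sigma_2) + log(prior_2);
g_check(3) = -0.5*(xj-mean_3)'*inv_sigma_3*(xj-mean_3) - 0.5*log(det_sigma_3) + log(prior_3);
assert(max(abs(g_check - gx(:,1))) <= 1e-8 * max(1, max(abs(g_check))), ...
       'direct evaluation of the discriminant disagrees with gx');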
[max_gx, ClassLabel] = max(gx);    % max_gx: largest value in each column of gx; ClassLabel: row of that maximum, i.e. the assigned class
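
% Optional illustration: each g_i(x) equals ln[p(x|w_i)P(w_i)] up to the same
% constant (-d/2 ln 2*pi, omitted above) for all three classes, so the
% posterior probabilities P(w_i|x) follow from normalizing exp(g_i(x)) over
% the classes. Subtracting max_gx first keeps exp from overflowing or
% underflowing.
post = exp(gx - repmat(max_gx, [3 1]));
post = post ./ repmat(sum(post, 1), [3 1]);   % each column now sums to 1
% The assigned class is also the class with the largest posterior.
idx = sub2ind(size(post), ClassLabel, 1:sampleAll);
assert(isequal(post(idx), max(post, [], 1)));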

Err_ind = find(ClassAllData(4,:)~=ClassLabel);   % Err_ind: indices of the misclassified samples

errorrate = numel(Err_ind)/sampleAll;   % error rate

disp('Error rate:'); disp(errorrate);
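
% Illustrative breakdown of the errors by class (a sketch; like the comparison
% above, it assumes row 4 holds integer labels 1-3). Entry (i,k) of confMat
% counts samples of true class i that were assigned to class k.
confMat = accumarray([ClassAllData(4,:)' ClassLabel'], 1, [3 3]);
disp('Confusion matrix (rows: true class, columns: assigned class)');
disp(confMat);
% The off-diagonal entries are exactly the misclassified samples.
assert(sum(confMat(:)) - trace(confMat) == numel(Err_ind));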

% Plot the three classes (red, green, blue) and mark the misclassified
% samples in black, in four figures that differ only in viewpoint:
% oblique, then the x-y, y-z and x-z planes
views = [127.5 30; 0 90; 90 0; 0 0];   % [azimuth elevation] of each figure
for k = 1:size(views,1)
    figure(k)
    clf   % start from an empty figure if the script is run again
    hold on
    scatter3(class_1(1,:),class_1(2,:),class_1(3,:),3,'r');
    scatter3(class_2(1,:),class_2(2,:),class_2(3,:),3,'g');
    scatter3(class_3(1,:),class_3(2,:),class_3(3,:),3,'b');
    scatter3(ClassAllData(1,Err_ind),ClassAllData(2,Err_ind),ClassAllData(3,Err_ind),23,'k');
    view(views(k,1),views(k,2))
    if k == 1
        grid on   % grid only in the oblique view
    end
    xlabel('x');
    ylabel('y');
    zlabel('z');
    hold off
end
