quanbu.m

来自「朴素贝叶斯分类器」· M 代码 · 共 83 行

M
83
字号
% Reset the workspace and the command window, then load the data set.
% The MAT-file must provide the matrix r1 (one sample per row).
clear;
clc;
load shengchenshuju;
T = size(r1, 1);                 % total number of samples

%---------------- hold-out split (one fold of a k-fold scheme) ----------------
% Every k-th row, starting at row s, is held out for testing; only this
% single fold (s = 4 of k = 4) is evaluated by the script.
k = 4;
s = 4;
heldOutRows = s:k:T;
tdata = r1(heldOutRows, :);      % test rows, still carrying their last column
r1(heldOutRows, :) = [];         % drop the test rows from r1 ...
traindata = r1;                  % ... what remains is the training set
%-------------------------------------------------------------------------------
% Partition the training rows into three classes according to the value in
% the last column, then estimate the class priors.
[n N] = size(traindata);         % n training rows, N columns; column N drives the class

% Two class thresholds computed from the last column by the project helper
% Classification (defined outside this file).
[a1 a2] = Classification(traindata(:, N));

% guiyihua is presumably a normalisation ("guiyihua" = normalise) that also
% returns its parameters A and B -- TODO confirm against the helper.
[traindata A B] = guiyihua(traindata);

% NOTE(review): the vectorised comparisons assume a1 and a2 are scalars.
label = traindata(:, N);
inClass1 = (label >= a2) & (label <= 1);
% The ~inClass* terms reproduce the precedence of an if/elseif chain, so a
% row can never land in more than one class.
inClass2 = ~inClass1 & (label >= a1) & (label < a2);
inClass3 = ~inClass1 & ~inClass2 & (label >= 0) & (label < a1);
H1 = find(inClass1);             % row indices of class 1 (label in [a2, 1])
H2 = find(inClass2);             % row indices of class 2 (label in [a1, a2))
H3 = find(inClass3);             % row indices of class 3 (label in [0, a1))

% fanguiyihua is presumably the inverse transform of guiyihua, so that the
% statistics below are taken on the original scale -- TODO confirm.
traindata = fanguiyihua(traindata, A, B);

% Prior probability of each class = class size / number of training rows.
c1 = size(H1, 1);
c2 = size(H2, 1);
c3 = size(H3, 1);
prior = [c1 c2 c3] / n;
% Per-class Gaussian parameters (maximum-likelihood estimates, i.e. divided
% by the class size c rather than c-1).
%
% For each class i = 1..3:
%   meansI    1-by-N  column means over all N columns of the class rows
%   sigmapfI  1-by-N  column variances (normalised by cI)
%   mI        (N-1)-by-(N-1) covariance of the first N-1 columns (the features)
%
% The reductions are given an explicit dimension so that a class holding a
% single row still yields 1-by-N statistics instead of collapsing the row
% into a scalar (which would break the concatenations below).
% NOTE(review): a class with fewer rows than features gives a singular
% covariance matrix; the densities computed later are then not usable.

% ---- class 1 ----
X1 = traindata(H1, 1:N);
means1 = mean(X1, 1);
sigmapf1 = var(X1, 1, 1);                           % weight 1 => divide by c1
D1 = X1(:, 1:N-1) - repmat(means1(1:N-1), c1, 1);   % centred features
m1 = (D1' * D1) / c1;

% ---- class 2 ----
X2 = traindata(H2, 1:N);
means2 = mean(X2, 1);
sigmapf2 = var(X2, 1, 1);
D2 = X2(:, 1:N-1) - repmat(means2(1:N-1), c2, 1);
m2 = (D2' * D2) / c2;

% ---- class 3 ----
X3 = traindata(H3, 1:N);
means3 = mean(X3, 1);
sigmapf3 = var(X3, 1, 1);
D3 = X3(:, 1:N-1) - repmat(means3(1:N-1), c3, 1);
m3 = (D3' * D3) / c3;

% Stack the per-class rows: row i belongs to class i.
means = [means1; means2; means3];
sigmapf = [sigmapf1; sigmapf2; sigmapf3];
sigma = sigmapf.^(1/2);                             % per-column standard deviations

% Square root of each covariance determinant (part of the density's
% normalising constant) and the feature mean of each class as a column vector.
t1 = sqrt(det(m1));
t2 = sqrt(det(m2));
t3 = sqrt(det(m3));
M1 = means(1, 1:N-1)';
M2 = means(2, 1:N-1)';
M3 = means(3, 1:N-1)';

%%%%%%%%%%%%%%% testing phase %%%%%%%%%%%%%%
% For every test row, evaluate the multivariate Gaussian density of each of
% the three classes, weight it by the class prior and pick the largest.
[TestSampleNum nn] = size(tdata);
Z = nn - 1;                          % number of feature columns
testdata = tdata(:, 1:Z);            % features only; column nn is kept in tdata

% (2*pi)^(Z/2) is the dimension-dependent part of the Gaussian normalising
% constant. It follows the actual feature count instead of assuming 3.
% Being common to all classes it does not affect which class wins, but it
% makes the values stored in M correct for any Z.
normConst = (2*pi)^(Z/2);

M = zeros(TestSampleNum, 3);         % row j: [p1 p2 p3] for test sample j
for j = 1:TestSampleNum
    x = testdata(j, 1:Z)';

    % Deviation of the sample from each class mean.
    d1 = x - M1;
    d2 = x - M2;
    d3 = x - M3;

    % Class-conditional densities; m \ d solves the linear system rather
    % than forming the explicit inverse of the covariance matrix.
    f1 = exp(-(d1' * (m1 \ d1)) / 2) / (normConst * t1);
    f2 = exp(-(d2' * (m2 \ d2)) / 2) / (normConst * t2);
    f3 = exp(-(d3' * (m3 \ d3)) / 2) / (normConst * t3);

    % Unnormalised posteriors: density times prior.
    p1 = f1 * prior(1);
    p2 = f2 * prior(2);
    p3 = f3 * prior(3);
    M(j, :) = [p1 p2 p3];
end

% ind(j) is the predicted class (1, 2 or 3) of test sample j; m is the
% corresponding largest score.
[m ind] = max(M, [], 2);
% Convert the last column of the test set into true class codes (1, 2, 3)
% and compare them with the predictions in ind.

% Rescale the test column with the parameters stored for column N.
% NOTE(review): guiyihua1 is assumed to apply the same linear normalisation
% that guiyihua applied to the training data -- confirm against the helper.
tdata(:, nn) = guiyihua1(tdata(:, nn), A(N), B(N));

% Held-out values can fall outside the training range, i.e. outside [0, 1]
% after rescaling. The outer classes are therefore open-ended: anything
% >= a2 is class 1 and anything < a1 is class 3. With closed bounds such
% rows would keep their raw value and always be counted as errors.
% All masks are computed before any value is overwritten, and the ~isClass*
% terms keep the precedence of an if/elseif chain. NaN values match no mask
% and stay unchanged.
trueValue = tdata(:, nn);
isClass1 = trueValue >= a2;
isClass2 = ~isClass1 & (trueValue >= a1);
isClass3 = ~isClass1 & ~isClass2 & (trueValue < a1);
tdata(isClass1, nn) = 1;
tdata(isClass2, nn) = 2;
tdata(isClass3, nn) = 3;

% Misclassified rows, their count, and the resulting accuracy.
Err_ind = find(tdata(:, nn) ~= ind);
E = size(Err_ind, 1);
Correctrate = (TestSampleNum - E) / TestSampleNum;

⌨️ 快捷键说明

复制代码Ctrl + C
搜索代码Ctrl + F
全屏模式F11
增大字号Ctrl + =
减小字号Ctrl + -
显示快捷键?