⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 bp_mlpno1.m

📁 模式识别的一个基础程序,手写数字模式识别,提供给大家分享
💻 M
字号:
function BP_MLP
% BP_MLP  Train a 2-input, 1-output multilayer perceptron by
% back-propagation with a momentum term to fit the surface
%   z = (sin(x)/x) * (sin(y)/y)   on the grid [xmin,xmax] x [xmin,xmax].
%
% Network: 2 inputs -> h_unitnum logistic-sigmoid hidden units -> 1 linear
% output. Biases are folded into "extended" weight matrices (*_exp) whose
% last column multiplies a constant input of 1.
%
% Each epoch walks over the rows of the training grid. One step feeds one
% row (fixed x(j), all y values) through the network and updates both
% weight matrices from that row's errors. Training stops early when the
% epoch's sum of squared errors falls below `threshold`. At the epochs
% listed in `snapshot_epochs` the current fit is plotted by meshout. The
% per-epoch error curve is plotted in figure 1 at the end.

clc;

sample_num = 20;   % training grid points per axis
test_num   = 360;  % test grid points per axis (used for plotting)
h_unitnum  = 10;   % hidden units
i_unitnum  = 2;    % inputs
o_unitnum  = 1;    % outputs
rand('state', sum(100*clock)); % reseed the (legacy) generator so each run differs
xmin = -10;
xmax = 10;

% Logistic sigmoid, numerically identical to the toolbox logsig but
% available in base MATLAB/Octave.
sigmoid = @(v) 1 ./ (1 + exp(-v));

x = linspace(xmin, xmax, sample_num);
y = linspace(xmin, xmax, sample_num);
% sin(t)/t with its limit value 1 at t = 0, so a grid that contains the
% origin (e.g. odd sample_num) does not produce 0/0 = NaN targets.
sx = sin(x) ./ x;  sx(x == 0) = 1;
sy = sin(y) ./ y;  sy(y == 0) = 1;
z = sx' * sy;      % z(i,j) is the target at (x(i), y(j))

% Rows (2i-1):2i of p are the inputs for grid row i: [x(i) ... x(i); y].
p = zeros(2*sample_num, sample_num);
for i = 1:sample_num
    p((2*i-1):2*i, :) = [repmat(x(i), 1, sample_num); y];
end

alpha     = 0.016; % learning rate
eita      = 0.01;  % momentum coefficient
threshold = 0.8;   % stop once an epoch's sum of squared errors is below this

% Uniform random initial weights/biases in [-a/2, a/2] and [-b/2, b/2].
a = 0.8;
b = 1.0;
w_h = a*rand(h_unitnum, i_unitnum) - 0.5*a;
w_o = b*rand(o_unitnum, h_unitnum) - 0.5*b;
b_h = a*rand(h_unitnum, 1) - 0.5*a;
b_o = b*rand(o_unitnum, 1) - 0.5*b;

% Extended weights: the last column holds the negated bias and pairs with
% the constant 1 appended to each layer's input.
w_h_exp = [w_h, -b_h]; % h_unitnum x (i_unitnum+1), input  -> hidden
w_o_exp = [w_o, -b_o]; % o_unitnum x (h_unitnum+1), hidden -> output

% Weights from the previous step, for the momentum term. Starting them
% equal to the current weights makes the first step a plain gradient step.
w_h_exp_old = w_h_exp;
w_o_exp_old = w_o_exp;

epochmax = 20000;
snapshot_epochs = [80 500 1000 5000 20000]; % epoch k -> meshout(find(...)) plot
err_hist = zeros(1, epochmax);              % sum of squared errors per epoch
e = zeros(sample_num);                      % e(j,:) = errors for grid row j
ones_row = ones(1, sample_num);

for epoch = 1:epochmax
    for j = 1:sample_num
        % Forward pass for grid row j.
        p_exp = [p((2*j-1):2*j, :); ones_row];    % (i_unitnum+1) x sample_num
        h_out = sigmoid(w_h_exp * p_exp);          % h_unitnum x sample_num
        h_out_exp = [h_out; ones_row];             % (h_unitnum+1) x sample_num
        o_out = w_o_exp * h_out_exp;               % o_unitnum x sample_num (linear)
        e(j, :) = z(j, :) - o_out;

        % Backward pass. Both gradients use the weights of the forward
        % pass, so the hidden delta is formed before w_o_exp is updated,
        % and it uses the CURRENT output weights (bias column excluded).
        delta_o = e(j, :);                                       % linear output unit
        w_o_cur = w_o_exp(:, 1:h_unitnum);
        delta_h = (w_o_cur' * delta_o) .* h_out .* (1 - h_out);  % sigmoid derivative
        grad_o  = delta_o * h_out_exp';                          % o_unitnum x (h_unitnum+1)
        grad_h  = delta_h * p_exp';                              % h_unitnum x (i_unitnum+1)

        % Momentum update  w(t+1) = w(t) + alpha*grad + eita*(w(t) - w(t-1)).
        % w(t) must be saved as the "old" value BEFORE it is overwritten;
        % otherwise the momentum term is always zero.
        w_o_new = w_o_exp + alpha*grad_o + eita*(w_o_exp - w_o_exp_old);
        w_o_exp_old = w_o_exp;
        w_o_exp = w_o_new;

        w_h_new = w_h_exp + alpha*grad_h + eita*(w_h_exp - w_h_exp_old);
        w_h_exp_old = w_h_exp;
        w_h_exp = w_h_new;
    end

    et = sum(e(:).^2);
    err_hist(epoch) = et;
    if et < threshold
        break;
    end

    % Plot the current fit at the selected epochs (figure sw+1).
    sw = find(snapshot_epochs == epoch, 1);
    if ~isempty(sw)
        meshout(sw, x, y, z, test_num, xmin, xmax, w_h_exp, w_o_exp);
    end
end
err_hist = err_hist(1:epoch); % drop unused slots after an early stop

% Training error curve.
figure(1);
plot(1:numel(err_hist), err_hist);
xlabel('epoch');
ylabel('sum of squared errors');

% Test output: plot the network's prediction against the target surface.
function meshout(sw,x,y,z,test_num,xmin,xmax,wh,wo)
% MESHOUT  Evaluate the network on a dense test grid and draw it on top of
% the training target surface.
%
%   sw         snapshot index; the plot goes to figure(sw+1)
%   x, y       training grid vectors; z(i,j) is the target at (x(i), y(j))
%   test_num   number of test points per axis
%   xmin, xmax range of the test grid on both axes
%   wh         extended input->hidden weights, h x 3 (last column = bias term)
%   wo         extended hidden->output weights, 1 x (h+1) (last column = bias term)
%
% The predicted surface is drawn shifted up by +1 (as in the original
% code), presumably to keep it visually separate from the target — TODO confirm.

xt = linspace(xmin, xmax, test_num);
yt = linspace(xmin, xmax, test_num);

% Evaluate every test point in one pass. meshgrid gives xg(r,c) = xt(c),
% yg(r,c) = yt(r), so reshaping the outputs yields zt(r,c) = f(xt(c), yt(r)),
% which is exactly the row=y / column=x layout mesh(xt, yt, zt) expects.
[xg, yg] = meshgrid(xt, yt);
npts = numel(xg);
inputs = [xg(:)'; yg(:)'; ones(1, npts)];     % 3 x npts (x, y, bias input)
hidden = 1 ./ (1 + exp(-(wh * inputs)));      % logistic sigmoid, h x npts
out = wo * [hidden; ones(1, npts)];           % linear output, 1 x npts
zt = reshape(out, test_num, test_num) + 1;

figure(sw+1);
% z is indexed z(x, y); mesh wants rows = y, so plot its transpose.
mesh(x, y, z.');
xlabel('input 1');
ylabel('input 2');
zlabel('output');
hold on;
mesh(xt, yt, zt);
hold off;

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -