bp.m

来自「利用bp算法对鸢尾花数据进行分类的matlab实现程序」· M 代码 · 共 87 行

M
87
字号
%%%%%%%%%%%%%%%%%%%%%%%%
%BP算法 March 24th,2009%
%%%%%%%%%%%%%%%%%%%%%%%%
clear all;
close all;
clc;
N = 4;                          %第一层神经元个数
M=4;                            %第二层神经元个数
alpha = 0.5;                    %步长
eta = 0.9;                      %惯性系数
in_dim = 4;
minErr = 0.01;                  %均方误差
cycle_num = 1;
array1=xlsread('data1.xls');    %从excel文件中读取待分类数据。
array2=xlsread('data2.xls');
for t=1:50
    array1(t,5)=-1;
    array1(t,6)=0.9;
    array2(t,5)=-1;
    array2(t,6)=0.1;
end
in_matrix=[array1(1,:);
           array2(1,:)];
for i=2:50
    in_matrix=[in_matrix;
               array1(i,:);
               array2(i,:)];
end
w_1 = rand(in_dim+1, N);
w_2 = rand(N+1, M);
w_3= rand(M+1,1);
dd_1 = zeros(in_dim+1, N);
dd_2 = zeros(N+1, M);
dd_3 = zeros(M+1, 1);
err = 1;
num = 0;

while err>minErr
   num = num+1;                                  %记录循环次数
   if mod(num,1000)==0
       num
   end
   err = 0;
   for cycle_num=1:100
   	    in_x = in_matrix(cycle_num, 1:5);
		Coutput = in_matrix(cycle_num, 6);
		out_1 = 1./(1+exp(-1*in_x*w_1));
		out_1s = [out_1 -1];
		out_2 = 1./(1+exp(-1*out_1s*w_2));
        out_2s = [out_2 -1];
        output = 1/(1+exp(-1*out_2s*w_3));

		epsilon = Coutput-output;
		delta3 = 2*epsilon*output*(1-output);
		w_3 = w_3+alpha*delta3.*out_2s'+eta*dd_3;
		w_3s = w_3(1:M);
		delta2 = delta3*w_3s'.*out_2.*(1-out_2);
   	    w_2 = w_2+alpha*out_1s'*delta2+eta*dd_1;
   	    w_2s = w_2(1:N,:);
        delta1 = delta2*w_2s'.*out_1.*(1-out_1);
        w_1 = w_1+alpha*in_x'*delta1+eta*dd_1;
        dd_3 = alpha*delta3.*out_2s'+eta*dd_3;
        dd_2 = alpha*out_1s'*delta2+eta*dd_2;
        dd_1 = alpha*in_x'*delta1+eta*dd_1;
        err = err+epsilon^2;
   end
   err = sqrt(err)/100;
end
sum=0;                                     %计算分类错误个数
yy = zeros(1,100);                         %输出分类结果
for k=1:100
   in_x = in_matrix(k, 1:5);
   op_1 = 1./(1+exp(-in_x*w_1));
   op_1s=[op_1 -1];
   op_2=1./(1+exp(-1*op_1s*w_2));
   op_2s=[op_2 -1];
   yy(k) =1/(1+exp(-1*op_2s*w_3));
   if mod(k,2)==1&&yy(k)<0.5
       sum=sum+1;
   end
   if mod(k,2)==0&&yy(k)>0.5
       sum=sum+1;
   end
end
num
sum
yy

⌨️ 快捷键说明

复制代码Ctrl + C
搜索代码Ctrl + F
全屏模式F11
增大字号Ctrl + =
减小字号Ctrl + -
显示快捷键?