⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 bp.m

📁 基于梯度下降的BP神经网络参数学习算法.
💻 M
字号:
clear
inputNums=3;    %输入层神经数
outputNums=1;   %输出层神经数
hideNums=10;    %隐藏层神经数
maxcount=2000;  %训练总数
samplenum=3;    %样本个数
precision=0.001;%训练精度

alpha=0.5;%设定值
beta=0.5;%设定值
count=1;%当前训练次数
error=zeros(1,count);
errorp=zeros(1,samplenum);

v=rand(inputNums,hideNums);
st1=zeros(1,hideNums);     %初始化输入层和隐藏层的权值和阈值

w=rand(hideNums,outputNums);
st2=zeros(1,outputNums);    %初始化输出层和隐藏层的权值和阈值

samplelist=[3,1,2;1,6,3;2,4,7];
expectlist=[1,0,1;0,1,1;1,1,1];

while (count<=maxcount)
    c=1;
    while (c<=samplenum)
        for k=1:outputNums
            d(k)=expectlist(c,k);%获得期望输出的向量
        end
        for i=1:inputNums
            x(i)=samplelist(c,i);%获得输入的向量(数据)
        end
        
        %计算隐藏层输入和输出
        for j=1:hideNums
            net=0.0;
            for i=1:inputNums
                net=net+x(i)*v(i,j);
            end
            net=net-st1(j);
            y(j)=1/(1+exp(-net));
        end
        %end
        
        %计算输出层输入和输出
        for k=1:outputNums
            net=0.0;
            for j=1:hideNums
                net=net+y(j)*w(j,k);
            end
            net=net-st2(k);
            o(k)=1/(1+exp(-net)); 
        end
        %end
        
        %误差;
        errortmp=0.0;
        for k=1:outputNums
            errortmp=errortmp+(d(k)-o(k))^2;
        end
        errorp(c)=0.5*errortmp;
        %end
        
        %一般化误差
        for k=1:outputNums
            yitao(k)=(d(k)-o(k))*o(k)*(1-o(k));
        end
        for j=1:hideNums
            tem=0.0;
            for k=1:outputNums
                tem=tem+yitao(k)*w(j,k);
            end
            yitay(j)=tem*y(j)*(1-y(j));
        end
        %end

        %调整各层权值和阈值
        for j=1:hideNums
            for k=1:outputNums
                w(j,k)=w(j,k)+alpha*yitao(k)*y(j);
            end
        end
        for k=1:outputNums
            st2(k)=st2(k)+alpha*yitao(k);
        end
        for i=1:inputNums
            for j=1:hideNums
                v(i,j)=v(i,j)+beta*yitay(j)*x(i);
            end
        end
        for j=1:hideNums
            st1(j)=st1(j)+beta*yitay(j);
        end
        %end

        c=c+1;
    end
    
    double tmp;
    tmp=0.0;
    for i=1:samplenum
        tmp=tmp+errorp(i)*errorp(i);
    end
    tmp=tmp/c;
    error(count)=sqrt(tmp);

    if (error(count)<precision)
        break;
    end  
    
  count=count+1;%训练次数加1
end

p=1:count;
plot(p,error(p),'-');

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -