⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 tbpx2.m

📁 一个BP网络改进的M程序
💻 M
字号:
function [w1,b1,w2,b2,i,tr] = tbpx2(w1,b1,f1,w2,b2,f2,p,t,tp)
%TBPX2 Train 2-layer feed-forward network w/fast backpropagation.
%  
%  This function is obsolete.
%  Use NNT2FF and TRAIN to update and train your network.

nntobsf('tbpx2','Use NNT2FF and TRAIN to update and train your network.')

%  [W1,B1,W2,B2,TE,TR] = TBPX2(W1,B1,F1,W2,B2,F2,P,T,TP)
%    Wi - Weight matrix for the ith layer.
%    Bi - Bias vector for the ith layer.
%    Fi - Transfer function (string) for the ith layer.
%    P  - RxQ matrix of input vectors.
%    T  - SxQ matrix of target vectors.
%    TP - Training parameters (optional).
%  Returns:
%    Wi - new weights.
%    Bi - new biases.
%    TE - the actual number of epochs trained.
%    TR - training record, 2 rows: row 1 is the sum-squared error,
%         row 2 is the learning rate of element (1,1) of the second
%         layer's weight matrix. Column 1 is the state before training.
%  
%  Training parameters are:
%    TP(1) - Epochs between updating display, default = 25.
%    TP(2) - Maximum number of epochs to train, default = 1000.
%    TP(3) - Sum-squared error goal, default = 0.02.
%    TP(4) - Learning rate increase, default = 1.05.
%    TP(5) - Learning rate decrease, default = 0.7.
%    TP(6) - Momentum constant, default = 0.9.
%    TP(7) - Maximum error ratio, default = 1.04 (accepted, not used).
%  Missing parameters and NaN's are replaced with defaults.
%
%  Unlike the stock routine there is no global learning rate parameter:
%  every weight and bias owns a learning rate that starts at 0.0005 and
%  is multiplied by TP(4) or TP(5) after each epoch (see LOCAL_ADAPT).
%  Every update is accepted; the error ratio is never checked.

% Mark Beale, 1-31-92
% Revised 12-15-93, MB
% Copyright 1992-2002 The MathWorks, Inc.
% $Revision: 1.12 $

if nargin < 8,error('Not enough arguments.');end

% TRAINING PARAMETERS
if nargin == 8, tp = []; end
tp = nndef(tp,[25 1000 0.02 1.05 0.7 0.9 1.04]);
df = tp(1);   % epochs between display updates
me = tp(2);   % maximum number of epochs
eg = tp(3);   % sum-squared error goal
im = tp(4);   % learning rate increase factor
dm = tp(5);   % learning rate decrease factor
mc = tp(6);   % momentum constant
df1 = feval(f1,'delta');
df2 = feval(f2,'delta');

% PER-PARAMETER LEARNING RATES
% One rate per weight/bias, sized from the parameters themselves so the
% second layer may have any number of neurons (a single output neuron
% gives exactly the 1-by-S1 / 1-by-1 layout used previously).
lr0 = 0.0005;
lr1 = lr0*ones(size(w1));
ls1 = lr0*ones(size(b1));
lr2 = lr0*ones(size(w2));
ls2 = lr0*ones(size(b2));
lr = lr2(1,1);               % representative rate for display/record

% Previous weight/bias changes handed back to NEWLEARNBPM each epoch.
dw1 = zeros(size(w1));
db1 = zeros(size(b1));
dw2 = zeros(size(w2));
db2 = zeros(size(b2));

% Previous values of the 3rd/4th outputs of NEWLEARNBPM (presumably the
% gradient terms - confirm against NEWLEARNBPM). Starting at one means
% the first comparison depends only on the sign of the new values.
dk1 = ones(size(w1));
dd1 = ones(size(b1));
dk2 = ones(size(w2));
dd2 = ones(size(b2));

% No momentum on the first epoch; switched to MC after the first update.
MC = 0;

% PRESENTATION PHASE
a1 = feval(f1,w1*p,b1);
a2 = feval(f2,w2*a1,b2);
e = t-a2;
SSE = sumsqr(e);

% TRAINING RECORD
tr = zeros(2,me+1);
tr(1:2,1) = [SSE; lr];

% PLOTTING FLAG
% Single-input/single-output problems are drawn as a function fit,
% everything else as an error curve.
r = size(p,1);
s = size(t,1);
plottype = (max(r,s) == 1);

% PLOTTING
newplot;
message = sprintf('TRAINBPX: %%g/%g epochs, lr = %%g, SSE = %%g.\n',me);
fprintf(message,0,lr,SSE)
if plottype
  h = plotfa(p,t,p,a2);
else
  h = plottr(tr(1:2,1),eg);
end

% BACKPROPAGATION PHASE
d2 = feval(df2,a2,e);
d1 = feval(df1,a1,d2,w2);

% I counts completed epochs. It is tracked separately from the loop
% variable so that it is a well defined 0 when no epoch runs at all.
i = 0;
for epoch=1:me

  % CHECK PHASE
  if SSE < eg, break, end
  i = epoch;

  % LEARNING PHASE
  [dw1,db1,new_dk1,new_dd1] = newlearnbpm(p,d1,lr1,ls1,MC,dw1,db1);
  [dw2,db2,new_dk2,new_dd2] = newlearnbpm(a1,d2,lr2,ls2,MC,dw2,db2);
  MC = mc;

  w1 = w1 + dw1; b1 = b1 + db1;
  w2 = w2 + dw2; b2 = b2 + db2;

  % PRESENTATION PHASE
  a1 = feval(f1,w1*p,b1);
  a2 = feval(f2,w2*a1,b2);
  e = t-a2;
  SSE = sumsqr(e);

  % ADAPTIVE LEARNING RATE PHASE
  lr1 = local_adapt(lr1,dk1,new_dk1,im,dm);
  ls1 = local_adapt(ls1,dd1,new_dd1,im,dm);
  lr2 = local_adapt(lr2,dk2,new_dk2,im,dm);
  ls2 = local_adapt(ls2,dd2,new_dd2,im,dm);
  dk1 = new_dk1; dd1 = new_dd1;
  dk2 = new_dk2; dd2 = new_dd2;

  % BACKPROPAGATION PHASE
  d2 = feval(df2,a2,e);
  d1 = feval(df1,a1,d2,w2);
  lr = lr2(1,1);

  % TRAINING RECORD
  tr(:,i+1) = [SSE; lr];

  % PLOTTING
  if rem(i,df) == 0
    fprintf(message,i,lr,SSE)
    if plottype
      delete(h);
      h = plot(p,a2);
      drawnow
    else
      h = plottr(tr(1:2,1:(i+1)),eg,h);
    end
  end
end

% TRAINING RECORD
tr = tr(1:2,1:(i+1));

% PLOTTING
% Final refresh, only needed when the last epoch was not a display epoch.
if rem(i,df) ~= 0
  fprintf(message,i,lr,SSE)
  if plottype
    delete(h);
    plot(p,a2);
    drawnow
  else
    plottr(tr,eg,h);
  end
end

% WARNINGS
if SSE > eg
  disp(' ')
  disp('TRAINBPX: Network error did not reach the error goal.')
  disp('  Further training may be necessary, or try different')
  disp('  initial weights and biases and/or more hidden neurons.')
  disp(' ')
end

%==============================================================
function rate = local_adapt(rate,prev,curr,inc,dec)
%LOCAL_ADAPT Element-wise learning rate adaptation.
%  Multiplies each element of RATE by INC where the matching elements
%  of PREV and CURR have a non-negative product, and by DEC everywhere
%  else (including where the product is NaN).

keep = (prev .* curr) >= 0;
rate(keep) = rate(keep) * inc;
rate(~keep) = rate(~keep) * dec;

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -