⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 nnd17no.m

📁 《神经网络设计》英文版的配套源代码
💻 M
📖 第 1 页 / 共 2 页
字号:
  axes(fa_axis)
  %plot(P,T,'color',nndkblue,'linewidth',3);
  plot(P,T,'color',nnblue,'linewidth',3);
  set(get(fa_axis,'ylabel'),...
    'string','Target')

  axes(fb_axis)
  delete(get(fb_axis,'children'))
  fb_plot = plot([P(1) P(end)],[0 0],'k');

  set(fig,'nextplot','new')

  set(lls_radiob,'value',0);
  set(ols_radiob,'value',0);
  set(rand_radiob,'value',0);
  set(but_train,'enable','off');  
  
  
%==================================================================
% Respond to difficulty index slider.
%
% ME('i')
%
% Snaps the slider to an integer i, shows it, redraws the target
% T = 1 + sin(i*pi*P/4) on the top axis, replaces the bottom axis with a
% flat zero line, clears the three radio buttons and disables Train.
%==================================================================

elseif strcmp(cmd,'i')

  % Difficulty index: slider value rounded to the nearest integer.
  i = round(get(i_bar,'value'));
  set(i_text,'string',sprintf('%g',i))

  % Input grid over [-2, 2] with step 0.4/i, and the target on it.
  P = -2:(.4/i):2;
  T = 1 + sin(i*pi*P/4);

  set(fig,'nextplot','add')

  % Top axis: new target curve (nndkblue was an alternative colour).
  delete(get(fa_axis,'children'))
  axes(fa_axis)
  plot(P,T,'color',nnblue,'linewidth',3);
  set(get(fa_axis,'ylabel'),'string','Target')

  % Bottom axis: cleared and replaced by a zero line spanning the inputs.
  axes(fb_axis)
  delete(get(fb_axis,'children'))
  fb_plot = plot(P([1 end]),[0 0],'k');

  set(fig,'nextplot','new')

  % Clear the radio-button selection (presumably the weight-initialisation
  % choice -- confirm) and disable training until a new one is made.
  set([lls_radiob ols_radiob rand_radiob],'value',0);
  set(but_train,'enable','off');
  
  
%==================================================================
% Respond to train button.
%
% ME('train')
%
% Trains the 1-S1-1 network -- hidden layer a1 = exp(-(|p - w1|.*b1).^2),
% linear output a2 = W2*a1 + b2 -- on the current target with
% Levenberg-Marquardt. After every accepted step it redraws the network
% output, each neuron's weighted contribution (dotted) and the hidden-layer
% responses on the bottom axis.
%==================================================================

elseif strcmp(cmd,'train') & (fig) & (nargin == 1)

  set(but_train,'enable','off');
  set(fig,'nextplot','add')
  set(fig,'pointer','watch')

  i = round(get(i_bar,'value'));

  % Training grid and target (same as the 'i' handler), plus a fine grid
  % used only to draw the hidden-layer responses.
  P = -2:(.4/i):2;
  Pfine = -2:(.4/1000):2;
  T = 1 + sin(i*pi*P/4);
  [R,Q] = size(P);
  P2 = P;
  [R,Q2] = size(P2);
  [R,Q2fine] = size(Pfine);

  % Network dimensions: R inputs, S1 hidden neurons, S2 outputs.
  S1 = round(get(s1_bar,'value'));
  R = 1;
  S2 = 1;

  % Starting parameters are kept in the 'userdata' of the *_ptr handles.
  W10=get(W1_ptr,'userdata');
  B10=get(B1_ptr,'userdata');
  W20=get(W2_ptr,'userdata');
  B20=get(B2_ptr,'userdata');

  % Levenberg-Marquardt settings.
  err_goal = 0.005;     % stop when the sum-squared error reaches this
  max_epoch = 200;
  mingrad=0.001;        % stop when norm(J'*e) falls below this
  mu_initial=.01;
  v=10;                 % factor mu is divided/multiplied by
  maxmu=1e10;           % give up when mu grows beyond this

  axes(fa_axis)
  set(get(fa_axis,'children'),'erasemode','normal');
  delete(get(fa_axis,'children'))
  delete(get(fb_axis,'children'))

  pp2 = repmat(P2,S1,1);
  pp2fine = repmat(Pfine,S1,1);
  % If fewer parameters are stored than neurons requested, pad with zeros.
  % W20 multiplies an S1-row matrix, so it is a row vector and must be
  % padded with a ROW of zeros (a zeros(n,1) column fails to concatenate
  % whenever n > 1).
  if length(W10)<S1
      W10=[W10; zeros(S1-length(W10),1)];
      B10=[B10; zeros(S1-length(B10),1)];
      W20=[W20 zeros(1,S1-length(W20))];
  end

  % Initial hidden responses (fine grid and training grid) and output.
  n12fine = abs(pp2fine-W10*ones(1,Q2fine)).*(B10*ones(1,Q2fine));
  a12fine = exp(-n12fine.^2);
  n12 = abs(pp2-W10*ones(1,Q2)).*(B10*ones(1,Q2));
  a12 = exp(-n12.^2);
  A = W20*a12 + B20*ones(1,Q2);
  Target = plot(P,T,'-','color',nnred,'linewidth',3,'erasemode','none');

  AA = A;

  % Dotted curves: each neuron's weighted response, plus the output bias.
  temp=[(W20'*ones(1,Q2)).*a12; B20*ones(1,Q2)];
  fa_plot2 = plot(P,temp,':k');

  Attempt = plot(P2,AA,'-','color',nndkblue,'linewidth',2,'erasemode','none');

  axes(fb_axis)
  fb_plot = plot(pp2fine(1,:),a12fine,'k');

  drawnow

  %%%%%%%%%%%%%%%%%%%%%%%%%% BEGINNING OF MARTIN'S CODE

  % Index ranges of each parameter group in dw = [dW1; dB1; dW2; dB2].
  RS = S1*R; RS1 = RS+1; RSS = RS + S1; RSS1 = RSS + 1;
  RSS2 = RSS + S1*S2; RSS3 = RSS2 + 1; RSS4 = RSS2 + S2;

  W1=W10;B1=B10;W2=W20;B2=B20;
  % Preallocated so that dW1(:)=... and dW2(:)=... keep the weight shapes.
  dW1=W10;dB1=B10;dW2=W20;dB2=B20;

  mu=mu_initial;
  ii=eye(RSS4);
  meu=zeros(max_epoch,1);
  mer=meu;grad=meu;
  A1 = exp(-(abs(pp2-W1*ones(1,Q2)).*(B1*ones(1,Q2))).^2);
  A2 = W2*A1 + B2*ones(1,Q2);
  E1 = T-A2;
  f1=sum(sum(E1.*E1));
  f2=f1;

  % MAIN LOOP

  tstnan=0;
  stopped_early=0;   % set by every exit other than running out of epochs
  for k=1:max_epoch,
    mu=mu/v;
    mer(k)=f1;
    meu(k)=mu;
    tst=1;

    % FIND JACOBIAN of the errors w.r.t. [W1; B1; W2; B2]
    A1 = kron(A1,ones(1,S2));
    D2 = nnmdlin(A2);
    SS2=-2*(abs(pp2-W1*ones(1,Q2)).*(B1*ones(1,Q2))).*A1.*(W2'*D2);
    den=abs(pp2-W1*ones(1,Q2));
    flg=den~=0;
    den=den+~flg;        % avoid 0/0 where p == w; flg zeroes that term
    D1 = SS2.*((B1*ones(1,Q2)).*(W1*ones(1,Q2)-pp2)).*flg./den;
    D1b = SS2.*abs(pp2-W1*ones(1,Q2));
    jac1 = D1';
    jac2 = nnlmarq(A1,D2);
    jac=[jac1,D1b',jac2,D2'];

    % CHECK THE MAGNITUDE OF THE GRADIENT
    E1=E1(:);
    je=jac'*E1;
    grad(k)=norm(je);
    if grad(k)<mingrad,
      mer=mer(1:k);
      meu=meu(1:k);
      grad=grad(1:k);
      disp('Training has stopped.')
      disp('Local minimum reached. Gradient is close to zero.')
      fprintf('Magnitude of gradient = %g.\n',grad(k));
      stopped_early=1;
      break
    end

    % INNER LOOP, INCREASE mu UNTIL THE ERRORS ARE REDUCED
    jj=jac'*jac;
    while tst>0,
      dw=-(jj+ii*mu)\je;
      % 'if isnan(dw)' is only true when EVERY entry is NaN and ignores
      % Inf, so test each entry for non-finite values.
      if ~all(isfinite(dw))
        tstnan=1;
        tst=0;
      else
        dW1(:)=dw(1:RS);
        dB1=dw(RS1:RSS);
        dW2(:)=dw(RSS1:RSS2);
        dB2=dw(RSS3:RSS4);
        W1n=W1+dW1;B1n=B1+dB1;W2n=W2+dW2;
        B2n=B2+dB2;
        A1 = exp(-(abs(pp2-W1n*ones(1,Q2)).*(B1n*ones(1,Q2))).^2);
        A2 = W2n*A1 + B2n*ones(1,Q2);
        E2 = T-A2;
        f2=sum(sum(E2.*E2));
        % Written as ~(f2<f1) so a NaN error counts as "not reduced".
        if ~(f2<f1),
          mu=mu*v;

          % TEST FOR MAXIMUM mu
          if (mu > maxmu),
            mer=mer(1:k);
            meu=[meu(1:k);mu];
            grad=grad(1:k);
            disp('Maximum mu exceeded.')
            fprintf('mu = %g.\n',mu);
            fprintf('Maximum allowable mu = %g.\n',maxmu);
            break;
          end
        else
          tst=0;
        end
      end
    end

    if tstnan
      disp('Training has stopped. Weight update is not finite.')
      stopped_early=1;
      break;
    end

    % mu overflowed: the last trial step was rejected. Stop without
    % committing it -- checked before the goal test so a rejected step is
    % never reported as success.
    if (mu>maxmu),
      stopped_early=1;
      break;
    end

    % TEST IF THE ERROR REACHES THE ERROR GOAL
    if f2<=err_goal,
      f1=f2;
      W1=W1n;B1=B1n;W2=W2n;B2=B2n;
      mer=[mer(1:k);f2];
      meu=[meu(1:k);mu];
      grad=grad(1:k);
      disp('Training has stopped. Goal achieved.')
      stopped_early=1;
      break;
    end

    % Accept the step.
    W1=W1n;B1=B1n;W2=W2n;B2=B2n;E1=E2;
    f1=f2;

    %%%%%%%%%%%%%%%%%%%%%%%%% PLOTTING ALTERED BY MARK
    if (R==1)&(S2==1),
      n12fine = abs(pp2fine-W1*ones(1,Q2fine)).*(B1*ones(1,Q2fine));
      a12fine = exp(-n12fine.^2);
      n12 = abs(pp2-W1*ones(1,Q2)).*(B1*ones(1,Q2));
      a12 = exp(-n12.^2);
      A = W2*a12 + B2*ones(1,Q2);
      set(Attempt,'color',nnltyell);
      set(Attempt,'visible','off');
      set(Target,'color',nndkblue);

      temp=[(W2'*ones(1,Q2)).*a12; B2*ones(1,Q2)];
      for ktemp=1:length(fa_plot2)
         set(fa_plot2(ktemp),'ydata',temp(ktemp,:));
      end
      for ktemp=1:length(fb_plot)
         set(fb_plot(ktemp),'ydata',a12fine(ktemp,:));
      end

      % NaN points are not drawn: hide outputs outside [-0.01, 2.01].
      AA = A;
      ind = find((AA < -0.01) | (AA > 2.01));
      if length(ind)
        AA(ind) = AA(ind)+NaN;
      end
      set(Attempt,'ydata',AA);
      set(Attempt,'color',nndkblue,'visible','on');
      drawnow
    end

    pause(0.025);

  end

  %%%%%%%%%%%%%%%%%%%%%%%%%% END OF MARTIN'S CODE

  % Final redraw with the last accepted parameters. The bottom axis is
  % refreshed too: on "Goal achieved" the loop exits before its own redraw.
  n12fine = abs(pp2fine-W1*ones(1,Q2fine)).*(B1*ones(1,Q2fine));
  a12fine = exp(-n12fine.^2);
  n12 = abs(pp2-W1*ones(1,Q2)).*(B1*ones(1,Q2));
  a12 = exp(-n12.^2);
  A = W2*a12 + B2*ones(1,Q2);
  set(Attempt,'color',nnltyell);
  set(Attempt,'visible','off');
  set(Target,'color',nndkblue);

  temp=[(W2'*ones(1,Q2)).*a12; B2*ones(1,Q2)];
  for ktemp=1:length(fa_plot2)
     set(fa_plot2(ktemp),'ydata',temp(ktemp,:));
  end
  for ktemp=1:length(fb_plot)
     set(fb_plot(ktemp),'ydata',a12fine(ktemp,:));
  end

  AA = A;
  ind = find((AA < -0.01) | (AA > 2.01));
  if length(ind)
    AA(ind) = AA(ind)+NaN;
  end
  set(Attempt,'ydata',AA);
  set(Attempt,'color',nndkblue,'visible','on');
  drawnow

  set(fig,'nextplot','new')

  % Only report the epoch limit when no other stopping rule fired.
  if ~stopped_early,
    disp('Training has stopped.')
    disp('Maximum number of epochs was reached.')
    fprintf('epochs = %g.\n',k);
    fprintf('Final error = %g.\n',f2);
  end

  set(fig,'pointer','arrow')
  set(but_train,'enable','on');

end



if strcmp(cmd,'update_plot')
  % Redraw the target, the network output and each neuron's weighted
  % contribution (top axis) and the hidden-layer responses (bottom axis)
  % for the stored parameters W10, B10, W20, B20, then enable Train.
  % NOTE(review): i, S1, W10, B10, W20 and B20 are not assigned in this
  % branch; they must already exist in the workspace -- confirm upstream.
  set(fig,'nextplot','add')
  delete(get(fa_axis,'children'))
  P = -2:(.4/i):2;
  Pfine = -2:(.4/1000):2;      % fine grid, only for the basis curves
  T = 1 + sin(i*pi*P/4);
  axes(fa_axis)
  plot(P,T,'color',nnblue,'linewidth',3);
  set(get(fa_axis,'ylabel'),...
    'string','Target')

  delete(get(fb_axis,'children'))

  [R,Q2] = size(P);
  [R,Q2fine] = size(Pfine);
  pp2 = repmat(P,S1,1);
  pp2fine = repmat(Pfine,S1,1);

  % If fewer parameters are stored than neurons requested, pad with zeros.
  % W20 multiplies an S1-row matrix, so it is a row vector and must be
  % padded with a ROW of zeros (a zeros(n,1) column fails to concatenate
  % whenever n > 1).
  if length(W10)<S1
      W10=[W10; zeros(S1-length(W10),1)];
      B10=[B10; zeros(S1-length(B10),1)];
      W20=[W20 zeros(1,S1-length(W20))];
  end

  % Hidden responses a = exp(-(|p - w|.*b).^2) and the linear output.
  n12fine = abs(pp2fine-W10*ones(1,Q2fine)).*(B10*ones(1,Q2fine));
  a12fine = exp(-n12fine.^2);
  n12 = abs(pp2-W10*ones(1,Q2)).*(B10*ones(1,Q2));
  a12 = exp(-n12.^2);
  A = W20*a12 + B20*ones(1,Q2);

  AA = A;

  % Dotted curves: each neuron's weighted response, plus the output bias.
  temp=[(W20'*ones(1,Q2)).*a12; B20*ones(1,Q2)];
  fa_plot2 = plot(P,temp,':k');

  Attempt = plot(P,AA,'-','color',nndkblue,'linewidth',2,'erasemode','none');

  axes(fb_axis)
  fb_plot = plot(pp2fine(1,:),a12fine,'k');

  drawnow

  set(but_train,'enable','on');

  set(fig,'nextplot','new')
end

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -