⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 nnd12cg.m

📁 Martin T. Hagan 等著、戴葵等译《神经网络设计》(机械工业出版社)一书的所有例程
💻 M
字号:
function nnd12cg(cmd,arg1)
%NND12CG Conjugate gradient backpropagation demonstration.
%
%	This demonstration requires the Neural Network Toolbox.
%
%	NND12CG or NND12CG('') brings the demo window to the front, creating
%	it first if it does not exist.  The remaining commands are issued by
%	the window's own callbacks:
%
%	  NND12CG('init')     builds the window.
%	  NND12CG('instr')    writes the instruction text.
%	  NND12CG('radio',I)  selects parameter pair I (1, 2 or 3).
%	  NND12CG('down')     trains from the point clicked in the contour plot.
%	  NND12CG('close')    deletes the window.
%
%	CMD is compared case-insensitively.

% First Version, 8-31-95.

%==================================================================


% CONSTANTS
me = 'nnd12cg';

% DEFAULTS
if nargin == 0, cmd = ''; else cmd = lower(cmd); end

% FIND WINDOW IF IT EXISTS
% A handle whose object has no children is treated as "no window".
fig = nnfgflag(me);
if isempty(get(fig,'children')), fig = 0; end
  
% GET WINDOW DATA IF IT EXISTS
% The figure's userdata is the handle vector H assembled by the 'init'
% command below; the *_ptr entries are invisible uicontrols whose own
% userdata holds the mutable state.
if fig
  H = get(fig,'userdata');
  fig_axis = H(1);            % window axis
  desc_text = H(2);           % handle to first line of text sequence
  cont_axis = H(3);           % error contour axis
  cont_ptr = H(4);            % pointer to error contour handles
  radios = H(5:7);            % radio buttons
  option_ptr = H(8);          % index of active radio
  path_ptr = H(9);            % pointer to training path handles
end

%==================================================================
% Activate the window.
%
% ME() or ME('')
%==================================================================

if strcmp(cmd,'')
  if fig
    figure(fig)
    set(fig,'visible','on')
  else
    feval(me,'init')
  end

%==================================================================
% Close the window.
%
% ME('close')
%==================================================================

elseif strcmp(cmd,'close') & (fig)
  delete(fig)

%==================================================================
% Initialize the window.
%
% ME('init')
%==================================================================

elseif strcmp(cmd,'init') & (~fig)

  % CHECK FOR NNT
  if ~nntexist(me), return, end

  % NEW DEMO FIGURE
  fig = nndemof2(me,'DESIGN','Conjugate Gradient Backprop','','Chapter 12');
  set(fig, ...
    'windowbuttondownfcn',nncallbk(me,'down'), ...
    'BackingStore','off',...
    'nextplot','add');
  H = get(fig,'userdata');
  fig_axis = H(1);
  desc_text = H(2);

  % ICON
  nndicon(12,458,363,'shadow')

  % RADIO BUTTONS
  % Each button calls back into this file with its own index; the first
  % one starts out selected.
  option = 1;
  radio1 = uicontrol(...
    'units','points',...
    'position',[20 335 130 20],...
    'style','radio',...
    'string','W1(1,1), W2(1,1)',...
    'back',nnltgray,...
    'callback',[me '(''radio'',1)'],...
    'value',1);
  radio2 = uicontrol(...
    'units','points',...
    'position',[155 335 115 20],...
    'style','radio',...
    'string','W1(1,1), b1(1)',...
    'back',nnltgray,...
    'callback',[me '(''radio'',2)']);
  radio3 = uicontrol(...
    'units','points',...
    'position',[270 335 105 20],...
    'style','radio',...
    'string','b1(1), b1(2)',...
    'back',nnltgray,...
    'callback',[me '(''radio'',3)']);

  % ERROR SURFACE
  % nndbp1 supplies every variable used below that is not assigned in this
  % file: v1, v2, range1, range2, x2, y2, E2, levels, optx, opty.
  load nndbp1

  cont_axis = nnsfo('a2','',v1,v2,'');
  set(cont_axis, ...
    'units','points',...
    'position',[50 40 280 280],...
    'color',nnltyell,...
    'xlim',range1,...
    'ylim',range2,...
    'colororder',[0 0 0])
  [dummy,cont_h] = contour(x2,y2,E2,levels);
  set(cont_h,'erasemode','none');
  % Rectangular frame around the plotting range, drawn at z = 1000.
  plot3(range1([1 2 2 1 1]),range2([1 1 2 2 1]),1000*ones(1,5),...
    'color',nndkblue);
  cont_h2 = plot(optx,opty,'+','color',nnred);
  cont_h = [cont_h; cont_h2];
  view(2)

  % BUTTONS
  uicontrol(...
    'units','points',...
    'position',[400 110 60 20],...
    'string','Contents',...
    'callback','nndtoc')
  uicontrol(...
    'units','points',...
    'position',[400 75 60 20],...
    'string','Close',...
    'callback',[me '(''close'')'])

  % DATA POINTERS
  cont_ptr = uicontrol('visible','off','userdata',cont_h);
  option_ptr = uicontrol('visible','off','userdata',option);
  path_ptr = uicontrol('visible','off','userdata',[]);

  % SAVE WINDOW DATA AND LOCK
  H = [fig_axis desc_text cont_axis cont_ptr ...
       radio1 radio2 radio3 option_ptr path_ptr];
  set(fig,'userdata',H,'nextplot','new')

  % INSTRUCTION TEXT
  feval(me,'instr');

  % LOCK WINDOW
  set(fig,'nextplot','new','color',nnltgray);

  nnchkfs;

%==================================================================
% Display the instructions.
%
% ME('instr')
%==================================================================

elseif strcmp(cmd,'instr') & (fig)
  nnsettxt(desc_text,...
    'Use the radio buttons',...
    'to select the network',...
    'parameters to train',...
    'with backpropagation.',...
    '',...
    'The corresponding',...
    'contour plot is',...
    'shown to the left.',...
    '',...
    'Click in the contour',...
    'graph to start the',...
    'conjugate gradient',...
    'learning algorithm.')
    
%==================================================================
% Respond to radio buttons.
%
% ME('radio',i)
%==================================================================

elseif strcmp(cmd,'radio') & (fig) & (nargin == 2)

  % GET DATA
  option = get(option_ptr,'userdata');
  
  % ALTER TRAINABLE PARAMETERS
  % Nothing to do when the already-active button is clicked again.
  if (arg1 ~= option)

    % HIGHLIGHT NEW RADIO BUTTON
    set(radios(option),'value',0)
    set(radios(arg1),'value',1)
    option = arg1;

    % CLEAR AXES
    delete(get(cont_axis,'children'))

    % ERROR SURFACE & VARIABLE NAMES
    % Each data file supplies v1, v2, range1, range2, x2, y2, E2, levels,
    % optx and opty for one parameter pair.
    if option == 1
      load nndbp1
    elseif option == 2
      load nndbp2
    else
      load nndbp3
    end

    % Unlock the figure so the plotting commands draw into it.
    set(fig,'nextplot','add')
    axes(cont_axis)
    set(get(cont_axis,'xlabel'),'string',v1)
    set(get(cont_axis,'ylabel'),'string',v2)
    set(cont_axis,'xlim',range1,'ylim',range2)
    [dummy,cont_h] = contour(x2,y2,E2,levels);
    set(cont_h,'erasemode','none');
    plot3(range1([1 2 2 1 1]),range2([1 1 2 2 1]),1000*ones(1,5),...
      'color',nndkblue);
    cont_h2 = plot(optx,opty,'+','color',nnred);
    cont_h = [cont_h; cont_h2];
    view(2)

    % Re-lock the figure, as the 'init' and 'down' commands do, so that
    % later plotting commands do not draw into the demo window.
    set(fig,'nextplot','new')

    % STORE DATA
    % The old path handles were deleted with the axes children above.
    set(cont_ptr,'userdata',cont_h);
    set(path_ptr,'userdata',[]);
    set(option_ptr,'userdata',option);
  end

%==================================================================
% Respond to mouse down.
%
% ME('down')
%==================================================================

elseif strcmp(cmd,'down') & (fig) & (nargin == 1)

  % 'currentpoint' is a 2-by-3 matrix; linear indices 1 and 3 are the
  % x and y entries of its first row.
  pt = get(cont_axis,'currentpoint');

  x = pt(1);
  y = pt(3);
  x_lim = get(cont_axis,'xlim');
  y_lim = get(cont_axis,'ylim');

  % Ignore clicks outside (or exactly on the edge of) the contour axis.
  if (x > x_lim(1) & x < x_lim(2) & y > y_lim(1) & y < y_lim(2))

    % GET DATA
    option = get(option_ptr,'userdata');
    path_h = get(path_ptr,'userdata');
    cont_h = get(cont_ptr,'userdata');

    % REMOVE PREVIOUS PATH
    set(fig,'nextplot','add')
    set(path_h,'erasemode','normal');
    delete(path_h);

    % INITIAL VALUES
    % T is the output of the network with these nominal parameters, so
    % the nominal parameters give zero error.
    W1 = [10; 10];
    b1 = [-5;5];
    W2 = [1 1];
    b2 = [-1];
    P = -2:0.1:2;
    T = logsig(W2*logsig(W1*P,b1),b2);

    % PLOT START POINT
    % Three concentric circle markers (sizes 8, 10, 12).
    dkblue = nndkblue;
    axes(cont_axis);
    path_h = [...
      plot(x,y,'o','color',dkblue,'markersize',8,'erasemode','none');
      plot(x,y,'o','color',[1 1 1],'markersize',10,'erasemode','none');
      plot(x,y,'o','color',dkblue,'markersize',12,'erasemode','none')];
    drawnow

    % PLOT PATH
    set(fig,'pointer','watch')

    % INITIALIZE TRAINING
    % The clicked point replaces the two parameters selected by the
    % active radio button; all other parameters keep their nominal values.
    if option == 1
      W1(1,1) = x;
      W2(1,1) = y;
    elseif option == 2
      W1(1,1) = x;
      b1(1) = y;
    else
      b1(1) = x;
      b1(2) = y;
    end

    A1 = logsig(W1*P,b1);
    A2 = logsig(W2*A1,b2);
    E = T-A2;
    fa=sumsqr(E);

    % The learnbp outputs are used directly as the search direction.
    D2 = feval('deltalog',A2,E);
    D1 = feval('deltalog',A1,D2,W2);
    [gW1,gb1] = feval('learnbp',P,D1,1);
    [gW2,gb2] = feval('learnbp',A1,D2,1);

    % Squared norm over the two selected parameters only.
    if (option == 1)
      nrmo = gW1(1,1)^2 + gW2(1,1)^2;
    elseif(option == 2)
      nrmo = gW1(1,1)^2 + gb1(1)^2;
    else
      nrmo = gb1(1)^2 + gb1(2)^2;
    end

    % NORM OF GRADIENT
    max_epoch = 15;
    nrmrt=sqrt(nrmo);
    if nrmrt == 0
      % The start point is stationary in the selected parameters: there is
      % no direction to search along.  Skip the training loop rather than
      % divide by zero and fill the path with NaN.
      max_epoch = 0;
      nrmrt = 1;
    end

    % INITIALIZE DIRECTION
    % d*old keep the unnormalized direction for the conjugate update;
    % dW1, db1, dW2, db2 are scaled to unit length in the selected pair.
    dW1old=gW1;db1old=gb1;dW2old=gW2;db2old=gb2;
    dW1=gW1/nrmrt;db1=gb1/nrmrt;dW2=gW2/nrmrt;db2=gb2/nrmrt;

    % ASSIGN PARAMETERS
    tau=0.618;           %golden section ratio
    tau1=1-tau;
    scaletol=20;
    delta=0.32;          %initial step for the interval search
    delta1=.03;
    tol=delta1/scaletol; %final width of the line search interval
    scale=2.0;           %growth factor for the interval search
    bmax=26;             %interval search stops growing once b reaches this
    n=2;                 %number of steps before reset


    % MAIN LOOP
    % xx(1),yy(1) is the start point; entry i is the point reached after
    % i-1 line minimizations.
    xx = [x zeros(1,max_epoch)];
    yy = [y zeros(1,max_epoch)];
    for i=2:(max_epoch+1)
    
      % INITIALIZE A
      % [a,b] are step lengths along the current direction; fa is the SSE
      % at step a (on entry, the SSE at the current point).
      a=0;
      aold=0;
      b=delta;
      faold=fa;

      % CALCULATE INITIAL SSE 
      % W1t, b1t, W2t, b2t are trial copies; only the two selected entries
      % are ever overwritten, so the copies are made once per iteration.
      W1t = W1; b1t = b1;
      W2t = W2; b2t = b2;
      if (option == 1)
        newx = W1(1,1) + b*dW1(1,1); W1t(1,1) = newx;
        newy = W2(1,1) + b*dW2(1,1); W2t(1,1) = newy;
      elseif(option == 2)
        newx = W1(1,1) + b*dW1(1,1); W1t(1,1) = newx;
        newy = b1(1)   + b*db1(1);   b1t(1) = newy;
      else
        newx = b1(1) + b*db1(1);   b1t(1) = newx;
        newy = b1(2) + b*db1(2);   b1t(2) = newy;
      end
    
      fb = sumsqr(T - logsig(W2t*logsig(W1t*P,b1t),b2t));
    
      % FIND INITIAL INTERVAL WHERE SSE MINIMUM OCCURS
      % Keep scaling b up while the SSE is still decreasing; aold/faold
      % trail one step behind a/fa.
      while (fa>fb)&(b<bmax)
        aold=a;
        faold=fa;
        fa=fb;
        a=b;
        b=scale*b;
        if (option == 1)
          newx = W1(1,1) + b*dW1(1,1); W1t(1,1) = newx;
          newy = W2(1,1) + b*dW2(1,1); W2t(1,1) = newy;
        elseif(option == 2)
          newx = W1(1,1) + b*dW1(1,1); W1t(1,1) = newx;
          newy = b1(1)   + b*db1(1);   b1t(1) = newy;
        else
          newx = b1(1) + b*db1(1);   b1t(1) = newx;
          newy = b1(2) + b*db1(2);   b1t(2) = newy;
        end
        fb = sumsqr(T - logsig(W2t*logsig(W1t*P,b1t),b2t));
      end
      % Step the lower end back by one so the interval spans the last two
      % steps of the search.
      a=aold;
      fa=faold;
    
      % INITIALIZE C AND D
      % c and d are the two interior golden section points, a < c < d < b.
      c=a+tau1*(b-a);
      if (option == 1)
        newx = W1(1,1) + c*dW1(1,1); W1t(1,1) = newx;
        newy = W2(1,1) + c*dW2(1,1); W2t(1,1) = newy;
      elseif(option == 2)
        newx = W1(1,1) + c*dW1(1,1); W1t(1,1) = newx;
        newy = b1(1)   + c*db1(1);   b1t(1) = newy;
      else
        newx = b1(1) + c*db1(1);   b1t(1) = newx;
        newy = b1(2) + c*db1(2);   b1t(2) = newy;
      end
      fc = sumsqr(T - logsig(W2t*logsig(W1t*P,b1t),b2t));
      d=b-tau1*(b-a);
      if (option == 1)
        newx = W1(1,1) + d*dW1(1,1); W1t(1,1) = newx;
        newy = W2(1,1) + d*dW2(1,1); W2t(1,1) = newy;
      elseif(option == 2)
        newx = W1(1,1) + d*dW1(1,1); W1t(1,1) = newx;
        newy = b1(1)   + d*db1(1);   b1t(1) = newy;
      else
        newx = b1(1) + d*db1(1);   b1t(1) = newx;
        newy = b1(2) + d*db1(2);   b1t(2) = newy;
      end
      fd = sumsqr(T - logsig(W2t*logsig(W1t*P,b1t),b2t));
    
      % MINIMIZE ALONG A LINE
      % Shrink [a,b] until it is narrower than tol, reusing one interior
      % point and evaluating the SSE at one new point per pass.
      while (b-a)>tol 
        if ( (fc<fd)&(fb>=min([fa fc fd])) ) | fa<min([fb fc fd])
          % Keep the lower part [a,d].
          b=d; d=c; fb=fd;
          c=a+tau1*(b-a);
          fd=fc;
          if (option == 1)
            newx = W1(1,1) + c*dW1(1,1); W1t(1,1) = newx;
            newy = W2(1,1) + c*dW2(1,1); W2t(1,1) = newy;
          elseif(option == 2)
            newx = W1(1,1) + c*dW1(1,1); W1t(1,1) = newx;
            newy = b1(1)   + c*db1(1);   b1t(1) = newy;
          else
            newx = b1(1) + c*db1(1);   b1t(1) = newx;
            newy = b1(2) + c*db1(2);   b1t(2) = newy;
          end
          fc = sumsqr(T - logsig(W2t*logsig(W1t*P,b1t),b2t));
    
        else
          % Keep the upper part [c,b].
          a=c; c=d; fa=fc;
          d=b-tau1*(b-a);
          fc=fd;
          if (option == 1)
            newx = W1(1,1) + d*dW1(1,1); W1t(1,1) = newx;
            newy = W2(1,1) + d*dW2(1,1); W2t(1,1) = newy;
          elseif(option == 2)
            newx = W1(1,1) + d*dW1(1,1); W1t(1,1) = newx;
            newy = b1(1)   + d*db1(1);   b1t(1) = newy;
          else
            newx = b1(1) + d*db1(1);   b1t(1) = newx;
            newy = b1(2) + d*db1(2);   b1t(2) = newy;
          end
          fd = sumsqr(T - logsig(W2t*logsig(W1t*P,b1t),b2t));
        end
      end

      % UPDATE VARIABLES
      % Move to step a; fa already holds the SSE there and is carried into
      % the next iteration.
      if (option == 1)
        newx = W1(1,1) + a*dW1(1,1); W1(1,1) = newx;
        newy = W2(1,1) + a*dW2(1,1); W2(1,1) = newy;
      elseif(option == 2)
        newx = W1(1,1) + a*dW1(1,1); W1(1,1) = newx;
        newy = b1(1)   + a*db1(1);   b1(1) = newy;
      else
        newx = b1(1) + a*db1(1);   b1(1) = newx;
        newy = b1(2) + a*db1(2);   b1(2) = newy;
      end
      xx(i) = newx;
      yy(i) = newy;
    
      % CALCULATE GRADIENT
      A1 = logsig(W1*P,b1);
      A2 = logsig(W2*A1,b2);
      E = T-A2;
      D2 = feval('deltalog',A2,E);
      D1 = feval('deltalog',A1,D2,W2);
      [gW1,gb1] = feval('learnbp',P,D1,1);
      [gW2,gb2] = feval('learnbp',A1,D2,1);
    
      % NORM SQUARE OF GRADIENT
      if (option == 1)
        nrmn = gW1(1,1)^2 + gW2(1,1)^2;
      elseif(option == 2)
        nrmn = gW1(1,1)^2 + gb1(1)^2;
      else
        nrmn = gb1(1)^2 + gb1(2)^2;
      end

      % CALCULATE DIRECTION
      % Z is the ratio of new to old squared gradient norms (the
      % Fletcher-Reeves form); it is forced to zero whenever i is a
      % multiple of n, which restarts from the plain gradient direction.
      if rem(i,n)==0
        Z=0;
      else
        Z=nrmn/nrmo;
      end

      % CALCULATE NEW DIRECTIONS
      dW1new = gW1 + dW1old*Z; db1new = gb1 + db1old*Z;
      dW2new = gW2 + dW2old*Z; db2new = gb2 + db2old*Z;

      % SAVE NEW DIRECTIONS
      dW1old = dW1new; db1old = db1new;
      dW2old = dW2new; db2old = db2new;
      nrmo=nrmn;

      %NORMALIZE DIRECTIONS
      if (option == 1)
        nrm = sqrt(dW1new(1,1)^2 + dW2new(1,1)^2);
      elseif(option == 2)
        nrm = sqrt(dW1new(1,1)^2 + db1new(1)^2);
      else
        nrm = sqrt(db1new(1)^2 + db1new(2)^2);
      end
      if nrm == 0
        % No direction left to search along (this also covers nrmn == 0,
        % so nrmo is never zero in the division above).  Stop here and
        % drop the unused tail of the path instead of dividing by zero.
        xx = xx(1:i);
        yy = yy(1:i);
        break
      end
      dW1=dW1new/nrm;db1=db1new/nrm;dW2=dW2new/nrm;db2=db2new/nrm;
    end

    % CONTOUR PLOT
    % Draw the path as a line plus a circle at every visited point, and
    % re-lock the figure.
    path_h = [path_h; plot(xx,yy,'color',nnred); plot(xx,yy,'o','color',nnred,'markersize',6)];
    set(fig,'nextplot','new')
    
    % SAVE DATA
    set(path_ptr,'userdata',path_h);
    set(fig,'pointer','arrow')

  end
end

    

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -