⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 nnd12mo.m

📁 《神经网络设计》英文版的配套源代码
💻 M
📖 第 1 页 / 共 2 页
字号:
%==================================================================
% Display the instructions.
%
% ME('instr')
%
% Fills the description panel with the usage text, one entry per
% displayed line (an empty string yields a blank line).
%==================================================================

elseif strcmp(cmd,'instr') & (fig)
  instr_lines = {
    'Use the radio buttons'
    'to select the network'
    'parameters to train'
    'with backpropagation.'
    ''
    'The corresponding'
    'contour plot is'
    'shown below.'
    ''
    'Click in the contour'
    'graph to start the'
    'momentum backprop'
    'learning algorithm.'
    'You can reset the'
    'algorithm parameters'
    'using the sliders.'
    };
  % Expand the cell into separate arguments, in order.
  nnsettxt(desc_text,instr_lines{:})
    
%==================================================================
% Respond to radio buttons.
%
% ME('radio',i)
%
% Selects which pair of network parameters is trained (option i),
% redraws the matching error-surface contour, and clears any path.
% Does nothing if option i is already selected.
%==================================================================

elseif strcmp(cmd,'radio') & (fig) & (nargin == 2)

  % GET DATA
  option = get(option_ptr,'userdata');

  % ALTER TRAINABLE PARAMETERS
  if (arg1 ~= option)

    % HIGHLIGHT NEW RADIO BUTTON
    set(radios(option),'value',0)
    set(radios(arg1),'value',1)
    option = arg1;

    % CLEAR AXES (removes the old contour and any plotted path)
    delete(get(cont_axis,'children'))

    % ERROR SURFACE & VARIABLE NAMES
    % NOTE(review): each nndbp*.mat file is expected to define the names
    % used below (v1, v2, range1, range2, x2, y2, E2, levels, optx,
    % opty) - confirm against the data files.
    % The label colors swap between white and nndkblue depending on
    % which parameters the selected option trains.
    if option == 1
      load nndbp1
      set(variables([1 4]),'color',[1 1 1])
      set(variables([2 3]),'color',nndkblue)
    elseif option == 2
      load nndbp2
      set(variables([1 2]),'color',[1 1 1])
      set(variables([3 4]),'color',nndkblue)
    else
      load nndbp3
      set(variables([2 3]),'color',[1 1 1])
      set(variables([1 4]),'color',nndkblue)
    end

    % Allow drawing into the demo figure while the contour is rebuilt.
    set(fig,'nextplot','add')
    axes(cont_axis)
    set(get(cont_axis,'xlabel'),'string',v1)
    set(get(cont_axis,'ylabel'),'string',v2)
    set(cont_axis,'xlim',range1,'ylim',range2)
    [dummy,cont_h] = contour(x2,y2,E2,levels);
    set(cont_h,'erasemode','none');
    % Border around the plot range, drawn at z = 1000 (presumably so it
    % is not hidden by the contour lines).
    plot3(range1([1 2 2 1 1]),range2([1 1 2 2 1]),1000*ones(1,5),...
      'color',nndkblue);
    % '+' marker at (optx,opty), presumably the error minimum.
    cont_h2 = plot(optx,opty,'+','color',nnred);
    cont_h = [cont_h; cont_h2];
    view(2)
    % Restore 'new' after drawing, as the mouse-down handler does, so
    % later plotting commands do not land in the demo figure.
    set(fig,'nextplot','new')

    % STORE DATA
    set(cont_ptr,'userdata',cont_h);
    set(path_ptr,'userdata',[]);
    set(option_ptr,'userdata',option);
  end

%==================================================================
% Respond to learning rate slider.
%
% ME('lr')
%
% Updates the learning-rate label. The slider is disabled here; the
% combined down/lr/mc handler after this if-chain re-enables it.
%==================================================================

elseif strcmp(cmd,'lr')

  set(lr_bar,'enable','off');
  lr = get(lr_bar,'value');
  % Label shows the value rounded to one decimal place.
  lr_shown = round(lr*10)*0.1;
  set(lr_text,'string',sprintf('%4.1f',lr_shown))

%==================================================================
% Respond to momentum constant slider.
%
% ME('mc')
%
% Updates the momentum label. The slider is disabled here; the
% combined down/lr/mc handler after this if-chain re-enables it.
%==================================================================

elseif strcmp(cmd,'mc')

  set(mc_bar,'enable','off');
  mc = get(mc_bar,'value');
  % Label shows the value rounded to two decimal places.
  mc_shown = round(mc*100)*0.01;
  set(mc_text,'string',sprintf('%4.2f',mc_shown))

end
%==================================================================
% Respond to mouse down, and replay training after a slider change.
%
% ME('down')          - train from the point clicked in the contour
% ME('lr'), ME('mc')  - retrain from the start point of the path
%                       currently drawn (if any) with the new
%                       learning rate / momentum values
%
% Trains two parameters of a 1-2-1 network with momentum backprop and
% draws the resulting trajectory over the error contour.
%==================================================================

if (strcmp(cmd,'down') && (fig) && (nargin == 1)) || strcmp(cmd,'lr') || strcmp(cmd,'mc')

  % STARTING POINT
  if strcmp(cmd,'lr') || strcmp(cmd,'mc')
    % Reuse the first sample of the stored path's first handle (the
    % start marker). With no stored path (empty userdata) the test is
    % false and x,y become NaN, which the range check below rejects.
    temp = get(path_ptr,'userdata');
    if ~isnan(temp)
      temp2 = get(temp(1),'xdata');
      x = temp2(1);
      temp2 = get(temp(1),'ydata');
      y = temp2(1);
    else
      x = NaN;
      y = NaN;
    end
  else
    % CurrentPoint is 2-by-3; pt(1) and pt(3) are the x and y of the
    % front point (column-major linear indexing).
    pt = get(cont_axis,'currentpoint');
    x = pt(1);
    y = pt(3);
  end
  % Named to avoid shadowing the xlim/ylim functions.
  xlims = get(cont_axis,'xlim');
  ylims = get(cont_axis,'ylim');

  if x > xlims(1) && x < xlims(2) && y > ylims(1) && y < ylims(2)

    % GET DATA
    option = get(option_ptr,'userdata');
    path_h = get(path_ptr,'userdata');   % not 'path': avoids shadowing path()

    % REMOVE PREVIOUS PATH
    set(fig,'nextplot','add')
    delete(path_h);

    % INITIAL VALUES
    % Nominal 1-2-1 network; its response on P defines the targets T.
    % (nndlogsig replaces the toolbox logsig used in earlier versions.)
    W1 = [10; 10];
    b1 = [-5;5];
    W2 = [1 1];
    b2 = [-1];
    P = -2:0.1:2;
    Q = size(P,2);
    A1 = nndlogsig(W1*P+b1*ones(1,Q));
    T = nndlogsig(W2*A1+b2*ones(1,Q));

    % PLOT START POINT
    dkblue = nndkblue;
    red = nnred;
    axes(cont_axis);
    path_h = [...
      plot(x,y,'o','color',dkblue,'markersize',8,'erasemode','none');
      plot(x,y,'o','color',[1 1 1],'markersize',10,'erasemode','none');
      plot(x,y,'o','color',dkblue,'markersize',12,'erasemode','none')];
    drawnow;

    set(fig,'pointer','watch')

    % INITIALIZE TRAINING
    % The clicked (x,y) sets the two parameters the selected option
    % trains; they are also the contour plot's axes.
    if option == 1
      ep = 300;
      W1(1,1) = x;
      W2(1,1) = y;
    elseif option == 2
      ep = 300;
      W1(1,1) = x;
      b1(1) = y;
    else
      ep = 60;
      b1(1) = x;
      b1(2) = y;
    end
    lr = get(lr_bar,'value');
    mc = get(mc_bar,'value');

    A1 = nndlogsig(W1*P+b1*ones(1,Q));
    A2 = nndlogsig(W2*A1+b2*ones(1,Q));
    E = T-A2;

    % Trajectory in the trained-parameter plane, start point first.
    xx = [x zeros(1,ep)];
    yy = [y zeros(1,ep)];

    % Momentum state; the scalar 0 expands on the first update.
    dW1 = 0;
    db1 = 0;
    dW2 = 0;

    % TRAINING
    % The backprop and momentum step is the same for every option.
    % Only the two parameters that are actually updated differ.
    for i = 2:(ep+1)
      % Sensitivities with E = T-A2, so adding the step descends the
      % squared error.
      D2 = A2.*(1-A2).*E;
      D1 = A1.*(1-A1).*(W2'*D2);
      dW1 = mc*dW1 + (1-mc)*D1*P'*lr;
      db1 = mc*db1 + (1-mc)*D1*ones(Q,1)*lr;
      dW2 = mc*dW2 + (1-mc)*D2*A1'*lr;

      if option == 1
        W1(1,1) = W1(1,1) + dW1(1,1);
        W2(1,1) = W2(1,1) + dW2(1,1);
        xx(i) = W1(1,1);
        yy(i) = W2(1,1);
      elseif option == 2
        W1(1,1) = W1(1,1) + dW1(1,1);
        b1(1) = b1(1) + db1(1);
        xx(i) = W1(1,1);
        yy(i) = b1(1);
      else
        b1(1) = b1(1) + db1(1);
        b1(2) = b1(2) + db1(2);
        xx(i) = b1(1);
        yy(i) = b1(2);
      end

      A1 = nndlogsig(W1*P+b1*ones(1,Q));
      A2 = nndlogsig(W2*A1+b2*ones(1,Q));
      E = T-A2;
    end

    % CONTOUR PLOT
    path_h = [path_h; plot(xx,yy,'color',red); plot(xx,yy,'o','color',red,'markersize',6)];
    set(fig,'nextplot','new')

    % SAVE DATA
    set(path_ptr,'userdata',path_h);
    set(fig,'pointer','arrow')

  end
  % The 'lr'/'mc' clauses disable the sliders; re-enable them whether
  % or not a new path was drawn.
  set(mc_bar,'enable','on');
  set(lr_bar,'enable','on');
end

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -