⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 nnd4pr.m

📁 Martin T.Hagan等著,戴葵等译,神经网络设计,机械工业出版社,一书的所有例程
💻 M
📖 第 1 页 / 共 2 页
字号:
      set(dot_ptr,'userdata',1);
    
    else
      pt = get(t2_axis,'currentpoint');
      x = pt(1);
      y = pt(3);

      % WHITE DOT
      if (x >= -.15) & (x <= .15) & (y >= -.15) & (y <= .15)
         set(fig,...
         'pointer','circle',...
         'WindowButtonUpFcn',[me '(''setdot'')'])
        set(dot_ptr,'userdata',0);
      end
    end
  end

%==================================================================
% Set decision dot.
%
% ME('setdot')
%==================================================================

elseif strcmp(cmd,'setdot') & (fig) & (nargin == 1)

  % Add a training point where the mouse button was released.
  % The release position in V_AXIS is snapped to a 0.5 grid.
  pt = get(v_axis,'currentpoint');
  x = pt(1);
  x = round(x*2)/2;
  y = pt(3);
  y = round(y*2)/2;

  % Target of the new point: 0 gives a white dot, 1 a black dot
  % (the fill colour below is [1 1 1]-dot).
  dot = get(dot_ptr,'userdata');
 
  % Releases outside the [-p_max, p_max] input square are ignored.
  if (x >= -p_max) & (x <= p_max) & (y >= -p_max) & (y <= p_max)
    P = get(P_ptr,'userdata');
    T = get(T_ptr,'userdata');
    w = get(w_ptr,'userdata');
    b = get(b_ptr,'userdata');
    dots = get(dots_ptr,'userdata');
    
    % Outline of a circle of radius 0.15 used to draw the new dot.
    q = length(dots)+1;
    deg = pi/180;
    angle = [0:5:360]*deg;
    cx = cos(angle)*0.15;
    cy = sin(angle)*0.15;
    set(fig,'nextplot','add');
    axes(v_axis);

    % Append the new input/target pair to the data set.
    P = [P [x;y]];
    T = [T dot];

    % Decision boundary is blue when every point is classified correctly,
    % red otherwise.  The bias is added to the net input explicitly: the
    % two-argument HARDLIM(N,B) form is obsolete, and the single-argument
    % form is what the per-point check below already uses.
    if all(T == hardlim(w*P+b))
      col = nndkblue;
    else
      col = nnred;
    end
    set(db_line,...
      'color',col);

    % Dark-gray outline if the new point is classified correctly, red if not.
    a = hardlim(w*[x;y]+b);
    if (a == dot)
      col = nndkgray;
    else
      col = nnred;
    end
    dots(q) = fill(cx+x,cy+y,[1 1 1]-dot,...
      'edgecolor',col,...
      'erasemode','none',...
      'linewidth',2);
    set(fig,'nextplot','new');
  
    % Store the updated data set and dot handles.
    set(P_ptr,'userdata',P);
    set(T_ptr,'userdata',T);
    set(dots_ptr,'userdata',dots);
  end

  % Restore the normal pointer and stop listening for button release.
  set(fig,...
    'pointer','arrow',...
    'WindowButtonUpFcn','')
    
%==================================================================
% Learn.
%
% ME('learn')
%==================================================================

elseif strcmp(cmd,'learn')
  
  % One "learn" click = a training run limited to a single point
  % presentation.  Re-labelling the command hands control to the
  % TRAIN section that follows this if/elseif chain.
  cmd = 'train';
  max_epoch = 1;

%==================================================================
% Bias.
%
% ME('bias')
%==================================================================

elseif strcmp(cmd,'bias')

  % BIAS and NO_BIAS are kept mutually exclusive.  Clicking an
  % already-selected toggle uicontrol turns it off, so BIAS is forced
  % back on; otherwise neither option would be selected and TRAIN
  % (which reads get(bias,'value')) would silently stop learning b.
  % NOTE(review): assumes BIAS/NO_BIAS are radio-button or checkbox
  % uicontrols created earlier in this file -- confirm.
  set(no_bias,'value',0);
  set(bias,'value',1);

%==================================================================
% No bias.
%
% ME('nobias')
%==================================================================

elseif strcmp(cmd,'nobias')

  % BIAS and NO_BIAS are kept mutually exclusive.  NO_BIAS is forced
  % back on because clicking an already-selected toggle uicontrol turns
  % it off, which would otherwise leave neither option selected while
  % b has been zeroed.
  % NOTE(review): assumes BIAS/NO_BIAS are radio-button or checkbox
  % uicontrols created earlier in this file -- confirm.
  set(bias,'value',0);
  set(no_bias,'value',1);
  
  % Switching the bias off zeroes b; the weights are kept.
  w = get(w_ptr,'userdata');
  b = 0;
  P = get(P_ptr,'userdata');
  T = get(T_ptr,'userdata');
  dots = get(dots_ptr,'userdata');
  
  % End points of the boundary w(1)*x + w(2)*y + b = 0 (x in pp1,
  % y in pp2), pulled back inside the [-p_max, p_max] square.
  if (w(1) ~= 0)
    pp2 = [-p_max p_max];
    pp1 = -(w(2)*pp2+b)/w(1);
    if (pp1(1) < -p_max | pp1(1) > p_max)
      pp1(1) = p_max*sign(pp1(1));
      pp2(1) = -(w(1)*pp1(1)+b)/w(2);
    end
    if (pp1(2) < -p_max | pp1(2) > p_max)
      pp1(2) = p_max*sign(pp1(2));
      pp2(2) = -(w(1)*pp1(2)+b)/w(2);
    end
  elseif (w(2) ~= 0)
    pp1 = [-p_max p_max];
    pp2 = -(w(1)*pp1+b)/w(2);
  else
    % w is all zero: there is no boundary; collapse the line to (0,0).
    pp1 = [0 0];
    pp2 = [0 0];
  end

  % NEW BOUNDARY
  set(db_line,...
    'color',nnltyell)
  set(cross,...
    'color',nndkblue);
  % Blue when every point is classified correctly (trivially so when
  % there are no points yet, where w*P could not even be formed if P is
  % []), red otherwise.  The bias is added to the net input instead of
  % using the obsolete two-argument HARDLIM(N,B) form.
  if isempty(P)
    col = nndkblue;
  elseif all(T == hardlim(w*P+b))
    col = nndkblue;
  else
    col = nnred;
  end
  set(db_line,...
    'xdata',pp1,...
    'ydata',pp2,...
    'color',col)

  % REFRESH DOTS: fill shows the target, outline shows correctness.
  for k=1:length(T)
    a = hardlim(w*P(:,k)+b);
    if (a == T(k))
      col = nndkgray;
    else
      col = nnred;
    end
    set(dots(k),...
     'facecolor',[1 1 1]-T(k),...
     'edgecolor',col)
  end

  % NEW PARAMETER VALUES: grey out the old text, then show the new bias.
  set(b_text,...
    'color',nnltgray)
  nntxtchk;
  set(b_text,...
    'string',sprintf('%5.3g',b),...
    'color',nndkblue)
  nntxtchk;
  
  set(b_ptr,'userdata',b);

%==================================================================
% Random weights.
%
% ME('random')
%==================================================================

elseif strcmp(cmd,'random')

  % New random weight vector w (used below as 1-by-2) and bias b; the
  % bias is discarded when the NO_BIAS option is selected.
  [w,b] = feval('rands',1,2);
  if get(no_bias,'value'), b = 0; end

  P = get(P_ptr,'userdata');
  T = get(T_ptr,'userdata');
  dots = get(dots_ptr,'userdata');

  % End points of the boundary w(1)*x + w(2)*y + b = 0 (x in pp1,
  % y in pp2), pulled back inside the [-p_max, p_max] square.
  if (w(1) ~= 0)
    pp2 = [-p_max p_max];
    pp1 = -(w(2)*pp2+b)/w(1);
    if (pp1(1) < -p_max | pp1(1) > p_max)
      pp1(1) = p_max*sign(pp1(1));
      pp2(1) = -(w(1)*pp1(1)+b)/w(2);
    end
    if (pp1(2) < -p_max | pp1(2) > p_max)
      pp1(2) = p_max*sign(pp1(2));
      pp2(2) = -(w(1)*pp1(2)+b)/w(2);
    end
  elseif (w(2) ~= 0)
    pp1 = [-p_max p_max];
    pp2 = -(w(1)*pp1+b)/w(2);
  else
    % w is all zero: there is no boundary; collapse the line to (0,0).
    pp1 = [0 0];
    pp2 = [0 0];
  end

  % NEW BOUNDARY
  set(db_line,...
    'color',nnltyell)
  set(cross,...
    'color',nndkblue);
  % Blue when every point is classified correctly (trivially so when
  % there are no points yet, where w*P could not even be formed if P is
  % []), red otherwise.  The bias is added to the net input instead of
  % using the obsolete two-argument HARDLIM(N,B) form.
  if isempty(P)
    col = nndkblue;
  elseif all(T == hardlim(w*P+b))
    col = nndkblue;
  else
    col = nnred;
  end
  set(db_line,...
    'xdata',pp1,...
    'ydata',pp2,...
    'color',col)

  % REFRESH DOTS: fill shows the target, outline shows correctness.
  for k=1:length(T)
    a = hardlim(w*P(:,k)+b);
    if (a == T(k))
      col = nndkgray;
    else
      col = nnred;
    end
    set(dots(k),...
     'facecolor',[1 1 1]-T(k),...
     'edgecolor',col)
  end

  % NEW PARAMETER VALUES: grey out the old text, then show w and b.
  set(w1_text,...
    'color',nnltgray)
  set(w2_text,...
    'color',nnltgray)
  set(b_text,...
    'color',nnltgray)
  nntxtchk;
  set(w1_text,...
    'string',sprintf('%5.3g',w(1)),...
    'color',nndkblue)
  set(w2_text,...
    'string',sprintf('%5.3g',w(2)),...
    'color',nndkblue)
  set(b_text,...
    'string',sprintf('%5.3g',b),...
    'color',nndkblue)
  nntxtchk;
  
  set(w_ptr,'userdata',w);
  set(b_ptr,'userdata',b);
  
%==================================================================
end

%==================================================================
% Train.
%
% ME('train')
%==================================================================

if strcmp(cmd,'train')
  
  % GET DATA
  w = get(w_ptr,'userdata');
  b = get(b_ptr,'userdata');
  P = get(P_ptr,'userdata');
  T = get(T_ptr,'userdata');
  dots = get(dots_ptr,'userdata');
  bf = get(bias,'value');         % nonzero => the bias is learned too
  j = get(index_ptr,'userdata');  % index of the next point to present
  q = length(T);
  if (q == 0)
    return
  end

  % Perceptron rule: present one point per iteration (cycling through the
  % data set) for at most MAX_EPOCH iterations, stopping early as soon as
  % every point is classified correctly.  The bias is added to the net
  % input explicitly instead of using the obsolete HARDLIM(N,B) form.
  for i=1:max_epoch
    if (j > q), j = 1; end
    if all(T == hardlim(w*P+b)), break, end

    % Flash the point being presented.
    dot_col = get(dots(j),'facecolor');
    set(dots(j),...
      'facecolor',nngreen)
    nnpause(0.5);
    set(dots(j),...
      'facecolor',nnltyell)
    nnpause(0.5);
    set(dots(j),...
      'facecolor',nngreen)
    nnpause(0.5);

    % Update: e = t - a, dw = e*p', db = e (applied only if bias enabled).
    a = hardlim(w*P(:,j)+b);
    e = T(:,j) - a;
    dw = e*P(:,j)';
    db = e;
    w = w + dw;
    if bf, b = b + db; end

    % Redraw only if a parameter actually changed.  dw is a vector, so it
    % is reduced with ANY (IF on a vector requires ALL elements nonzero),
    % and db only changes b when the bias is being learned.
    if any(dw ~= 0) | (bf & (db ~= 0))
      % End points of the boundary w(1)*x + w(2)*y + b = 0 (x in pp1,
      % y in pp2), pulled back inside the [-p_max, p_max] square.
      if (w(1) ~= 0)
        pp2 = [-p_max p_max];
        pp1 = -(w(2)*pp2+b)/w(1);
        if (pp1(1) < -p_max | pp1(1) > p_max)
          pp1(1) = p_max*sign(pp1(1));
          pp2(1) = -(w(1)*pp1(1)+b)/w(2);
        end
        if (pp1(2) < -p_max | pp1(2) > p_max)
          pp1(2) = p_max*sign(pp1(2));
          pp2(2) = -(w(1)*pp1(2)+b)/w(2);
        end
      elseif (w(2) ~= 0)
        pp1 = [-p_max p_max];
        pp2 = -(w(1)*pp1+b)/w(2);
      else
        % w is all zero: there is no boundary; collapse the line to (0,0).
        pp1 = [0 0];
        pp2 = [0 0];
      end

      % NEW BOUNDARY: blue if every point is now correct, red otherwise.
      set(db_line,...
        'color',nnltyell)
      set(cross,...
        'color',nndkblue);

      if all(T == hardlim(w*P+b))
        col = nndkblue;
      else
        col = nnred;
      end
      set(db_line,...
        'xdata',pp1,...
        'ydata',pp2,...
        'color',col)

      % REFRESH DOTS: fill shows the target, outline shows correctness.
      for k=1:length(T)
        a = hardlim(w*P(:,k)+b);
        if (a == T(k))
          col = nndkgray;
        else
          col = nnred;
        end
        set(dots(k),...
          'facecolor',[1 1 1]-T(k),...
          'edgecolor',col)
      end

      % NEW PARAMETER VALUES: grey out the old text, then show w and b.
      set(w1_text,...
        'color',nnltgray)
      set(w2_text,...
        'color',nnltgray)
      set(b_text,...
        'color',nnltgray)
      nntxtchk;
      set(w1_text,...
        'string',sprintf('%5.3g',w(1)),...
        'color',nndkblue)
      set(w2_text,...
        'string',sprintf('%5.3g',w(2)),...
        'color',nndkblue)
      set(b_text,...
        'string',sprintf('%5.3g',b),...
        'color',nndkblue)
      nntxtchk;
    end
    % Restore the presented point's colour and move to the next point.
    set(dots(j),...
      'facecolor',dot_col)
    j = j + 1;
  end
 
  % Persist the trained parameters and where to resume presentation.
  set(w_ptr,'userdata',w);
  set(b_ptr,'userdata',b);
  set(index_ptr,'userdata',j);
end

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -