📄 nnd12mo.m
字号:
%==================================================================
% Display the instructions.
%
% ME('instr')
%
% Fills the description text area (desc_text) with the usage help,
% one cell per displayed line ('' gives a blank line).
%==================================================================
elseif strcmp(cmd,'instr') & (fig)
  help_lines = { ...
    'Use the radio buttons', ...
    'to select the network', ...
    'parameters to train', ...
    'with backpropagation.', ...
    '', ...
    'The corresponding', ...
    'contour plot is', ...
    'shown below.', ...
    '', ...
    'Click in the contour', ...
    'graph to start the', ...
    'momentum backprop', ...
    'learning algorithm.', ...
    'You can reset the', ...
    'algorithm parameters', ...
    'using the sliders.'};
  % Expand the cell array into the same argument list nnsettxt expects.
  nnsettxt(desc_text, help_lines{:})
%==================================================================
% Respond to radio buttons.
%
% ME('radio',i)
%
% Selects which pair of network parameters is trained (i = 1, 2 or 3).
% If button i is not already selected, this clause:
%   - highlights radio button i,
%   - recolors the 'variables' handles,
%   - loads the matching nndbp<i> data file,
%   - redraws the contour plot,
%   - clears the stored training path.
% It does nothing when button i is already the selected one.
%==================================================================
elseif strcmp(cmd,'radio') & (fig) & (nargin == 2)
  % GET DATA
  option = get(option_ptr,'userdata');
  % ALTER TRAINABLE PARAMETERS
  if (arg1 ~= option)
    % HIGHLIGHT NEW RADIO BUTTON
    set(radios(option),'value',0)
    set(radios(arg1),'value',1)
    option = arg1;
    % CLEAR AXES (removes the old contour plot and any training path)
    delete(get(cont_axis,'children'))
    % ERROR SURFACE & VARIABLE NAMES
    % No network constants are needed here: the plotted surface comes
    % entirely from the loaded file. (The 'down' handler builds its own
    % network before training.)
    % NOTE(review): the nndbp* files presumably define x2, y2, E2, levels,
    % v1, v2, range1, range2, optx and opty, all used below - confirm.
    if option == 1
      load nndbp1
      set(variables([1 4]),'color',[1 1 1])
      set(variables([2 3]),'color',nndkblue)
    elseif option == 2
      load nndbp2
      set(variables([1 2]),'color',[1 1 1])
      set(variables([3 4]),'color',nndkblue)
    else
      load nndbp3
      set(variables([2 3]),'color',[1 1 1])
      set(variables([1 4]),'color',nndkblue)
    end
    set(fig,'nextplot','add')
    axes(cont_axis)
    set(get(cont_axis,'xlabel'),'string',v1)
    set(get(cont_axis,'ylabel'),'string',v2)
    set(cont_axis,'xlim',range1,'ylim',range2)
    [dummy,cont_h] = contour(x2,y2,E2,levels);
    set(cont_h,'erasemode','none');
    % Rectangle around the plotting range, drawn at z = 1000
    % (presumably to keep it above the contour lines in the 2-D view).
    plot3(range1([1 2 2 1 1]),range2([1 1 2 2 1]),1000*ones(1,5),...
      'color',nndkblue);
    % '+' marker at (optx, opty) - presumably the error minimum.
    cont_h2 = plot(optx,opty,'+','color',nnred);
    cont_h = [cont_h; cont_h2];
    view(2)
    % STORE DATA (the old path handles were deleted with the axes
    % children, so the stored path is reset to empty)
    set(cont_ptr,'userdata',cont_h);
    set(path_ptr,'userdata',[]);
    set(option_ptr,'userdata',option);
  end
%==================================================================
% Respond to learning rate slider.
%
% ME('lr')
%
% Reads the learning-rate slider and shows its value rounded to one
% decimal. The slider is disabled here; the mouse-down/slider block
% after this if-chain re-runs training and re-enables it.
%==================================================================
elseif strcmp(cmd,'lr')
  lr = get(lr_bar,'value');
  set(lr_bar,'enable','off');
  lr_shown = round(lr*10)*0.1;
  set(lr_text,'string',sprintf('%4.1f',lr_shown))
%==================================================================
% Respond to momentum constant slider.
%
% ME('mc')
%
% Reads the momentum slider and shows its value rounded to two
% decimals. The slider is disabled here; the mouse-down/slider block
% after this if-chain re-runs training and re-enables it.
%==================================================================
elseif strcmp(cmd,'mc')
  mc = get(mc_bar,'value');
  set(mc_bar,'enable','off');
  mc_shown = round(mc*100)*0.01;
  set(mc_text,'string',sprintf('%4.2f',mc_shown))
end  % closes the if/elseif dispatch on cmd
%==================================================================
% Respond to mouse down, or to a change of either slider.
%
% ME('down')          - train from the point clicked in the contour axes
% ME('lr'), ME('mc')  - re-train from the start point of the current
%                       path (if one is drawn) with the new slider values
%
% Trains the two parameters chosen by the radio buttons with momentum
% backpropagation and draws the trajectory on the contour plot.
% The two slider values are re-enabled at the end. (The 'lr'/'mc'
% clauses above disable them.)
%==================================================================
if (strcmp(cmd,'down') && (fig) && (nargin == 1)) || strcmp(cmd,'lr') || strcmp(cmd,'mc')
  if strcmp(cmd,'lr') || strcmp(cmd,'mc')
    % Restart from the first marker of the existing path. The stored path
    % is [] after a radio-button change. ishandle also rejects NaN and
    % deleted handles, and works for numeric handles and graphics objects
    % alike (isnan does not).
    old_path = get(path_ptr,'userdata');
    if ~isempty(old_path) && ishandle(old_path(1))
      xdata = get(old_path(1),'xdata');
      ydata = get(old_path(1),'ydata');
      x = xdata(1);
      y = ydata(1);
    else
      x = NaN;
      y = NaN;
    end
  else
    % The first row of 'currentpoint' holds the clicked point [x y z].
    pt = get(cont_axis,'currentpoint');
    x = pt(1,1);
    y = pt(1,2);
  end
  xl = get(cont_axis,'xlim');
  yl = get(cont_axis,'ylim');
  % A NaN start point (nothing to restart from) fails these tests.
  if (x > xl(1) && x < xl(2) && y > yl(1) && y < yl(2))
    % GET DATA
    option = get(option_ptr,'userdata');
    path_h = get(path_ptr,'userdata');
    % REMOVE PREVIOUS PATH (only entries that are still valid handles)
    set(fig,'nextplot','add')
    delete(path_h(ishandle(path_h)));
    % INITIAL VALUES: the 1-2-1 network whose output defines the targets T.
    % nndlogsig is used in place of the toolbox logsig.
    W1 = [10; 10];
    b1 = [-5;5];
    W2 = [1 1];
    b2 = [-1];
    P = -2:0.1:2;
    Q = size(P,2);
    A1 = nndlogsig(W1*P+b1*ones(1,Q));
    T = nndlogsig(W2*A1+b2*ones(1,Q));
    % PLOT START POINT
    dkblue = nndkblue;
    axes(cont_axis);
    path_h = [...
      plot(x,y,'o','color',dkblue,'markersize',8,'erasemode','none');
      plot(x,y,'o','color',[1 1 1],'markersize',10,'erasemode','none');
      plot(x,y,'o','color',dkblue,'markersize',12,'erasemode','none')];
    drawnow;
    % PLOT PATH
    set(fig,'pointer','watch')
    % INITIALIZE TRAINING: put the start point into the two trainable
    % parameters (option 3 runs 60 epochs, the others 300).
    if option == 1
      ep = 300;
      W1(1,1) = x;
      W2(1,1) = y;
    elseif option == 2
      ep = 300;
      W1(1,1) = x;
      b1(1) = y;
    else
      ep = 60;
      b1(1) = x;
      b1(2) = y;
    end
    lr = get(lr_bar,'value');
    mc = get(mc_bar,'value');
    A1 = nndlogsig(W1*P+b1*ones(1,Q));
    A2 = nndlogsig(W2*A1+b2*ones(1,Q));
    E = T-A2;
    % Trajectory of the two trained parameters, starting point first.
    xx = [x zeros(1,ep)];
    yy = [y zeros(1,ep)];
    dW1 = 0;
    db1 = 0;
    dW2 = 0;
    % TRAINING: the same momentum backprop step for every option. Only the
    % two parameters selected by the radio buttons are actually changed.
    for epoch = 2:(ep+1)
      D2 = A2.*(1-A2).*E;
      D1 = A1.*(1-A1).*(W2'*D2);
      dW1 = mc*dW1 + (1-mc)*D1*P'*lr;
      db1 = mc*db1 + (1-mc)*D1*ones(Q,1)*lr;
      dW2 = mc*dW2 + (1-mc)*D2*A1'*lr;
      if option == 1
        W1(1,1) = W1(1,1) + dW1(1,1);
        W2(1,1) = W2(1,1) + dW2(1,1);
        xx(epoch) = W1(1,1);
        yy(epoch) = W2(1,1);
      elseif option == 2
        W1(1,1) = W1(1,1) + dW1(1,1);
        b1(1) = b1(1) + db1(1);
        xx(epoch) = W1(1,1);
        yy(epoch) = b1(1);
      else
        b1(1) = b1(1) + db1(1);
        b1(2) = b1(2) + db1(2);
        xx(epoch) = b1(1);
        yy(epoch) = b1(2);
      end
      A1 = nndlogsig(W1*P+b1*ones(1,Q));
      A2 = nndlogsig(W2*A1+b2*ones(1,Q));
      E = T-A2;
    end
    % CONTOUR PLOT
    path_h = [path_h; plot(xx,yy,'color',nnred); plot(xx,yy,'o','color',nnred,'markersize',6)];
    set(fig,'nextplot','new')
    % SAVE DATA
    set(path_ptr,'userdata',path_h);
    set(fig,'pointer','arrow')
  end
  set(mc_bar,'enable','on');
  set(lr_bar,'enable','on');
end
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -