% nnd12vl.asv
% [web code-viewer residue] Font size:
% (Tail of the window-initialization branch; it begins above this view.)
% "Contents" button: its callback runs nndtoc.
uicontrol(...
'units','points',...
'position',[400 110 60 20],...
'string','Contents',...
'callback','nndtoc')
% "Close" button: calls this file back with the 'close' command.
uicontrol(...
'units','points',...
'position',[400 75 60 20],...
'string','Close',...
'callback',[me '(''close'')'])
% DATA POINTERS
% Invisible uicontrols used purely as 'userdata' containers. The 'radio'
% and 'down' callbacks read and update them, so the contour handles, the
% selected radio option and the current training-path handles persist
% between callbacks.
% NOTE(review): dummy is not read anywhere in the visible code.
dummy = 0;
cont_ptr = uicontrol('visible','off','userdata',cont_h);
option_ptr = uicontrol('visible','off','userdata',option);
path_ptr = uicontrol('visible','off','userdata',[]);
% SAVE WINDOW DATA AND LOCK
% All handles are stored in the figure's userdata. NOTE(review): the
% order of H must match the code that unpacks it at the top of this
% function (not visible here) -- keep the two in sync.
H = [fig_axis desc_text lr_bar lr_text inc_bar inc_text ...
dec_bar dec_text cont_axis cont_ptr ...
var1 var2 var3 var4 radio1 radio2 radio3 option_ptr path_ptr];
set(fig,'userdata',H,'nextplot','new')
% INSTRUCTION TEXT
% Re-enter this file with the 'instr' command to fill the description text.
feval(me,'instr');
% LOCK WINDOW
% nextplot 'new' keeps later plot commands from drawing into this figure.
set(fig,'nextplot','new','color',nnltgray);
% NOTE(review): nnchkfs is a helper defined elsewhere in the toolbox;
% its exact effect is not visible here.
nnchkfs;
%==================================================================
% Display the instructions.
%
% ME('instr')
%==================================================================
elseif strcmp(cmd,'instr') & (fig)
  % Show the usage instructions in the description text area.
  % Each cell below is passed to nnsettxt as a separate argument, in order.
  instr_lines = {
    'Use the radio buttons'
    'to select the network'
    'parameters to train'
    'with backpropagation.'
    ''
    'The corresponding'
    'contour plot is'
    'shown below.'
    ''
    'Click in the contour'
    'graph to start the'
    'variable learning'
    'rate backpropagation'
    'learning algorithm.'
    };
  nnsettxt(desc_text,instr_lines{:})
%==================================================================
% Respond to radio buttons.
%
% ME('radio',i)
%==================================================================
elseif strcmp(cmd,'radio') & (fig) & (nargin == 2)
  % Handle a click on radio button number arg1: select which pair of
  % network parameters is trained, and redraw the matching error contour.
  % NOTE(review): radios and variables are handle vectors unpacked from
  % the figure data earlier in this function (not visible here).
  % GET DATA
  option = get(option_ptr,'userdata');
  % Nothing to do when the already-selected button is clicked again.
  if (arg1 ~= option)
    % HIGHLIGHT NEW RADIO BUTTON
    set(radios(option),'value',0)
    set(radios(arg1),'value',1)
    option = arg1;
    % CLEAR AXES (removes the old contour plot and any training path)
    delete(get(cont_axis,'children'))
    % ERROR SURFACE & VARIABLE NAMES
    % Each MAT-file is expected to supply x2, y2, E2, levels, v1, v2,
    % range1, range2, optx and opty, which are used below -- confirm
    % against the data files. The labels of the two trained parameters
    % are drawn in white ([1 1 1]); the other two in dark blue.
    if option == 1
      load nndbp1
      set(variables([1 4]),'color',[1 1 1])
      set(variables([2 3]),'color',nndkblue)
    elseif option == 2
      load nndbp2
      set(variables([1 2]),'color',[1 1 1])
      set(variables([3 4]),'color',nndkblue)
    else
      load nndbp3
      set(variables([2 3]),'color',[1 1 1])
      set(variables([1 4]),'color',nndkblue)
    end
    % Temporarily unlock the figure so the new plot goes into it.
    set(fig,'nextplot','add')
    axes(cont_axis)
    set(get(cont_axis,'xlabel'),'string',v1)
    set(get(cont_axis,'ylabel'),'string',v2)
    set(cont_axis,'xlim',range1,'ylim',range2)
    [dummy,cont_h] = contour(x2,y2,E2,levels);
    set(cont_h,'erasemode','none');
    % Frame around the plot region, raised to z = 1000.
    plot3(range1([1 2 2 1 1]),range2([1 1 2 2 1]),1000*ones(1,5),...
      'color',nndkblue);
    % Mark the optimum point.
    cont_h2 = plot(optx,opty,'+','color',nnred);
    cont_h = [cont_h; cont_h2];
    view(2)
    % Re-lock the window, as the init branch and the 'down' handler do.
    set(fig,'nextplot','new')
    % STORE DATA (the old path was deleted with the axes children)
    set(cont_ptr,'userdata',cont_h);
    set(path_ptr,'userdata',[]);
    set(option_ptr,'userdata',option);
  end
%==================================================================
% Respond to learning rate slider.
%
% ME('lr')
%==================================================================
elseif strcmp(cmd,'lr')
  % Echo the learning-rate slider value, rounded to one decimal place.
  rate_val = get(lr_bar,'value');
  rate_str = sprintf('%4.1f',round(rate_val*10)*0.1);
  set(lr_text,'string',rate_str)
%==================================================================
% Respond to learning rate increase slider.
%
% ME('inc')
%==================================================================
elseif strcmp(cmd,'inc')
  % Echo the learning-rate increase factor, rounded to two decimals.
  inc_val = get(inc_bar,'value');
  inc_str = sprintf('%4.2f',round(inc_val*100)*0.01);
  set(inc_text,'string',inc_str)
%==================================================================
% Respond to learning rate decrease slider.
%
% ME('dec')
%==================================================================
elseif strcmp(cmd,'dec')
  % Echo the learning-rate decrease factor, rounded to two decimals.
  dec_val = get(dec_bar,'value');
  dec_str = sprintf('%4.2f',round(dec_val*100)*0.01);
  set(dec_text,'string',dec_str)
%==================================================================
% Respond to mouse down.
%
% ME('down')
%==================================================================
elseif strcmp(cmd,'down') & (fig) & (nargin == 1)
  % A click in the contour axes starts a training run. The click
  % coordinates become the initial values of the two parameters chosen
  % with the radio buttons. The trajectory of variable learning rate
  % backpropagation over ep epochs is then drawn on the contour plot.
  pt = get(cont_axis,'currentpoint');
  x = pt(1);
  y = pt(3);
  % Named xl/yl so the xlim/ylim functions are not shadowed.
  xl = get(cont_axis,'xlim');
  yl = get(cont_axis,'ylim');
  % Ignore clicks outside the axes limits.
  if (x > xl(1) & x < xl(2) & y > yl(1) & y < yl(2))
    % GET DATA
    option = get(option_ptr,'userdata');
    path_h = get(path_ptr,'userdata');   % handles of the previous path
    % REMOVE PREVIOUS PATH
    set(fig,'nextplot','add')
    delete(path_h);
    % INITIAL VALUES
    % Parameters of the network that generates the targets T.
    W1 = [10; 10];
    b1 = [-5;5];
    W2 = [1 1];
    b2 = [-1];
    P = -2:0.1:2;
    Q = size(P,2);
    A1 = nndlogsig(W1*P+b1*ones(1,Q));
    T = nndlogsig(W2*A1+b2*ones(1,Q));
    % PLOT START POINT
    dkblue = nndkblue;
    axes(cont_axis);
    path_h = [...
      plot(x,y,'o','color',dkblue,'markersize',8,'erasemode','none');
      plot(x,y,'o','color',[1 1 1],'markersize',10,'erasemode','none');
      plot(x,y,'o','color',dkblue,'markersize',12,'erasemode','none')];
    drawnow;
    % PLOT PATH
    set(fig,'pointer','watch')
    % INITIALIZE TRAINING
    % Overwrite the two selected parameters with the clicked coordinates.
    if option == 1
      W1(1,1) = x;
      W2(1,1) = y;
    elseif option == 2
      W1(1,1) = x;
      b1(1) = y;
    else
      b1(1) = x;
      b1(2) = y;
    end
    ep = 100;
    lr = get(lr_bar,'value');
    inc = get(inc_bar,'value');
    dec = get(dec_bar,'value');
    % nndlogsig is used throughout, so training uses the same transfer
    % function that produced the targets T.
    A1 = nndlogsig(W1*P+b1*ones(1,Q));
    A2 = nndlogsig(W2*A1+b2*ones(1,Q));
    E = T-A2;
    % BACKPROPAGATION PHASE
    D2 = A2.*(1-A2).*E;
    D1 = A1.*(1-A1).*(W2'*D2);
    SSE = sum(sum(E.*E));
    dW1 = 0;
    db1 = 0;
    dW2 = 0;
    db2 = 0;
    % Trajectory of (x,y); element 1 is the starting point.
    xx = [x zeros(1,ep)];
    yy = [y zeros(1,ep)];
    % NOTE(review): mc (momentum coefficient) and er (largest accepted
    % ratio new_SSE/SSE) are not defined in this branch; they are assumed
    % to be set earlier in this function -- confirm.
    MC = mc;
    for i=2:(ep+1)
      % LEARNING PHASE
      % Momentum-filtered steps. MC is 0 right after a rejected step, so
      % that step uses the plain gradient with the reduced learning rate.
      dW1 = MC*dW1 + (1-MC)*D1*P'*lr;
      db1 = MC*db1 + (1-MC)*D1*ones(Q,1)*lr;
      dW2 = MC*dW2 + (1-MC)*D2*A1'*lr;
      db2 = MC*db2 + (1-MC)*D2*ones(Q,1)*lr;
      MC = mc;
      % Only the two selected parameters change in the trial step.
      new_W1 = W1; new_b1 = b1;
      new_W2 = W2; new_b2 = b2;
      if (option == 1)
        newx = W1(1,1) + dW1(1,1); new_W1(1,1) = newx;
        newy = W2(1,1) + dW2(1,1); new_W2(1,1) = newy;
      elseif (option == 2)
        newx = W1(1,1) + dW1(1,1); new_W1(1,1) = newx;
        newy = b1(1) + db1(1); new_b1(1) = newy;
      else
        newx = b1(1) + db1(1); new_b1(1) = newx;
        newy = b1(2) + db1(2); new_b1(2) = newy;
      end
      % PRESENTATION PHASE
      new_A1 = nndlogsig(new_W1*P+new_b1*ones(1,Q));
      % The trial output must use the hidden-layer output of the trial
      % first-layer parameters (new_A1), not the previous A1.
      new_A2 = nndlogsig(new_W2*new_A1+new_b2*ones(1,Q));
      new_E = T-new_A2;
      new_SSE = sum(sum(new_E.*new_E));
      % MOMENTUM & ADAPTIVE LEARNING RATE PHASE
      if new_SSE > SSE*er
        % Reject the step: shrink the learning rate and drop momentum.
        lr = lr * dec;
        MC = 0;
      else
        % Accept the step; grow the learning rate only if the error fell.
        if new_SSE < SSE
          lr = lr * inc;
        end
        W1 = new_W1; b1 = new_b1; A1 = new_A1;
        W2 = new_W2; b2 = new_b2; A2 = new_A2;
        x = newx;
        y = newy;
        E = new_E; SSE = new_SSE;
        % BACKPROPAGATION PHASE
        D2 = A2.*(1-A2).*E;
        D1 = A1.*(1-A1).*(W2'*D2);
      end
      % TRAINING RECORD (point is unchanged when the step was rejected)
      xx(i) = x;
      yy(i) = y;
    end
    % CONTOUR PLOT
    path_h = [path_h; plot(xx,yy,'color',nnred); plot(xx,yy,'o','color',nnred,'markersize',6)];
    % Re-lock the window.
    set(fig,'nextplot','new')
    % SAVE DATA
    set(path_ptr,'userdata',path_h);
    set(fig,'pointer','arrow')
  end
end
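% [web code-viewer residue] Keyboard shortcuts:
%   Copy code            Ctrl + C
%   Search code          Ctrl + F
%   Full-screen mode     F11
%   Toggle theme         Ctrl + Shift + D
%   Show shortcuts       ?
%   Increase font size   Ctrl + =
%   Decrease font size   Ctrl + -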