📄 scgm.m
字号:
function [ep,tempo,w1,w2] = scgm(P,T,nh,minerr,maxep,w1,w2)
% SCGM  Batch training of a one-hidden-layer MLP (tanh hidden units, linear
%       outputs) with the Scaled Conjugate Gradient method.
%
% Scaled Conjugate Gradient after Moller (1993); the curvature d'*H*d is
% obtained from an exact Hessian-vector product (Pearlmutter, 1994).
% Uses the subfunctions PROCESS, GOLDSEC and CALCHV. Off-line updating.
%
% Inputs:
%   P      - input patterns, one row per sample (np x ni); a column of ones
%            is prepended internally as the bias input
%   T      - target outputs (np x no)
%   nh     - number of hidden units
%   minerr - training stops once the sum of squared errors is <= minerr
%   maxep  - maximum number of epochs
%   w1     - initial input-to-hidden weights, (ni+1) x nh
%   w2     - initial hidden-to-output weights, (nh+1) x no
%
% Outputs:
%   ep     - number of epochs performed
%   tempo  - elapsed training time in seconds (etime)
%   w1,w2  - weights after training (optional extra outputs)
%
% Author: Leandro Nunes de Castro
% Unicamp, January 1998

%-------------------------------------------------
% Definition and initialization of the parameters
%-------------------------------------------------
[np,ni] = size(P);
no = size(T,2);
P  = [ones(np,1) P];            % prepend the bias column
n1 = (ni+1)*nh;                 % number of first-layer weights
Nt = n1 + (nh+1)*no;            % total number of weights
ep = 0;
lambda = 1e-6; lambdab = 0;     % scale parameter and its "raised" copy
delta = 0;
sucesso = 1;                    % 1 -> curvature must be (re)computed

% PROCESS returns the descent direction -dE/dw for E = SSE/2.
[sse,vgrad] = process(w1,w2,P,T);

% The first search direction is the plain descent direction, exactly as in
% the restarts inside the loop. The weight-shaped copies of d must be taken
% AFTER d is set, otherwise the first epoch searches along a zero direction.
d = vgrad;
gdw1 = reshape(d(1:n1),ni+1,nh);
gdw2 = reshape(d(n1+1:Nt),nh+1,no);
t0 = clock;

%--------------------------------------------------
% Network training - SCGM
%--------------------------------------------------
while (ep < maxep & sse > minerr)
    ssea = sse; vgrada = vgrad;
    normd2 = d'*d;
    if normd2 == 0,
        break;                  % zero direction: stationary point, avoid 0/0
    end;

    % Second-order information along d (only after a successful step;
    % after a failure the previous delta is reused with a larger lambda).
    if sucesso == 1,
        s = calcHv(w1,w2,gdw1,gdw2,T,P,d);
        delta = d'*s;
    end;
    delta = delta + (lambda-lambdab)*normd2;

    % Make the (scaled) Hessian positive definite along d.
    if delta <= 0,
        lambdab = 2*(lambda-delta/normd2);
        delta = -delta + lambda*normd2;     % > 0 whenever lambda > 0
        lambda = lambdab;
    end;

    % Step size from the local quadratic model.
    mi = d'*vgrad;
    alfa = mi/delta;
    w1t = w1+alfa*gdw1;
    w2t = w2+alfa*gdw2;
    [sse,vgrad] = process(w1t,w2t,P,T);

    % Fall back to a golden-section line search if the error did not drop.
    if sse >= ssea
        alfa = goldsec(w1,w2,gdw1,gdw2,T,P,.0001);
        w1t = w1+alfa*gdw1;
        w2t = w2+alfa*gdw2;
        [sse,vgrad] = process(w1t,w2t,P,T);
    end;

    % Comparison parameter: actual vs. predicted reduction of E = SSE/2,
    % the function whose gradient and Hessian are used above.
    deltak = 2*delta*(0.5*(ssea-sse))/(mi*mi);

    % NOTE(review): the trial weights are accepted even when deltak < 0,
    % as in the original code; Moller's algorithm keeps the old weights
    % on a failed step - confirm which behaviour is intended.
    w1 = w1t; w2 = w2t;

    if deltak >= 0,             % successful reduction of the error
        lambdab = 0; sucesso = 1;
        if rem(ep,Nt) == 0,     % periodic restart with the descent direction
            d = vgrad;
        else
            beta = (vgrad'*(vgrad-vgrada))/(vgrada'*vgrada);
            beta = max(beta,0);
            d = vgrad + beta*d;
        end;
        if deltak >= 0.75,      % good quadratic fit: reduce the scale
            lambda = 0.25*lambda;
        end;
    else
        lambdab = lambda; sucesso = 0;
    end;

    % Poor quadratic fit: increase the scale parameter (applied once).
    if deltak < 0.25,
        lambda = lambda + (delta*(1-deltak)/normd2);
    end;

    % Weight-shaped copies of the current search direction.
    gdw1 = reshape(d(1:n1),ni+1,nh);
    gdw2 = reshape(d(n1+1:Nt),nh+1,no);
    ep = ep + 1;
end; % end of stopping criteria

tempo = etime(clock,t0);
% ---------------------------- %
% SECONDARY INTERNAL FUNCTIONS %
% ---------------------------- %
% Function PROCESS
function [sse, vgrad,y] = process(w1,w2,P,T)
% PROCESS  Forward pass of the network plus error and descent direction.
%   P must already contain the leading bias column of ones.
%   sse   - sum of squared errors over all samples and outputs
%   vgrad - column vector [first-layer terms; second-layer terms] built
%           from the residual T-y (i.e. it points downhill for SSE/2)
%   y     - network outputs (linear output units)
npat = size(P,1);
nhid = size(w2,1) - 1;

act = P*w1;                          % hidden pre-activations
hid = [ones(npat,1) tanh(act)];      % hidden outputs with bias unit
y = hid*w2;                          % linear outputs

resid = T - y;                       % output residuals
gOut = hid'*resid;                   % second-layer term
back = resid*w2(2:nhid+1,:)';        % residual sent back, bias row skipped
gHid = P'*(back.*dfat(act));         % first-layer term

sse = resid(:)'*resid(:);
vgrad = [gHid(:); gOut(:)];
% End Function PROCESS
% Function GOLDSEC
function[step] = goldsec(w1,w2,gdw1,gdw2,T,P,dN);
% GOLDSEC  Golden-section line search for the step size along a direction.
%   Minimizes the sum of squared errors of the network with weights
%   (w1 + step*gdw1, w2 + step*gdw2) over step in [0,1].
%   P must already contain the leading bias column of ones.
%   dN   - tolerance: the search stops when half the bracket width <= dN
%   step - interior point with the lowest error found in the final bracket
%----------------------------------
% Global definitions
%----------------------------------
ra = (sqrt(5)-1)/2;                  % golden ratio conjugate
lo = 0; hi = 1;                      % initial interval
lb = lo + (1 - ra)*(hi - lo);        % lower interior point
mi = lo + ra*(hi - lo);              % upper interior point
%----------------------------------
% Function evaluations
%----------------------------------
flb = goldsec_sse(w1,w2,gdw1,gdw2,T,P,lb);
fmi = goldsec_sse(w1,w2,gdw1,gdw2,T,P,mi);
%----------------------------------
% Processing
%----------------------------------
while abs(hi - lo)/2 > dN,
    if flb > fmi,
        % Minimum lies in [lb,hi]: the old upper point becomes the new
        % lower point and keeps its function value.
        lo = lb;
        lb = mi; flb = fmi;
        mi = lo + ra*(hi - lo);
        fmi = goldsec_sse(w1,w2,gdw1,gdw2,T,P,mi);
    else
        % Minimum lies in [lo,mi]: the old lower point becomes the new
        % upper point and keeps its function value.
        hi = mi;
        mi = lb; fmi = flb;
        lb = lo + (1 - ra)*(hi - lo);
        flb = goldsec_sse(w1,w2,gdw1,gdw2,T,P,lb);
    end;
end;
% Return the better of the two evaluated interior points.
if flb < fmi,
    step = lb;
else
    step = mi;
end;
% End Function GOLDSEC

function sse = goldsec_sse(w1,w2,gdw1,gdw2,T,P,a)
% Sum of squared errors of the network at weights (w1+a*gdw1, w2+a*gdw2).
z = [ones(size(P,1),1) tanh(P*(w1+a*gdw1))];
verr = T - z*(w2+a*gdw2);            % linear outputs
sse = verr(:)'*verr(:);
% Function CALCHV
function[Hv] = calcHv(w1,w2,gdw1,gdw2,T,P,d);
% CALCHV  Exact Hessian-vector product (R-operator, Pearlmutter 1994).
%   Returns H*v, where H is the Hessian of E = 0.5*sum((y-T).^2) with
%   respect to all weights and v is the direction given in weight shape
%   by gdw1 (same size as w1) and gdw2 (same size as w2).
%   P must already contain the leading bias column of ones.
%   The vector form d of the direction is accepted for interface
%   compatibility but is not used.
%   Hv is ordered as [first-layer entries; second-layer entries].
%----------------------------------
% Global Definitions
%----------------------------------
np = size(P,1);
[nh,no] = size(w2); nh = nh - 1;
%---------------
% Forward pass and its directional derivative
%---------------
h   = tanh(P*w1);                    % hidden outputs
dh  = (1+h).*(1-h);                  % tanh'
d2h = -2.*h.*dh;                     % tanh''
Ra  = P*gdw1;                        % R{pre-activation}; inputs are constant
Rh  = Ra.*dh;                        % R{hidden outputs}
z   = [ones(np,1) h];
Rz  = [zeros(np,1) Rh];              % the bias unit does not vary
y   = z*w2;                          % Linear output
Ry  = z*gdw2 + Rz*w2;                % Linear output
%---------------
% Second order
%---------------
% The residual is y-T so that it is consistent with R{residual} = Ry.
erro = y - T;
Rw2 = Rz'*erro + z'*Ry;
w20 = w2(2:nh+1,:);                  % weights without the bias row
gdw20 = gdw2(2:nh+1,:);
back = erro*w20';
Rback = Ry*w20' + erro*gdw20';
% Product rule on back.*tanh'(a): the second term carries R{a}.
Rerro1 = Rback.*dh + back.*d2h.*Ra;
Rw1 = P'*Rerro1;
Hv = [Rw1(:); Rw2(:)];
% End Function CALCHV
% Function DFAT
function mat = dfat(x)
% DFAT  Element-wise derivative of tanh: (1+tanh(x)).*(1-tanh(x)).
th = tanh(x);
mat = (1+th).*(1-th);
% End Function DFAT
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -