% NNPRUNE.M (excerpt) -- neural network pruning session using
% Optimal Brain Surgeon; evaluates, prunes, and re-plots the network.
% >>>>>>>>>>>>> COMPUTE NETWORK OUTPUT FROM TEST DATA y2(theta) <<<<<<<<<<<<<<
% Evaluate the current (pruned) network on the test set (PHI2_aug, Y2) and
% collect the normalized test-set cost in PI_test_vec(reduced).
% -- Compute only if a test set is present --
if TestDataFlag,
  % ---------- NNARX model ----------
  if mflag==1,
    % One matrix pass: predictions for all test samples at once.
    htest1 = W1*PHI2_aug;
    ytest1(H_hidden,:) = pmntanh(htest1(H_hidden,:)); % tanh hidden units
    ytest1(L_hidden,:) = htest1(L_hidden,:);          % linear hidden units
    htest2 = W2*ytest1;
    ytest2(H_output,:) = pmntanh(htest2(H_output,:)); % tanh output units
    ytest2(L_output,:) = htest2(L_output,:);          % linear output units
    E = Y2 - ytest2;                 % Error between Y and deterministic part

  % --------- NNARMAX1 model --------
  elseif mflag==2,
    htest1 = W1*PHI2_aug;
    ytest1(H_hidden,:) = pmntanh(htest1(H_hidden,:));
    ytest1(L_hidden,:) = htest1(L_hidden,:);
    htest2 = W2*ytest1;
    ytest2(H_output,:) = pmntanh(htest2(H_output,:));
    ytest2(L_output,:) = htest2(L_output,:);
    Ebar = Y2 - ytest2;              % Error between Y and deterministic part
    E = filter(1,Chat,Ebar);         % Prediction error (filtered through 1/C)
    ytest2 = ytest2 - E;             % One-step-ahead prediction

  % --------- NNARMAX2 model --------
  elseif mflag==3,
    % Sample-by-sample pass: past prediction errors are regressors.
    for t=1:N2,
      htest1 = W1*PHI2_aug(:,t);
      ytest1(H_hidden,t) = pmntanh(htest1(H_hidden));
      ytest1(L_hidden,t) = htest1(L_hidden);
      htest2 = W2*ytest1(:,t);
      ytest2(H_output,t) = pmntanh(htest2(H_output,:));
      ytest2(L_output,t) = htest2(L_output,:);
      E(:,t) = Y2(:,t) - ytest2(:,t);            % Prediction error
      % Feed the error into the next nc regressor vectors (MA part).
      for d=1:min(nc,N2-t),
        PHI2_aug(nab+d,t+d) = E(:,t);
      end
    end

  % ---------- NNOE model ----------
  elseif mflag==4,
    % Sample-by-sample pass: past model outputs are regressors.
    for t=1:N2,
      htest1 = W1*PHI2_aug(:,t);                 % (fixed: stray ';;' removed)
      ytest1(H_hidden,t) = pmntanh(htest1(H_hidden));
      ytest1(L_hidden,t) = htest1(L_hidden);
      htest2 = W2*ytest1(:,t);
      ytest2(H_output,t) = pmntanh(htest2(H_output,:));
      ytest2(L_output,t) = htest2(L_output,:);
      % Feed the model output into the next na regressor vectors.
      for d=1:min(na,N2-t),
        PHI2_aug(d,t+d) = ytest2(:,t);
      end
    end
    E = Y2 - ytest2;                 % Error between Y and deterministic part
  end
  % NOTE(review): E(skip:N2) indexes E as a vector -- assumes a single
  % output; verify behavior for outputs>1.
  SSE = E(skip:N2)*E(skip:N2)';      % Sum of squared errors (SSE)
  PI_test = SSE/(2*N2tot);           % Cost function evaluated on test data
  PI_test_vec(reduced) = PI_test;    % Collect PI_test in vector
end
% >>>>>>>>>>>>>>>>>>>>>> GET NETWORK OUTPUT AND GRADIENT <<<<<<<<<<<<<<<<<<<<<<
% Recompute the prediction errors E and the derivative matrix PSI for the
% current pruned network, then evaluate the training-set performance index.
[PSI,E] = getgrad(method,NetDef,NN,W1,W2,Chat,Y,U);
PI = E(skip:N)*E(skip:N)'/(2*NN2);          % Performance index
PI_vector(reduced) = PI;                    % Collect PI in vector
% >>>>>>>>>>>>>>>>>>>>>>>> COMPUTE THE HESSIAN MATRIX <<<<<<<<<<<<<<<<<<<<<<
% Rows of PSI restricted to the weights still present in the network.
PSI_red = PSI(theta_index,skip:N);
% --- Inverse Hessian if no weight decay ---
if D==0,
% Build the inverse Gauss-Newton Hessian recursively with the matrix
% inversion lemma, starting from a large diagonal p0*eye(reduced).
H_inv = p0*eye(reduced);
for k=1:outputs:NN2, % Iterative solution
% NOTE(review): k already advances in steps of 'outputs', yet the column
% range is (k-1)*outputs+1:k*outputs -- correct only for outputs==1;
% verify before using with multi-output networks.
psi=PSI_red(:,(k-1)*outputs+1:k*outputs);
H_inv = H_inv - H_inv*psi*inv(Ident + psi'*H_inv*psi)*psi'*H_inv;
end
FPE = (NN2+reduced)*PI/(NN2-reduced); % FPE estimate
gamma1 = reduced; % Without decay, all remaining weights are effective
% --- Inverse Hessian if weight decay is being used ---
else
R = PSI_red*PSI_red'; % Gauss-Newton Hessian without decay term
H = R;
index3 = 1:(reduced+1):(reduced^2); % Linear indices of the diagonal of H
H(index3) = H(index3) + D'; % Add weight decay to diagonal
H_inv = inv(H); % Inverse Hessian
RHinv = R*H_inv;
gamma1=trace(RHinv*RHinv); % Effective # of parameters
gamma2=trace(RHinv);
FPE = (NN2+gamma1)*PI/(NN2+gamma1-2*gamma2);% FPE estimate
end
FPE_vector(reduced) = FPE; % Collect FPE estimate
deff_vec(reduced)=gamma1; % Collect effective # of param.
% >>>>>>>>>>>> PLOT THE PI's AND THE CURRENT NETWORK STRUCTURE <<<<<<<<<<<<<
% --- Draw PI's: training error, FPE estimate, and (if present) test error ---
figure(1);
pvec = [reduced pvec];
% Assemble the plot arguments once; append the test-error series only
% when a test set was supplied.
plotargs = {pvec,PI_vector(pvec),'x',pvec,FPE_vector(pvec),'+'};
if TestDataFlag
  plotargs = [plotargs {pvec,PI_test_vec(pvec),'o'}];
  plottitle = 'x = training error, + = FPE, o = test error';
else
  plottitle = 'x = training error, + = FPE';
end
plot(plotargs{:})
title(plottitle)
set(gca,'Xlim',[0 reduced0]);
xlabel('Parameters');
drawnow
% --- Draw pruned network ---
figure(2);
drawnet(W1,W2,eps);
title(['Network after having pruned ',int2str(pr),' weights']);
figure(2); drawnow
% >>>>>>>>>>>>> ELIMINATE THE WEIGHT HAVING THE SMALLEST SALIENCY <<<<<<<<<<<<<
% Optimal Brain Surgeon step: remove the weight(s) whose elimination
% increases the cost the least, adjust the surviving weights to
% compensate, and downdate the inverse Hessian. Up to 'nopruned' weights
% are removed per outer iteration (RePercent percent of those remaining).
nopruned = floor(max(1,reduced*RePercent/100)); % No of parms to prune
if reduced<=minweights, break; end % Stop once the minimum size is reached
bpr = 1;
while bpr<=nopruned
if D==0,
gamma = theta_red./diag(H_inv); % Lagrange multipliers
zeta = theta_red.*gamma; % Saliencies if D=0
else
gamma = theta_red./diag(H_inv); % Lagrange multipliers
HRH = H_inv*R*H_inv;
% Saliencies, including the weight-decay contribution
zeta = gamma.*(H_inv*(D.*theta_red))+gamma.*gamma.*diag(HRH)/2;
end
Critical=[]; WhileFlag=1; ElimWeights=1;
% Find the smallest saliency among the input-to-hidden weights. If the
% candidate is the last connection into its hidden unit, its saliency is
% recomputed for removing it jointly with that unit's output weight;
% re-selecting an index already in Critical switches to the two-weight
% elimination branch (ElimWeights=2).
while WhileFlag,
WhileFlag = 0;
HiddenLeft = hidden-length(find(ConnectToHidden==0)); % Hidden units;
[zeta_min,min_index] = min(zeta(HiddenLeft+1:reduced));% Find smallest saliency
min_index = min_index+HiddenLeft;
HiddenNo = HiddenIndex(theta_index(min_index)); % Connection to hidden no?
if theta_index(min_index)~=(hidden+1), % Not the bias
if ConnectToHidden(HiddenNo)==1
% Index (within theta_red) of the same unit's hidden-to-output weight.
HiddenNoRed = find(theta_index(1:HiddenLeft)==HiddenNo);
Eidx = [HiddenNoRed;min_index];
gamma_new = inv(H_inv(Eidx,Eidx))*theta_red(Eidx);
if find(Critical==min_index);
ElimWeights=2;
break;
end
% Replace the saliency by the joint two-weight saliency.
if D==0,
zeta(min_index) = gamma_new'*theta_red(Eidx);
else
zeta(min_index) = gamma_new'*H_inv(Eidx,:)*(D.*theta_red)...
+ gamma_new'*HRH(Eidx,Eidx)*gamma_new/2;
end
Critical = [Critical;min_index];
WhileFlag = 1;
end
end
end
if theta_index(min_index)~=(hidden+1),
ConnectToHidden(HiddenNo) = ConnectToHidden(HiddenNo) - 1;
end
% ----- Eliminate one weight -----
if ElimWeights==1,
% OBS update: compensate the remaining weights, zero the pruned one,
% then drop its row/column from every per-weight structure.
theta_red = theta_red - gamma(min_index)*H_inv(:,min_index);
theta_red(min_index) = 0; % Zero the pruned weight
theta(theta_index) = theta_red;
tmpindex = [1:min_index-1 min_index+1:length(theta_index)];
theta_index = theta_index(tmpindex);
theta_red = theta(theta_index);
reduced = reduced-1; % Remaining weights
theta_data(:,reduced) = theta; % Store parameter vector
D = D([1:min_index-1 min_index+1:length(D)]);
% --- Update inverse Hessian (rank-one downdate) ---
H_inv = H_inv(tmpindex,tmpindex)-H_inv(tmpindex,min_index)...
*H_inv(min_index,tmpindex)/H_inv(min_index,min_index);
% NOTE(review): for a vector D, 'D~=0' is true only when ALL elements are
% nonzero -- presumably D is either the scalar 0 or strictly positive;
% verify against the caller.
if D~=0, R = R(tmpindex,tmpindex); end
% ----- Eliminate two weights -----
elseif ElimWeights==2,
theta_red = theta_red - H_inv(:,Eidx)*gamma_new;
theta_red(Eidx) = [0;0]; % Zero the pruned pair
theta(theta_index) =theta_red;
tmpindex = [1:Eidx(1)-1 Eidx(1)+1:Eidx(2)-1 Eidx(2)+1:length(theta_index)];
theta_index = theta_index(tmpindex);
theta_red = theta(theta_index);
reduced = reduced-2; % Remaining weights
theta_data(:,[reduced reduced+1]) = [theta theta]; % Store parameter vector
D = D(tmpindex);
% --- Update inverse Hessian (rank-two downdate) ---
H_inv = H_inv(tmpindex,tmpindex)-H_inv(tmpindex,Eidx)...
*inv(H_inv(Eidx,Eidx))*H_inv(Eidx,tmpindex);
if D~=0, R = R(tmpindex,tmpindex); end
bpr = bpr + 1; % Two weights removed: count the extra one here
end
bpr = bpr + 1;
end
% -- Put the parameters back into the weight matrices --
if mflag==2,
W1 = reshape(theta(parameters2+1:parameters12),inputs+1,hidden)';
W2 = reshape(theta(1:parameters2),hidden+1,outputs)';
Chat = [1 theta(parameters12+1:parameters)'];
else
W1 = reshape(theta(parameters2+1:parameters),inputs+1,hidden)';
W2 = reshape(theta(1:parameters2),hidden+1,outputs)';
end
FirstTimeFlag=0;
end
%----------------------------------------------------------------------------------
%------------- END OF NETWORK PRUNING -------------
%----------------------------------------------------------------------------------
fprintf('\n\n\n --> Pruning session terminated <--\n\n\n');
% (removed: code-viewer web page UI text accidentally pasted into the file)