% File: obsprune.m
% (Web code-viewer header residue: font-size control.)
PSI(index(j):index(j)+inputs,index2+i) = tmp(ones_i,:)...
.*tmp2(ones_i,:).* PHI_aug;
end
% ---------------------------------------------------------------------
end
% >>>>>>>>>>>>>>>>>>>>>>>>   COMPUTE THE HESSIAN MATRIX   <<<<<<<<<<<<<<<<<<<<<<
% Builds the inverse Hessian H_inv for the weights that are still in the
% network (rows theta_index of PSI), plus the FPE estimate and the effective
% number of parameters, which are stored at position 'reduced' of
% FPE_vector and deff_vec.
PSI_red = PSI(theta_index,:);                  % Rows of the remaining weights

% --- Inverse Hessian if no weight decay ---
if D==0,
  H_inv = p0*eye(reduced);                     % Initial inverse Hessian
  % Recursive rank-'outputs' update, one sample per pass. Sample k is taken
  % from columns (k-1)*outputs+1 ... k*outputs, so k has to visit EVERY
  % sample 1..N. The previous stride (k=1:outputs:N) skipped samples whenever
  % outputs>1, while the weight-decay branch below uses all columns.
  % NOTE(review): assumes PSI has outputs*N columns - confirm where PSI is
  % allocated.
  for k=1:N,
    psi   = PSI_red(:,(k-1)*outputs+1:k*outputs);
    % Right division solves the small system instead of forming inv(...)
    H_inv = H_inv - (H_inv*psi)/(Ident + psi'*H_inv*psi)*(psi'*H_inv);
  end
  FPE    = (N+reduced)*PI/(N-reduced);         % FPE estimate
  gamma1 = reduced;                            % All remaining weights count

% --- Inverse Hessian if weight decay is being used ---
else
  R = PSI_red*PSI_red';
  H = R;
  index3 = 1:(reduced+1):(reduced^2);          % Linear indices of the diagonal
  H(index3) = H(index3) + D';                  % Add weight decay to diagonal
  H_inv = inv(H);                              % Inverse Hessian (needed explicitly)
  RHinv = R*H_inv;
  gamma1 = trace(RHinv*RHinv);                 % Effective # of parameters
  gamma2 = trace(RHinv);
  FPE = (N+gamma1)*PI/(N+gamma1-2*gamma2);     % FPE estimate
end
FPE_vector(reduced) = FPE;                     % Collect FPE estimate
deff_vec(reduced)   = gamma1;                  % Collect effective # of param.
% >>>>>>>>>>>>   PLOT THE PI's AND THE CURRENT NETWORK STRUCTURE   <<<<<<<<<<<<<
% --- Figure 1: error measures versus number of remaining parameters ---
figure(1);
pvec = [reduced pvec];                         % Prepend the current parameter count
if TestDataFlag
  plot(pvec, PI_vector(pvec), 'x', ...
       pvec, FPE_vector(pvec), '+', ...
       pvec, PI_test_vec(pvec), 'o')
  title('x = training error, + = FPE, o = test error')
else
  plot(pvec, PI_vector(pvec), 'x', ...
       pvec, FPE_vector(pvec), '+')
  title('x = training error, + = FPE')
end
set(gca, 'Xlim', [0 reduced0]);                % Fixed x-axis range 0..reduced0
xlabel('Parameters');
drawnow

% --- Figure 2: the network as it looks after pruning ---
figure(2);
drawnet(W1, W2, eps);
title(['Network after having pruned ', int2str(pr), ' weights']);
figure(2); drawnow
% >>>>>>>>>>>>>  ELIMINATE THE WEIGHT HAVING THE SMALLEST SALIENCY  <<<<<<<<<<<<<
% Number of weights to prune in this pass: RePercent per cent of the
% remaining weights rounded down, but never fewer than one.
nopruned = max(1, floor(reduced*RePercent/100));
if reduced <= minweights,
  break;                                       % Exits the enclosing loop (header outside this excerpt)
end
bpr = 1;                                       % Counter for the pruning loop that follows
while bpr<=nopruned
% ----- Saliency of every remaining weight -----
% The Lagrange multipliers have the same expression with and without weight
% decay; only the saliency formula differs.
gamma = theta_red./diag(H_inv);                % Lagrange multipliers
if D==0,
  zeta = theta_red.*gamma;                     % Saliencies if D=0
else
  HRH  = H_inv*R*H_inv;
  zeta = gamma.*(H_inv*(D.*theta_red)) + gamma.*gamma.*diag(HRH)/2;  % Saliencies
end

% ----- Protect weights that must not be picked on their own -----
% For every hidden unit with ConnectFromHidden==1, the one non-zero weight
% found at that unit's positions in theta gets a saliency of twice the
% current maximum, so min(zeta) will not select it.
index5 = find(ConnectFromHidden==1);
[zeta_max,max_index] = max(zeta);              % Find largest saliency
for hidno=index5',
  % Positions in theta belonging to hidden unit 'hidno'
  index6       = hidno:(hidden+1):(hidden+1)*outputs;
  IndexInTheta = index6(theta(index6)~=0);     % Keep only the non-zero ones
  % Matching position in zeta
  IndexInZeta  = find(theta_index == IndexInTheta);
  zeta(IndexInZeta) = 2*abs(zeta_max);
end
% ----- Select the weight (or group of weights) to eliminate -----
% Repeatedly pick the weight with the smallest saliency. If it is a W1 weight
% and ConnectToHidden(HiddenNo)==1, its saliency is replaced by a joint value
% computed over Eidx (that weight together with the W2 positions of the same
% hidden unit that are still present in theta_index) and the search is run
% again. If the same weight wins a second time it is already listed in
% Critical: ElimWeights is then set to 2 and the loop is left with Eidx and
% gamma_new describing the group to prune.
% NOTE(review): ConnectToHidden(HiddenNo) presumably counts the connections
% leading into hidden unit HiddenNo - confirm where it is initialised.
Critical=[]; WhileFlag=1; ElimWeights=1;
while WhileFlag,
WhileFlag = 0; % Loop again only if a saliency is revised below
[zeta_min,min_index] = min(zeta); % Find smallest saliency
% -- Check if weight in question belongs to W1 or W2 --
% Positions 1..(hidden+1)*outputs of theta are treated as W2, the rest as W1
if (theta_index(min_index)>=1 & theta_index(min_index)<=(hidden+1)*outputs),
W=2; % Belongs to W2
else
W=1; % Belongs to W1
end
% HiddenIndex is a lookup table defined outside this excerpt
HiddenNo = HiddenIndex(theta_index(min_index)); % Connection to hidden no?
% -- If not a bias for an output unit --
% (positions hidden+1, 2*(hidden+1), ..., outputs*(hidden+1) of theta)
if ~( any([hidden+1:hidden+1:outputs*(hidden+1)]==theta_index(min_index)) ),
% -- If weight has only one connection leading to it --
if W==1 & ConnectToHidden(HiddenNo)==1,
index6 = HiddenNo:(hidden+1):(hidden+1)*outputs;
HiddenNoRed = [];
% Collect the positions (within theta_index) of the unit's W2 entries that
% are still part of the reduced parameter vector
for k=index6,
HiddenNoRed = [HiddenNoRed,find(theta_index==k)];% Index in zeta vector
end
Eidx = [HiddenNoRed,min_index]; % Indices to weights to and from unit
% Multi-weight counterpart of gamma:
% solves H_inv(Eidx,Eidx)*gamma_new = theta_red(Eidx)
gamma_new = inv(H_inv(Eidx,Eidx))*theta_red(Eidx);
% Already revised once and still the minimum: prune the whole group
if find(Critical==min_index);
ElimWeights=2;
break;
end
% Replace the single-weight saliency by the joint saliency of the group
if D==0,
zeta(min_index) = gamma_new'*theta_red(Eidx);
else
zeta(min_index) = gamma_new'*H_inv(Eidx,:)*(D.*theta_red)...
+ gamma_new'*HRH(Eidx,Eidx)*gamma_new/2;
end
Critical = [Critical;min_index]; % Remember that this saliency was revised
WhileFlag = 1; % Search for the minimum once more
end
end
end
% ----- Update the connection counters of the affected hidden unit -----
% Skipped when the selected weight sits at one of the positions
% hidden+1, 2*(hidden+1), ..., outputs*(hidden+1) of theta (output biases).
if all((hidden+1:hidden+1:outputs*(hidden+1)) ~= theta_index(min_index)),
  if W~=2,
    % -- A W1 weight is eliminated --
    ConnectToHidden(HiddenNo) = ConnectToHidden(HiddenNo) - 1;
  else
    % -- A W2 weight is eliminated --
    ConnectFromHidden(HiddenNo) = ConnectFromHidden(HiddenNo) - 1;
  end
  % A group elimination clears the unit's outgoing count entirely
  if ElimWeights==2,
    ConnectFromHidden(HiddenNo) = 0;
  end
end
% ----- Eliminate one weight -----
% Removes position min_index from the reduced vectors, adjusts the remaining
% weights, stores the new parameter vector in theta_data and downdates H_inv.
if ElimWeights==1,
  theta_red = theta_red - gamma(min_index)*H_inv(:,min_index);  % Adjust remaining weights
  theta_red(min_index) = 0;                    % Pruned weight is exactly zero
  theta(theta_index) = theta_red;
  tmpindex = [1:min_index-1 min_index+1:length(theta_index)];   % Surviving positions
  theta_index = theta_index(tmpindex);
  theta_red = theta(theta_index);              % Non-zero weights
  reduced = reduced-1;                         % Remaining weights
  theta_data(:,reduced) = theta;               % Store parameter vector
  D = D(tmpindex);

  % --- Update inverse Hessian ---
  H_inv = H_inv(tmpindex,tmpindex)-H_inv(tmpindex,min_index)...
          *H_inv(min_index,tmpindex)/H_inv(min_index,min_index);
  % R exists whenever 'if D==0' is false, i.e. whenever D is not all-zero.
  % The former test 'if D~=0' required EVERY entry to be non-zero, so a D
  % with some zero entries left R at its old size and broke H_inv*R*H_inv
  % on the next pass.
  if isempty(D) | any(D~=0), R = R(tmpindex,tmpindex); end

% ----- Eliminate more than one weight -----
elseif ElimWeights==2,
  LEidx = length(Eidx);                        % # of weights to be pruned
  theta_red = theta_red - H_inv(:,Eidx)*gamma_new;
  theta_red(Eidx) = zeros(LEidx,1);            % Pruned weights are exactly zero
  theta(theta_index) = theta_red;
  tmpindex = 1:length(theta_index);            % Surviving positions, valid for
  tmpindex(Eidx) = [];                         % any ordering of Eidx
  theta_index = theta_index(tmpindex);
  theta_red = theta(theta_index);              % Non-zero weights
  reduced = reduced-LEidx;                     % Remaining weights
  % Fill columns reduced .. reduced+LEidx-1 with the same theta. The former
  % index [reduced reduced+LEidx-1] named only two columns and failed with a
  % size mismatch as soon as more than two weights were pruned together.
  theta_data(:,reduced:reduced+LEidx-1) = theta(:,ones(1,LEidx)); % Store theta
  D = D(tmpindex);

  % --- Update inverse Hessian ---
  % Right division solves with the small block instead of forming its inverse
  H_inv = H_inv(tmpindex,tmpindex)-(H_inv(tmpindex,Eidx)...
          /H_inv(Eidx,Eidx))*H_inv(Eidx,tmpindex);
  if isempty(D) | any(D~=0), R = R(tmpindex,tmpindex); end  % Same guard as above
  bpr = bpr + LEidx-1;                         % Count the extra pruned weights
end
bpr = bpr + 1;
end
pr = pr + bpr - 1;                             % Running total of pruned weights

% -- Rebuild the weight matrices from the parameter vector --
% theta(1:parameters2) is laid out as (hidden+1)-by-outputs and the rest as
% (inputs+1)-by-hidden; both are transposed after reshaping.
W2 = reshape(theta(1:parameters2), hidden+1, outputs)';
W1 = reshape(theta(parameters2+1:parameters), inputs+1, hidden)';
FirstTimeFlag = 0;
end
%----------------------------------------------------------------------------------
%------------- END OF NETWORK PRUNING -------------
%----------------------------------------------------------------------------------
% Print a message to the command window once the pruning loop has been left.
fprintf('\n\n\n --> Pruning session terminated <--\n\n\n');
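% --- Web code-viewer residue (not part of obsprune.m): keyboard shortcuts ---
%   Copy code:          Ctrl + C
%   Search code:        Ctrl + F
%   Full-screen mode:   F11
%   Toggle theme:       Ctrl + Shift + D
%   Show shortcuts:     ?
%   Increase font size: Ctrl + =
%   Decrease font size: Ctrl + -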