📄 nnprune.m
字号:
function [theta_data,PI_vector,FPE_vector,PI_test_vec,deff_vec,pvec]=...
nnprune(method,NetDef,W1,W2,U,Y,NN,trparms,prparms,U2,Y2,skip,Chat)
% NNPRUNE
% -------
% This function applies the Optimal Brain Surgeon (OBS) strategy for
% pruning neural network models of dynamic systems, that is, networks
% trained by NNARX, NNOE, NNARMAX1, NNARMAX2, or their recursive
% counterparts.
%
%
% CALL:
% [theta_data,NSSEvec,FPEvec,NSSEtestvec,deff,pvec]=...
% nnprune(method,NetDef,W1,W2,U,Y,NN,trparms,prparms,U2,Y2,skip,Chat)
%
% INPUT:
% method : The function applied for generating the model. For
% example method='nnarx' or method='nnoe'
% NetDef, W1, W2,
% U, Y, trparms : See for example the function MARQ
% U2,Y2 : Test data. This can be used for pointing out where
% the optimal network architecture is achieved. Pass
% two []'s if a test set is not available.
% skip (optional) : See for example NNOE or NNARMAX1/2. If passed as []
% it is set to 0.
% Chat (optional) : See NNARMAX1
% prparms : Parameters associated with the pruning session
% prparms = [iter RePercent]
% iter : Max. number of retraining iterations
% RePercent : Prune 'RePercent' percent of the
% remaining weights (0 = prune one at a time)
% If passed as [], prparms=[50 0] will be used.
%
% OUTPUT:
% theta_data : Matrix containing the parameter vectors saved after each
% weight elimination round.
% NSSEvec : Vector containing the training error (SSE/2N) after each
% weight elimination.
% FPEvec : Contains the FPE estimate of the average generalization error
% NSSEtestvec : Contains the normalized SSE evaluated on the test set
% deff : Contains the "effective" number of weights
% pvec : Index to the above vectors
%
% SEE ALSO: OBSPRUNE and OBDPRUNE on how to prune ordinary feedforward
% networks. See also the function NETSTRUC on how to extract
% the weight matrices from the matrix theta_data (notice that for
% NNARMAX1-models one must remove the bottom deg(C) rows first).
%
% Programmed by : Magnus Norgaard, IAU/IMM, Technical Univ. of Denmark
% LastEditDate : July 17, 1996
%----------------------------------------------------------------------------------
%-------------- NETWORK INITIALIZATIONS -------------
%----------------------------------------------------------------------------------
more off
if ~isempty(Y2), TestDataFlag = 1; % Check if test data was given as argument
else TestDataFlag = 0;end
if isempty(prparms),
prparms=[50 0];
end
iter = prparms(1); % Max. retraining iterations
RePercent = prparms(2); % % of remaining weights to prune
[outputs,N] = size(Y); % # of outputs and # of data
[hidden,inputs] = size(W1); % # of hidden units
inputs=inputs-1; % # of inputs
L_hidden = find(NetDef(1,:)=='L')'; % Location of linear hidden neurons
H_hidden = find(NetDef(1,:)=='H')'; % Location of tanh hidden neuron
L_output = find(NetDef(2,:)=='L')'; % Location of linear output neurons
H_output = find(NetDef(2,:)=='H')'; % Location of tanh output neurons
parameters1= hidden*(inputs+1); % # of input-to-hidden weights
parameters2= outputs*(hidden+1); % # of hidden-to-output weights
parameters = parameters1 + parameters2; % Total # of weights
% Parameter vector containing all weights
if strcmp(method,'nnarmax1') | strcmp(method,'nnrarmx1'),
mflag=2;
parameters12 = parameters;
nc = length(Chat)-1;
parameters = parameters12+nc;
theta = [reshape(W2',parameters2,1) ; reshape(W1',parameters1,1) ; Chat(2:nc+1)'];
else
theta = [reshape(W2',parameters2,1) ; reshape(W1',parameters1,1)];
end
theta_index = find(theta); % Index to weights<>0
theta_red = theta(theta_index); % Reduced parameter vector
reduced = length(theta_index); % The # of parameters in theta_red
reduced0 = reduced; % Copy of 'reduced'. Will be constant
theta_data=zeros(parameters,parameters);% Matrix used for collecting theta vectors
theta_data(:,reduced) = theta; % Insert 'initial' theta
p0 = 1e6; % Diag. element of H_inv (no weight decay)
H_inv = p0*eye(reduced); % Initial inverse Hessian (no weight decay)
Ident = eye(outputs); % Identity matrix
PI_vector= zeros(1,reduced); % A vector containing the collected PI's
FPE_vector= zeros(1,reduced); % Vector used for collecting FPE estimates
if length(trparms)==4, % Scalar weight decay parameter
D0 = trparms(4*ones(1,reduced))';
elseif length(trparms)==5, % Two weight decay parameters
D0 = trparms([4*ones(1,parameters2) 5*ones(1,parameters1)])';
D0 = D0(theta_index);
else % No weight decay D = 0;
D0 = 0;
end
D = D0;
deff_vec = zeros(1,reduced); % The effective number of parameters
minweights = 2; % Prune until 'minweights' weights remain
FirstTimeFlag=1; % Initialize flag
pr = 0; % Initialize counter
pvec=[]; % Initialize index vector
HiddenIndex = ones((hidden+1),1); % Connection to hidden no.
for k=1:hidden,
HiddenIndex = [HiddenIndex;k*ones(inputs+1,1)];
end
ConnectToHidden = (inputs+1)*ones(hidden,1); % Connections to each hidden unit
if ~exist('skip')
skip=0;
elseif isempty(skip),
skip=0;
end
skip=skip+1;
if ~exist('Chat'), Chat=[]; end;
% ---------- NNARX model ----------
if strcmp(method,'nnarx') | strcmp(method,'nnrarx'),
mflag=1;
if length(NN)==1 % nnar model
nb = 0;
nk = 0;
nu = 0;
else % nnarx or nnoe model
[nu,N] = size(U);
nb = NN(2:1+nu);
nk = NN(2+nu:1+2*nu);
end
nc = 0;
% --------- NNARMAX1 model --------
elseif strcmp(method,'nnarmax1') | strcmp(method,'nnrarmx1'),
mflag=2;
% --------- NNARMAX2 model --------
elseif strcmp(method,'nnarmax2') | strcmp(method,'nnrarmx2'),
mflag=3;
% --------- NNOE model --------
elseif strcmp(method,'nnoe'),
mflag=4;
if length(NN)==1 % nnar model
nb = 0;
nk = 0;
nu = 0;
else % nnarx or nnoe model
[nu,N] = size(U);
nb = NN(2:1+nu);
nk = NN(2+nu:1+2*nu);
end
nc = 0;
else
disp('Unknown method!!!!!!!!');
break
end
if mflag==2 | mflag==3,
if length(NN)==2 % nnarma model
nc = NN(2);
nb = 0;
nk = 0;
nu = 0;
else % nnarmax model
[nu,Ndat]= size(U);
nb = NN(2:1+nu);
nc = NN(2+nu);
nk = NN(2+nu+1:2+2*nu);
end
end
% --------- Common initializations --------
Ndat = length(Y); % # of data
na = NN(1);
nab = na+sum(nb);
nabc = nab+nc;
nmax = max([na,nb+nk-1,nc]); % 'Oldest' signal used as input to the model
N = Ndat - nmax; % Size of training set
NN2 = N-skip +1;
if TestDataFlag, % Initializations if a test set exists
Ndat2 = length(Y2); % Total # of data in test set
N2 = Ndat2 - nmax; % Size of test set
N2tot = N2 - skip+1;
ytest1 = zeros(hidden,N2); % Hidden layer outputs
ytest1 = [ytest1;ones(1,N2)]; % Hidden layer outputs
ytest2 = zeros(outputs,N2); % Network output
%------ CONSTRUCT THE REGRESSION MATRIX PHI ------
if mflag~=3,
PHI2 = zeros(nab,N2);
else
PHI2 = zeros(nabc,N2);
end
jj = nmax+1:Ndat2;
for k = 1:na, PHI2(k,:) = Y2(jj-k); end
index4 = na;
for kk = 1:nu,
for k = 1:nb(kk), PHI2(k+index4,:) = U2(kk,jj-k-nk(kk)+1); end
index4 = index4 + nb(kk);
end
PHI2_aug = [PHI2;ones(1,N2)]; % Augment PHI with a row containing ones
Y2 = Y2(nmax+1:Ndat2);
PI_test_vec = zeros(1,reduced); % Collected PI's for the test set
end
%----------------------------------------------------------------------------------
%--------------- MAIN LOOP --------------
%----------------------------------------------------------------------------------
while reduced>=minweights,
% >>>>>>>>>>>>>>>>>>>>>>>>> Retrain Network <<<<<<<<<<<<<<<<<<<<<<<<<<<
% -- Don't retrain the first time --
if ~FirstTimeFlag,
if mflag==1,
[W1,W2,dummy1,dummy2,dummy3] = nnarx(NetDef,NN,W1,W2,[iter,0,1,D'],Y,U);
elseif mflag==2,
[W1,W2,Chat] = nnarmax1(NetDef,NN,W1,W2,Chat,[iter,0,1,D'],skip,Y,U);
elseif mflag==3,
[W1,W2,dummy1,dummy2,dummy3] = nnarmax2(NetDef,NN,W1,W2,[iter,0,1,D'],skip,Y,U);
elseif mflag==4,
[W1,W2,dummy1,dummy2,dummy3] = nnoe(NetDef,NN,W1,W2,[iter,0,1,D'],skip,Y,U);
end
if mflag==2,
theta = [reshape(W2',parameters2,1) ; reshape(W1',parameters1,1) ; Chat(2:nc+1)'];
else
theta = [reshape(W2',parameters2,1) ; reshape(W1',parameters1,1)];
end
theta_red = theta(theta_index); % Vector containing non-zero parameters
if ElimWeights==1, % Store parameter vector
theta_data(:,reduced) = theta;
else
theta_data(:,[reduced reduced+1]) = [theta theta];
end
end
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -