⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 nnprune.m

📁 类神经网路─MATLAB的应用(范例程式)
💻 M
📖 第 1 页 / 共 2 页
字号:
function [theta_data,PI_vector,FPE_vector,PI_test_vec,deff_vec,pvec]=...
             nnprune(method,NetDef,W1,W2,U,Y,NN,trparms,prparms,U2,Y2,skip,Chat)
%  NNPRUNE
%  -------
%         This function applies the Optimal Brain Surgeon (OBS) strategy for
%         pruning neural network models of dynamic systems, that is, networks
%         trained by NNARX, NNOE, NNARMAX1, NNARMAX2, or their recursive
%         counterparts (NNRARX, NNRARMX1, NNRARMX2).
%
%
%  CALL:
% [theta_data,NSSEvec,FPEvec,NSSEtestvec,deff,pvec]=...
%          nnprune(method,NetDef,W1,W2,U,Y,NN,trparms,prparms,U2,Y2,skip,Chat)
%
%  INPUT:
%  method             : The function applied for generating the model. For
%                       example method='nnarx' or method='nnoe'
%  NetDef, W1, W2,
%  U, Y, trparms      : See for example the function MARQ
%  U2,Y2              : Test data. This can be used for pointing out
%                       when the optimal network architecture is achieved.
%                       Pass two []'s if a test set is not available.
%  skip (optional)    : See for example NNOE or NNARMAX1/2. If omitted or
%                       passed as [], it is set to 0.
%  Chat (optional)    : See NNARMAX1. Must be given when method is
%                       'nnarmax1' or 'nnrarmx1'.
%  prparms            : Parameters associated with the pruning session
%                       prparms = [iter RePercent]
%                       iter      : Max. number of retraining iterations
%                       RePercent : Prune 'RePercent' percent of the
%                                   remaining weights (0 = prune one at a time)
%                       If passed as [], prparms=[50 0] will be used.
%  
%  OUTPUT:
%  theta_data  : Matrix containing the parameter vectors saved after each
%                weight elimination round.
%  NSSEvec     : Vector containing the training error (SSE/2N) after each
%                weight elimination.
%  FPEvec      : Contains the FPE estimate of the average generalization error
%  NSSEtestvec : Contains the normalized SSE evaluated on the test set
%  deff        : Contains the "effective" number of weights
%  pvec        : Index to the above vectors
%
%  SEE ALSO: OBSPRUNE and OBDPRUNE on how to prune ordinary feedforward
%            networks. See also the function NETSTRUC on how to extract
%            the weight matrices from the matrix theta_data (notice that for
%            NNARMAX1-models one must remove the bottom deg(C) rows first).
% 
%  Programmed by : Magnus Norgaard, IAU/IMM, Technical Univ. of Denmark
%  LastEditDate  : July 17, 1996

%----------------------------------------------------------------------------------
%--------------             NETWORK INITIALIZATIONS                   -------------
%----------------------------------------------------------------------------------
more off
if ~isempty(Y2), TestDataFlag = 1;      % Check if test data was given as argument
else TestDataFlag = 0;end
if isempty(prparms),
  prparms=[50 0];
end
iter      = prparms(1);                 % Max. retraining iterations
RePercent = prparms(2);                 % % of remaining weights to prune
[outputs,N] = size(Y);                  % # of outputs and # of data
[hidden,inputs] = size(W1);             % # of hidden units 
inputs=inputs-1;                        % # of inputs
L_hidden = find(NetDef(1,:)=='L')';     % Location of linear hidden neurons
H_hidden = find(NetDef(1,:)=='H')';     % Location of tanh hidden neuron
L_output = find(NetDef(2,:)=='L')';     % Location of linear output neurons
H_output = find(NetDef(2,:)=='H')';     % Location of tanh output neurons
parameters1= hidden*(inputs+1);         % # of input-to-hidden weights
parameters2= outputs*(hidden+1);        % # of hidden-to-output weights
parameters = parameters1 + parameters2; % Total # of weights
                                        % Parameter vector containing all weights
if strcmp(method,'nnarmax1') | strcmp(method,'nnrarmx1'),
  mflag=2;
  parameters12 = parameters;
  nc = length(Chat)-1;
  parameters = parameters12+nc;
  theta = [reshape(W2',parameters2,1) ; reshape(W1',parameters1,1) ; Chat(2:nc+1)'];
else
  theta = [reshape(W2',parameters2,1) ; reshape(W1',parameters1,1)];
end
theta_index = find(theta);              % Index to weights<>0
theta_red = theta(theta_index);         % Reduced parameter vector
reduced  = length(theta_index);         % The # of parameters in theta_red
reduced0 = reduced;                     % Copy of 'reduced'. Will be constant
theta_data=zeros(parameters,parameters);% Matrix used for collecting theta vectors
theta_data(:,reduced) = theta;          % Insert 'initial' theta
p0       = 1e6;                         % Diag. element of H_inv (no weight decay)
H_inv    = p0*eye(reduced);             % Initial inverse Hessian (no weight decay)
Ident    = eye(outputs);                % Identity matrix
PI_vector= zeros(1,reduced);            % A vector containing the collected PI's
FPE_vector= zeros(1,reduced);           % Vector used for collecting FPE estimates
if length(trparms)==4,                  % Scalar weight decay parameter
  D0 = trparms(4*ones(1,reduced))';      
elseif length(trparms)==5,              % Two weight decay parameters
  D0 = trparms([4*ones(1,parameters2) 5*ones(1,parameters1)])';
  D0 = D0(theta_index);
else                                    % No weight decay  D = 0;
  D0 = 0;
end
D = D0;

deff_vec = zeros(1,reduced);            % The effective number of parameters
minweights = 2;                         % Prune until 'minweights' weights remain
FirstTimeFlag=1;                        % Initialize flag
pr = 0;                                 % Initialize counter
pvec=[];                                % Initialize index vector
HiddenIndex = ones((hidden+1),1);       % Connection to hidden no.
for k=1:hidden,
  HiddenIndex = [HiddenIndex;k*ones(inputs+1,1)];
end
ConnectToHidden = (inputs+1)*ones(hidden,1); % Connections to each hidden unit

if ~exist('skip')
  skip=0;
elseif isempty(skip),
  skip=0;
end
skip=skip+1;
if ~exist('Chat'), Chat=[]; end;

% ---------- NNARX model ----------
if strcmp(method,'nnarx') | strcmp(method,'nnrarx'),
  mflag=1;
    if length(NN)==1                      % nnar model
    nb = 0;
    nk = 0;
    nu = 0;
  else                                  % nnarx or nnoe model
    [nu,N] = size(U); 
    nb = NN(2:1+nu); 
    nk = NN(2+nu:1+2*nu);
  end
  nc = 0;
  
% --------- NNARMAX1 model --------
elseif strcmp(method,'nnarmax1') | strcmp(method,'nnrarmx1'),
  mflag=2;

% --------- NNARMAX2 model --------
elseif strcmp(method,'nnarmax2') | strcmp(method,'nnrarmx2'),
  mflag=3;

% --------- NNOE model --------
elseif strcmp(method,'nnoe'),
 mflag=4;
  if length(NN)==1                      % nnar model
    nb = 0;
    nk = 0;
    nu = 0;
  else                                  % nnarx or nnoe model
    [nu,N] = size(U); 
    nb = NN(2:1+nu); 
    nk = NN(2+nu:1+2*nu);
  end
  nc = 0;
else
  disp('Unknown method!!!!!!!!');
  break
end

if mflag==2 | mflag==3,
  if length(NN)==2                      % nnarma model
    nc     = NN(2);
    nb     = 0;
    nk     = 0;
    nu     = 0;
  else                                  % nnarmax model
    [nu,Ndat]= size(U); 
    nb     = NN(2:1+nu);
    nc     = NN(2+nu);
    nk     = NN(2+nu+1:2+2*nu);
  end
end

% --------- Common initializations --------
Ndat     = length(Y);                   % # of data
na = NN(1);
nab = na+sum(nb);
nabc = nab+nc;
nmax     = max([na,nb+nk-1,nc]);        % 'Oldest' signal used as input to the model
N        = Ndat - nmax;                 % Size of training set
NN2      = N-skip +1;
if TestDataFlag,                        % Initializations if a test set exists
  Ndat2 = length(Y2);                   % Total # of data in test set
  N2    = Ndat2 - nmax;                 % Size of test set
  N2tot = N2 - skip+1;
  ytest1 = zeros(hidden,N2);            % Hidden layer outputs 
  ytest1 = [ytest1;ones(1,N2)];         % Hidden layer outputs 
  ytest2 = zeros(outputs,N2);           % Network output

  %------  CONSTRUCT THE REGRESSION MATRIX PHI ------
  if mflag~=3,
    PHI2 = zeros(nab,N2);
  else
    PHI2 = zeros(nabc,N2);
  end
  jj  = nmax+1:Ndat2;
  for k = 1:na, PHI2(k,:)    = Y2(jj-k); end
  index4 = na;
  for kk = 1:nu,
    for k = 1:nb(kk), PHI2(k+index4,:) = U2(kk,jj-k-nk(kk)+1); end
    index4 = index4 + nb(kk);
  end
  PHI2_aug    = [PHI2;ones(1,N2)];      % Augment PHI with a row containing ones
  Y2    = Y2(nmax+1:Ndat2);
  PI_test_vec = zeros(1,reduced);       % Collected PI's for the test set
end

%----------------------------------------------------------------------------------
%---------------                    MAIN LOOP                        --------------
%----------------------------------------------------------------------------------
while reduced>=minweights,   
  % >>>>>>>>>>>>>>>>>>>>>>>>>      Retrain Network      <<<<<<<<<<<<<<<<<<<<<<<<<<< 
  % -- Don't retrain the first time --
  if ~FirstTimeFlag,
    if mflag==1,
      [W1,W2,dummy1,dummy2,dummy3] = nnarx(NetDef,NN,W1,W2,[iter,0,1,D'],Y,U);
    elseif mflag==2,
      [W1,W2,Chat] = nnarmax1(NetDef,NN,W1,W2,Chat,[iter,0,1,D'],skip,Y,U);
    elseif mflag==3,
      [W1,W2,dummy1,dummy2,dummy3] = nnarmax2(NetDef,NN,W1,W2,[iter,0,1,D'],skip,Y,U);
    elseif mflag==4,
      [W1,W2,dummy1,dummy2,dummy3] = nnoe(NetDef,NN,W1,W2,[iter,0,1,D'],skip,Y,U);
    end
    if mflag==2,
      theta = [reshape(W2',parameters2,1) ; reshape(W1',parameters1,1) ; Chat(2:nc+1)'];
    else
      theta = [reshape(W2',parameters2,1) ; reshape(W1',parameters1,1)];
    end
    theta_red = theta(theta_index);       % Vector containing non-zero parameters
    if ElimWeights==1,                    % Store parameter vector
      theta_data(:,reduced) = theta;
    else
      theta_data(:,[reduced reduced+1]) = [theta theta];
    end
  end  

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -