⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 obdprune.m

📁 类神经网路─MATLAB的应用(范例程式)
💻 M
字号:
function [theta_data,PI_vector,FPE_vector,PI_test_vec,deff_vec,pvec]=...
                             obdprune(NetDef,W1,W2,PHI,Y,trparms,prparms,PHI2,Y2)
%  OBDPRUNE
%  --------
%           This function applies the Optimal Brain Damage (OBD) algorithm
%           for pruning ordinary feedforward neural networks.
%
%           In each pass the network is (except for the first pass) retrained
%           with MARQ, the training error, the FPE estimate and (optionally)
%           the test error are recorded, and the weight(s) with the smallest
%           saliency are set to zero. Pruning continues until fewer than two
%           weights remain.
%
%  CALL:
%   [theta_data,NSSEvec,FPEvec,NSSEtestvec,deff,pvec]=...
%                    obdprune(NetDef,W1,W2,PHI,Y,trparms,prparms,PHI2,Y2)
%
%  INPUT:
%  NetDef, W1, W2,
%  PHI, Y, trparms    : See for example the function MARQ.
%                       Only the weight decay part of trparms is used here:
%                         length(trparms)==4 : trparms(4) is a scalar decay
%                         length(trparms)==5 : trparms(4) is the decay for the
%                                              hidden-to-output weights and
%                                              trparms(5) for the
%                                              input-to-hidden weights
%                         otherwise          : no weight decay
%  PHI2,Y2 (optional) : Test data. Can be used for pointing out the
%                       optimal network architecture. 
%  prparms            : Parameters associated with the pruning session
%                       prparms = [iter RePercent]
%                       iter      : Max. number of retraining iterations
%                       RePercent : Prune 'RePercent' percent of the
%                                   remaining weights (0 = prune one at a time)
%                       If passed as [], prparms=[50 0] will be used.
% 
%  OUTPUT:
%  theta_data  : Matrix containing all the parameter vectors. Column k holds
%                the parameter vector of the network with k non-zero weights.
%  NSSEvec     : Vector containing the training error (SSE/2N) after each
%                weight elimination
%  FPEvec      : Contains the FPE estimate of the average generalization error 
%  NSSEtestvec : Contains the test error (only assigned if test data is given)
%  deff        : Contains the "effective" number of weights
%  pvec        : Index into the above vectors
%
%  SIDE EFFECTS:
%  Plots the error curves in figure 1 and the current network in figure 2
%  (via DRAWNET) after each pass.
%
%  Programmed by : Magnus Norgaard, IAU/IMM, Technical University of Denmark
%  LastEditDate  : July 17, 1996


%----------------------------------------------------------------------------------
%--------------             NETWORK INITIALIZATIONS                   -------------
%----------------------------------------------------------------------------------
more off
if nargin>7, TestDataFlag = 1;          % Check if test data was given as argument
else TestDataFlag = 0;end
if isempty(prparms),                    % Use default pruning parameters
  prparms=[50 0];
end
iter      = prparms(1);                 % Max. retraining iterations
RePercent = prparms(2);                 % % of remaining weights to prune
[outputs,N] = size(Y);                  % # of outputs and # of data
[hidden,inputs] = size(W1);             % # of hidden units 
inputs=inputs-1;                        % # of inputs
L_hidden = find(NetDef(1,:)=='L')';     % Location of linear hidden neurons
H_hidden = find(NetDef(1,:)=='H')';     % Location of tanh hidden neurons
L_output = find(NetDef(2,:)=='L')';     % Location of linear output neurons
H_output = find(NetDef(2,:)=='H')';     % Location of tanh output neurons
y1       = zeros(hidden,N);             % Hidden layer outputs
y2       = zeros(outputs,N);            % Network output
                                        % Start row in PSI/theta of the weights
                                        % leading into each hidden unit
index = outputs*(hidden+1) + 1 + [0:hidden-1]*(inputs+1);
index2 = (0:N-1)*outputs;               % Column offset in PSI of each data point
PHI_aug  = [PHI;ones(1,N)];             % Augment PHI with a row containing ones
parameters1= hidden*(inputs+1);         % # of input-to-hidden weights
parameters2= outputs*(hidden+1);        % # of hidden-to-output weights
parameters = parameters1 + parameters2; % Total # of weights
ones_h   = ones(hidden+1,1);            % A vector of ones
ones_i   = ones(inputs+1,1);            % Another vector of ones
                                        % Parameter vector containing all weights
theta = [reshape(W2',parameters2,1) ; reshape(W1',parameters1,1)];
theta_index = find(theta);              % Index to weights<>0
theta_red = theta(theta_index);         % Reduced parameter vector
reduced  = length(theta_index);         % The # of parameters in theta_red
reduced0 = reduced;                     % Copy of 'reduced'. Will be constant
theta_data=zeros(parameters,parameters);% Matrix used for collecting theta vectors
theta_data(:,reduced) = theta;          % Insert 'initial' theta
PSI      = zeros(parameters,outputs*N); % Deriv. of each output w.r.t. each weight
p0       = 1e6;                         % Diag. element of H_inv (no weight decay)
H_inv    = p0*eye(reduced);             % Initial inverse Hessian (no weight decay)
Ident    = eye(outputs);                % Identity matrix
PI_vector= zeros(1,reduced);            % A vector containing the collected PI's
FPE_vector= zeros(1,reduced);           % Vector used for collecting FPE estimates

% D0 holds one weight decay value per entry of the FULL parameter vector theta,
% so that it can always be indexed with theta_index (which refers to positions
% in theta). Without weight decay D0 stays the scalar 0.
if length(trparms)==4,                  % Scalar weight decay parameter
  D0 = trparms(4)*ones(parameters,1);
  D  = D0(theta_index);
elseif length(trparms)==5,              % Two weight decay parameters
  D0 = [trparms(4)*ones(parameters2,1) ; trparms(5)*ones(parameters1,1)];
  D  = D0(theta_index);
else                                    % No weight decay  D = 0;
  D0 = 0;
  D  = D0;
end
if TestDataFlag,                        % Initializations if a test set exists
  [tmp,N2]    = size(Y2);               % # of data in test set
  ytest1      = zeros(hidden,N2);       % Hidden layer outputs 
  ytest2      = zeros(outputs,N2);      % Network output
  PHI2_aug    = [PHI2;ones(1,N2)];      % Augment PHI with a row containing ones
  PI_test_vec = zeros(1,reduced);       % Collected PI's for the test set
end
deff_vec = zeros(1,reduced);            % The effective number of parameters
minweights = 2;                         % Prune until 'minweights' weights remain
FirstTimeFlag=1;                        % Initialize flag
pr = 0;                                 % Initialize counter
pvec=[];                                % Initialize index vector


%----------------------------------------------------------------------------------
%---------------                    MAIN LOOP                        --------------
%----------------------------------------------------------------------------------
while reduced>=minweights,   
        
  % >>>>>>>>>>>>>>>>>>>>>>>>>      Retrain Network      <<<<<<<<<<<<<<<<<<<<<<<<<<<
  % -- Don't retrain the first time --
  if ~FirstTimeFlag,
    [W1,W2,dummy1,dummy2,dummy3] = marq(NetDef,W1,W2,PHI,Y,[iter,0,1,D']);
    theta = [reshape(W2',parameters2,1) ; reshape(W1',parameters1,1)];
    theta_red = theta(theta_index);       % Vector containing  non-zero parameters
    theta_data(:,reduced) = theta;        % Store parameter vector
  end
      

  % >>>>>>>>>>>>>  COMPUTE NETWORK OUTPUT FROM TEST DATA y2(theta)   <<<<<<<<<<<<<<
  % -- Compute only if a test set is present --
  if TestDataFlag,
    htest1 = W1*PHI2_aug; 
    ytest1(H_hidden,:) = pmntanh(htest1(H_hidden,:));
    ytest1(L_hidden,:) = htest1(L_hidden,:);
    ytest1_aug=[ytest1;ones(1,N2)];
        
    htest2 = W2*ytest1_aug;
    ytest2(H_output,:) = pmntanh(htest2(H_output,:));
    ytest2(L_output,:) = htest2(L_output,:);

    E        = Y2 - ytest2;               % Test error
    E_vector = E(:);                      % Reshape E into a long vector
    SSE      = E_vector'*E_vector;        % Sum of squared errors (SSE)
    PI_test = SSE/(2*N2);                 % Cost function evaluated on test data
    PI_test_vec(reduced) = PI_test;       % Collect PI_test in vector
  end


  % >>>>>>>>>>>  COMPUTE NETWORK OUTPUT FROM TRAINING DATA y2(theta)   <<<<<<<<<<<<
  h1 = W1*PHI_aug;  
  y1(H_hidden,:) = pmntanh(h1(H_hidden,:));
  y1(L_hidden,:) = h1(L_hidden,:);
  y1_aug=[y1; ones(1,N)];

  h2 = W2*y1_aug;
  y2(H_output,:) = pmntanh(h2(H_output,:));
  y2(L_output,:) = h2(L_output,:);
        
  E        = Y - y2;                      % Training error
  E_vector = E(:);                        % Reshape E into a long vector
  SSE      = E_vector'*E_vector;          % Sum of squared errors (SSE)
  PI = SSE/(2*N);                         % Value of cost function
  PI_vector(reduced) = PI;                % Collect PI in vector


  % >>>>>>>>>>>>>>>>>>>>>>>>>   COMPUTE THE PSI MATRIX    <<<<<<<<<<<<<<<<<<<<<<<<<
  % (The derivative of each network output (y2) with respect to each weight)

  % ============   Elements corresponding to the linear output units   ============
  for i = L_output',
    index1 = (i-1) * (hidden + 1) + 1;

    % -- The part of PSI corresponding to hidden-to-output layer weights --
    PSI(index1:index1+hidden,index2+i) = y1_aug;
    % ---------------------------------------------------------------------
  
    % -- The part of PSI corresponding to input-to-hidden layer weights ---
    for j = L_hidden',
      PSI(index(j):index(j)+inputs,index2+i) = W2(i,j)*PHI_aug;
    end
      
    for j = H_hidden',
      tmp = W2(i,j)*(1-y1(j,:).*y1(j,:)); 
      PSI(index(j):index(j)+inputs,index2+i) = tmp(ones_i,:).*PHI_aug;
    end 
    % ---------------------------------------------------------------------    
  end
  
  % ======= Elements corresponding to the hyperbolic tangent output units   =======
  for i = H_output',
    index1 = (i-1) * (hidden + 1) + 1;

    % -- The part of PSI corresponding to hidden-to-output layer weights --
    tmp = 1 - y2(i,:).*y2(i,:);
    PSI(index1:index1+hidden,index2+i) = y1_aug.*tmp(ones_h,:);
    % ---------------------------------------------------------------------
         
    % -- The part of PSI corresponding to input-to-hidden layer weights ---
    for j = L_hidden',
      tmp = W2(i,j)*(1-y2(i,:).*y2(i,:));
      PSI(index(j):index(j)+inputs,index2+i) = tmp(ones_i,:).* PHI_aug;
    end
      
    for j = H_hidden',
      tmp  = W2(i,j)*(1-y1(j,:).*y1(j,:));
      tmp2 = (1-y2(i,:).*y2(i,:));
      PSI(index(j):index(j)+inputs,index2+i) = tmp(ones_i,:)...
                                                .*tmp2(ones_i,:).* PHI_aug;
    end
    % ---------------------------------------------------------------------
  end
        

  % >>>>>>>>>>>>>>>>>>>>>>>>    COMPUTE THE HESSIAN MATRIX   <<<<<<<<<<<<<<<<<<<<<<
  PSI_red = PSI(theta_index,:);                    % Rows of the remaining weights
  R     = PSI_red*PSI_red';
  H     = R;
  index3   = 1:(reduced+1):(reduced^2);            % Linear indices of the diagonal
  H(index3) = H(index3) + D';                      % Add weight decay to diagonal
  gamma1 = (diag(R)./diag(H))'*(diag(R)./diag(H)); % Effective # of parameters
  FPE = (N+gamma1)*PI/(N-gamma1);                  % FPE estimate
  
  % --- Use the following lines for a more precise FPE estimate ---
  % H_inv = inv(H);      
  % RHinv = R*H_inv;
  % gamma1=trace(RHinv*RHinv);
  % gamma2=trace(RHinv);        
  % FPE = (N+gamma1)*PI/(N+gamma1-2*gamma2);  
  
  FPE_vector(reduced) = FPE;                       % Collect FPE estimate
  deff_vec(reduced)=gamma1;                        % Collect effective # of param.


  % >>>>>>>>>>>>    PLOT THE PI's AND THE CURRENT NETWORK STRUCTURE   <<<<<<<<<<<<<
  % --- Draw PI's ---
  figure(1);
  pvec=[reduced pvec];
  if TestDataFlag
    plot(pvec,PI_vector(pvec),'x',pvec,FPE_vector(pvec),'+',...
    pvec,PI_test_vec(pvec),'o')
    title('x = training error,   + = FPE,   o = test error')
  else
    plot(pvec,PI_vector(pvec),'x',pvec,FPE_vector(pvec),'+')
    title('x = training error,  + = FPE')
  end
  set(gca,'Xlim',[0 reduced0]);
  xlabel('Parameters');
  drawnow
     
  % --- Draw pruned network ---
  figure(2);
  drawnet(W1,W2,eps);
  title(['Network after having pruned ',int2str(pr),' weights']);
  drawnow


  % >>>>>>>>>>>>>  ELIMINATE THE WEIGHT HAVING THE SMALLEST SALIENCY  <<<<<<<<<<<<<
  nopruned = floor(max(1,reduced*RePercent/100));       % No of parms to prune
  % Always leave at least one weight: with RePercent>=100 'reduced' would
  % otherwise become 0 and theta_data(:,reduced) below would be an invalid index.
  nopruned = min(nopruned,reduced-1);
  zeta = (D+diag(R)/2).*theta_red.*theta_red;           % Saliencies
  [zeta_sorted,min_index] = sort(zeta);                 % Sort in ascending order
  theta_red(min_index(1:nopruned)) = zeros(nopruned,1); % Eliminate weights
  theta(theta_index) =theta_red;
       
  % Keep the indices of the surviving weights, in their original order
  theta_index = theta_index(sort(min_index(nopruned+1:reduced)));
  theta_red = theta(theta_index);                       % Non-zero weights
  reduced  = reduced - nopruned;                        % Remaining weights
  pr = pr + nopruned;                                   % Total # of pruned weights
  theta_data(:,reduced) = theta;                        % Store parameter vector
  if length(D0)>1,                                      % Weight decay in use:
    D = D0(theta_index);                                % pick decay of remaining
  end                                                   % weights (else D stays 0)
   
  % -- Put the parameters back into the weight matrices --
  W1 = reshape(theta(parameters2+1:parameters),inputs+1,hidden)';
  W2 = reshape(theta(1:parameters2),hidden+1,outputs)';
  FirstTimeFlag=0;
end

%----------------------------------------------------------------------------------
%-------------                END OF NETWORK PRUNING                  -------------
%----------------------------------------------------------------------------------

fprintf('\n\n\n  -->  Pruning session terminated  <--\n\n\n');

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -