
📄 obsprune.m

📁 Neural Networks: Applications of MATLAB (example programs)
💻 MATLAB (M-file)
      PSI(index(j):index(j)+inputs,index2+i) = tmp(ones_i,:)...
                                                .*tmp2(ones_i,:).* PHI_aug;
    end
    % ---------------------------------------------------------------------
  end
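  % NOTE: the listing starts inside the loop that assembles PSI. From its use
  % below (R = PSI_red*PSI_red'), PSI appears to hold the derivatives of the
  % network outputs with respect to the weights, one column per output sample,
  % i.e. the regressor matrix of the Gauss-Newton approximation to the Hessian.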
        
        
  % >>>>>>>>>>>>>>>>>>>>>>>>    COMPUTE THE HESSIAN MATRIX   <<<<<<<<<<<<<<<<<<<<<<
  PSI_red = PSI(theta_index,:);

  % --- Inverse Hessian if no weight decay ---
  if D==0,
    H_inv = p0*eye(reduced);
    for k=1:outputs:N,                        % Iterative solution
      psi=PSI_red(:,(k-1)*outputs+1:k*outputs);
      H_inv = H_inv - H_inv*psi*inv(Ident + psi'*H_inv*psi)*psi'*H_inv;
    end
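    % The loop above never forms the full Hessian explicitly: starting from the
    % large diagonal matrix p0*eye(reduced), each block of columns psi is
    % absorbed with the matrix inversion lemma,
    %   H_inv <- H_inv - H_inv*psi*inv(I + psi'*H_inv*psi)*psi'*H_inv,
    % so only a small matrix (Ident is presumably eye(outputs)) is inverted
    % at each step.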
    FPE = (N+reduced)*PI/(N-reduced);         % FPE estimate
    gamma1 = reduced;
    
  % --- Inverse Hessian if weight decay is being used ---
  else
    R     = PSI_red*PSI_red';
    H     = R;
    index3   = 1:(reduced+1):(reduced^2);       % Linear indices of the diagonal of H
    H(index3) = H(index3) + D';                 % Add weight decay to the diagonal
    H_inv = inv(H);                             % Inverse Hessian
    RHinv = R*H_inv;
    gamma1=trace(RHinv*RHinv);                  % Effective # of parameters
    gamma2=trace(RHinv);        
    FPE = (N+gamma1)*PI/(N+gamma1-2*gamma2);    % FPE estimate
  end
  FPE_vector(reduced) = FPE;                    % Collect FPE estimate
  deff_vec(reduced)=gamma1;                     % Collect effective # of param.
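  % Both branches above yield Akaike-style final prediction error estimates:
  % without weight decay, FPE = (N+d)*PI/(N-d) with d = reduced (the number of
  % remaining weights); with weight decay, the effective number of parameters
  % gamma1 = trace((R*H_inv)^2) and gamma2 = trace(R*H_inv) replace d in the
  % generalized expression FPE = (N+gamma1)*PI/(N+gamma1-2*gamma2).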


  % >>>>>>>>>>>>    PLOT THE PI's AND THE CURRENT NETWORK STRUCTURE   <<<<<<<<<<<<<
  % --- Draw PI's ---
  figure(1);
  pvec=[reduced pvec];
  if TestDataFlag
    plot(pvec,PI_vector(pvec),'x',pvec,FPE_vector(pvec),'+',...
    pvec,PI_test_vec(pvec),'o')
    title('x = training error,   + = FPE,   o = test error')
  else
    plot(pvec,PI_vector(pvec),'x',pvec,FPE_vector(pvec),'+')
    title('x = training error,  + = FPE')
  end
  set(gca,'Xlim',[0 reduced0]);
  xlabel('Parameters');
  drawnow
     
  % --- Draw pruned network ---
  figure(2);
  drawnet(W1,W2,eps);
  title(['Network after having pruned ',int2str(pr),' weights']);
  figure(2); drawnow


  % >>>>>>>>>>>>>  ELIMINATE THE WEIGHT HAVING THE SMALLEST SALIENCY  <<<<<<<<<<<<<
  nopruned = floor(max(1,reduced*RePercent/100)); % Number of parameters to prune
  if reduced<=minweights, break; end
  bpr = 1;
  while bpr<=nopruned
  
    % ----- Calculate all saliencies -----
    if D==0,
      gamma = theta_red./diag(H_inv);             % Lagrange multipliers
      zeta = theta_red.*gamma;                    % Saliencies if D=0
    else
      gamma = theta_red./diag(H_inv);             % Lagrange multipliers
      HRH   = H_inv*R*H_inv;
                                                  % Saliencies
      zeta  = gamma.*(H_inv*(D.*theta_red))+gamma.*gamma.*diag(HRH)/2; 
    end
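    % With D=0 the expression above reduces to zeta_j = theta_j^2/[H_inv]_jj,
    % i.e. the classical Optimal Brain Surgeon saliency theta_j^2/(2*[H_inv]_jj)
    % up to a constant factor of 2, which does not change the ranking of the
    % weights. With weight decay, the extra terms involving D and R account for
    % the regularization part of the cost of constraining a weight to zero.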
    
  
   % ----- Add a large number to uninteresting saliencies -----
   index5 = find(ConnectFromHidden==1);
   [zeta_max,max_index] = max(zeta);              % Find largest saliency
   for hidno=index5',
     % Find the weight's location in theta
     index6 = hidno:(hidden+1):(hidden+1)*outputs;
     IndexInTheta = find(theta(index6)~=0);
     IndexInTheta = index6(IndexInTheta);
     
     % Find the location in zeta
     IndexInZeta = find(theta_index == IndexInTheta);
     zeta(IndexInZeta) = 2*abs(zeta_max);
   end
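   % Hidden units with only a single remaining connection to the output layer
   % (ConnectFromHidden==1) have that weight's saliency raised to 2*|zeta_max|
   % above, so it can never be selected as the individually cheapest weight;
   % this apparently prevents a unit from being cut off from the outputs by
   % single-weight pruning.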
   
    Critical=[]; WhileFlag=1; ElimWeights=1;
    while WhileFlag,
      WhileFlag = 0;
      [zeta_min,min_index] = min(zeta);               % Find smallest saliency
      
      % -- Check if weight in question belongs to W1 or W2 --
      if (theta_index(min_index)>=1 & theta_index(min_index)<=(hidden+1)*outputs),
        W=2;     % Belongs to W2
      else
        W=1;     % Belongs to W1
      end

      HiddenNo = HiddenIndex(theta_index(min_index)); % Which hidden unit is it connected to?

      % -- If not a bias for an output unit --
      if ~( any([hidden+1:hidden+1:outputs*(hidden+1)]==theta_index(min_index)) ),
      
        % -- If weight has only one connection leading to it --
        if W==1 & ConnectToHidden(HiddenNo)==1,
          index6 = HiddenNo:(hidden+1):(hidden+1)*outputs;
          HiddenNoRed = [];
          for k=index6,
            HiddenNoRed = [HiddenNoRed,find(theta_index==k)];% Index in zeta vector
          end
          Eidx = [HiddenNoRed,min_index];     % Indices to weights to and from unit
          gamma_new = inv(H_inv(Eidx,Eidx))*theta_red(Eidx);
          if any(Critical==min_index)
            ElimWeights=2;
            break;
          end
          if D==0,
            zeta(min_index) = gamma_new'*theta_red(Eidx);
          else
            zeta(min_index) = gamma_new'*H_inv(Eidx,:)*(D.*theta_red)...
                              + gamma_new'*HRH(Eidx,Eidx)*gamma_new/2;
          end
          Critical = [Critical;min_index];
          WhileFlag = 1;
        end
      end      
    end
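    % The loop above handles the case where the cheapest weight is the last
    % input connection of a hidden unit: the saliency is then recomputed for
    % removing all weights to and from that unit jointly (indices Eidx), with
    % gamma_new = inv(H_inv(Eidx,Eidx))*theta_red(Eidx). If the same weight is
    % selected again (it is already in Critical), ElimWeights is set to 2 and
    % the whole unit is eliminated below; otherwise a single weight is removed.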
    
    % If not a bias for an output unit
    if ~(any([hidden+1:hidden+1:outputs*(hidden+1)]==theta_index(min_index))),
    
      % -- If a W2 weight is eliminated --
      if W==2,
        ConnectFromHidden(HiddenNo) = ConnectFromHidden(HiddenNo) - 1;
        
      % --if a W1 weight is eliminated -- 
      else
        ConnectToHidden(HiddenNo) = ConnectToHidden(HiddenNo) - 1;
      end
      
      if ElimWeights==2,
        ConnectFromHidden(HiddenNo) = 0;
      end
    end
    
    % ----- Eliminate one weight -----
    if ElimWeights==1, 
      theta_red = theta_red - gamma(min_index)*H_inv(:,min_index);
      theta_red(min_index) = 0;                     % Zero the pruned weight
      theta(theta_index) = theta_red;
      tmpindex = [1:min_index-1 min_index+1:length(theta_index)];
      theta_index = theta_index(tmpindex);
      theta_red = theta(theta_index);
      reduced  = reduced-1;                         % Remaining weights
      theta_data(:,reduced) = theta;                % Store parameter vector
      D = D([1:min_index-1 min_index+1:length(D)]);
    
      % --- Update inverse Hessian ---
      H_inv = H_inv(tmpindex,tmpindex)-H_inv(tmpindex,min_index)...
               *H_inv(min_index,tmpindex)/H_inv(min_index,min_index);
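      % This is the standard downdate of the inverse Hessian after deleting
      % row/column min_index from the Hessian: by the partitioned-inverse
      % (Schur complement) identity, the surviving block of the new inverse is
      % H_inv(t,t) - H_inv(t,m)*H_inv(m,t)/H_inv(m,m), so no refactorization
      % of the full Hessian is needed.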
      if D~=0, R = R(tmpindex,tmpindex); end
      
     % ----- Eliminate more than one weight -----
    elseif ElimWeights==2,
      LEidx = length(Eidx);                  % # of weights to be pruned
      theta_red = theta_red - H_inv(:,Eidx)*gamma_new;
      theta_red(Eidx) = zeros(LEidx,1);      % Zero the pruned weights
      theta(theta_index) = theta_red;
      tmpindex = [1:Eidx(1)-1];
      for k=2:LEidx,
        tmpindex = [tmpindex Eidx(k-1)+1:Eidx(k)-1];
      end
      tmpindex = [tmpindex Eidx(LEidx)+1:length(theta_index)];
      theta_index = theta_index(tmpindex);
      theta_red = theta(theta_index);
      reduced  = reduced-length(Eidx);                   % Remaining weights
      theta_data(:,reduced:reduced+LEidx-1) = theta(:,ones(1,LEidx)); % Store theta
      D = D(tmpindex);
    
      % --- Update inverse Hessian ---
      H_inv = H_inv(tmpindex,tmpindex)-H_inv(tmpindex,Eidx)...
               *inv(H_inv(Eidx,Eidx))*H_inv(Eidx,tmpindex);
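      % Block version of the same downdate: all rows/columns in Eidx are
      % removed from the inverse Hessian at once, with inv(H_inv(Eidx,Eidx))
      % taking the place of the scalar pivot 1/H_inv(min_index,min_index)
      % used in the single-weight case above.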
      if D~=0, R = R(tmpindex,tmpindex); end
      bpr = bpr + LEidx-1;
    end
    bpr = bpr + 1;
  end    
  pr = pr + bpr-1;                             % Total # of pruned weights
  
  % -- Put the parameters back into the weight matrices --
  W1 = reshape(theta(parameters2+1:parameters),inputs+1,hidden)';
  W2 = reshape(theta(1:parameters2),hidden+1,outputs)';
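  % The reshapes above imply the layout of theta: the first
  % parameters2 = (hidden+1)*outputs entries hold W2 and the remaining
  % (inputs+1)*hidden entries hold W1, each stored column-wise before the
  % transpose back to the usual (units x inputs+1) weight-matrix shape.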
  FirstTimeFlag=0;
end
%----------------------------------------------------------------------------------
%-------------                END OF NETWORK PRUNING                  -------------
%----------------------------------------------------------------------------------
fprintf('\n\n\n  -->  Pruning session terminated  <--\n\n\n');
