⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 nnprune.m

📁 类神经网路─MATLAB的应用(范例程式)
💻 M
📖 第 1 页 / 共 2 页
字号:


  % >>>>>>>>>>>>>  COMPUTE NETWORK OUTPUT FROM TEST DATA y2(theta)   <<<<<<<<<<<<<<
  % -- Compute only if a test set is present --
  % Simulates the current (pruned) network on the test regressors PHI2_aug and
  % stores the resulting test cost in PI_test_vec(reduced). The model type is
  % selected by mflag: 1 = NNARX, 2 = NNARMAX1, 3 = NNARMAX2, 4 = NNOE.
  % Units indexed by H_hidden / H_output pass through pmntanh; units indexed by
  % L_hidden / L_output are linear.
  if TestDataFlag,
    % ---------- NNARX model ----------
    if mflag==1,
      htest1 = W1*PHI2_aug;
      ytest1(H_hidden,:) = pmntanh(htest1(H_hidden,:));
      ytest1(L_hidden,:) = htest1(L_hidden,:);

      htest2 = W2*ytest1;
      ytest2(H_output,:) = pmntanh(htest2(H_output,:));
      ytest2(L_output,:) = htest2(L_output,:);

      E     = Y2 - ytest2;                    % Prediction error

    % --------- NNARMAX1 model --------
    elseif mflag==2,
      htest1 = W1*PHI2_aug;
      ytest1(H_hidden,:) = pmntanh(htest1(H_hidden,:));
      ytest1(L_hidden,:) = htest1(L_hidden,:);

      htest2 = W2*ytest1;
      ytest2(H_output,:) = pmntanh(htest2(H_output,:));
      ytest2(L_output,:) = htest2(L_output,:);

      Ebar     = Y2 - ytest2;                 % Error between Y and deterministic part
      E        = filter(1,Chat,Ebar);         % Prediction error (Ebar filtered by 1/Chat)
      % One-step-ahead prediction. E is the prediction error Y2 - yhat, so the
      % prediction is Y2 - E; subtracting E from the deterministic part
      % (ytest2 - E) does not give yhat.
      ytest2   = Y2 - E;

    % --------- NNARMAX2 model --------
    elseif mflag==3,
      % Past prediction errors are fed back into the regressors, so the network
      % is simulated one sample at a time while PHI2_aug is filled in.
      for t=1:N2,
        htest1 = W1*PHI2_aug(:,t);
        ytest1(H_hidden,t) = pmntanh(htest1(H_hidden));
        ytest1(L_hidden,t) = htest1(L_hidden);

        htest2 = W2*ytest1(:,t);
        ytest2(H_output,t) = pmntanh(htest2(H_output,:));
        ytest2(L_output,t) = htest2(L_output,:);

        E(:,t) = Y2(:,t) - ytest2(:,t);        % Prediction error
        for d=1:min(nc,N2-t),
          PHI2_aug(nab+d,t+d) = E(:,t);        % Error at t becomes a regressor at t+d
        end
      end

    % ---------- NNOE model ----------
    elseif mflag==4,
      % Past network outputs are fed back into the regressors, so the network
      % is simulated one sample at a time while PHI2_aug is filled in.
      for t=1:N2,
        htest1 = W1*PHI2_aug(:,t);
        ytest1(H_hidden,t) = pmntanh(htest1(H_hidden));
        ytest1(L_hidden,t) = htest1(L_hidden);

        htest2 = W2*ytest1(:,t);
        ytest2(H_output,t) = pmntanh(htest2(H_output,:));
        ytest2(L_output,t) = htest2(L_output,:);

        for d=1:min(na,N2-t),
          PHI2_aug(d,t+d) = ytest2(:,t);       % Output at t becomes a regressor at t+d
        end
      end
      E     = Y2 - ytest2;                    % Error between Y and deterministic part
    end

    % NOTE(review): E(skip:N2) uses linear indexing, which assumes a single
    % output (E is a row vector) -- confirm before using multi-output models.
    SSE     = E(skip:N2)*E(skip:N2)';     % Sum of squared errors (SSE)
    PI_test = SSE/(2*N2tot);              % Cost function evaluated on test data
    PI_test_vec(reduced) = PI_test;       % Collect PI_test in vector
  end


  % >>>>>>>>>>>>>>>>>>>>>>  GET NETWORK OUTPUT AND GRADIENT   <<<<<<<<<<<<<<<<<<<<<<
  % getgrad (defined elsewhere) returns the gradient matrix PSI and the
  % training prediction errors E for the current weights. The training cost
  % PI is the sum of squared errors over samples skip..N divided by 2*NN2,
  % and is recorded under the current number of remaining weights.
  [PSI, E] = getgrad(method, NetDef, NN, W1, W2, Chat, Y, U);
  PI = E(skip:N) * E(skip:N)' / (2*NN2);
  PI_vector(reduced) = PI;

        
        
  % >>>>>>>>>>>>>>>>>>>>>>>>    COMPUTE THE HESSIAN MATRIX   <<<<<<<<<<<<<<<<<<<<<<
  % Builds the inverse Hessian H_inv restricted to the remaining (unpruned)
  % weights, selected by theta_index, plus the FPE estimate of the generalization
  % error. gamma1 holds the (effective) number of parameters used by FPE.
  PSI_red = PSI(theta_index,skip:N);

  % --- Inverse Hessian if no weight decay ---
  % The inverse is built recursively with the matrix inversion lemma, starting
  % from p0*I. The result is the inverse of PSI_red*PSI_red' + I/p0, which avoids
  % explicitly inverting a possibly singular PSI_red*PSI_red'.
  % NOTE(review): Ident is defined outside this view (presumably eye(outputs)).
  % The column indexing (k-1)*outputs+1:k*outputs with k stepping by outputs
  % only partitions 1..NN2 when outputs==1 -- confirm single-output use.
  if D==0,
    H_inv = p0*eye(reduced);
    for k=1:outputs:NN2,                       % Iterative solution
      psi=PSI_red(:,(k-1)*outputs+1:k*outputs);
      H_inv = H_inv - H_inv*psi*inv(Ident + psi'*H_inv*psi)*psi'*H_inv;
    end
    FPE = (NN2+reduced)*PI/(NN2-reduced);     % FPE estimate
    gamma1 = reduced;                          % All remaining weights count

  % --- Inverse Hessian if weight decay is being used ---
  % D holds one decay value per remaining weight (it is transposed and added to
  % the diagonal). R is kept because the saliency computation (HRH) and the
  % effective-parameter count below reuse it.
  else
    R     = PSI_red*PSI_red';
    H     = R;
    index3   = 1:(reduced+1):(reduced^2);       % A third useful vector
    H(index3) = H(index3) + D';                 % Add weight deacy to diagonal
    H_inv = inv(H);                             % Inverse Hessian
    RHinv = R*H_inv;
    gamma1=trace(RHinv*RHinv);                  % Effective # of parameters
    gamma2=trace(RHinv);        
    FPE = (NN2+gamma1)*PI/(NN2+gamma1-2*gamma2);% FPE estimate
  end 
  FPE_vector(reduced) = FPE;                  % Collect FPE estimate
  deff_vec(reduced)=gamma1;                     % Collect effective # of param.


  % >>>>>>>>>>>>    PLOT THE PI's AND THE CURRENT NETWORK STRUCTURE   <<<<<<<<<<<<<
  % Figure 1 plots training cost (x), FPE (+) and, when a test set is present,
  % test cost (o) against the number of remaining weights. pvec collects every
  % network size visited so far. Figure 2 shows the current network via drawnet
  % (defined elsewhere; eps is passed through -- presumably a display threshold).
  figure(1);
  pvec = [reduced pvec];                 % Prepend the current network size
  if TestDataFlag
    plot(pvec, PI_vector(pvec), 'x', ...
         pvec, FPE_vector(pvec), '+', ...
         pvec, PI_test_vec(pvec), 'o');
    title('x = training error,   + = FPE,   o = test error');
  else
    plot(pvec, PI_vector(pvec), 'x', ...
         pvec, FPE_vector(pvec), '+');
    title('x = training error,  + = FPE');
  end
  xlim([0 reduced0]);
  xlabel('Parameters');
  drawnow;

  figure(2);
  drawnet(W1, W2, eps);
  title(['Network after having pruned ' int2str(pr) ' weights']);
  figure(2);
  drawnow;


  % >>>>>>>>>>>>>  ELIMINATE THE WEIGHT HAVING THE SMALLEST SALIENCY  <<<<<<<<<<<<<
  % Each pass removes nopruned weights: RePercent percent of the remaining ones,
  % and at least one.
  %
  % Each removal:
  %   - picks the remaining weight with the smallest saliency zeta;
  %   - adjusts the other weights (theta_red) to compensate;
  %   - downdates the inverse Hessian H_inv (and R when weight decay is active)
  %     instead of recomputing it.
  %
  % The first HiddenLeft entries of theta_index are never picked directly.
  % When the selected weight is the last connection counted in ConnectToHidden
  % for its hidden unit, the matching entry among those first HiddenLeft is
  % removed together with it (two weights at once, ElimWeights==2).
  nopruned = floor(max(1,reduced*RePercent/100)); % No of parms to prune
  if reduced<=minweights, break; end              % Leaves the enclosing pruning loop
  bpr = 1;
  while bpr<=nopruned
    % --- Lagrange multipliers and saliencies of all remaining weights ---
    gamma = theta_red./diag(H_inv);               % Lagrange multipliers
    if D==0,
      zeta = theta_red.*gamma;                    % Saliencies if D=0
    else
      HRH   = H_inv*R*H_inv;
                                                  % Saliencies with weight decay
      zeta  = gamma.*(H_inv*(D.*theta_red))+gamma.*gamma.*diag(HRH)/2;
    end

    % --- Select the weight to remove ---
    % A "critical" weight is the last connection into its hidden unit. Its
    % saliency is replaced by the saliency of removing it together with that
    % unit's output weight, and the search is repeated. If the same critical
    % weight wins again, both weights are removed.
    Critical=[]; WhileFlag=1; ElimWeights=1;
    while WhileFlag,
      WhileFlag = 0;
      HiddenLeft = hidden-nnz(ConnectToHidden==0);  % Hidden units still connected
      [zeta_min,min_index] = min(zeta(HiddenLeft+1:reduced));% Find smallest saliency
      min_index = min_index+HiddenLeft;
      HiddenNo = HiddenIndex(theta_index(min_index)); % Connection to hidden no?
      if theta_index(min_index)~=(hidden+1),  % Not the bias
        if ConnectToHidden(HiddenNo)==1
          HiddenNoRed = find(theta_index(1:HiddenLeft)==HiddenNo);
          Eidx = [HiddenNoRed;min_index];
          gamma_new = H_inv(Eidx,Eidx)\theta_red(Eidx);  % Multipliers for the pair
          if any(Critical==min_index)
            ElimWeights=2;
            break;
          end
          if D==0,
            zeta(min_index) = gamma_new'*theta_red(Eidx);
          else
            zeta(min_index) = gamma_new'*H_inv(Eidx,:)*(D.*theta_red)...
                              + gamma_new'*HRH(Eidx,Eidx)*gamma_new/2;
          end
          Critical = [Critical;min_index];
          WhileFlag = 1;
        end
      end
    end

    if theta_index(min_index)~=(hidden+1),
      ConnectToHidden(HiddenNo) = ConnectToHidden(HiddenNo) - 1;
    end

    % R is only used by the weight-decay branch above, which runs whenever ANY
    % decay entry is nonzero. It must therefore be downdated under that same
    % condition. A plain "if D~=0" on a vector is true only when ALL entries are
    % nonzero, which left R oversized for mixed decay vectors.

    % ----- Eliminate one weight -----
    if ElimWeights==1,
      theta_red = theta_red - gamma(min_index)*H_inv(:,min_index);
      theta_red(min_index) = 0;                     % Pruned weight is exactly zero
      theta(theta_index) = theta_red;
      tmpindex = [1:min_index-1 min_index+1:length(theta_index)];
      theta_index = theta_index(tmpindex);
      theta_red = theta(theta_index);
      reduced  = reduced-1;                         % Remaining weights
      theta_data(:,reduced) = theta;                % Store parameter vector
      D = D([1:min_index-1 min_index+1:length(D)]);

      % --- Update inverse Hessian ---
      H_inv = H_inv(tmpindex,tmpindex)-H_inv(tmpindex,min_index)...
               *H_inv(min_index,tmpindex)/H_inv(min_index,min_index);
      if any(D~=0), R = R(tmpindex,tmpindex); end

    % ----- Eliminate two weights -----
    elseif ElimWeights==2,
      theta_red = theta_red - H_inv(:,Eidx)*gamma_new;
      theta_red(Eidx) = [0;0];                      % Pruned weights are exactly zero
      theta(theta_index) =theta_red;
      tmpindex = [1:Eidx(1)-1 Eidx(1)+1:Eidx(2)-1 Eidx(2)+1:length(theta_index)];
      theta_index = theta_index(tmpindex);
      theta_red = theta(theta_index);
      reduced  = reduced-2;                         % Remaining weights
      theta_data(:,[reduced reduced+1]) = [theta theta]; % Store parameter vector
      D = D(tmpindex);

      % --- Update inverse Hessian ---
      H_inv = H_inv(tmpindex,tmpindex)-H_inv(tmpindex,Eidx)...
               *(H_inv(Eidx,Eidx)\H_inv(Eidx,tmpindex));
      if any(D~=0), R = R(tmpindex,tmpindex); end
      bpr = bpr + 1;                                % Second weight counts too
    end
    bpr = bpr + 1;
  end
  pr = pr + bpr-1;                             % Total # of pruned weights
  
 
  % -- Unpack the pruned parameter vector theta into the weight matrices --
  % theta starts with the parameters2 entries of W2, followed by W1. For
  % mflag==2 the entries after parameters12 are appended to a leading 1 to form
  % the Chat polynomial.
  W2 = reshape(theta(1:parameters2),hidden+1,outputs)';
  if mflag==2,
    W1   = reshape(theta(parameters2+1:parameters12),inputs+1,hidden)';
    Chat = [1 theta(parameters12+1:parameters)'];
  else
    W1   = reshape(theta(parameters2+1:parameters),inputs+1,hidden)';
  end
  FirstTimeFlag = 0;   % Cleared once the first pass through this loop is done
end
%----------------------------------------------------------------------------------
%-------------                END OF NETWORK PRUNING                  -------------
%----------------------------------------------------------------------------------
% Announce on standard output (file identifier 1) that pruning has finished.
fprintf(1,'\n\n\n  -->  Pruning session terminated  <--\n\n\n');

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -