input_relevance_snn.m

From "Neural network toolbox" · M code · 104 lines

function [expl_var, y, ds, Xall] = input_relevance_snn(net, data, F, groups)
%INPUT_RELEVANCE_SNN Determine order of relevance of the inputs.
%
% Syntax
%
%   [expl_var, indices] = input_relevance_snn(net, data)
%   [expl_var, indices] = input_relevance_snn(net, data, F)
%   [expl_var, indices] = input_relevance_snn(net, data, F, groups)
%
% Description
%
%   INPUT_RELEVANCE_SNN takes
%       net    - the network for which the input relevance is computed
%       data   - the training data
%       F      - Fisher matrix (optional; computed with fisherff_snn
%                if omitted)
%       groups - cell matrix with group indices (optional; by default
%                every input forms its own group)
%   and returns
%       expl_var  - the remaining explained variance (in percent) of the
%                   full network and after each removal step
%       indices   - indices of inputs (or groups of inputs) in order
%                   of removal (output y)
%       ds        - ds(k,i) is the quadratic cost 0.5*dX'*F*dX of
%                   removing group i at step k (NaN if not evaluated)
%       Xall      - weight vector of the full network (column 1) and
%                   after each removal step (columns 2 to end)
%
% Algorithm
%
%   See: P. van de Laar and T. Heskes; Pruning Using Parameter and
%   Neuronal Metrics, Neural Computation, 1999.
%

if (nargin < 4)
   % groups = no groups: each input forms its own group
   N0 = size(net.weights{1}, 2);
%   groups = mat2cell_snn([1:N0], 1, ones(1,N0));
   groups = tocell_snn(1:N0);
end

if (nargin < 3)
   % F = fisherff_snn
   F = fisherff_snn(net, data);
end

% Invert F; if it is not invertible, add a small diagonal term
% proportional to its largest eigenvalue.
oldwarning = lastwarn;
lastwarn('');
s = warning;
warning off;
F_inv = inv(F);
warning(s);
if ~isempty(lastwarn)
   options.disp = 0;
   l_max = eigs(F, 1, 'LM', options);
   F_inv = inv(F + 1e-7 * l_max * eye(size(F,1)));
end
lastwarn(oldwarning);

X = getx_snn(net);
NGT = numel(groups);
Xall = zeros(size(X,1), NGT+1);   % column k+1 holds X after step k
Xall(:,1) = X;
y = zeros(NGT,1);
E_quad = zeros(NGT,1);
ds = NaN(NGT, NGT);
todo = 1:NGT;
for k = 1:NGT
    % cost of removing each remaining group from the current network
    for i = todo
        dX = prune_input_snn(net, data, groups{i}, F_inv, X);
        ds(k, i) = 0.5 * dX' * F * dX;
    end
    % remove the group with the smallest cost (min ignores the NaNs)
    [minima, ind] = min(ds(k,:));
    winner = ind(1);
    y(k) = winner;
    E_quad(k) = minima;
    [dX, dF_inv] = ...
        prune_input_snn(net, data, groups{winner}, F_inv, X);
    F_inv = F_inv + dF_inv;
    X = X + dX;
    Xall(:, k+1) = X;
    todo = setdiff(todo, winner);
end

%#function wcf_snn argmin_wcf_snn
y_min = feval(feval(net.costFcn.name, 'argmin'), net, data);
E_min = feval(net.costFcn.name, net, ...
            setfield(data, 'Y', repmat(y_min,1,size(data.T,2))));
%E = zeros(size(Xall,2), 1);
%for k = 1:size(Xall,2)
%    E(k) = feval(net.costFcn.name, Xall(:,k), net, data);
%end
%expl_var = 100 * (E_min - E)/E_min;
E_0 = feval(net.costFcn.name, Xall(:,1), net, data);
E = zeros(NGT, 1);
for k = 1:NGT
    E(k) = E_0 + (E_quad(k)/E_quad(NGT))*(E_min - E_0);
end
expl_var = 100 * (E_min - [E_0; E])/E_min;

function B = tocell_snn(A)
B = cell(size(A));
for i = 1:numel(A)
    B{i} = A(i);
end
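If the fourth argument is omitted, the local helper `tocell_snn` places each input in its own group. Its body suggests that `tocell_snn(A)` builds the same cell array as MATLAB's built-in `num2cell(A)`. The script below checks this on a toy index vector. It also shows a hypothetical custom grouping in which inputs 1 and 2 are scored and removed as one unit. The vector and the grouping are illustrative assumptions, not values from the toolbox.

```matlab
% Default grouping: one group per input, built the way tocell_snn does it.
A = 1:5;                         % toy input indices (hypothetical N0 = 5)
B = cell(size(A));
for i = 1:numel(A)
    B{i} = A(i);
end
same_as_num2cell = isequal(B, num2cell(A));   % identical cell arrays

% Hypothetical custom grouping: inputs 1 and 2 are removed together.
groups = {[1 2], 3, 4, 5};
n_groups = numel(groups);        % NGT; the removal loop would run 4 times
```

A grouping like this would be passed as the fourth argument, and the returned indices then refer to positions in `groups` rather than to single inputs.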
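The Fisher matrix of an over-parameterized network can be singular. For that case the function falls back to inverting F + 1e-7·λ_max·I whenever `inv` issues a warning. The script below is a minimal sketch of the same regularization on a toy rank-1 matrix. It makes two substitutions, which are assumptions of the sketch and not the toolbox's code:

- Singularity is detected with `rcond` instead of by catching the warning through `lastwarn`.
- The largest eigenvalue comes from `eig`, which is adequate for a small dense matrix, rather than from `eigs(F, 1, 'LM')`.

For a positive semidefinite F both routes give the same λ_max.

```matlab
% Regularized inverse of a possibly singular Fisher matrix.
% ASSUMPTIONS: rcond replaces the lastwarn test of the original, and eig
% replaces eigs(F, 1, 'LM'); F is symmetric positive semidefinite.
F = [4 2; 2 1];                          % toy Fisher matrix of rank 1
if rcond(F) < eps                        % inv(F) would be unreliable
    l_max = max(eig(F));                 % largest eigenvalue of F
    F_reg = F + 1e-7 * l_max * eye(size(F, 1));
else
    F_reg = F;
end
F_inv = inv(F_reg);
residual = norm(F_inv * F_reg - eye(size(F, 1)));   % check: close to zero
```

The shift is relative to λ_max, so the regularized matrix stays close to F. For a positive semidefinite F, this caps its condition number at roughly 1e7.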
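The main loop is a greedy backward elimination. Each step does three things:

1. It evaluates the quadratic cost 0.5·dX'·F·dX of removing each remaining group.
2. It removes the cheapest group.
3. It updates both the weights and `F_inv`.

The source of `prune_input_snn` is not part of this file. The script below therefore assumes an Optimal-Brain-Surgeon-style step for a group of weight indices q: dX = -F_inv(:,q)·(F_inv(q,q) \ X(q)). This step zeroes X(q) at minimal quadratic cost. It is paired with the rank-|q| update F_inv ← F_inv - F_inv(:,q)·(F_inv(q,q) \ F_inv(q,:)). Treat this step as a hypothetical stand-in for the toolbox routine. The toy F, X and groups are made up, and each group indexes weights directly rather than network inputs.

```matlab
% Greedy backward elimination with the quadratic cost 0.5*dX'*F*dX.
% ASSUMPTION: the OBS-style step below stands in for prune_input_snn,
% whose actual implementation is not shown in this file.
F = [4 1 0 0; 1 3 1 0; 0 1 3 1; 0 0 1 5];  % toy symmetric positive definite F
X = [0.5; -1.0; 0.2; 0.8];                % toy weight vector
X0 = X;                                   % keep the starting weights
groups = {[1 2], 3, 4};                   % hypothetical weight groups
NGT = numel(groups);
F_inv = inv(F);
removal_order = zeros(NGT, 1);            % corresponds to output y
E_quad = zeros(NGT, 1);                   % cost of each removal step
todo = 1:NGT;
for k = 1:NGT
    best = Inf;
    winner = todo(1);
    dX_best = zeros(size(X));
    for i = todo
        q = groups{i};
        % step that sets X(q) to zero at minimal quadratic cost
        dX = -F_inv(:, q) * (F_inv(q, q) \ X(q));
        cost = 0.5 * (dX' * F * dX);
        if cost < best
            best = cost;
            winner = i;
            dX_best = dX;
        end
    end
    q = groups{winner};
    % inverse of F restricted to the weights that are still present
    F_inv = F_inv - F_inv(:, q) * (F_inv(q, q) \ F_inv(q, :));
    X = X + dX_best;
    X(q) = 0;                             % clear round-off on pruned weights
    removal_order(k) = winner;
    E_quad(k) = best;
    todo = setdiff(todo, winner);
end
```

In this sketch the update zeroes the rows and columns of `F_inv` that belong to pruned weights. As a result, successive steps are F-orthogonal. In exact arithmetic the step costs in `E_quad` therefore add up to 0.5·(X - X0)'·F·(X - X0), the quadratic cost of the total weight change.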
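The function does not re-evaluate the cost function after every removal; that variant is left commented out. Instead it interpolates linearly between two errors:

- E_0, the error of the full network.
- E_min, the error obtained when every sample gets the same output y_min, as returned by the cost function's `argmin` routine.

The interpolation weights are the step costs E_quad. Written out with N = `NGT`, this is:

```latex
\begin{aligned}
E_k &= E_0 + \frac{E^{\mathrm{quad}}_k}{E^{\mathrm{quad}}_N}\,\bigl(E_{\min} - E_0\bigr), && k = 1,\dots,N,\\
\mathrm{expl\_var}_k &= 100\,\frac{E_{\min} - E_k}{E_{\min}}, && k = 0,\dots,N.
\end{aligned}
```

Hence `expl_var` has N + 1 entries. The first is the explained variance of the unpruned network. The last is 0, because E_N = E_min.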
