input_relevance_snn.m
From "Neural network toolbox" · M code · 104 lines
function [expl_var, y, ds, Xall] = input_relevance_snn(net, data, F, groups)
%INPUT_RELEVANCE_SNN Determine order of relevance of the inputs.
%
%  Syntax
%
%   [expl_var, indices] = input_relevance_snn(net, data)
%   [expl_var, indices] = input_relevance_snn(net, data, F)
%   [expl_var, indices] = input_relevance_snn(net, data, F, groups)
%
%  Description
%
%   INPUT_RELEVANCE_SNN takes
%     net    - the network for which the input relevance is computed
%     data   - the training data
%     F      - fisher matrix (optional)
%     groups - cell matrix with group indices (optional)
%   and returns
%     expl_var - the remaining explained variance
%     indices  - indices of inputs (or groups of inputs) in order
%                of removal (second output, named y in the code)
%   Two further outputs are available:
%     ds       - quadratic cost of removing group i at step k (NaN if
%                the group was already removed)
%     Xall     - parameter vector before pruning and after each removal
%
%  Algorithm
%
%   See: P. van de Laar and T. Heskes; Pruning Using Parameter and
%        Neuronal Metrics, Neural Computation, 1999.
%

if (nargin < 4) % groups = no groups;
  N0 = size(net.weights{1}, 2);
%  groups = mat2cell_snn([1:N0], 1, ones(1,N0));
  groups = tocell_snn([1:N0]);
end
if (nargin < 3) % F = fisherff_snn
  F = fisherff_snn(net, data);
end

% invert F; if not invertible, add a small diagonal term.
oldwarning = lastwarn;
lastwarn('');
s = warning;
warning off;
F_inv = inv(F);
warning(s);
if lastwarn
  % inv(F) raised a warning: regularize with a small multiple of the
  % largest eigenvalue and invert again.
  options.disp = 0;
  l_max = eigs(F, 1, 'LM', options);
  F_inv = inv(F + 1e-7 * l_max * eye(size(F,1)));
end
lastwarn(oldwarning)

X = getx_snn(net);
NGT = prod(size(groups));
% One column for the unpruned parameters plus one per removal step.
Xall = zeros(size(X,1), NGT+1);
Xall(:,1) = X;
y = zeros(NGT,1);
E_quad = zeros(1, NGT);
ds = NaN * ones(NGT, NGT);
todo = [1:NGT];
for k = 1:NGT
  % Quadratic cost of removing each group that is still present.
  for i = todo
    dX = prune_input_snn(net, data, groups{i}, F_inv, X);
    ds(k, i) = 0.5 * dX' * F * dX;
  end
  % Remove the cheapest group (min ignores the NaN entries).
  [minima, ind] = min(ds(k,:));
  winner = ind(1);
  y(k) = winner;
  E_quad(k) = minima;
  [dX, dF_inv] = ...
      prune_input_snn(net, data, groups{winner}, F_inv, X);
  F_inv = F_inv + dF_inv;
  X = X + dX;
  Xall(:, k+1) = X;
  todo = setdiff(todo, winner);
end

%#function wcf_snn argmin_wcf_snn
y_min = feval(feval(net.costFcn.name, 'argmin'), net, data);
E_min = feval(net.costFcn.name, net, ...
    setfield(data, 'Y', repmat(y_min,1,size(data.T,2))));
%E = zeros(size(Xall,2), 1);
%for k = 1:size(Xall,2)
%  E(k) = feval(net.costFcn.name, Xall(:,k), net, data);
%end
%expl_var = 100 * (E_min - E)/E_min;
E_0 = feval(net.costFcn.name, Xall(:,1), net, data);
E = zeros(NGT, 1);
for k = 1:NGT
  % Rescale the quadratic costs so that the last step reaches E_min.
  E(k) = E_0 + (E_quad(k)/E_quad(NGT))*(E_min - E_0);
end
expl_var = 100 * (E_min - [E_0; E])/E_min;

function B = tocell_snn(A)
% Wrap every element of A in its own cell.
B = cell(size(A));
for i = 1:prod(size(A))
  B{i} = A(i);
end
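
The ranking loop in the code above can be summarized in formulas. This is only a reading of the code as listed, not a restatement of the cited paper. The symbols $\delta x_i$, $\Delta_{k,i}$, $w_k$, $Q_k$, $E_k$ and $N$ are notation introduced here for `dX`, `ds(k,i)`, `winner`, `E_quad(k)`, `E(k)` and `NGT`. $E_{\min}$ appears to be the cost obtained when the output is replaced by the constant `y_min`.

```latex
% Step k: for every group i still in `todo`, prune_input_snn returns a
% parameter change \delta x_i; its quadratic cost under the Fisher matrix F is
\Delta_{k,i} = \tfrac{1}{2}\, \delta x_i^{\top} F \, \delta x_i ,
\qquad
w_k = \arg\min_{i \in \mathrm{todo}} \Delta_{k,i} ,
\qquad
Q_k = \Delta_{k,\,w_k} .

% After all N groups are ranked, the quadratic costs are rescaled so that the
% last step lands on E_min, and the result is reported in percent:
E_k = E_0 + \frac{Q_k}{Q_N}\,\bigl(E_{\min} - E_0\bigr) ,
\qquad
\mathrm{expl\_var}_k = 100\,\frac{E_{\min} - E_k}{E_{\min}} ,
\quad k = 0, \dots, N .
```

With this rescaling the first entry of `expl_var` corresponds to the unpruned network ($k = 0$) and the last entry is zero by construction, since $E_N = E_{\min}$.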