📄 ndfa_iter.m
字号:
function [sources, net, tnet, params, status, fs, tfs] = ndfa_iter(...
    data, sources, net, tnet, params, status, ...
    missing, clamped, notimedep)
% NDFA_ITER  Perform the NDFA iteration
%
%  Internal function, use NDFA to call.
%
%  [sources, net, tnet, params, status, fs, tfs] = NDFA_ITER(...
%      data, sources, net, tnet, params, status, ...
%      missing, clamped, notimedep)
%
%  Runs at most status.iters iterations.  Each iteration
%    1. feeds the sources forward through the observation network (net)
%       and the dynamical network (tnet),
%    2. updates the hyperparameters (unless status.updateparams < 0),
%    3. evaluates the cost with KLDIV and appends it and the CPU time to
%       status.kls / status.cputime,
%    4. computes the gradients of the cost and updates sources and
%       networks with the algorithm selected by status.updatealg
%       ('old', 'natgrad', or anything else -> UPDATE_EVERYTHING),
%    5. advances the negative "delay" counters in status towards zero,
%    6. optionally prunes the model, clips control channels to
%       status.constraints and periodically saves a checkpoint to the
%       file 'cp' (via SAVE) in the current directory.
%
%  The loop ends early when the cost is not finite (immediate return) or
%  when the convergence test on status.kls / status.epsilon succeeds.
%
%  Outputs fs and tfs are the probdist objects built from the final
%  feedforward passes through net and tnet respectively.

% Copyright (C) 2002-2005 Harri Valpola, Antti Honkela and Matti Tornio.
%
% This package comes with ABSOLUTELY NO WARRANTY; for details
% see License.txt in the program package.  This is free software,
% and you are welcome to redistribute it under certain conditions;
% see License.txt for details.

% Make sure sources are acprobdist_alpha
if (~isa(sources, 'acprobdist_alpha'))
  sources = updatevar(acprobdist_alpha(sources));
end

% Make sure networks are probdist_alpha for the old update algorithm
if ~isa(net.w1, 'probdist_alpha') && ...
      strcmp(status.updatealg, 'old')
  net.w1 = probdist_alpha(net.w1);
  net.w2 = probdist_alpha(net.w2);
  net.b1 = probdist_alpha(net.b1);
  net.b2 = probdist_alpha(net.b2);
  tnet.w1 = probdist_alpha(tnet.w1);
  tnet.w2 = probdist_alpha(tnet.w2);
  tnet.b1 = probdist_alpha(tnet.b1);
  tnet.b2 = probdist_alpha(tnet.b2);
end

% Extract some model parameters for easier use.
% Nx is taken from the input data only.
% NOTE(review): prune() below may replace data; confirm whether
% status.constraints is indexed by the original or the pruned data
% dimension before refreshing Nx as well.
Nx = size(data, 1);
% NOTE(review): this is 1 whenever status.controlchannels is a scalar
% count; confirm whether the count itself was intended.
ncontrols = size(status.controlchannels, 1);
iters_left = status.iters;

% Save first checkpoint after 100 (full) iterations
iters_before_cp = 100 - min([status.updatenet status.updatetnet ...
                    status.updateparams status.updatesrcs ...
                    status.updatesrcvars]);

% Reset the gradient if appropriate for the new update algorithms
if ~strcmp(status.updatealg, 'old'),
  if isfield(status, 'oldgrads'),
    oldgrads = status.oldgrads;
  else
    oldgrads = reset_cg([], net, tnet, sources);
  end
end

while iters_left > 0
  % Source dimensions are re-read every iteration because the pruning
  % steps at the end of the loop body can return a different sources
  % object; cached sizes would then be stale.
  [Ns, T] = size(sources);
  nsources = Ns;
  controls = nsources - 1 + (1:ncontrols);

  dcp_dnetm = netgrad_zeros(net);
  dcp_dnetv = netgrad_zeros(net);
  dcp_dtnetm = netgrad_zeros(tnet);
  dcp_dtnetv = netgrad_zeros(tnet);

  % Do feedforward calculations
  x = feedfw(sources, net, status.approximation);
  tx = acfeedfw(sources, tnet, status.approximation);
  fs = probdist(x{4}.e, x{4}.var);
  if (status.freeinitial)
    tfs = probdist([sources.e(:,1) tx{4}.e], ...
                   [zeros(size(sources, 1), 1) tx{4}.var]);
  else
    tfs = probdist([zeros(size(sources, 1), 1) tx{4}.e], ...
                   [zeros(size(sources, 1), 1) tx{4}.var]);
  end

  % Update estimates for different parameters if appropriate
  if status.updateparams < 0
    status.updateparams = status.updateparams + 1;
    if (status.updateparams == 0) && (~strcmp(status.updatealg, 'old')),
      oldgrads = reset_cg(oldgrads, net, tnet, sources);
    end
  else
    params = update_params(fs, tfs, sources, data, net, tnet, params);
  end

  % Calculate the current value of Kullback-Leibler divergence
  [newkls, kls_data, kls_param, kls_ac, kls_nets] = kldiv(...
      tfs, fs, sources, data, net, tnet, params, missing, notimedep, status);

  % Store the cost function and elapsed CPU time
  status.kls = [status.kls newkls];
  status.cputime = [status.cputime cputime];

  % Display the current value of the cost function
  fprintf('Iteration #%d: %f = %f+%f+%f.\n', size(status.kls, 2) - 1, ...
          newkls, kls_data, kls_param + kls_nets, kls_ac)

  % Store the components of the cost function for debugging
  if status.debug >= 2,
    Nh = size(status.kls, 2);
    [t, status.history.obs(:,:,Nh)] = kl_data(...
        probdist(fs.e - data, fs.var), probdist(0), params.noise, 2, missing);
    [t, status.history.dyn(:,:,Nh)] = kl_acparam(...
        sources, tfs, params.src, 2, notimedep);
    [t, status.history.net_w2(:,:,Nh)] = kl_param(...
        net.w2, probdist(0), params.net.w2var, 1);
    [t, status.history.net_w1(:,:,Nh)] = kl_param(...
        net.w1, probdist(0), probdist(0));
    [t, status.history.tnet_w2(:,:,Nh)] = kl_param(...
        tnet.w2, probdist(0), params.tnet.w2var, 1);
    [t, status.history.tnet_w1(:,:,Nh)] = kl_param(...
        tnet.w1, probdist(0), probdist(0));
    status.history.noise(:,Nh) = normalvar(params.noise);
    status.history.src(:,Nh) = normalvar(params.src);
  end

  % Check if overflow has occurred
  if ~isfinite(newkls),
    iters_left = 0;
    fprintf('Cost is NaN, bailing out ...\n');
    % NOTE(review): newkls and cputime were already appended above, so
    % this stores the last entry a second time.  Kept as in the original;
    % confirm whether the caller relies on the extra entry.
    status.kls = [status.kls newkls];
    status.cputime = [status.cputime cputime];
    return;
  end

  % Check if the iteration has converged: the last 10 steps all increased
  % the cost and no step in the last 400 decreased it by more than
  % status.epsilon.  The rest of this iteration still runs.
  if (size(status.kls, 2) > 400 && ...
      ((min(diff(status.kls(end-10:end))) > 0) && ...
       (min(diff(status.kls(end-400:end))) > -status.epsilon))),
    fprintf('The iteration appears to have converged, bailing out...\n');
    iters_left = 0;
  end

  % Calculate partial derivatives for the parameters to adapt

  % Feedback for the observation mapping
  [dcp_dsm, dcp_dsv, newdcp_dnetm, newdcp_dnetv] =...
      feedback(x, net, sources, data, params.noise, status, ...
               missing);
  dcp_dnetm = sum_structs(dcp_dnetm, newdcp_dnetm);
  dcp_dnetv = sum_structs(dcp_dnetv, newdcp_dnetv);

  % Feedback for the dynamical mapping
  [newdcp_dsm, newdcp_dsv, newdcp_dtnetm, newdcp_dtnetv] = ...
      feedbackac(tx, tnet, sources(:, 1:end-1), sources(:, 2:end).e, ...
                 params.src, status, notimedep);
  dcp_dsm(:,1:end-1) = dcp_dsm(:,1:end-1) + newdcp_dsm;
  dcp_dsv(:,1:end-1) = dcp_dsv(:,1:end-1) + newdcp_dsv;
  dcp_dtnetm = sum_structs(dcp_dtnetm, newdcp_dtnetm);
  dcp_dtnetv = sum_structs(dcp_dtnetv, newdcp_dtnetv);

  % Feedback for the source priors
  [newdcp_dsm, dcp_dsvn] = ...
      feedback_srcpriors(sources - tfs, params.src, notimedep);
  dcp_dsm = dcp_dsm + newdcp_dsm;

  [dcp_dsvn, newac] = computevard(dcp_dsv, dcp_dsvn, sources.ac, ...
                                  tx{5});

  % Drop the variance dependencies between independent samples:
  % where the mask is set, keep the current sources.ac value.
  indep = find(sparse([zeros(Ns,1) notimedep]));
  newac(indep) = sources.ac(indep);

  % Feedback for the observation mapping priors
  [newdcp_dnetm, newdcp_dnetv] = ...
      feedback_netpriors(net, params.net, params.hyper.net);
  dcp_dnetm = sum_structs(dcp_dnetm, newdcp_dnetm);
  dcp_dnetv = sum_structs(dcp_dnetv, newdcp_dnetv);

  % Feedback for the dynamical mapping priors
  [newdcp_dtnetm, newdcp_dtnetv] = ...
      feedback_netpriors(tnet, params.tnet, params.hyper.tnet);
  dcp_dtnetm = sum_structs(dcp_dtnetm, newdcp_dtnetm);
  dcp_dtnetv = sum_structs(dcp_dtnetv, newdcp_dtnetv);

  % Calculate the 'spilled' gradients.
  % Preallocated each iteration so the arrays always match the current
  % source dimensions.  They are only consumed by the disabled averaging
  % line below.
  dcp_dsm2 = zeros(Ns, T);
  for k = 2:T,
    dcp_dsm2(:,k) = tx{4}.multi(:,:,k-1) * dcp_dsm2(:,k-1) + dcp_dsm(:,k);
  end
  dcp_dsm3 = zeros(Ns, T);
  for k = (T-1):-1:1,
    dcp_dsm3(:,k) = tx{4}.multi(:,:,k) * (dcp_dsm(:,k+1) + dcp_dsm3(:,k+1));
  end
  % dcp_dsm = 1/3*(dcp_dsm + dcp_dsm2 + dcp_dsm3);

  % Old updatealg
  if strcmp(status.updatealg, 'old'),
    % Get new values for sources and alphas if appropriate
    if max([status.updatesrcs, status.updatesrcvars]) >= 0
      newsources = ...
          updatesources(sources, dcp_dsm, dcp_dsvn, x{4}.multi, tx{5}, ...
                        params.src, params.noise, newac, clamped);
      if status.updatesrcs < 0
        % Source means are still frozen: take over only the
        % variance-related fields.
        sources.var = newsources.var;
        sources.nvar = newsources.nvar;
        sources.valpha = newsources.valpha;
        sources.vsign = newsources.vsign;
      else
        sources = newsources;
      end
      sources.ac = newac;
      sources = updatevar(sources);
    end

    if status.debug >= 2,
      [newc, datac, restc, dync, netc] = kldiv(...
          [], [], sources, data, net, tnet, params, ...
          missing, notimedep, status);
      fprintf('Cost after source update: c=%.16g\n', newc);
    end

    % Update the network and alphas if appropriate
    if status.updatenet >= 0
      net = updatenetwork(net, dcp_dnetm, dcp_dnetv);
    end
    if status.updatetnet >= 0
      tnet = updatenetwork(tnet, dcp_dtnetm, dcp_dtnetv);
    end
  % Old updatealg ends
  elseif strcmp(status.updatealg, 'natgrad')
    % Natural gradient
    [sources, net, tnet, oldgrads, status] = update_natgrad(...
        sources, net, tnet, dcp_dsm, dcp_dsvn, x{4}, tx{4}, params, ...
        dcp_dnetm, dcp_dnetv, dcp_dtnetm, dcp_dtnetv, newac, ...
        data, newkls, status, oldgrads, clamped, missing, notimedep);
  else
    % Rest of the new updatealgs
    [sources, net, tnet, oldgrads, status] = update_everything(...
        sources, net, tnet, dcp_dsm, dcp_dsvn, x{4}, tx{4}, params, ...
        dcp_dnetm, dcp_dnetv, dcp_dtnetm, dcp_dtnetv, newac, ...
        data, newkls, status, oldgrads, clamped, missing, notimedep);
  end

  % Advance the negative delay counters towards zero
  if status.updatesrcs < 0
    status.updatesrcs = status.updatesrcs + 1;
    if (status.updatesrcs == 0) && (~strcmp(status.updatealg, 'old')),
      oldgrads = reset_cg(oldgrads, net, tnet, sources);
    end
  end
  if status.updatesrcvars < 0
    status.updatesrcvars = status.updatesrcvars + 1;
    % NOTE(review): this tests updatesrcs, not updatesrcvars; it may be
    % a copy-paste slip.  Kept as in the original -- confirm intent.
    if (status.updatesrcs == 0) && (~strcmp(status.updatealg, 'old')),
      sources.ac(:,2:end) = deal(1e-10);
    end
  end
  if status.updatenet < 0
    status.updatenet = status.updatenet + 1;
  end
  if status.updatetnet < 0
    status.updatetnet = status.updatetnet + 1;
  end

  iters_left = iters_left - 1;

  % Prune the embedded model after embed.iters has been reached
  if status.embed.iters < 0,
    status.embed.iters = status.embed.iters + 1;
    if status.embed.iters == 0,
      [data, net, sources, params, clamped, notimedep, missing] = prune(...
          data, net, sources, params, clamped, notimedep, missing,...
          status.embed.datadim, status.embed.timedim);
      fprintf('Pruning model to match the original data.\n');
    end
  end

  % Prune neurons and sources if appropriate
  if status.prune.hidneurons > 0 || status.prune.thidneurons > 0 || ...
        status.prune.sources > 0,
    if status.prune.iters == 0,
      fprintf('Pruning model.\n');
      [sources, net, tnet, params, status] = prune_neurons(...
          data, sources, net, tnet, params, missing, notimedep, status);
      if status.prune.hidneurons > 0 || status.prune.thidneurons > 0 || ...
            status.prune.sources > 0,
        % More pruning is pending: schedule the next round and enable
        % all updates immediately.
        status.prune.iters = -10;
        status.updatesrcs = 0;
        status.updatesrcvars = 0;
        status.updatenet = 0;
        status.updatetnet = 0;
      end
    else
      status.prune.iters = status.prune.iters + 1;
    end
  end

  % Force constraints on control channels
  if isfield(status, 'constraints'),
    % The pruning steps above may have changed the sources, so the
    % dimensions used for clipping are read again here.
    [Ns, T] = size(sources);
    nsources = Ns;
    controls = nsources - 1 + (1:ncontrols);
    constraints = status.constraints(Nx+1:end,:);
    sources.e(controls,:) = ...
        cut(sources.e(controls,:), ...
            repmat(constraints(:,1), 1, T), ...
            repmat(constraints(:,2), 1, T));
  end

%   if mod(size(status.kls, 2), 5) == 0,
% %    result = pack_result(sources, net, tnet, params, status, ...
% %                         clamped, notimedep);
% %    save(['cp' num2str(size(status.kls, 2))], 'result');
% %    test_if_nan(result);
%   end

  % Checkpoints after the first are made every 100 iterations
  iters_before_cp = iters_before_cp - 1;
  if iters_before_cp == 0,
    result = pack_result(sources, net, tnet, params, status, ...
                         clamped, notimedep);
    save cp data result;
    iters_before_cp = 100;
  end
end

% Store the old gradient for new updatealgs
if ~strcmp(status.updatealg, 'old'),
  status.oldgrads = oldgrads;
end

% Do feedforward calculations.
% NOTE(review): unlike inside the loop, status.approximation is not
% passed here, so feedfw/acfeedfw use whatever default they define.
x = feedfw(sources, net);
tx = acfeedfw(sources, tnet);
fs = probdist(x{4}.e, x{4}.var);
if (status.freeinitial)
  tfs = probdist([sources.e(:,1) tx{4}.e], ...
                 [zeros(size(sources, 1), 1) tx{4}.var]);
else
  tfs = probdist([zeros(size(sources, 1), 1) tx{4}.e], ...
                 [zeros(size(sources, 1), 1) tx{4}.var]);
end

% Calculate and display the current value of Kullback-Leibler divergence
[newkls, kls_data, kls_param, kls_ac, kls_nets] = kldiv(...
    tfs, fs, sources, data, net, tnet, params, missing, notimedep, status);
fprintf('Finally after %d iterations: %f = %f+%f+%f.\n', ...
        size(status.kls, 2), newkls, kls_data, kls_param + kls_nets, kls_ac)


function result = pack_result(...
    sources, net, tnet, params, status, clamped, notimedep),
% PACK_RESULT  Packs everything to a single structure
%
%  The returned struct has one field per input argument, named after it:
%  sources, net, tnet, params, status, clamped, notimedep.

% Copyright (C) 2002-2005 Harri Valpola, Antti Honkela and Matti Tornio.
%
% This package comes with ABSOLUTELY NO WARRANTY; for details
% see License.txt in the program package.  This is free software,
% and you are welcome to redistribute it under certain conditions;
% see License.txt for details.

result.sources = sources;
result.net = net;
result.tnet = tnet;
result.params = params;
result.status = status;
result.clamped = clamped;
result.notimedep = notimedep;


function [dc_dsm, dc_dsvn] = feedback_srcpriors(sources, srcparams, ignore)
% FEEDBACK_SRCPRIORS  Calculate the contribution of source priors
%    to the gradients of the cost function with respect to source values
%
%  sources    object with field e and a size() of [rows, samples]
%  srcparams  passed to NORMALVAR to obtain one variance per source row
%  ignore     optional mask; when given and non-empty, the means in the
%             first column and wherever the mask is nonzero are zeroed
%             before the gradient is formed
%
%  dc_dsm  = sources.e ./ variance   (variance replicated over samples)
%  dc_dsvn = 0.5 ./ variance

% Copyright (C) 2002 Harri Valpola and Antti Honkela.
%
% This package comes with ABSOLUTELY NO WARRANTY; for details
% see License.txt in the program package.  This is free software,
% and you are welcome to redistribute it under certain conditions;
% see License.txt for details.

if nargin >= 3 && ~isempty(ignore),
  % A leading column of ones is prepended, so the first sample is
  % always treated as ignored.
  ignore = full([ones(size(sources.e, 1), 1) ignore]);
  sources.e = sources.e .* ~ignore;
end

sourcevar = normalvar(srcparams);
nsampl = size(sources, 2);

% Replicate the per-source variance across all samples
temp = sourcevar * ones(1, nsampl);
dc_dsm = sources.e ./ temp;
dc_dsvn = .5 ./ temp;


function [dc_dnetm, dc_dnetv] = feedback_netpriors(net, params, hypers)
% FEEDBACK_NETPRIORS  Calculate the contribution of network priors
%    to the gradients of the cost function with respect to network weights
%
%  net     struct with fields w1, b1, w2, b2
%  params  struct whose w2var field is passed to NORMALVAR
%  hypers  struct with fields b1 and b2, forwarded to NETGRADSPRIOR
%
%  Returns two structs (mean and variance gradients), each with fields
%  w1, b1, w2, b2 as produced by NETGRADSPRIOR.  The first layer uses a
%  fixed unit prior variance; the second layer uses normalvar(params.w2var).

% Copyright (C) 2002 Harri Valpola and Antti Honkela.
%
% This package comes with ABSOLUTELY NO WARRANTY; for details
% see License.txt in the program package.  This is free software,
% and you are welcome to redistribute it under certain conditions;
% see License.txt for details.

w1var = ones(1, size(net.w1, 2));
w2var = normalvar(params.w2var);

[dc_dnetm.w2, dc_dnetv.w2, dc_dnetm.b2, dc_dnetv.b2] = ...
    netgradsprior(net.w2, net.b2, w2var, hypers.b2);
[dc_dnetm.w1, dc_dnetv.w1, dc_dnetm.b1, dc_dnetv.b1] = ...
    netgradsprior(net.w1, net.b1, w1var, hypers.b1);


function oldgrads = reset_cg(oldgrads, net, tnet, sources),
% RESET_CG  Reset the stored (conjugate) gradient
%
%  Overwrites the fields net, tnet, s and norm of oldgrads with zeros
%  shaped after net, tnet and sources, and prints a notice.  Any other
%  fields already present in oldgrads are left as they are; pass [] to
%  start from scratch.

% Copyright (C) 2002-2005 Harri Valpola, Antti Honkela and Matti Tornio.
%
% This package comes with ABSOLUTELY NO WARRANTY; for details
% see License.txt in the program package.  This is free software,
% and you are welcome to redistribute it under certain conditions;
% see License.txt for details.

fprintf('Resetting CG\n');

oldgrads.net = netgrad_zeros(net);
oldgrads.tnet = netgrad_zeros(tnet);
oldgrads.s = zeros(size(sources));
oldgrads.norm = 0;


function test_if_nan(result),
% TEST_IF_NAN  Warn if any variance in a packed result is negative
%
%  Despite its name this does not test for NaN values: it issues a
%  warning when any element of the source variances or of the w1, w2,
%  b1, b2 variances of net and tnet is below zero.  Its only call site
%  in this file is inside the disabled checkpoint code of NDFA_ITER.

if (any(any(result.sources.var < 0)) || ...
    any(any(result.net.w1.var < 0)) || ...
    any(any(result.net.w2.var < 0)) || ...
    any(any(result.net.b1.var < 0)) || ...
    any(any(result.net.b2.var < 0)) || ...
    any(any(result.tnet.w1.var < 0)) || ...
    any(any(result.tnet.w2.var < 0)) || ...
    any(any(result.tnet.b1.var < 0)) || ...
    any(any(result.tnet.b2.var < 0))),
  warning('Negative variance detected');
end
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -