📄 ndfa_iter.m

📁 Nonlinear dynamical factor analysis Matlab package
function [sources, net, tnet, params, status, fs, tfs] = ndfa_iter(...
    data, sources, net, tnet, params, status, ...
    missing, clamped, notimedep)
% NDFA_ITER Perform the NDFA iteration
%
% Internal function, use NDFA to call.

% Copyright (C) 2002-2005 Harri Valpola, Antti Honkela and Matti Tornio.
%
% This package comes with ABSOLUTELY NO WARRANTY; for details
% see License.txt in the program package.  This is free software,
% and you are welcome to redistribute it under certain conditions;
% see License.txt for details.

% Make sure sources are acprobdist_alpha
if (~isa(sources, 'acprobdist_alpha'))
  sources = updatevar(acprobdist_alpha(sources));
end

% Make sure networks are probdist_alpha for the old update algorithm
if ~isa(net.w1, 'probdist_alpha') & ...
     strcmp(status.updatealg, 'old')
  net.w1 = probdist_alpha(net.w1);
  net.w2 = probdist_alpha(net.w2);
  net.b1 = probdist_alpha(net.b1);
  net.b2 = probdist_alpha(net.b2);
  tnet.w1 = probdist_alpha(tnet.w1);
  tnet.w2 = probdist_alpha(tnet.w2);
  tnet.b1 = probdist_alpha(tnet.b1);
  tnet.b2 = probdist_alpha(tnet.b2);
end

% Extract some model parameters for easier use
[Ns T] = size(sources);
Nx = size(data, 1);
truesources = 1:(size(sources, 1) - status.controlchannels);
nsources = size(sources, 1);
ncontrols = size(status.controlchannels, 1);
controls = nsources - 1 + (1:ncontrols);

iters_left = status.iters;

% Save first checkpoint after 100 (full) iterations
iters_before_cp = 100 - min([status.updatenet status.updatetnet ...
                             status.updateparams status.updatesrcs ...
                             status.updatesrcvars]);

% Reset the gradient if appropriate for the new update algorithms
if ~strcmp(status.updatealg, 'old'),
  if isfield(status, 'oldgrads'),
    oldgrads = status.oldgrads;
  else
    oldgrads = reset_cg([], net, tnet, sources);
  end
end

while iters_left > 0
  dcp_dnetm = netgrad_zeros(net);
  dcp_dnetv = netgrad_zeros(net);
  dcp_dtnetm = netgrad_zeros(tnet);
  dcp_dtnetv = netgrad_zeros(tnet);

  fs = probdist(zeros(size(data)), ones(size(data)));
  tfs = probdist(zeros(size(sources)), ones(size(sources)));

  % Do feedforward calculations
  x = feedfw(sources, net, status.approximation);
  tx = acfeedfw(sources, tnet, status.approximation);

  fs = probdist(x{4}.e, x{4}.var);
  if (status.freeinitial)
    tfs = probdist([sources.e(:,1) tx{4}.e], ...
                   [zeros(size(sources, 1), 1) tx{4}.var]);
  else
    tfs = probdist([zeros(size(sources, 1), 1) tx{4}.e], ...
                   [zeros(size(sources, 1), 1) tx{4}.var]);
  end

  % Update estimates for different parameters if appropriate
  if status.updateparams < 0
    status.updateparams = status.updateparams + 1;
    if (status.updateparams == 0) && (~strcmp(status.updatealg, 'old')),
      oldgrads = reset_cg(oldgrads, net, tnet, sources);
    end
  else
    params = update_params(fs, tfs, sources, data, net, tnet, params);
  end

  % Calculate the current value of Kullback-Leibler divergence
  [newkls, kls_data, kls_param, kls_ac, kls_nets] = kldiv(...
      tfs, fs, sources, data, net, tnet, params, missing, notimedep, status);

  % Store the cost function and elapsed CPU time
  status.kls = [status.kls newkls];
  status.cputime = [status.cputime cputime];

  % Display the current value of the cost function
  fprintf('Iteration #%d: %f = %f+%f+%f.\n', size(status.kls, 2) - 1, ...
          newkls, kls_data, kls_param + kls_nets, kls_ac)

  % Store the components of the cost function for debugging
  if status.debug >= 2,
    Nh = size(status.kls, 2);
    [t, status.history.obs(:,:,Nh)] = kl_data(...
        probdist(fs.e - data, fs.var), probdist(0), params.noise, 2, missing);
    [t, status.history.dyn(:,:,Nh)] = kl_acparam(...
        sources, tfs, params.src, 2, notimedep);
    [t, status.history.net_w2(:,:,Nh)]  = kl_param(...
        net.w2,  probdist(0), params.net.w2var, 1);
    [t, status.history.net_w1(:,:,Nh)]  = kl_param(...
        net.w1,  probdist(0), probdist(0));
    [t, status.history.tnet_w2(:,:,Nh)] = kl_param(...
        tnet.w2, probdist(0), params.tnet.w2var, 1);
    [t, status.history.tnet_w1(:,:,Nh)] = kl_param(...
        tnet.w1, probdist(0), probdist(0));
    status.history.noise(:,Nh) = normalvar(params.noise);
    status.history.src(:,Nh) = normalvar(params.src);
  end

  % Check if overflow has occurred
  if ~isfinite(newkls),
    iters_left = 0;
    fprintf('Cost is NaN, bailing out ...\n');
    status.kls = [status.kls newkls];
    status.cputime = [status.cputime cputime];
    return;
  end

  % Check if the iteration has converged
  if (size(status.kls, 2) > 400 && ...
      ((min(diff(status.kls(end-10:end))) > 0) && ...
       (min(diff(status.kls(end-400:end))) > -status.epsilon))),
    fprintf('The iteration appears to have converged, bailing out...\n');
    iters_left = 0;
  end

  % Calculate partial derivatives for the parameters to adapt
  % Feedback for the observation mapping
  [dcp_dsm, dcp_dsv, newdcp_dnetm, newdcp_dnetv] =...
      feedback(x, net, sources, data, params.noise, status, ...
               missing);

  dcp_dnetm = sum_structs(dcp_dnetm, newdcp_dnetm);
  dcp_dnetv = sum_structs(dcp_dnetv, newdcp_dnetv);

  % Feedback for the dynamical mapping
  [newdcp_dsm, newdcp_dsv, newdcp_dtnetm, newdcp_dtnetv] = ...
      feedbackac(tx, tnet, sources(:, 1:end-1), sources(:, 2:end).e, ...
                 params.src, status, notimedep);

  dcp_dsm(:,1:end-1) = dcp_dsm(:,1:end-1) + newdcp_dsm;
  dcp_dsv(:,1:end-1) = dcp_dsv(:,1:end-1) + newdcp_dsv;
  dcp_dtnetm = sum_structs(dcp_dtnetm, newdcp_dtnetm);
  dcp_dtnetv = sum_structs(dcp_dtnetv, newdcp_dtnetv);

  % Feedback for the source priors
  [newdcp_dsm, dcp_dsvn] = ...
      feedback_srcpriors(sources - tfs, params.src, notimedep);
  dcp_dsm = dcp_dsm + newdcp_dsm;
  [dcp_dsvn, newac] = computevard(dcp_dsv, dcp_dsvn, sources.ac, ...
                                  tx{5});

  % Drop the variance dependencies between independent samples
  i = sparse([zeros(Ns,1) notimedep]);
  newac(find(i)) = sources.ac(find(i));

  % Feedback for the observation mapping priors
  [newdcp_dnetm, newdcp_dnetv] = ...
      feedback_netpriors(net, params.net, params.hyper.net);

  dcp_dnetm = sum_structs(dcp_dnetm, newdcp_dnetm);
  dcp_dnetv = sum_structs(dcp_dnetv, newdcp_dnetv);

  % Feedback for the dynamical mapping priors
  [newdcp_dtnetm, newdcp_dtnetv] = ...
      feedback_netpriors(tnet, params.tnet, params.hyper.tnet);

  dcp_dtnetm = sum_structs(dcp_dtnetm, newdcp_dtnetm);
  dcp_dtnetv = sum_structs(dcp_dtnetv, newdcp_dtnetv);

  % Calculate the 'spilled' gradients
  dcp_dsm2(:,1) = zeros(1, Ns);
  for i = 2:T,
    dcp_dsm2(:,i) = tx{4}.multi(:,:,i-1) * dcp_dsm2(:,i-1) + dcp_dsm(:,i);
  end
  dcp_dsm3(:,T) = zeros(1, Ns);
  for i = (T-1):-1:1,
    dcp_dsm3(:,i) = tx{4}.multi(:,:,i) * (dcp_dsm(:,i+1) + dcp_dsm3(:,i+1));
  end
  %  dcp_dsm = 1/3*(dcp_dsm + dcp_dsm2 + dcp_dsm3);

  % Old updatealg
  if strcmp(status.updatealg, 'old'),
    % Get new values for sources and alphas if appropriate
    if max([status.updatesrcs, status.updatesrcvars]) >= 0
      newsources = ...
          updatesources(sources, dcp_dsm, dcp_dsvn, x{4}.multi, tx{5}, ...
                        params.src, params.noise, newac, clamped);
      if status.updatesrcs < 0
        sources.var = newsources.var;
        sources.nvar = newsources.nvar;
        sources.valpha = newsources.valpha;
        sources.vsign = newsources.vsign;
      else
        sources = newsources;
      end
      sources.ac = newac;
      sources = updatevar(sources);
    end

    if status.debug >= 2,
      [newc, datac, restc, dync, netc] = kldiv(...
          [], [], sources, data, net, tnet, params, ...
          missing, notimedep, status);
      fprintf('Cost after source update:  c=%.16g\n', newc);
    end

    % Update the network and alphas if appropriate
    if status.updatenet >= 0
      net = updatenetwork(net, dcp_dnetm, dcp_dnetv);
    end
    if status.updatetnet >= 0
      tnet = updatenetwork(tnet, dcp_dtnetm, dcp_dtnetv);
    end
    % Old updatealg ends
  elseif strcmp(status.updatealg, 'natgrad') % Natural gradient
    [sources, net, tnet, oldgrads, status] = update_natgrad(...
        sources, net, tnet, dcp_dsm, dcp_dsvn, x{4}, tx{4}, params, ...
        dcp_dnetm, dcp_dnetv, dcp_dtnetm, dcp_dtnetv, newac, ...
        data, newkls, status, oldgrads, clamped, missing, notimedep);
  else % Rest of the new updatealgs
    [sources, net, tnet, oldgrads, status] = update_everything(...
        sources, net, tnet, dcp_dsm, dcp_dsvn, x{4}, tx{4}, params, ...
        dcp_dnetm, dcp_dnetv, dcp_dtnetm, dcp_dtnetv, newac, ...
        data, newkls, status, oldgrads, clamped, missing, notimedep);
  end

  if status.updatesrcs < 0
    status.updatesrcs = status.updatesrcs + 1;
    if (status.updatesrcs == 0) && (~strcmp(status.updatealg, 'old')),
      oldgrads = reset_cg(oldgrads, net, tnet, sources);
    end
  end
  if status.updatesrcvars < 0
    status.updatesrcvars = status.updatesrcvars + 1;
    if (status.updatesrcs == 0) && (~strcmp(status.updatealg, 'old')),
      sources.ac(:,2:end) = deal(1e-10);
    end
  end
  if status.updatenet < 0
    status.updatenet = status.updatenet + 1;
  end
  if status.updatetnet < 0
    status.updatetnet = status.updatetnet + 1;
  end

  iters_left = iters_left - 1;

  % Prune the embedded model after embed.iters has been reached
  if status.embed.iters < 0,
    status.embed.iters = status.embed.iters + 1;
    if status.embed.iters == 0,
      [data, net, sources, params, clamped, notimedep, missing] = prune(...
          data, net, sources, params, clamped, notimedep, missing,...
          status.embed.datadim, status.embed.timedim);
      fprintf('Pruning model to match the original data.\n');
    end
  end

  % Prune neurons and sources if appropriate
  if status.prune.hidneurons > 0 || status.prune.thidneurons > 0 || ...
        status.prune.sources > 0,
    if status.prune.iters == 0,
      fprintf('Pruning model.\n');
      [sources, net, tnet, params, status] = prune_neurons(...
          data, sources, net, tnet, params, missing, notimedep, status);
      if status.prune.hidneurons > 0 || status.prune.thidneurons > 0 || ...
            status.prune.sources > 0,
        status.prune.iters = -10;
        status.updatesrcs = 0;
        status.updatesrcvars = 0;
        status.updatenet = 0;
        status.updatetnet = 0;
      end
    else
      status.prune.iters = status.prune.iters + 1;
    end
  end

  % Force constraints on control channels
  if isfield(status, 'constraints'),
    constraints = status.constraints(Nx+1:end,:);
    sources.e(controls,:) = ...
        cut(sources.e(controls,:), ...
            repmat(constraints(:,1), 1, T), ...
            repmat(constraints(:,2), 1, T));
  end

%  if mod(size(status.kls, 2), 5) == 0,
%
%    result = pack_result(sources, net, tnet, params, status, ...
%                         clamped, notimedep);
%    save(['cp' num2str(size(status.kls, 2))], 'result');
%    test_if_nan(result);
%  end

  % Checkpoints after the first are made every 100 iterations
  iters_before_cp = iters_before_cp - 1;
  if iters_before_cp == 0,
    result = pack_result(sources, net, tnet, params, status, ...
                         clamped, notimedep);
    save cp data result;
    iters_before_cp = 100;
  end
end

% Store the old gradient for new updatealgs
if ~strcmp(status.updatealg, 'old'),
  status.oldgrads = oldgrads;
end

% Do feedforward calculations
fs = probdist(zeros(size(data)), ones(size(data)));
tfs = probdist(zeros(size(sources)), ones(size(sources)));
x = feedfw(sources, net);
tx = acfeedfw(sources, tnet);
fs = probdist(x{4}.e, x{4}.var);
if (status.freeinitial)
  tfs = probdist([sources.e(:,1) tx{4}.e], ...
                 [zeros(size(sources, 1), 1) tx{4}.var]);
else
  tfs = probdist([zeros(size(sources, 1), 1) tx{4}.e], ...
                 [zeros(size(sources, 1), 1) tx{4}.var]);
end

% Calculate and display the current value of Kullback-Leibler divergence
[newkls, kls_data, kls_param, kls_ac, kls_nets] = kldiv(...
  tfs, fs, sources, data, net, tnet, params, missing, notimedep, status);
fprintf('Finally after %d iterations: %f = %f+%f+%f.\n', ...
  size(status.kls, 2), newkls, kls_data, kls_param + kls_nets, kls_ac)


function result = pack_result(...
    sources, net, tnet, params, status, clamped, notimedep),
% PACK_RESULT  Packs everything to a single structure

% Copyright (C) 2002-2005 Harri Valpola, Antti Honkela and Matti Tornio.
%
% This package comes with ABSOLUTELY NO WARRANTY; for details
% see License.txt in the program package.  This is free software,
% and you are welcome to redistribute it under certain conditions;
% see License.txt for details.

result.sources = sources;
result.net = net;
result.tnet = tnet;
result.params = params;
result.status = status;
result.clamped = clamped;
result.notimedep = notimedep;


function [dc_dsm, dc_dsvn] = feedback_srcpriors(sources, srcparams, ignore)
% FEEDBACK_SRCPRIORS Calculate the contribution of source priors
%   to the gradients of the cost function with respect to source values
%
% Copyright (C) 2002 Harri Valpola and Antti Honkela.
%
% This package comes with ABSOLUTELY NO WARRANTY; for details
% see License.txt in the program package.  This is free software,
% and you are welcome to redistribute it under certain conditions;
% see License.txt for details.

if nargin >= 3 & ~isempty(ignore),
  ignore = full([ones(size(sources.e, 1), 1) ignore]);
  sources.e = sources.e .* ~ignore;
end

sourcevar = normalvar(srcparams);
nsampl = size(sources, 2);

temp = sourcevar * ones(1, nsampl);
dc_dsm = sources.e ./ temp;
dc_dsvn = .5 ./ temp;


function [dc_dnetm, dc_dnetv] = feedback_netpriors(net, params, hypers)
% FEEDBACK_NETPRIORS Calculate the contribution of network priors
%   to the gradients of the cost function with respect to network weights
%
% Copyright (C) 2002 Harri Valpola and Antti Honkela.
%
% This package comes with ABSOLUTELY NO WARRANTY; for details
% see License.txt in the program package.  This is free software,
% and you are welcome to redistribute it under certain conditions;
% see License.txt for details.

w1var = ones(1, size(net.w1, 2));
w2var = normalvar(params.w2var);

[dc_dnetm.w2, dc_dnetv.w2, dc_dnetm.b2, dc_dnetv.b2] = ...
    netgradsprior(net.w2, net.b2, w2var, hypers.b2);
[dc_dnetm.w1, dc_dnetv.w1, dc_dnetm.b1, dc_dnetv.b1] = ...
    netgradsprior(net.w1, net.b1, w1var, hypers.b1);


function oldgrads = reset_cg(oldgrads, net, tnet, sources),
% RESET_CG Reset the stored (conjugate) gradient

% Copyright (C) 2002-2005 Harri Valpola, Antti Honkela and Matti Tornio.
%
% This package comes with ABSOLUTELY NO WARRANTY; for details
% see License.txt in the program package.  This is free software,
% and you are welcome to redistribute it under certain conditions;
% see License.txt for details.

fprintf('Resetting CG\n');
oldgrads.net = netgrad_zeros(net);
oldgrads.tnet = netgrad_zeros(tnet);
oldgrads.s = zeros(size(sources));
oldgrads.norm = 0;


function test_if_nan(result),

if (any(any(result.sources.var < 0)) | ...
    any(any(result.net.w1.var < 0)) | ...
    any(any(result.net.w2.var < 0)) | ...
    any(any(result.net.b1.var < 0)) | ...
    any(any(result.net.b2.var < 0)) | ...
    any(any(result.tnet.w1.var < 0)) | ...
    any(any(result.tnet.w2.var < 0)) | ...
    any(any(result.tnet.b1.var < 0)) | ...
    any(any(result.tnet.b2.var < 0))),
  warning('Negative variance detected');
end
