⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 nlfa_iter.m

📁 非线性因素分析 MATLAB 仿真程序包
💻 M
字号:
function [sources, net, params, status, fs] = ...
    nlfa_iter(data, sources, net, params, status)
% NLFA_ITER  Perform the NLFA iteration
%
%   [SOURCES, NET, PARAMS, STATUS, FS] = NLFA_ITER(DATA, SOURCES, NET,
%   PARAMS, STATUS) runs at most STATUS.iters update iterations over all
%   samples and returns the updated model.  FS is the distribution built
%   from the network output (x{4}.e, x{4}.var) of a final feedforward
%   pass with the returned model.
%
%   DATA     Observation matrix, one sample per column.
%   SOURCES  Source posterior; its .e field is read (presumably a
%            PROBDIST object -- confirm against callers).
%   NET      Network weights; fields w1, b1, w2, b2 are used here.
%   PARAMS   Model parameters (noise, src, net.w2var) together with their
%            hyperparameters (hyper.*) and priors (prior.*).
%   STATUS   Control and bookkeeping structure.  Fields read here: iters,
%            updatealg, cgreset, oldgrads, kls, cputime, epsilon,
%            approximation, updatesrcs, updatesrcvars, updatenet,
%            updateparams.  A negative update* value is a countdown that
%            is incremented by one per iteration until it reaches 0.
%            One entry per iteration is appended to STATUS.kls and
%            STATUS.cputime.  For any updatealg other than 'old' the CG
%            state is saved back to STATUS.oldgrads.
%
%   The loop stops early when the cost is NaN (returning immediately,
%   without the final pass and without saving STATUS.oldgrads), or when
%   more than 400 costs have been recorded and either the cost rose in
%   each of the last 10 iterations or no step among the last 200 lowered
%   it by STATUS.epsilon or more.

% Copyright (C) 1999-2004 Antti Honkela, Harri Valpola,
% and Xavier Giannakopoulos.
%
% This package comes with ABSOLUTELY NO WARRANTY; for details
% see License.txt in the program package.  This is free software,
% and you are welcome to redistribute it under certain conditions;
% see License.txt for details.

nsampl = size(data, 2);
% All samples are processed as one full batch.
curbatch = 1:nsampl;

iters_left = status.iters;

if ~strcmp(status.updatealg, 'old'),
  % Resume the stored CG state unless none exists or STATUS.cgreset is -1.
  if isfield(status, 'oldgrads') && status.cgreset ~= -1,
    oldgrads = status.oldgrads;
  else
    oldgrads = reset_cg(struct(), net, sources);
  end
end

while iters_left > 0
  dcp_dnetm = netgrad_zeros(net);
  dcp_dnetv = netgrad_zeros(net);

  % Feedforward pass and current value of the cost function
  [newkls, fs, x] = evaluate_cost(data, sources, net, params, ...
                                  status.approximation, curbatch);

  fprintf('Iteration #%d: %f\n', size(status.kls, 2), newkls);
  if isnan(newkls),
    fprintf('Cost is NaN, bailing out...\n');
    return
  end
  % Finish this iteration and then stop if the cost rose in each of the
  % last 10 iterations or no step in the last 200 lowered it by epsilon.
  if (size(status.kls, 2) > 400 && ...
      ((min(diff(status.kls(end-10:end))) > 0) || ...
       (min(diff(status.kls(end-200:end))) > -status.epsilon))),
    fprintf('The iteration appears to have converged, bailing out...\n');
    iters_left = 0;
  end

  status.kls = [status.kls newkls];
  status.cputime = [status.cputime cputime];

  % Partial derivatives of the cost: the data term returned by FEEDBACK
  % plus the source-prior and network-prior terms.
  [dcp_dsm, dcp_dsv, newdcp_dnetm, newdcp_dnetv] = ...
      feedback(x, net, sources(:, curbatch), data(:, curbatch), ...
               params.noise, status);
  [newdcp_dsm, newdcp_dsv] = ...
      feedback_srcpriors(sources(:, curbatch), params.src);
  dcp_dsm = dcp_dsm + newdcp_dsm;
  dcp_dsv = dcp_dsv + newdcp_dsv;
  dcp_dnetm = sum_structs(dcp_dnetm, newdcp_dnetm);
  dcp_dnetv = sum_structs(dcp_dnetv, newdcp_dnetv);
  [newdcp_dnetm, newdcp_dnetv] = ...
      feedback_netpriors(net, params.net, params.hyper.net);
  dcp_dnetm = sum_structs(dcp_dnetm, newdcp_dnetm);
  dcp_dnetv = sum_structs(dcp_dnetv, newdcp_dnetv);

  if strcmp(status.updatealg, 'old'),
    % Get new values for sources and alphas if appropriate
    if max([status.updatesrcs, status.updatesrcvars]) >= 0
      sources = probdist_alpha(sources);
      newsources = ...
          updatesources(sources(:, curbatch), dcp_dsm, dcp_dsv, x{4}, ...
                        params.src, params.noise);
      if status.updatesrcs < 0
        % Source countdown still running: keep the old e/malpha/msign
        % and take only var/valpha/vsign from the update.
        sources = ...
            probdist_alpha(sources.e(:, curbatch), newsources.var, ...
                           sources.malpha(:, curbatch), newsources.valpha, ...
                           sources.msign(:, curbatch), newsources.vsign);
      else
        sources = newsources;
      end
    end

    if status.updatenet >= 0
      net = updatenetwork(net, dcp_dnetm, dcp_dnetv);
    end
  else % new updatealg
    [sources, net, oldgrads, status] = update_everything(...
        sources, net, dcp_dsm, dcp_dsv, x{4}, params, dcp_dnetm, dcp_dnetv, ...
        data, newkls, status, oldgrads);
    % Periodic CG restart every STATUS.cgreset recorded costs
    if (status.cgreset > 0) && (mod(length(status.kls), status.cgreset) == 0),
      oldgrads = reset_cg(oldgrads, net, sources);
    end
  end % updatealg

  % Advance the update* countdowns; restart CG when the source or
  % parameter countdown reaches 0.
  if status.updatesrcs < 0
    status.updatesrcs = status.updatesrcs + 1;
    if (status.updatesrcs == 0) && (~strcmp(status.updatealg, 'old')),
      oldgrads = reset_cg(oldgrads, net, sources);
    end
  end
  if status.updatesrcvars < 0
    status.updatesrcvars = status.updatesrcvars + 1;
  end
  if status.updatenet < 0
    status.updatenet = status.updatenet + 1;
  end

  % Update estimates for different parameters if appropriate
  if status.updateparams < 0
    status.updateparams = status.updateparams + 1;
    if (status.updateparams == 0) && (~strcmp(status.updatealg, 'old')),
      oldgrads = reset_cg(oldgrads, net, sources);
    end
  else
    params.noise = estimatevars(probdist(fs.e-data, fs.var), ...
                                params.hyper.noise, params.noise);
    params.src   = estimatevars(sources, params.hyper.src, params.src);
    params.net.w2var = estimatevars(net.w2, params.hyper.net.w2var, ...
                                    params.net.w2var, 1);

    % NOTE(review): unlike the other estimatemeanvars calls below, this one
    % passes no trailing 1 -- confirm this is intentional.
    [params.hyper.net.w2var.mean, params.hyper.net.w2var.var] = ...
        estimatemeanvars(params.net.w2var, params.prior.net.w2var.mean, ...
                         params.prior.net.w2var.var, params.hyper.net.w2var.var);
    [params.hyper.noise.mean, params.hyper.noise.var] = ...
        estimatemeanvars(params.noise, params.prior.noise.mean, ...
                         params.prior.noise.var, params.hyper.noise.var, 1);
    [params.hyper.net.b1.mean, params.hyper.net.b1.var] = ...
        estimatemeanvars(net.b1, params.prior.net.b1.mean, ...
                         params.prior.net.b1.var, params.hyper.net.b1.var, 1);
    [params.hyper.net.b2.mean, params.hyper.net.b2.var] = ...
        estimatemeanvars(net.b2, params.prior.net.b2.mean, ...
                         params.prior.net.b2.var, params.hyper.net.b2.var, 1);
    [params.hyper.src.mean, params.hyper.src.var] = ...
        estimatemeanvars(params.src, params.prior.src.mean, ...
                         params.prior.src.var, params.hyper.src.var, 1);
  end

  % The old algorithm rescales sources whenever there is more than one.
  if strcmp(status.updatealg, 'old'),
    if (size(sources, 1) > 1),
      [sources, net, params] = ...
          scalesources(sources, net, params);
    end
  end

  iters_left = iters_left - 1;
end

if ~strcmp(status.updatealg, 'old'),
  status.oldgrads = oldgrads;
end

% Final feedforward pass so that FS and the reported cost correspond to
% the returned model.
[newkls, fs] = evaluate_cost(data, sources, net, params, ...
                             status.approximation, curbatch);
fprintf('Finally after %d iterations: %f\n', size(status.kls, 2), newkls);


function [newkls, fs, x] = ...
    evaluate_cost(data, sources, net, params, approximation, curbatch)
% EVALUATE_COST  Feedforward pass plus cost evaluation.
%   Returns the cost NEWKLS (KL_STATIC + KL_BATCH), the output
%   distribution FS (unit-variance zeros outside CURBATCH) and the
%   FEEDFW result X.

fs = probdist(zeros(size(data)), ones(size(data)));
newkls = kl_static(net, params);
x = feedfw( sources(:, curbatch) , net, approximation);
fs(:, curbatch) = probdist(x{4}.e, x{4}.var);
newkls = newkls + kl_batch(fs(:, curbatch), sources(:, curbatch), ...
                           data(:, curbatch), params);


function oldgrads = reset_cg(oldgrads, net, sources)
% RESET_CG  Announce a CG restart and zero the stored CG state
%   (fields net, s and norm); any other fields of OLDGRADS are kept.

fprintf('Resetting CG\n');
oldgrads.net = netgrad_zeros(net);
oldgrads.s = zeros(size(sources));
oldgrads.norm = 0;


function [dc_dsm, dc_dsv] = feedback_srcpriors(sources, srcparams)
% FEEDBACK_SRCPRIORS Calculate the contribution of source priors
%   to the gradients of the cost function with respect to source values
%
%   The prior variance normalvar(SRCPARAMS) is replicated across all
%   samples; the returned terms (mean ./ var and 0.5 ./ var) match a
%   zero-mean Gaussian prior.

sourcevar = normalvar(srcparams);
nsampl = size(sources, 2);
temp = sourcevar * ones(1, nsampl);

dc_dsm = sources.e ./ temp;
dc_dsv = .5 ./ temp;


function [dc_dnetm, dc_dnetv] = feedback_netpriors(net, params, hypers)
% FEEDBACK_NETPRIORS Calculate the contribution of network priors
%   to the gradients of the cost function with respect to network weights
%
%   First-layer weights use a fixed unit prior variance; second-layer
%   weights use normalvar(PARAMS.w2var).  Biases use HYPERS.b1 / HYPERS.b2.

w1var = ones(1, size(net.w1, 2));
w2var = normalvar(params.w2var);
[dc_dnetm.w2, dc_dnetv.w2, dc_dnetm.b2, dc_dnetv.b2] = ...
    netgradsprior(net.w2, net.b2, w2var, hypers.b2);
[dc_dnetm.w1, dc_dnetv.w1, dc_dnetm.b1, dc_dnetv.b1] = ...
    netgradsprior(net.w1, net.b1, w1var, hypers.b1);


function [dcp_dwm, dcp_dwv, dcp_dbm, dcp_dbv] = ...
    netgradsprior(w, b, wprior, bprior)
% NETGRADSPRIOR Calculate the contribution of priors to partial
%   derivatives of kldiv with respect to network weights
%
%   WPRIOR is a row of per-column weight prior variances (replicated over
%   the rows of W); weights have a zero prior mean.  Biases have prior
%   mean BPRIOR.mean.e and variance normalvar(BPRIOR.var).

wpvar = repmat(wprior, [size(w, 1) 1]);
bpexp = repmat(bprior.mean.e, size(b));
bpvar = repmat(normalvar(bprior.var), size(b));

dcp_dwm = w.e ./ wpvar;
dcp_dwv = .5 ./ wpvar;
dcp_dbm = (b.e - bpexp) ./ bpvar;
dcp_dbv = .5 ./ bpvar;


function grad = netgrad_zeros(net)
% NETGRAD_ZEROS  Gradient structure (fields w2, b2, w1, b1) of zero
%   arrays sized like the corresponding fields of NET.

grad.w2 = zeros(size(net.w2));
grad.b2 = zeros(size(net.b2));
grad.w1 = zeros(size(net.w1));
grad.b1 = zeros(size(net.b1));


function s = sum_structs(s1, s2)
% SUM_STRUCTS  Add all the fields of two structures together
%
%   Fields are matched by name, so S2 may list them in a different order;
%   the result keeps the field order of S1.  Both structures must have
%   exactly the same set of field names.

f = fieldnames(s1);
% Compare the field-name sets: the former size(c1) ~= size(c2) test was a
% vector condition whose column part is always equal, so it never fired.
if ~isempty(setxor(f, fieldnames(s2)))
  error('sum_structs: Structures must be of same type')
end
s = s1;
for k = 1:length(f),
  s.(f{k}) = s1.(f{k}) + s2.(f{k});
end

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -