loopy_dbn_inf_engine.m

From the Bayes Net Toolbox · M code · 80 lines

function engine = loopy_dbn_inf_engine(bnet, onodes, max_iter, fast, momentum, obj)
% LOOPY_DBN_INF_ENGINE Loopy Pearl version of forwards-backwards
% engine = loopy_dbn_inf_engine(bnet, onodes, max_iter, fast, momentum, object_oriented)
%
% 'onodes' specifies which nodes are observed.
% 'max_iter' specifies the maximum number of forward-backward passes to perform (default: 1).
% 'fast' = 1 means pre-compute indices for table manipulation (as in jtree_fast) (default: 0).
% 'momentum' is as in loopy_pearl: default = 0 means use the current msg only.
% 'object_oriented' = 1 will call a method to compute msgs (good for noisy-or, but slow in general).
%
% The model must obey the same topological restrictions as hmm_inf_engine.
% In addition, each hidden node is assumed to have at most one observed child,
% and each observed child is assumed to have exactly one hidden parent.
%
% For details of this algorithm, see
% "The Factored Frontier Algorithm for Approximate Inference in DBNs",
% Kevin Murphy and Yair Weiss, submitted to NIPS 2000.
%
% WARNING: THIS IS HIGHLY EXPERIMENTAL CODE!

if nargin < 3, max_iter = 1; end
if nargin < 4 || isempty(fast), fast = 0; else fast = (fast ~= 0); end
if nargin < 5, momentum = 0; end
if nargin < 6 || isempty(obj), obj = 0; else obj = (obj ~= 0); end

engine.object_oriented = obj;
engine.fast = fast;
engine.momentum = momentum;
engine.max_iter = max_iter;
engine.onodes = onodes;
engine.marginal = [];
engine.evidence = [];
engine.msg = [];
engine.parent_index = [];
engine.child_index = [];
%[engine.parent_index, engine.child_index] = mk_pearl_msg_indices(bnet); % need to unroll first

% Record the (at most one) child of each hidden node within a slice.
ss = length(bnet.intra);
hnodes = mysetdiff(1:ss, onodes);
obschild = zeros(1,ss);
for i=hnodes(:)'
  %ocs = myintersect(children(bnet.dag, i), onodes);
  ocs = children(bnet.intra, i);
  assert(length(ocs) <= 1);
  if length(ocs)==1
    obschild(i) = ocs(1);
  end
end
engine.obschild = obschild;

engine.mult_self_ndx = [];
engine.mult_parent_ndx = [];
engine.marg_self_ndx = [];
engine.marg_parent_ndx = [];

% Optionally pre-compute the index tables used to multiply and marginalise family tables.
if fast
  onodes2 = [onodes(:); onodes(:)+ss];
  ns = bnet.node_sizes(:);
  ns(onodes2) = 1;
  engine.mult_self_ndx = cell(1, ss);
  engine.mult_parent_ndx = cell(ss, ss);
  engine.marg_self_ndx = cell(1, ss);
  engine.marg_parent_ndx = cell(ss, ss);
  for i=hnodes(:)'
    ps = parents(bnet.inter, i);
    fam = [ps i+ss];
    engine.mult_self_ndx{i} = mk_multiply_table_ndx(fam, i+ss, ns);
    engine.marg_self_ndx{i} = mk_marginalise_table_ndx(fam, i+ss, ns);
    for p=ps(:)'
      engine.mult_parent_ndx{i,p} = mk_multiply_table_ndx(fam, p, ns);
      engine.marg_parent_ndx{i,p} = mk_marginalise_table_ndx(fam, p, ns);
    end
  end
end

engine = class(engine, 'loopy_dbn_inf_engine', inf_engine(bnet));
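
The 'momentum' argument damps message updates. The sketch below assumes it behaves as in the loopy Pearl engine, i.e. as the weight given to the previous message in a convex combination, so that momentum = 0 (the default above) uses the new message only. The anonymous functions `damp` and `normalise` and the message values are hypothetical and are not part of the toolbox.

```matlab
% Hypothetical damped message update; momentum m is ASSUMED to weight the old message.
damp = @(new_msg, old_msg, m) (1 - m) * new_msg + m * old_msg;
normalise = @(v) v / sum(v);

old_msg = [0.5; 0.5];   % message from the previous iteration
new_msg = [0.9; 0.1];   % freshly computed message

msg0 = normalise(damp(new_msg, old_msg, 0));    % m = 0: new message only
msg5 = normalise(damp(new_msg, old_msg, 0.5));  % m = 0.5: halfway between old and new
```

Damping of this kind trades convergence speed for stability: a larger m smooths out oscillating messages on loopy graphs.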
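
The obschild loop enforces the structural restriction stated in the header: every hidden node may have at most one child inside its slice. The sketch below reproduces that check for a hypothetical two-chain slice (hidden nodes 1 and 2, observed nodes 3 and 4), substituting the built-ins `find` and `setdiff` for the toolbox helpers `children` and `mysetdiff`.

```matlab
% Hypothetical intra-slice adjacency: 1 -> 3 and 2 -> 4.
intra = zeros(4);
intra(1, 3) = 1;
intra(2, 4) = 1;
onodes = [3 4];

ss = length(intra);
hnodes = setdiff(1:ss, onodes);
obschild = zeros(1, ss);
for i = hnodes(:)'
  ocs = find(intra(i, :));          % children of node i within the slice
  assert(length(ocs) <= 1, 'hidden node %d has more than one child', i);
  if length(ocs) == 1
    obschild(i) = ocs(1);           % remember the single child
  end
end
% obschild is [3 4 0 0]
```

Adding a second edge such as `intra(1, 4) = 1` would make node 1 have two children and trip the assertion, exactly as the engine's own `assert` would.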
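
With 'fast' set, the engine pre-computes index tables so that multiplying a family table by a smaller table, or summing a family table down to a subset of its variables, becomes a single gather or accumulate. The sketch below shows that general idea for a hypothetical three-variable family; the index layout is an illustration only and is not necessarily what `mk_multiply_table_ndx` or `mk_marginalise_table_ndx` return.

```matlab
% Hypothetical family table with sizes [2 3 2]; the small table covers positions 1 and 3.
ns_fam = [2 3 2];
keep = [1 3];
big = reshape(1:prod(ns_fam), ns_fam);
small = [1 2; 3 4];                       % indexed by (var 1, var 3)

% Pre-compute, once, which small-table entry each big-table entry maps to.
N = prod(ns_fam);
subs = cell(1, numel(ns_fam));
[subs{:}] = ind2sub(ns_fam, (1:N)');
small_sz = ns_fam(keep);
strides = cumprod([1 small_sz(1:end-1)]);
ndx = ones(N, 1);
for k = 1:numel(keep)
  ndx = ndx + (subs{keep(k)} - 1) * strides(k);
end

% Multiply: big(i,j,k) * small(i,k) via one gather.
s = small(:);
prod_tab = reshape(big(:) .* s(ndx), ns_fam);

% Marginalise: sum big over the variable not in keep via one accumarray.
marg = reshape(accumarray(ndx, big(:), [prod(small_sz) 1]), [small_sz 1]);
```

The index vector depends only on the node sizes, which is why it can be built once when the engine is constructed and reused on every time slice and iteration.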
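
The referenced Factored Frontier paper approximates the belief state as a product of per-node marginals. Under the restrictions above (inter-slice parents only, one observed child per hidden node), a forward update for a hidden node can be sketched as follows: weight its CPT by the product of its parents' marginals, sum the parents out, then multiply in the likelihood from its observed child. All sizes and numbers here are hypothetical, and this is an illustration of the technique, not the engine's `enter_evidence` code.

```matlab
% Hypothetical node i (3 states) with two inter-slice parents (2 states each).
cpt = reshape(1:12, [2 2 3]);                 % unnormalised P(i | p1, p2)
cpt = bsxfun(@rdivide, cpt, sum(cpt, 3));     % normalise over the states of i
alpha1 = [0.6; 0.4];     % marginal of parent 1 at time t
alpha2 = [0.3; 0.7];     % marginal of parent 2 at time t
lik = [0.9; 0.1; 0.5];   % hypothetical P(observed child value | i)

% Factored prior over the parents: P(p1, p2) ~= alpha1(p1) * alpha2(p2).
joint = alpha1 * alpha2';                                       % 2x2
pred = squeeze(sum(sum(bsxfun(@times, cpt, joint), 1), 2));     % 3x1 predicted marginal

% Fold in the single observed child's likelihood and renormalise.
post = pred .* lik;
post = post / sum(post);
```

Because each observed child has exactly one hidden parent, its evidence touches only one marginal, so the filtered belief state stays fully factored after conditioning.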
