loopy_dbn_inf_engine.m
From the Bayes Net Toolbox · M (MATLAB) code · 80 lines
```matlab
function engine = loopy_dbn_inf_engine(bnet, onodes, max_iter, fast, momentum, obj)
% LOOPY_DBN_INF_ENGINE Loopy Pearl version of forwards-backwards
% engine = loopy_dbn_inf_engine(bnet, onodes, max_iter, fast, momentum, object_oriented)
%
% 'onodes' specifies which nodes are observed.
% 'max_iter' specifies the maximum number of forward-backward passes to perform (default: 1).
% 'fast' = 1 means pre-compute indices for table manipulation, as in jtree_fast (default: 0).
% 'momentum' is as in loopy_pearl: the default of 0 means use the current message only.
% 'object_oriented' = 1 will call a method to compute messages (good for noisy-or, but slow in general).
%
% The model must obey the same topological restrictions as hmm_inf_engine.
% In addition, each hidden node is assumed to have at most one observed child,
% and each observed child is assumed to have exactly one hidden parent.
%
% For details of this algorithm, see
% "The Factored Frontier Algorithm for Approximate Inference in DBNs",
% Kevin Murphy and Yair Weiss, submitted to NIPS 2000.
%
% WARNING: THIS IS HIGHLY EXPERIMENTAL CODE!

% Default arguments; any nonzero value of fast/obj switches that option on.
if nargin < 3, max_iter = 1; end
if nargin < 4 || isempty(fast), fast = 0; else fast = double(fast ~= 0); end
if nargin < 5, momentum = 0; end
if nargin < 6 || isempty(obj), obj = 0; else obj = double(obj ~= 0); end

% Engine state; marginals, evidence and messages are filled in during inference.
engine.object_oriented = obj;
engine.fast = fast;
engine.momentum = momentum;
engine.max_iter = max_iter;
engine.onodes = onodes;
engine.marginal = [];
engine.evidence = [];
engine.msg = [];
engine.parent_index = [];
engine.child_index = [];
%[engine.parent_index, engine.child_index] = mk_pearl_msg_indices(bnet); % need to unroll first

% For each hidden node, record its single child within the slice (0 if it has none).
ss = length(bnet.intra);            % number of nodes per time slice
hnodes = mysetdiff(1:ss, onodes);   % hidden nodes
obschild = zeros(1,ss);
for i=hnodes(:)'
  %ocs = myintersect(children(bnet.dag, i), onodes);
  ocs = children(bnet.intra, i);    % all intra-slice children, observed or not
  assert(length(ocs) <= 1);
  if length(ocs)==1
    obschild(i) = ocs(1);
  end
end
engine.obschild = obschild;

% Optionally precompute index tables for multiplying/marginalising each hidden node's
% two-slice family: its parents in slice t-1 (nodes 1..ss) and itself in slice t (i+ss).
engine.mult_self_ndx = [];
engine.mult_parent_ndx = [];
engine.marg_self_ndx = [];
engine.marg_parent_ndx = [];
if fast
  onodes2 = [onodes(:); onodes(:)+ss];   % observed nodes in both slices
  ns = bnet.node_sizes(:);
  ns(onodes2) = 1;                       % treat observed nodes as having a single state
  engine.mult_self_ndx = cell(1, ss);
  engine.mult_parent_ndx = cell(ss, ss);
  engine.marg_self_ndx = cell(1, ss);
  engine.marg_parent_ndx = cell(ss, ss);
  for i=hnodes(:)'
    ps = parents(bnet.inter, i);         % parents of node i in the previous slice
    fam = [ps i+ss];
    engine.mult_self_ndx{i} = mk_multiply_table_ndx(fam, i+ss, ns);
    engine.marg_self_ndx{i} = mk_marginalise_table_ndx(fam, i+ss, ns);
    for p=ps(:)'
      engine.mult_parent_ndx{i,p} = mk_multiply_table_ndx(fam, p, ns);
      engine.marg_parent_ndx{i,p} = mk_marginalise_table_ndx(fam, p, ns);
    end
  end
end

% Old-style MATLAB class that inherits from inf_engine.
engine = class(engine, 'loopy_dbn_inf_engine', inf_engine(bnet));
```
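The header's structural assumptions (each hidden node has at most one child in its slice, and each observed node has exactly one hidden parent) can be checked without the toolbox. The snippet below is a sketch that mirrors the `obschild` loop using only base MATLAB/Octave; the adjacency matrix `intra`, the three-node example network, and the use of `find`/`setdiff` in place of BNT's `children`/`mysetdiff` are assumptions for illustration. The second check (one hidden parent per observed node) is stated in the header but not enforced by the engine itself.

```matlab
% Hypothetical 3-node slice: node 1 (hidden) -> node 2 (observed);
% node 3 is hidden and has no children within the slice.
intra  = [0 1 0; 0 0 0; 0 0 0];   % intra(i,j) = 1 means an edge i -> j
onodes = 2;

ss = size(intra, 1);
hnodes = setdiff(1:ss, onodes);   % hidden nodes
obschild = zeros(1, ss);          % 0 means "no child in this slice"
for i = hnodes(:)'
  ocs = find(intra(i, :));        % all children of node i in the slice
  assert(numel(ocs) <= 1, 'hidden node %d has more than one child in its slice', i);
  if numel(ocs) == 1
    obschild(i) = ocs(1);
  end
end

% Each observed node must have exactly one hidden parent in the slice.
for j = onodes(:)'
  assert(nnz(intra(hnodes, j)) == 1, 'observed node %d needs exactly one hidden parent', j);
end
% obschild is now [2 0 0]
```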
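The header describes `momentum` only by reference to `loopy_pearl`. Assuming it is the weight given to the previous message in a convex combination with the newly computed one (so the default of 0 keeps only the new message), a damped update could look like the sketch below; `damp`, `old_msg` and `new_msg` are hypothetical names, not part of the engine.

```matlab
% Hypothetical damped message update (a sketch, not the toolbox's code).
% momentum is assumed to be the weight on the previous message.
damp = @(old_msg, new_msg, momentum) momentum * old_msg + (1 - momentum) * new_msg;

old_msg = [0.5; 0.5];   % message from the previous iteration
new_msg = [0.9; 0.1];   % freshly computed message

m0 = damp(old_msg, new_msg, 0);     % [0.9; 0.1]: current message only
m5 = damp(old_msg, new_msg, 0.5);   % [0.7; 0.3]: halfway between old and new
```

Values between 0 and 1 slow down how quickly messages change from one iteration to the next, which can help loopy propagation settle when undamped updates oscillate.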