marginal_nodes.m

From the Bayes Net Toolbox · MATLAB (.m) code · 85 lines

```matlab
function marginal = marginal_nodes(engine, nodes, t, isfam)
% MARGINAL_NODES Compute the marginal on the specified query nodes (bk_fast)
%
%   marginal = marginal_nodes(engine, i, t)
% returns Pr(X(i,t) | Y(1:T)), where X(i,t) is the i'th node in the t'th slice.
% If enter_evidence used filtering instead of smoothing, this will return Pr(X(i,t) | Y(1:t)).
%
%   marginal = marginal_nodes(engine, query, t)
% returns Pr(X(query(1),t), ... X(query(end),t) | Y(1:T)),
% where X(q,t) is the q'th node in the t'th slice. If q > ss (slice size), this is equal
% to X(q-ss, t+1). That is, 't' specifies the time slice of the earliest node.
% 'query' cannot span more than 2 time slices.
% Example:
% Consider a DBN with 2 nodes per slice.
% Then t=2, nodes=[1 3] refers to node 1 in slice 2 and node 1 in slice 3.

if nargin < 3, t = 1; end
if nargin < 4, isfam = 0; end

% clpot{t} contains slices t-1 and t
% Example
% clpot #: 1    2    3
% slices:  1  1,2  2,3
% For filtering, we must take care not to take future evidence into account.
% For smoothing, clpot{1} does not exist.

bnet = bnet_from_engine(engine);
ss = length(bnet.intra);

nodes2 = nodes;
if ~engine.filter
  if t < engine.T
    slice = t+1;
  else % earliest t is T, so all nodes fit in one slice
    slice = engine.T;
    nodes2 = nodes + ss;
  end
else
  if t == 1
    slice = 1;
  else
    if all(nodes <= ss)
      % all query nodes lie in slice t (indices 1..ss, including node ss itself):
      % read clpot{t}, where slice t occupies positions ss+1..2*ss, so that
      % no evidence from slice t+1 is used
      slice = t;
      nodes2 = nodes + ss;
    elseif t == engine.T
      slice = t;
    else
      slice = t + 1;
    end
  end
end

if engine.filter && t == 1
  c = clq_containing_nodes(engine.sub_engine1, nodes2, isfam);
  jengine = struct(engine.sub_engine1); % violate object privacy
else
  c = clq_containing_nodes(engine.sub_engine, nodes2, isfam);
  jengine = struct(engine.sub_engine); % violate object privacy
end
assert(c >= 1);
jengine2 = struct(jengine.jtree_inf_engine);

bigpot = engine.cltables{c, slice};
ns = bnet.node_sizes(:);
onodes2 = [engine.onodes(:); engine.onodes(:)+ss];
ns(onodes2) = 1;

if isfam
  ndx = jengine.marg_fam_ndx{nodes2(end)};
else
  ndx = mk_marginalise_table_ndx(jengine2.cliques{c}, nodes2, ns);
end
%pot = sum(bigpot(ndx), 2);
%marginal.T = myreshape(pot, ns(nodes));
marginal.T = marginalise_table(bigpot, ndx, jengine2.cliques{c}, nodes2, ns);

% we convert the domain to the unrolled numbering system
% so that update_ess extracts the right evidence.
marginal.domain = nodes+(t-1)*ss;
```
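The slice bookkeeping in marginal_nodes is easy to get wrong. The script below is a minimal sketch, not part of BNT, that replays only the index arithmetic for a hypothetical DBN with ss = 2 nodes per slice and T = 4 slices.

For each query it computes three things:
- which clique-potential slot clpot{slice} is read,
- which node indices are used inside that two-slice window,
- the unrolled numbering stored in marginal.domain.

It assumes, following the comments in the listing, that clpot{1} holds slice 1 and clpot{t} holds slices t-1 and t. A filtered query counts as lying in slice t when every node index is at most ss, so a filtered estimate never reads a potential that has already absorbed evidence from slice t+1. The variable names `queries`, `slices`, `local` and `unrolled` are illustrative only.

```matlab
% Hypothetical sketch (not BNT code): replay the clpot slot / local-index
% choice of marginal_nodes without running any inference.
ss = 2;   % slice size (assumed example value)
T  = 4;   % number of slices (assumed example value)

% Each row: filtering flag, t (slice of earliest node), query nodes (1..2*ss)
queries = { false, 2, [1 3];     % node 1 in slices 2 and 3, smoothing
            false, 4, [2];       % node 2 in the last slice, smoothing
            true,  1, [1 2];     % both nodes of slice 1, filtering
            true,  3, [2];       % node ss of slice 3, filtering
            true,  3, [1 3] };   % node 1 in slices 3 and 4, filtering

n = size(queries, 1);
slices   = zeros(n, 1);   % which clpot{slice} is read
local    = cell(n, 1);    % node indices inside that two-slice window
unrolled = cell(n, 1);    % numbering stored in marginal.domain

for k = 1:n
  [filt, t, nodes] = queries{k, :};
  nodes2 = nodes;
  if ~filt
    if t < T
      slice = t + 1;          % clpot{t+1} holds slices t and t+1
    else
      slice = T;              % slice T is the second half of clpot{T}
      nodes2 = nodes + ss;
    end
  elseif t == 1
    slice = 1;                % clpot{1} holds slice 1 alone
  elseif all(nodes <= ss)
    slice = t;                % stay in clpot{t}: no evidence from slice t+1
    nodes2 = nodes + ss;
  elseif t == T
    slice = t;
  else
    slice = t + 1;            % query spans slices t and t+1
  end
  slices(k)   = slice;
  local{k}    = nodes2;
  unrolled{k} = nodes + (t - 1) * ss;   % same formula as marginal.domain
end

for k = 1:n
  fprintf('query %d: clpot{%d}, local nodes [%s], unrolled domain [%s]\n', ...
          k, slices(k), num2str(local{k}), num2str(unrolled{k}));
end
```

Two queries show the behaviour:
- The first query is the header's own example (t=2, nodes=[1 3]). It is read from clpot{3} and maps to unrolled nodes [3 5], that is, node 1 of slice 2 and node 1 of slice 3.
- The fourth query (node ss of slice 3 under filtering) is read from clpot{3} at local index 4 rather than from clpot{4}.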