enter_evidence.m
来自「Bayes网络工具箱」· M 代码 · 共 156 行
M
156 行
function [engine, loglik] = enter_evidence(engine, evidence, filter)
% ENTER_EVIDENCE Add the specified evidence to the network (frontier)
% [engine, loglik] = enter_evidence(engine, evidence, filter)
%
% evidence{i,t} = [] if X(i,t) is hidden, and otherwise contains its observed value (scalar or column vector)
% If filter = 1, we do filtering, otherwise smoothing (default).
%
% Outputs
%   engine.fwd     - fwd{i,t} is the forward frontier potential saved right after
%                    node (i,t) was added to the frontier.
%   engine.fwdback - currently identical to engine.fwd (see NOTE below).
%   loglik         - sum over slices of the second output of normalize_pot
%                    (presumably the log normalizing constant of each slice).
%
% NOTE(review): the backward (smoothing) pass is currently disabled by the
% unconditional early return after the forward pass ("if 1 % hfilter"), so the
% 'filter' argument has no effect and only filtered frontiers are returned.
% Switch that test to "if filter" only after the backward pass has been
% verified -- when it re-adds a node it builds the domain from engine.fdom{s},
% which should be confirmed against the engine constructor.

if nargin < 3, filter = 0; end %#ok<NASGU> only read by the disabled smoothing switch

[ss T] = size(evidence);
bnet = bnet_from_engine(engine);
ns = repmat(bnet.node_sizes_slice(:), 1, T);
eclass = [bnet.equiv_class(:,1) repmat(bnet.equiv_class(:,2), 1, T-1)];
onodes = find(~isemptycell(evidence));
cnodes = unroll_set(bnet.cnodes(:), ss, T);
big_dag = unroll_dbn_topology(bnet.intra1, bnet.intra, bnet.inter, T);
pot_type = determine_pot_type(onodes, cnodes, big_dag);
ns = ns(:)';
cnodes = cnodes(:)';

% Convert evidence-specific CPDs to potentials.
% CPD is ss x T, so the linear index CPD{j} for the unrolled node number
% j = (t-1)*ss + i is the same entry as CPD{i,t}.
CPD = cell(ss,T);
for t=1:T
  for i=1:ss
    fam = family(bnet.dag, i, t);
    CPD{i,t} = CPD_to_pot(pot_type, bnet.CPD{eclass(i,t)}, fam, ns, cnodes, evidence);
  end
end

% FORWARDS
fwd = cell(ss,T);
ll = zeros(1,T);
S = 2*ss; % num. intermediate frontiers to get from t to t+1
frontier = cell(S,T);

% Start with an empty frontier, and add each node in slice 1 in turn
t = 1;
prev = mk_initial_pot(pot_type, [], ns, cnodes, onodes);
for s=1:ss
  j = s; % add node j at step s
  frontier{s,t} = update(prev, j, 1, CPD{j}, engine.fdom1{s}, pot_type, ns, cnodes, onodes);
  fwd{j} = frontier{s,t};
  prev = frontier{s,t};
end
frontier{S,t} = frontier{ss,t};
[frontier{S,t}, ll(1)] = normalize_pot(frontier{S,t});

% Now move the frontier from slice to slice.
% engine.ops lists one signed node per step: positive = add, negative = remove.
OPS = engine.ops;
add = OPS>0;
nodes = [zeros(S,1) unroll_set(abs(OPS(:)), ss, T-1)];
for t=2:T
  offset = (t-2)*ss;
  for s=1:S
    % linear indices into the S x T cell array 'frontier'
    if s==1
      prev_ndx = (t-2)*S + S; % S,t-1
    else
      prev_ndx = (t-1)*S + s-1; % s-1,t
    end
    j = nodes(s,t);
    frontier{s,t} = update(frontier{prev_ndx}, j, add(s), CPD{j}, engine.fdom{s}+offset, pot_type, ns, cnodes, onodes);
    if add(s)
      fwd{j} = frontier{s,t};
    end
  end
  [frontier{S,t}, ll(t)] = normalize_pot(frontier{S,t});
end
loglik = sum(ll);

engine.fwd = fwd;
if 1 % hfilter -- smoothing disabled; see NOTE(review) in the header
  engine.fwdback = fwd;
  return;
end

% BACKWARDS (currently unreachable)
back = cell(ss,T);
add = ~add; % forwards add = backwards remove
frontier = cell(S,T+1);
t = T;
dom = (1:ss) + (t-1)*ss;
frontier{1,T+1} = mk_initial_pot(pot_type, dom, ns, cnodes, onodes); % all 1s for last slice
for t=T:-1:2
  offset = (t-2)*ss;
  for s=S:-1:1 % reverse order
    % linear indices into the S x (T+1) cell array 'frontier'
    if s==S
      prev_ndx = t*S + 1; % 1,t+1
    else
      prev_ndx = (t-1)*S + (s+1); % s+1,t
    end
    j = nodes(s,t);
    if ~add(s)
      back{j} = frontier{prev_ndx}; % save frontier before removing
    end
    frontier{s,t} = rev_update(frontier{prev_ndx}, j, add(s), CPD{j}, engine.fdom{s}+offset, pot_type, ns, cnodes, onodes);
  end
  frontier{1,t} = normalize_pot(frontier{1,t});
end

% Remove each node in first slice until left with empty set
t = 1;
frontier{ss+1,t} = frontier{1,2};
for s=ss:-1:1
  j = s; % remove node j at step s
  back{j} = frontier{s+1,t};
  frontier{s,t} = rev_update(frontier{s+1,t}, j, 0, CPD{j}, 1:s, pot_type, ns, cnodes, onodes);
end

% COMBINE
for t=1:T
  for i=1:ss
    engine.fwdback{i,t} = normalize_pot(multiply_pots(fwd{i,t}, back{i,t}));
  end
end

% for debugging
engine.back = back;

%%%%%%%%%%
function new_frontier = update(old_frontier, j, add, CPD, newdom, pot_type, ns, cnodes, onodes)
% UPDATE Perform one forward frontier step.
% If add is true, the frontier is extended to the domain newdom (via an initial
% potential on newdom) and multiplied by CPD, the potential of node j.
% Otherwise node j is marginalized out of old_frontier, and newdom is unused.

if add
  new_frontier = mk_initial_pot(pot_type, newdom, ns, cnodes, onodes);
  new_frontier = multiply_by_pot(new_frontier, old_frontier);
  new_frontier = multiply_by_pot(new_frontier, CPD);
else
  new_frontier = marginalize_pot(old_frontier, mysetdiff(domain_pot(old_frontier), j));
end

%%%%%%
function new_frontier = rev_update(old_frontier, j, add, CPD, newdom, pot_type, ns, cnodes, onodes)
% REV_UPDATE Perform one backward frontier step.
% If add is true, old_frontier is extended to the domain newdom by multiplying
% it into an initial potential on newdom. Otherwise old_frontier is extended to
% newdom, multiplied by CPD (the potential of node j), and j is summed out.

if add
  % add: extend domain to include j by multiplying by 1
  new_frontier = mk_initial_pot(pot_type, newdom, ns, cnodes, onodes);
  new_frontier = multiply_by_pot(new_frontier, old_frontier);
else
  % remove: multiply in CPT and then marginalize it out
  temp = mk_initial_pot(pot_type, newdom, ns, cnodes, onodes);
  temp = multiply_by_pot(temp, old_frontier);
  temp = multiply_by_pot(temp, CPD);
  new_frontier = marginalize_pot(temp, mysetdiff(domain_pot(temp), j));
end
⌨️ 快捷键说明
复制代码Ctrl + C
搜索代码Ctrl + F
全屏模式F11
增大字号Ctrl + =
减小字号Ctrl + -
显示快捷键?