smooth_evidence_fast.m

来自「Bayes网络工具箱」· M 代码 · 共 186 行

M
186
字号
function [marginal, msg, loglik] = smooth_evidence_fast(engine, evidence)
% SMOOTH_EVIDENCE_FAST Iterative pi/lambda message passing over T slices of a DBN.
% [marginal, msg, loglik] = smooth_evidence_fast(engine, evidence)
%
% Inputs:
%   engine   - inference engine; this function reads the fields onodes,
%              obschild, max_iter, momentum, mult_parent_ndx, mult_self_ndx,
%              marg_self_ndx and marg_parent_ndx, and obtains the network
%              via bnet_from_engine(engine).
%   evidence - ss-by-T array (ss = nodes per slice, T = number of slices).
%              NOTE(review): presumably a cell array with [] for hidden
%              nodes, as consumed by CPD_to_table(s) - confirm with caller.
%
% Outputs:
%   marginal - ss-by-T cell; marginal{i,t} = normalise(pi .* lambda) for each
%              hidden node i. Entries for observed nodes are left empty.
%   msg      - struct holding the final messages and beliefs, with fields
%              inter_pi_msg, inter_lambda_msg, intra_lambda_msg, pi, lambda.
%   loglik   - always 0; the log-likelihood is not computed here.
%
% Each of the engine.max_iter iterations does one forward sweep (t = 1..T),
% updating pi and sending pi messages to the next slice, followed by one
% backward sweep (t = T..1), updating lambda and sending lambda messages to
% the previous slice. New inter-slice messages are blended with the old ones
% using engine.momentum.

[ss, T] = size(evidence);
bnet = bnet_from_engine(engine);
onodes = engine.onodes(:)';
hnodes = mysetdiff(1:ss, onodes);
hnodes = hnodes(:)';

% Node sizes with observed nodes (in both slices of the 2-slice template)
% clamped to size 1.
ns_clamped = bnet.node_sizes(:);
onodes2 = [onodes(:); onodes(:)+ss];
ns_clamped(onodes2) = 1;

verbose = 0;

if verbose, fprintf('new smooth\n'); end

% msg(i1,t1,i2,j2) (i1,t1) -> (i2,t2)
%lambda_msg = cell(ss,T,ss,T);
%pi_msg = cell(ss,T,ss,T);

% intra_lambda_msg(i,j,t) (i,t) -> (j,t), i is child
% inter_lambda_msg(i,j,t) (i,t+1) -> (j,t), i is child
% inter_pi_msg(i,j,t) (i,t-1) -> (j,t), i is parent
intra_lambda_msg = cell(ss,ss,T);
inter_lambda_msg = cell(ss,ss,T);
inter_pi_msg = cell(ss,ss,T);

lambda = cell(ss,T);
pi_bel = cell(ss,T);  % named pi_bel so the builtin constant pi is not shadowed

% Initialise every belief and message to a vector of ones.
for t=1:T
  for i=1:ss
    lambda{i,t} = ones(ns_clamped(i), 1);
    pi_bel{i,t} = ones(ns_clamped(i), 1);

    cs = children(bnet.intra, i);
    for c=cs(:)'
      intra_lambda_msg{c,i,t} = ones(ns_clamped(i),1);
    end

    cs = children(bnet.inter, i);
    for c=cs(:)'
      inter_lambda_msg{c,i,t} = ones(ns_clamped(i),1);
    end

    ps = parents(bnet.inter, i);
    for p=ps(:)'
      % A pi message from parent p is a vector over p's states, so it must
      % have length ns(p) (the forward pass overwrites it with a potential
      % built from pi{p,t-1}). The previous ns(i) length made the momentum
      % blend of old and new messages fail when parent and child sizes differ.
      inter_pi_msg{p,i,t} = ones(ns_clamped(p), 1); % not used for t==1
    end
  end
end

% Convert CPDs of instantiated nodes to potential form
CPDpot = cell(ss,T);
ns_full = repmat(bnet.node_sizes_slice(:), 1, T);  % unclamped sizes, one column per slice
cnodes = unroll_set(bnet.cnodes(:), ss, T);
for n=onodes
  fam = family(bnet.dag, n);
  e = bnet.equiv_class(n, 1);
  CPDpot{n,1} = CPD_to_table(bnet.CPD{e}, fam, ns_full, cnodes, evidence(:,1));
end
if T > 1  % slices 2..T only exist when there is more than one slice
  for n=onodes
    fam = family(bnet.dag, n, 2);
    doms = unroll_set(fam, ss, T-1);
    e = bnet.equiv_class(n, 2);
    CPDpot(n,2:T) = CPD_to_tables(bnet.CPD{e}, doms, ns_full, cnodes, evidence);
  end
end

% each hidden node absorbs lambda from its observed child (if any)
for t=1:T
  for i=hnodes
    c = engine.obschild(i);
    if c > 0
      intra_lambda_msg{c,i,t} = normalise(CPDpot{c,t});
    end
  end
end

for iter=1:engine.max_iter
  % FORWARD
  for t=1:T
    % update pi
    for i=hnodes
      if t==1
        % First slice: pi is taken directly from the slice-1 CPT.
        e = bnet.equiv_class(i,1);
        temp = struct(bnet.CPD{e});
        pi_bel{i,t} = temp.CPT;
      else
        % Later slices: weight the CPT by each incoming pi message, then
        % sum along dim 2 of the marg_self_ndx-indexed table.
        e = bnet.equiv_class(i,2);
        CPD = struct(bnet.CPD{e});
        ps = parents(bnet.inter, i);
        temp = CPD.CPT;
        for p=ps(:)'
          temp(:) = temp(:) .* inter_pi_msg{p,i,t}(engine.mult_parent_ndx{i,p});
        end
        pi_bel{i,t} = sum(temp(engine.marg_self_ndx{i}), 2);
      end
      if verbose, fprintf('%d updates pi\n', i+(t-1)*ss); disp(pi_bel{i,t}); end
    end

    % send pi msg to (hidden) children in next slice
    if t < T
      for i=hnodes
        cs = children(bnet.inter, i);
        for c=cs(:)'
          % Combine pi with the lambda messages from every child except
          % the recipient c.
          pot = pi_bel{i,t};
          for k=cs(:)'
            if k ~= c
              pot = pot .* inter_lambda_msg{k,i,t};
            end
          end
          cs2 = children(bnet.intra, i);
          for k=cs2(:)'
            pot = pot .* intra_lambda_msg{k,i,t};
          end
          old_msg = inter_pi_msg{i,c,t+1};
          new_msg = normalise(pot);
          % Blend old and new messages, weighted by engine.momentum.
          pi_msg = engine.momentum * old_msg + (1-engine.momentum)*new_msg;
          inter_pi_msg{i,c,t+1} = pi_msg;
          if verbose, fprintf('%d sends pi to %d\n', i+(t-1)*ss, c+t*ss); disp(inter_pi_msg{i,c,t+1}); end
        end
      end
    end
  end

  if verbose, fprintf('backwards\n'); end
  % BACKWARD
  for t=T:-1:1
    % update lambda
    for i=hnodes
      pot = ones(ns_full(i), 1);
      cs = children(bnet.inter, i);
      for c=cs(:)'
        pot = pot .* inter_lambda_msg{c,i,t};
      end
      cs = children(bnet.intra, i);
      for c=cs(:)'
        pot = pot .* intra_lambda_msg{c,i,t};
      end
      lambda{i,t} = normalise(pot);
      if verbose, fprintf('%d computes lambda\n', i+(t-1)*ss); disp(lambda{i,t}); end
    end

    % send lambda msgs to (hidden) parents in prev slice
    for i=hnodes
      ps = parents(bnet.inter, i);
      if t > 1
        e = bnet.equiv_class(i, 2);
        CPD = struct(bnet.CPD{e});
        for p=ps(:)'
          % Weight the CPT by this node's lambda and by the pi messages of
          % all other parents, then reduce to a message for parent p.
          temp = CPD.CPT(:) .* lambda{i,t}(engine.mult_self_ndx{i});
          for k=ps(:)'
            if k ~= p
              temp(:) = temp(:) .* inter_pi_msg{k,i,t}(engine.mult_parent_ndx{i,k});
            end
          end
          new_msg = normalise(sum(temp(engine.marg_parent_ndx{i,p}), 2));
          old_msg = inter_lambda_msg{i,p,t-1};
          lam_msg = engine.momentum * old_msg + (1-engine.momentum)*new_msg;
          inter_lambda_msg{i,p,t-1} = lam_msg;
          if verbose, fprintf('%d sends lambda to %d\n', i+(t-1)*ss, p+(t-2)*ss); disp(inter_lambda_msg{i,p,t-1}); end
        end
      end
    end
  end
end

% Posterior marginal of each hidden node: normalised product of pi and lambda.
marginal = cell(ss,T);
for t=1:T
  for i=hnodes
    marginal{i,t} = normalise(pi_bel{i,t} .* lambda{i,t});
  end
end

loglik = 0;  % log-likelihood is not computed by this routine

msg.inter_pi_msg = inter_pi_msg;
msg.inter_lambda_msg = inter_lambda_msg;
msg.intra_lambda_msg = intra_lambda_msg;
msg.pi = pi_bel;
msg.lambda = lambda;

⌨️ 快捷键说明

复制代码Ctrl + C
搜索代码Ctrl + F
全屏模式F11
增大字号Ctrl + =
减小字号Ctrl + -
显示快捷键?