enter_evidence.m

来自「Bayes网络工具箱」· M 代码 · 共 159 行

M
159
字号
function [engine, iter] = enter_evidence(engine, evidence, filename)
% ENTER_EVIDENCE Add the specified evidence to the network (loopy_pearl)
% [engine, num_iter] = enter_evidence(engine, evidence, filename)
%
% evidence{i} = [] if X(i) is hidden, and otherwise contains its observed
% value (scalar or column vector).
%
% If filename is specified, the first component of the marginal prob. of each
% hidden node is printed to that file at every checked iteration. This can be
% used to monitor convergence while the process is running.
%
% 'num_iter' is the iteration at which convergence was detected, or
% engine.max_iter + 1 if the loop ran out of iterations without converging.
%
% On return, engine.marginal{n} holds the normalised belief of node n, and
% engine.evidence / engine.msg hold the evidence and the final messages.

if nargin < 3, filename = []; end

% fid stays EMPTY when no log file was requested. check_converged only prints
% when fid is non-empty, so a placeholder such as 0 would be treated as a real
% file identifier and passed to fprintf.
fid = [];
if ~isempty(filename)
  [fid, errmsg] = fopen(filename, 'w');
  if fid < 0
    error('enter_evidence: could not open ''%s'' for writing: %s', filename, errmsg);
  end
end

bnet = bnet_from_engine(engine);
N = length(bnet.dag);
ns = bnet.node_sizes(:);
msg = init_msgs(bnet.dag, ns, evidence);

% Convergence criterion is that the last W bel's are approximately the same,
% so we use a wrap-around buffer.
if engine.momentum == 0
  W = 2; % check bel(t) == bel(t-1)
else
  W = 3; % check bel(t) == bel(t-1) == bel(t-2)
end
history = cell(1,N);
for n=1:N
  history{n} = zeros(W, ns(n));
end

converged = 0;
iter = 1;
hidden = find(isemptycell(evidence));

while ~converged && (iter <= engine.max_iter)
  % Everybody updates their state in parallel
  for n=1:N
    cs = children(bnet.dag, n);
    msg{n}.lambda = compute_lambda(n, cs, msg);
    ps = parents(bnet.dag, n);
    msg{n}.pi = compute_pi(bnet.CPD{bnet.equiv_class(n)}, n, ps, msg);
  end

  [converged, history] = check_converged(iter, msg, history, hidden, engine.tol, fid);

  if ~converged
    % Everybody sends to all their neighbors in parallel
    for n=1:N
      % lambda msgs to parents
      ps = parents(bnet.dag, n);
      for p=ps(:)'
        j = engine.child_index{p}(n); % n is p's j'th child
        old_msg = msg{p}.lambda_from_child{j}(:);
        new_msg = normalise(compute_lambda_msg(bnet.CPD{bnet.equiv_class(n)}, n, ps, msg, p));
        % Damped update: blend the previous message with the new one.
        lam_msg = engine.momentum * old_msg + (1-engine.momentum)*new_msg;
        msg{p}.lambda_from_child{j} = lam_msg(:);
      end

      % pi msgs to children
      cs = children(bnet.dag, n);
      for c=cs(:)'
        j = engine.parent_index{c}(n); % n is c's j'th parent
        old_msg = msg{c}.pi_from_parent{j}(:);
        new_msg = normalise(compute_pi_msg(n, cs, msg, c));
        pi_msg = engine.momentum * old_msg + (1-engine.momentum)*new_msg;
        % Stored as a column, like the lambda messages above.
        msg{c}.pi_from_parent{j} = pi_msg(:);
      end
    end
    % Only advance the counter when another round is needed, so that on
    % convergence iter is the iteration at which it was detected.
    iter = iter + 1;
  end
end

% Message passing is finished; release the monitoring file if one was opened.
if ~isempty(fid)
  fclose(fid);
end

engine.marginal = cell(1,N);
for n=1:N
  engine.marginal{n} = normalise(msg{n}.pi .* msg{n}.lambda);
end

engine.evidence = evidence; % needed by marginal_nodes and marginal_family
engine.msg = msg;  % needed by marginal_family


%%%%%%%

function lambda = compute_lambda(n, cs, msg)
% COMPUTE_LAMBDA Combine node n's own lambda with the lambda messages from all its children.
% Pearl p183 eq 4.50
lambda = prod_lambda_msgs(n, cs, msg);

%%%%%%%

function pi_msg = compute_pi_msg(n, cs, msg, c)
% COMPUTE_PI_MSG Unnormalised pi message from node n to its child c.
% The lambda message coming from c itself is left out of the product.
% Pearl p183 eq 4.53 and 4.51
pi_msg = msg{n}.pi .* prod_lambda_msgs(n, cs, msg, c);

%%%%%%%%%

function lam = prod_lambda_msgs(n, cs, msg, except)
% PROD_LAMBDA_MSGS Elementwise product of lambda_from_self and the children's lambda messages.
% cs lists the children of n in the same order as msg{n}.lambda_from_child.
% If 'except' is given, the message from that child is skipped.
if nargin < 4, except = -1; end % -1 matches no child, so nothing is skipped

lam = msg{n}.lambda_from_self(:);
for i=1:length(cs)
  c = cs(i);
  if c ~= except
    lam = lam .* msg{n}.lambda_from_child{i};
  end
end

%%%%%%%%

function [converged, history] = check_converged(iter, msg, history, hidden, tol, fid)
% CHECK_CONVERGED Record the current beliefs of the hidden nodes and test for convergence.
% history{n} is a W-row wrap-around buffer of node n's recent beliefs.
% converged = 1 iff, for every hidden node, all W rows agree (per approxeq with tol).
% If fid is non-empty, the first component of each hidden node's belief is logged to it.
% For iter < W the function returns at once without recording or logging anything.

converged = 0;
W = size(history{1},1);
if iter < W
  return;
end

changed = 0;
if ~isempty(fid)
  fprintf(fid, 'iteration %d\n', iter);
end
for n=hidden(:)'
  bel = normalise(msg{n}.pi .* msg{n}.lambda);
  if ~isempty(fid)
    fprintf(fid, '%9.7f ', bel(1));
  end
  % Overwrite the oldest slot of the circular buffer with the current belief.
  history{n}(wrap(iter, W), :) = bel;
  for i=2:W
    if ~approxeq(history{n}(1,:), history{n}(i,:), tol)
      changed = 1;
    end
  end
end
if ~isempty(fid)
  fprintf(fid, '\n\n');
end

if ~changed, converged = 1; end

%%%%%%%

function v = wrap(u,N)
% WRAP Wrap a vector of indices around a torus.
% v = wrap(u,N)
%
% e.g., wrap([-1 0 1 2 3 4], 3)   =   2 3 1 2 3 1
v = mod(u-1,N)+1;

⌨️ 快捷键说明

复制代码Ctrl + C
搜索代码Ctrl + F
全屏模式F11
增大字号Ctrl + =
减小字号Ctrl + -
显示快捷键?