enter_evidence.m
来自「Bayes网络工具箱」· M 代码 · 共 159 行
M
159 行
function [engine, iter] = enter_evidence(engine, evidence, filename)
% ENTER_EVIDENCE Add the specified evidence to the network (loopy_pearl)
% [engine, num_iter] = enter_evidence(engine, evidence, filename)
% evidence{i} = [] if X(i) is hidden, and otherwise contains its observed value (scalar or column vector)
%
% If filename is specified, we will print the first component of the marginal prob. of each
% hidden node at every iteration to said file. This can be used to monitor convergence while the
% process is running. The file is closed before this function returns.
%
% 'num_iter' contains the number of iterations used. If the beliefs did not converge within
% engine.max_iter iterations, num_iter is engine.max_iter + 1.
%
% On return, engine.marginal{n} holds the normalised belief of node n, and engine.evidence /
% engine.msg hold the evidence and the final messages.

if nargin < 3, filename = []; end

% fid = [] means "no monitoring output": check_converged tests ~isempty(fid)
% before every fprintf, so an empty handle (not 0) must be used here.
if ~isempty(filename)
  fid = fopen(filename, 'w');
  if fid < 0
    error(['enter_evidence: cannot open ' filename ' for writing']);
  end
else
  fid = [];
end

bnet = bnet_from_engine(engine);
N = length(bnet.dag);
ns = bnet.node_sizes(:);

msg = init_msgs(bnet.dag, ns, evidence);

% Convergence criterion is that the last W bel's are approximately the same,
% so we use a wrap-around buffer.
if engine.momentum == 0
  W = 2; % check bel(t) == bel(t-1)
else
  W = 3; % check bel(t) == bel(t-1) == bel(t-2)
end
history = cell(1,N);
for n=1:N
  history{n} = zeros(W, ns(n));
end

converged = 0;
iter = 1;
hidden = find(isemptycell(evidence));
while ~converged && (iter <= engine.max_iter)
  % Every node recomputes its lambda and pi values from the current messages.
  for n=1:N
    cs = children(bnet.dag, n);
    msg{n}.lambda = compute_lambda(n, cs, msg);
    ps = parents(bnet.dag, n);
    msg{n}.pi = compute_pi(bnet.CPD{bnet.equiv_class(n)}, n, ps, msg);
  end

  [converged, history] = check_converged(iter, msg, history, hidden, engine.tol, fid);

  if ~converged
    % Every node sends messages to all its neighbours.
    % NOTE(review): msg is updated in place while sweeping n, so e.g. the pi
    % messages of node n (via prod_lambda_msgs) read lambda_from_child entries
    % already refreshed earlier in this sweep; the schedule is not strictly parallel.
    for n=1:N
      % lambda msgs to parents
      ps = parents(bnet.dag, n);
      for p=ps(:)'
        j = engine.child_index{p}(n); % n is p's j'th child
        old_msg = msg{p}.lambda_from_child{j}(:);
        new_msg = normalise(compute_lambda_msg(bnet.CPD{bnet.equiv_class(n)}, n, ps, msg, p));
        % Damped update: engine.momentum is the weight kept on the previous message.
        lam_msg = engine.momentum * old_msg + (1-engine.momentum)*new_msg;
        msg{p}.lambda_from_child{j} = lam_msg(:);
      end

      % pi msgs to children
      cs = children(bnet.dag, n);
      for c=cs(:)'
        j = engine.parent_index{c}(n); % n is c's j'th parent
        old_msg = msg{c}.pi_from_parent{j}(:);
        new_msg = normalise(compute_pi_msg(n, cs, msg, c));
        pi_msg = engine.momentum * old_msg + (1-engine.momentum)*new_msg;
        msg{c}.pi_from_parent{j} = pi_msg(:); % stored as a column, like the lambda msgs
      end
    end
    % Only advanced when not converged, so on convergence iter is the iteration
    % at which it happened; without convergence the loop exits with max_iter+1.
    iter = iter + 1;
  end
end

if ~isempty(fid)
  fclose(fid);
end

engine.marginal = cell(1,N);
for n=1:N
  engine.marginal{n} = normalise(msg{n}.pi .* msg{n}.lambda);
end
engine.evidence = evidence; % needed by marginal_nodes and marginal_family
engine.msg = msg; % needed by marginal_family


%%%%%%%

function lambda = compute_lambda(n, cs, msg)
% COMPUTE_LAMBDA Lambda value of node n: its own evidence message times the
% lambda messages from all of its children cs.
% Pearl p183 eq 4.50
lambda = prod_lambda_msgs(n, cs, msg);


%%%%%%%

function pi_msg = compute_pi_msg(n, cs, msg, c)
% COMPUTE_PI_MSG Unnormalised pi message from node n to its child c:
% n's pi value times the lambda messages from every child of n except c.
% (The caller normalises the result.)
% Pearl p183 eq 4.53 and 4.51
pi_msg = msg{n}.pi .* prod_lambda_msgs(n, cs, msg, c);


%%%%%%%%%

function lam = prod_lambda_msgs(n, cs, msg, except)
% PROD_LAMBDA_MSGS Elementwise product of msg{n}.lambda_from_self and the
% messages msg{n}.lambda_from_child{i} for every child cs(i) ~= except.
% If except is omitted, all children are included.
% Assumes lambda_from_child{i} is the message from child cs(i), i.e. the same
% ordering as engine.child_index -- TODO confirm against init_msgs.
if nargin < 4, except = -1; end
lam = msg{n}.lambda_from_self(:);
for i=1:length(cs)
  c = cs(i);
  if c ~= except
    lam = lam .* msg{n}.lambda_from_child{i};
  end
end


%%%%%%%%

function [converged, history] = check_converged(iter, msg, history, hidden, tol, fid)
% CHECK_CONVERGED Record the current beliefs and test for convergence.
% [converged, history] = check_converged(iter, msg, history, hidden, tol, fid)
%
% history{n} is a W-row circular buffer. The current normalised belief of each
% hidden node is stored in row wrap(iter, W). Beliefs are recorded on every
% iteration so the buffer is full by iteration W.
% converged is 1 once iter >= W and, for every hidden node, all W stored rows
% are approxeq (within tol) to the first row.
% If fid is non-empty, the first component of each hidden node's belief is
% written to that file on every iteration.
converged = 0;
W = size(history{1},1);
slot = wrap(iter, W);
changed = 0;
if ~isempty(fid)
  fprintf(fid, 'iteration %d\n', iter);
end
for n=hidden(:)'
  bel = normalise(msg{n}.pi .* msg{n}.lambda);
  if ~isempty(fid)
    fprintf(fid, '%9.7f ', bel(1));
  end
  history{n}(slot, :) = bel;
  for i=2:W
    if ~approxeq(history{n}(1,:), history{n}(i,:), tol)
      changed = 1;
    end
  end
end
if ~isempty(fid)
  fprintf(fid, '\n\n');
end
% Before W iterations have been recorded, some buffer rows still hold their
% initial zeros, so convergence cannot be declared yet.
if iter >= W && ~changed
  converged = 1;
end


%%%%%%%

function v = wrap(u,N)
% WRAP Wrap a vector of indices around a torus.
% v = wrap(u,N)
%
% e.g., wrap([-1 0 1 2 3 4], 3) = 2 3 1 2 3 1
v = mod(u-1,N)+1;
⌨️ 快捷键说明
复制代码Ctrl + C
搜索代码Ctrl + F
全屏模式F11
增大字号Ctrl + =
减小字号Ctrl + -
显示快捷键?