marginal_nodes.m
function [marginal, engine] = marginal_nodes(engine, nodes, varargin)
% MARGINAL_NODES Compute the marginal on the specified query nodes
% (gibbs_sampling_engine)
% [marginal, engine] = marginal_nodes(engine, nodes, ...)
%
% Returns Pr(X(nodes) | X(observedNodes)), estimated from Gibbs samples.
%
% The engine is also modified, so it is returned as well: MATLAB passes
% arguments by value, not by reference(!). If you want, for example, to
% run Gibbs sampling incrementally, a few hundred steps at a time, you
% must pass the returned engine to the next call.
%
% Optional arguments:
%
% 'reset_counts' - 1 to discard the counts accumulated by previous calls,
%                  0 to keep adding to them. It must be 1 if the current
%                  query nodes differ from the previous query nodes, or if
%                  marginal_nodes has not been called on this engine before.
%                  Default: 1.

reset_counts = 1;

if (nargin > 3)
  args = varargin;
  nargs = length(args);
  for i = 1:2:nargs
    switch args{i}
      case 'reset_counts'
        reset_counts = args{i+1};
      otherwise
        error('Incorrect argument to gibbs_sampling_engine/marginal_nodes');
    end
  end
end

% Initialization
bnet = bnet_from_engine(engine);
slice_size = engine.slice_size;
hnodes = engine.hnodes;
onodes = engine.onodes;
nonqnodes = mysetdiff(1:slice_size, nodes);
gap = engine.gap;
burnin = engine.burnin;
T_max = engine.T;
ns = bnet.node_sizes(nodes);

% Cache the strides for the marginal table (column-major order, as sub2ind uses)
marg_strides = [1 cumprod(ns(1:end-1))];

% Reset counts if necessary
if (reset_counts == 1)
  %state = sample_bnet(bnet, 1, 0);
  %state = cell2num(sample_bnet(bnet, 'evidence', num2cell(engine.evidence)));
  % Initialise the chain with a sample from the network, then clamp the
  % observed nodes to the evidence
  state = cell2num(sample_bnet(bnet));
  state(onodes) = engine.evidence(onodes);
  if (length(ns) == 1)
    marginal_counts = zeros(ns(1), 1);
  else
    marginal_counts = zeros(ns);
  end
% Otherwise, resume from the state and counts stored in the engine
else
  state = engine.state;
  state(onodes, :) = engine.evidence(onodes, :);
  marginal_counts = engine.marginal_counts;
end

if (engine.deterministic == 1)
  pos = 1;
  order = engine.order;
  orderSize = length(order);
else
  sampling_dist = normalise(engine.sampling_dist);
end

for t = 1:(T_max*gap + burnin)
  % First, select the node m to resample
  if (engine.deterministic == 1)
    m = order(pos);
    pos = pos + 1;
    if (pos > orderSize)
      pos = 1;
    end
  else
    m = my_sample_discrete(sampling_dist);
  end

  % Observed nodes stay clamped to the evidence, so only hidden nodes are
  % resampled. The count update below still runs on every step, so a step
  % that selects an observed node is not dropped from the estimate.
  if (~myismember(m, onodes))
    % Compute the distribution of node m given the current values of all
    % other nodes, and draw a new value for m from it
    post = compute_posterior(bnet, state, m, engine.strides, engine.families, ...
                             engine.children, engine.CPT);
    state(m) = my_sample_discrete(post);
  end

  % After burn-in, add every gap-th state to the Monte Carlo estimate of
  % the marginal on the query nodes
  if ((t > burnin) && (mod(t - burnin, gap) == 0))
    vals = state(nodes);
    index = 1 + marg_strides*(vals - 1);
    marginal_counts(index) = marginal_counts(index) + 1;
  end
end

% Store the results for future calls. Note that the counts are stored
% unnormalized.
engine.state = state;
engine.marginal_counts = marginal_counts;
marginal.T = normalise(marginal_counts);
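
% Worked examples (illustrative sketches; the values and the names engine
% and q below are hypothetical).
%
% 1. Indexing into marginal_counts. marg_strides holds column-major strides,
%    so 1 + marg_strides*(vals-1) is the same linear index that sub2ind
%    returns. For three query nodes of sizes 2, 3 and 4:
%
%      ns = [2 3 4];
%      marg_strides = [1 cumprod(ns(1:end-1))];   % [1 2 6]
%      vals = [2; 3; 1];                           % must be a column vector
%      index = 1 + marg_strides*(vals - 1)         % 1 + 1*1 + 2*2 + 6*0 = 6
%      sub2ind(ns, 2, 3, 1)                        % also 6
%
% 2. Sampling schedule. Counts are updated only at steps t with
%    t > burnin and mod(t - burnin, gap) == 0, i.e. t = burnin+gap,
%    burnin+2*gap, ..., burnin+T_max*gap, so at most T_max states are
%    recorded per call:
%
%      burnin = 100; gap = 5; T_max = 20;
%      t = 1:(T_max*gap + burnin);
%      kept = t(t > burnin & mod(t - burnin, gap) == 0);   % 105:5:200
%      numel(kept)                                          % 20
%
% 3. Incremental use, assuming engine is a gibbs_sampling_engine that
%    already holds evidence and q is a vector of query node ids. Keep the
%    returned engine and pass reset_counts = 0 so that the second call adds
%    its samples to the counts of the first:
%
%      [marg, engine] = marginal_nodes(engine, q);                     % fresh counts
%      [marg, engine] = marginal_nodes(engine, q, 'reset_counts', 0);  % more samples
%      marg.T                                                          % updated estimate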