jtree_ndx_dbn_inf_engine.m

来自「麻省理工学院的人工智能工具箱,很珍贵,希望对大家有用!」· M 代码 · 共 131 行

M
131
字号
function engine = jtree_ndx_dbn_inf_engine(bnet, varargin)
% JTREE_NDX_DBN_INF_ENGINE Junction tree inference algorithm for DBNs which pre-computes indices.
% engine = jtree_ndx_dbn_inf_engine(bnet, ...)
%
% This is to jtree_dbn_inf_engine what jtree_ndx_inf_engine is to jtree_inf_engine.
%
% Input
%   bnet - DBN structure. This constructor reads the fields intra, inter,
%          intra1, dag, dnodes, cnodes, node_sizes, node_sizes_slice,
%          equiv_class, eclass1, eclass2, CPD and observed.
%
% Optional arguments (name/value pairs; unrecognised names are ignored)
%   clusters - cell array of node sets that are passed on to the junction
%              tree engine as extra clusters [ {} ]
%   ndx_type - 'B', 'D', or 'SD' [ 'SD' ]
%              NOTE(review): the value is only stored in engine.ndx_type;
%              the code below always calls the 'SD' routines
%              (jtree_ndxSD_inf_engine, add_ndxSD).
%
% Output
%   engine - object of class 'jtree_ndxSD_dbn_inf_engine' derived from
%            inf_engine(bnet).

ss = length(bnet.intra);   % number of nodes per slice

% The per-slice observed nodes are taken from the network itself; the
% signature no longer carries a separate onodes argument.
onodes = bnet.observed;

% ---- parse optional name/value arguments --------------------------------
clusters = {};
ndx_type = 'SD';
if nargin >= 2
  args = varargin;
  nargs = length(args);
  for i = 1:2:nargs
    switch args{i},
     case 'clusters', clusters = args{i+1};
     case 'ndx_type', ndx_type = args{i+1};
    end
  end
end

engine.evidence = [];
engine.node_sizes = [];
engine.ndx_type = ndx_type;

% ---- interface / non-interface nodes ------------------------------------
[int, engine.persist, engine.transient] = compute_interface_nodes(bnet.intra, bnet.inter);
%engine.interface = engine.persist; % WRONG!
engine.interface = int;
engine.nonint = mysetdiff(1:ss, int);

% Cache the parents of every node of the 2-slice DAG.
engine.bnet_parents = {};
for i = 1:2*ss
  engine.bnet_parents{i} = parents(bnet.dag, i);
end

% ---- figure out when we can speedup computation of the observation likelihood
engine.parents_in_same_slice = ones(1,ss);
engine.parents_hidden = ones(1,ss);
engine.time_invariant = ones(1,ss);
engine.tabular_node = zeros(1,ss);
engine.gaussian_node = zeros(1,ss);
for i = 1:ss
  ps = parents(bnet.dag, i+ss);   % parents of node i in slice 2 (2-slice numbering)
  if any(ps <= ss)
    % at least one parent lives in slice 1
    engine.parents_in_same_slice(i) = 0;
  end
  % NOTE(review): this clears the flag when ALL parents are in onodes
  % (including the no-parent case, if mysubset treats [] as a subset);
  % confirm that this is the intended meaning of "parents hidden".
  if mysubset(ps, onodes)
    engine.parents_hidden(i) = 0;
  end
  if bnet.eclass1(i) ~= bnet.eclass2(i)
    engine.time_invariant(i) = 0;
  end
  engine.simple_obs(i) = engine.parents_in_same_slice(i) & engine.parents_hidden(i) & ...
      engine.time_invariant(i);
  % the following is not always correct...
  engine.tabular_node(i) = myismember(i, bnet.dnodes);
  engine.gaussian_node(i) = myismember(i, bnet.cnodes);
end

% ---- "1.5 slice" junction tree ------------------------------------------
% Create a "1.5 slice" jtree, containing slice 1 and the interface nodes of slice 2.
% To keep the node numbering the same, we simply disconnect the non-interface nodes
% from slice 2, and set their size to 1.
% We do this to speed things up, and so that the likelihood is computed correctly -
% we do not need to do this if we just want to compute marginals.
intra15 = bnet.intra;
for i = engine.nonint(:)'
  intra15(i,:) = 0;
  intra15(:,i) = 0;
end
dag15 = [bnet.intra  bnet.inter;
         zeros(ss)   intra15];
ns = bnet.node_sizes(:);
ns(engine.nonint+ss) = 1; % disconnected nodes get size 1
obs_nodes = [onodes(:) onodes(:)+ss];   % observed nodes in both slices
bnet15 = mk_bnet(dag15, ns, 'discrete', bnet.dnodes, 'equiv_class', bnet.equiv_class(:), ...
                 'observed', obs_nodes(:));
ns(obs_nodes) = 1;   % from here on ns treats observed nodes as size 1

% use unconstrained elimination,
% but force there to be a clique containing both interfaces
clusters(end+1:end+2) = {int, int+ss};
engine.jtree_engine = jtree_ndxSD_inf_engine(bnet15, 'clusters', clusters, 'root', int+ss);
engine.in_clq = clq_containing_nodes(engine.jtree_engine, int);
engine.out_clq = clq_containing_nodes(engine.jtree_engine, int+ss);
engine.jtree_struct = struct(engine.jtree_engine); % violate object privacy

% Pre-computed index ids for multiplying/marginalising between the
% interface potentials and the cliques that contain them.
cliques = engine.jtree_struct.cliques;
engine.mult_int_onto_inclq_ndx_id = add_ndxSD(cliques{engine.in_clq}, engine.interface, ns);
engine.marg_root_onto_int2_ndx_id = add_ndxSD(cliques{engine.jtree_struct.root_clq}, engine.interface+ss, ns);
engine.marg_inclq_onto_int_ndx_id = add_ndxSD(cliques{engine.in_clq}, engine.interface, ns);
engine.marg_outclq_onto_int2_ndx_id = add_ndxSD(cliques{engine.out_clq}, engine.interface+ss, ns);
engine.mult_int2_onto_outclq_ndx_id = add_ndxSD(cliques{engine.out_clq}, engine.interface+ss, ns);
engine.mult_int_onto_int_ndx_id = add_ndxSD(engine.interface, engine.interface, ns);

% ---- Also create an engine just for slice 1 -----------------------------
% Use the same name/value calling convention as for bnet15 above. Only the
% discrete nodes that fall inside slice 1 are passed on, and the observed
% nodes travel with the network rather than as a positional engine argument.
bnet1 = mk_bnet(bnet.intra1, bnet.node_sizes_slice, ...
                'discrete', bnet.dnodes(bnet.dnodes <= ss), ...
                'equiv_class', bnet.equiv_class(:,1), ...
                'observed', onodes);
for i = 1:max(bnet1.equiv_class)
  bnet1.CPD{i} = bnet.CPD{i};
end

engine.jtree_engine1 = jtree_ndxSD_inf_engine(bnet1, 'clusters', {int}, 'root', int);
engine.in_clq1 = clq_containing_nodes(engine.jtree_engine1, int);
engine.jtree_struct1 = struct(engine.jtree_engine1); % violate object privacy

cliques = engine.jtree_struct1.cliques;
engine.mult_int_onto_inclq1_ndx_id = add_ndxSD(cliques{engine.in_clq1}, engine.interface, ns);
engine.marg_inclq1_onto_int_ndx_id = add_ndxSD(cliques{engine.in_clq1}, engine.interface, ns);

% ---- stuff needed by marginal_nodes --------------------------------------
engine.clpot = [];
engine.T = [];
engine.maximize = [];
engine.actual_node_sizes = [];
engine.eff_node_sizes = [];

% NOTE(review): the class name differs from this constructor's name
% (jtree_ndx_dbn_inf_engine); confirm which @-directory this file lives in.
engine = class(engine, 'jtree_ndxSD_dbn_inf_engine', inf_engine(bnet));

⌨️ 快捷键说明

复制代码Ctrl + C
搜索代码Ctrl + F
全屏模式F11
增大字号Ctrl + =
减小字号Ctrl + -
显示快捷键?