kpm_learn_struct_mcmc.m

MIT's artificial intelligence toolbox. Very valuable; hopefully it is useful to everyone!
function [sampled_graphs, accept_ratio, num_edges] = learn_struct_mcmc(data, ns, varargin)
% LEARN_STRUCT_MCMC  Monte Carlo Markov Chain search over DAGs assuming fully observed data
% [sampled_graphs, accept_ratio, num_edges] = learn_struct_mcmc(data, ns, ...)
%
% data(i,m) is the value of node i in case m.
% ns(i) is the number of discrete values node i can take on.
%
% sampled_graphs{m} is the m'th sampled graph.
% accept_ratio(t) = (number of accepted moves) / (number of rejected moves) after
%                   iteration t; both counts start at 1
% num_edges(t) = number of edges in model at iteration t
%
% The following optional arguments can be specified in the form of name/value pairs:
% [default value in brackets]
%
% scoring_fn - 'bayesian' or 'bic' [ 'bayesian' ]
%              Currently, only networks with all tabular nodes support Bayesian scoring.
% type       - type{i} is the type of CPD to use for node i, where the type is a string
%              of the form 'tabular', 'noisy_or', 'gaussian', etc. [ all cells contain 'tabular' ]
% params     - params{i} contains optional arguments passed to the CPD constructor for node i,
%              or [] if none.  [ all cells contain {'prior_type', 'dirichlet', 'dirichlet_weight', 1},
%              meaning use uniform Dirichlet priors ]
% discrete   - the list of discrete nodes [ 1:N ]
% clamped    - clamped(i,m) = 1 if node i is clamped in case m [ zeros(N, ncases) ]
% nsamples   - number of samples to draw from the chain after burn-in [ 100*N ]
% burnin     - number of steps to take before drawing samples [ 5*N ]
% init_dag   - starting point for the search [ zeros(N,N) ]
%
% e.g., samples = learn_struct_mcmc(data, ns, 'nsamples', 1000);
%
% This interface is not backwards compatible with BNT2,
% but is designed to be compatible with the other learn_struct_xxx routines.
%
% Note: We currently assume a uniform structural prior.

[n, ncases] = size(data);

% set default params
type = cell(1,n);
params = cell(1,n);
for i=1:n
  type{i} = 'tabular';
  %params{i} = { 'prior', 1 };
  params{i} = { 'prior_type', 'dirichlet', 'dirichlet_weight', 1 };
end
scoring_fn = 'bayesian';
discrete = 1:n;
clamped = zeros(n, ncases);

nsamples = 100*n;
burnin = 5*n;
dag = zeros(n);

args = varargin;
nargs = length(args);
for i=1:2:nargs
  switch args{i},
   case 'nsamples',   nsamples = args{i+1};
   case 'burnin',     burnin = args{i+1};
   case 'init_dag',   dag = args{i+1};
   case 'scoring_fn', scoring_fn = args{i+1};
   case 'type',       type = args{i+1};
   case 'discrete',   discrete = args{i+1};
   case 'clamped',    clamped = args{i+1};
   case 'params',     if isempty(args{i+1}), params = cell(1,n); else params = args{i+1};  end
  end
end

% We implement the fast acyclicity check described by P. Giudici and R. Castelo,
% "Improving MCMC model search for data mining", submitted to J. Machine Learning, 2001.

use_giudici = 1;
if use_giudici
  [nbrs, ops, nodes] = mk_nbrs_of_digraph(dag);
  A = init_ancestor_matrix(dag);
else
  [nbrs, ops, nodes] = mk_nbrs_of_dag(dag);
  A = [];
end

num_accepts = 1;
num_rejects = 1;
T = burnin + nsamples;
accept_ratio = zeros(1, T);
num_edges = zeros(1, T);
sampled_graphs = cell(1, nsamples);
%sampled_bitv = zeros(nsamples, n^2);

for t=1:T
  [dag, nbrs, ops, nodes, A, accept] = take_step(dag, nbrs, ops, nodes, ns, data, clamped, A, ...
                                                 scoring_fn, discrete, type, params);
  num_edges(t) = sum(dag(:));
  num_accepts = num_accepts + accept;
  num_rejects = num_rejects + (1-accept);
  accept_ratio(t) = num_accepts/num_rejects;
  if t > burnin
    sampled_graphs{t-burnin} = dag;
    %sampled_bitv(t-burnin, :) = dag(:)';
  end
end


%%%%%%%%%


function [new_dag, new_nbrs, new_ops, new_nodes, A, accept] = ...
    take_step(dag, nbrs, ops, nodes, ns, data, clamped, A, ...
              scoring_fn, discrete, type, params)


use_giudici = ~isempty(A);
if use_giudici
  [new_dag, op, i, j] = pick_digraph_nbr(dag, nbrs, ops, nodes, A);
  %assert(acyclic(new_dag));
  [new_nbrs, new_ops, new_nodes] = mk_nbrs_of_digraph(new_dag);
else
  d = sample_discrete(normalise(ones(1, length(nbrs))));
  new_dag = nbrs{d};
  op = ops{d};
  i = nodes(d, 1); j = nodes(d, 2);
  [new_nbrs, new_ops, new_nodes] = mk_nbrs_of_dag(new_dag);
end

bf = bayes_factor(dag, new_dag, op, i, j, ns, data, clamped, scoring_fn, discrete, type, params);

% Acceptance ratio: Bayes factor times the ratio of neighbourhood sizes.
% Under the uniform structural prior the prior ratio is 1.
%R = bf * (new_prior / prior) * (length(nbrs) / length(new_nbrs));
R = bf * (length(nbrs) / length(new_nbrs));
u = rand(1,1);
if u > min(1,R) % reject the move
  accept = 0;
  new_dag = dag;
  new_nbrs = nbrs;
  new_ops = ops;
  new_nodes = nodes;
else
  accept = 1;
  if use_giudici
    A = update_ancestor_matrix(A, op, i, j, new_dag);
  end
end


%%%%%%%%%

function bfactor = bayes_factor(old_dag, new_dag, op, i, j, ns, data, clamped, scoring_fn, discrete, type, params)

u = find(clamped(j,:)==0);
LLnew = score_family(j, parents(new_dag, j), type{j}, scoring_fn, ns, discrete, data(:,u), params{j});
LLold = score_family(j, parents(old_dag, j), type{j}, scoring_fn, ns, discrete, data(:,u), params{j});
bf1 = exp(LLnew - LLold);

if strcmp(op, 'rev')  % must also multiply in the changes to i's family
  u = find(clamped(i,:)==0);
  LLnew = score_family(i, parents(new_dag, i), type{i}, scoring_fn, ns, discrete, data(:,u), params{i});
  LLold = score_family(i, parents(old_dag, i), type{i}, scoring_fn, ns, discrete, data(:,u), params{i});
  bf2 = exp(LLnew - LLold);
else
  bf2 = 1;
end
bfactor = bf1 * bf2;


%%%%%%%% Giudici stuff follows %%%%%%%%%%


function [new_dag, op, i, j] = pick_digraph_nbr(dag, digraph_nbrs, ops, nodes, A)

legal = 0;
while ~legal
  d = sample_discrete(normalise(ones(1, length(digraph_nbrs))));
  i = nodes(d, 1); j = nodes(d, 2);
  switch ops{d}
   case 'add',
    if A(i,j)==0
      legal = 1;
    end
   case 'del',
    legal = 1;
   case 'rev',
    ps = mysetdiff(parents(dag, j), i);
    % if any(A(ps,i)) then there is a path i -> parent of j -> j
    % so reversing i->j would create a cycle
    legal = ~any(A(ps, i));
  end
end

%new_dag = digraph_nbrs{d};
new_dag = digraph_nbrs(:,:,d);
op = ops{d};
i = nodes(d, 1); j = nodes(d, 2);


%%%%%%%%%%%%%%


function A = update_ancestor_matrix(A, op, i, j, dag)

switch op
 case 'add',
  A = do_addition(A, op, i, j, dag);
 case 'del',
  A = do_removal(A, op, i, j, dag);
 case 'rev',
  A = do_removal(A, op, i, j, dag);
  A = do_addition(A, op, j, i, dag);
end


%%%%%%%%%%%%

function A = do_addition(A, op, i, j, dag)

A(j,i) = 1; % i is an ancestor of j
anci = find(A(i,:));
if ~isempty(anci)
  A(j,anci) = 1; % all of i's ancestors are added to Anc(j)
end
ancj = find(A(j,:));
descj = find(A(:,j));
if ~isempty(ancj)
  for k=descj(:)'
    A(k,ancj) = 1; % all of j's ancestors are added to each descendant of j
  end
end

%%%%%%%%%%%
function A = do_removal(A, op, i, j, dag)

% find all the descendants of j, and put them in topological order

%descj = find(A(:,j));
R = reachability_graph(dag);
descj = find(R(j,:));

order = topological_sort(dag);
% order(k) is the k'th node in topological order, so invert it to get each
% node's position before sorting the descendants by that position
topnum = zeros(1, length(dag));
topnum(order) = 1:length(dag);
descj_topnum = topnum(descj);
[junk, perm] = sort(descj_topnum);
descj = descj(perm);

% Update j and all its descendants
A = update_row(A, j, dag);
for k = descj(:)'
  A = update_row(A, k, dag);
end

%%%%%%%%%

function A = update_row(A, j, dag)

% We compute row j of A
A(j, :) = 0;
ps = parents(dag, j);
if ~isempty(ps)
  A(j, ps) = 1;
end
for k=ps(:)'
  anck = find(A(k,:));
  if ~isempty(anck)
    A(j, anck) = 1;
  end
end

%%%%%%%%

function A = init_ancestor_matrix(dag)

order = topological_sort(dag);
A = zeros(length(dag));
for j=order(:)'
  A = update_row(A, j, dag);
end
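The legality tests in pick_digraph_nbr depend on the ancestor-matrix convention A(j,i) = 1 iff node i is an ancestor of node j. The snippet below is a minimal standalone sketch of those tests on a hypothetical 4-node DAG. It makes one assumption not in the file: it builds A from scratch with Warshall's transitive-closure algorithm and does not use the toolbox's incremental init_ancestor_matrix/update_row helpers. It also uses only built-in functions (find, setdiff), not BNT's parents or mysetdiff.

```matlab
% Sketch: Giudici-Castelo legality checks on a hypothetical 4-node DAG.
% Convention (as in learn_struct_mcmc): dag(i,j) = 1 means an edge i -> j,
% and A(j,i) = 1 means node i is an ancestor of node j.
n = 4;
dag = zeros(n);
dag(1,2) = 1; dag(2,3) = 1; dag(1,3) = 1;   % node 4 is isolated

% Transitive closure by Warshall's algorithm:
% R(i,j) = 1 iff there is a directed path i -> ... -> j.
R = double(dag > 0);
for k = 1:n
  % a path i -> k -> j exists if i reaches k and k reaches j
  R = double(R | (R(:,k) * R(k,:) > 0));
end
A = R';   % A(j,i) = 1 iff i is an ancestor of j

% Adding i -> j is legal iff j is not already an ancestor of i.
add_legal = @(i, j) A(i, j) == 0;

% Reversing an existing edge i -> j is legal iff no other parent p of j has
% i as an ancestor (otherwise i -> ... -> p -> j -> i would be a cycle).
rev_legal = @(i, j) ~any(A(setdiff(find(dag(:, j)), i), i));

fprintf('add 3->1: %d, add 3->4: %d\n', double(add_legal(3, 1)), double(add_legal(3, 4)));
fprintf('rev 1->3: %d, rev 2->3: %d\n', double(rev_legal(1, 3)), double(rev_legal(2, 3)));
```

Here adding 3->1 and reversing 1->3 are rejected, because node 1 already reaches node 3, once directly and once through node 2. Adding 3->4 and reversing 2->3 are accepted. The toolbox keeps A up to date move by move instead of recomputing the closure.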

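take_step accepts a proposed move with probability min(1, R). Here R is the Bayes factor from bayes_factor times length(nbrs)/length(new_nbrs), and bayes_factor exponentiates each family's log-score difference separately. The sketch below shows why the same test is often written in log space. Its score differences and neighbourhood sizes are hypothetical values chosen to trigger floating-point overflow, not results from any dataset. For a reversal whose two family-score changes are large and of opposite sign, the direct product evaluates to Inf * 0 = NaN, so the comparison with u no longer reflects the true ratio.

```matlab
% Sketch: log-space Metropolis-Hastings test for one proposed 'rev' move.
% All numbers are hypothetical and chosen to trigger floating-point overflow.
dLL_j = 800;      % LLnew - LLold for node j's family
dLL_i = -799;     % LLnew - LLold for node i's family (nonzero only for 'rev')
n_nbrs_old = 5;   % plays the role of length(nbrs)
n_nbrs_new = 4;   % plays the role of length(new_nbrs)

% Direct form, mirroring bayes_factor/take_step: exp(800) is Inf and
% exp(-799) underflows to 0 in double precision, so the product is NaN.
naive_R = exp(dLL_j) * exp(dLL_i) * (n_nbrs_old / n_nbrs_new);

% Log-space form: sum the log terms and compare against log(u).
logR = dLL_j + dLL_i + log(n_nbrs_old) - log(n_nbrs_new);
u = rand(1, 1);
accept = log(u) <= min(0, logR);   % equivalent to accepting iff u <= min(1, R)

fprintf('naive R = %g, log R = %.4f, accept = %d\n', naive_R, logR, double(accept));
```

Because log is monotone, the log-space test accepts exactly the same moves as `u <= min(1,R)` whenever R is finite and positive. It also stays well defined when the individual exponentials are not.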