📄 graphalgorithms.cc.svn-base
字号:
#include <boost/graph/subgraph.hpp>
#include <boost/graph/copy.hpp>
#include <boost/graph/depth_first_search.hpp>
#include <GraphAlgorithms.h>
#include <Generic.h>
#include <Message.h>
#include <GraphTreePartition.h>
#include <Potential.h>
#include <Node.h>
#include <RandomVariable.h>
#include <OutputGraph.h>
#include <Experiment.h>

using namespace boost;
using namespace std;

// Explicit Template Instantiations //

// Pairwise MRF: vertices hold message nodes, edges carry the pairwise potential.
typedef property< edge_index_t, unsigned int, DiscretePotential * > DiscretePairwiseEdgeProperty;
typedef adjacency_list<vecS, vecS, undirectedS, MessageNode <DiscreteRandomVariable, boost::adjacency_list_traits<vecS, vecS, undirectedS>::vertex_descriptor> *, DiscretePairwiseEdgeProperty > DiscretePairwiseGraph;
typedef subgraph <DiscretePairwiseGraph> SubDiscretePairwiseGraph;

// Factor graph: edges carry only an index; potentials live on the (factor) vertices.
typedef property< edge_index_t, unsigned int > EdgeProperty;
typedef adjacency_list<vecS, vecS, undirectedS, MessageNode <DiscreteRandomVariable, boost::adjacency_list_traits<vecS, vecS, undirectedS>::vertex_descriptor> *, EdgeProperty > DiscreteFactorGraph;
typedef subgraph <DiscreteFactorGraph> SubDiscreteFactorGraph;

//End of typedefs

// NOTE(review): the default arguments repeated in some of these explicit
// instantiations play no part in matching the instantiation, and repeating
// defaults outside the template's own declaration may be rejected by stricter
// compilers. Consider dropping them once it is confirmed that the declarations
// (presumably in GraphAlgorithms.h) carry the defaults.
template void belief_propagation < DiscretePairwiseGraph > (DiscretePairwiseGraph &);
template void belief_propagation < DiscreteFactorGraph > (DiscreteFactorGraph &);
template void loopy_belief_propagation < DiscretePairwiseGraph > (DiscretePairwiseGraph &, const unsigned int, const bool = false);
template void loopy_belief_propagation < DiscreteFactorGraph > (DiscreteFactorGraph &, const unsigned int, const bool = false);
template void gibbs_sampler < DiscretePairwiseGraph > (DiscretePairwiseGraph &, const unsigned int, const unsigned int, const bool);
template void gibbs_sampler < DiscreteFactorGraph > (DiscreteFactorGraph &, const unsigned int, const unsigned int, const bool);
template void tree_gibbs_sampler < DiscretePairwiseGraph > (DiscretePairwiseGraph &, const unsigned int, const unsigned int = 0, const bool = true, const bool = false);
template void tree_gibbs_sampler < DiscreteFactorGraph > (DiscreteFactorGraph &, const unsigned int, const unsigned int = 0, const bool = true, const bool = false);

// End Explicit Template Instantiations //

//******** Helper functions (need to be specialized) ************//

// Generic version
// NOTE(review): no generic definition appears in this file; the primary
// templates are presumably declared elsewhere (GraphAlgorithms.h?) — confirm.

// Version for the DiscretePairwiseGraph (and SubGraph, needed for the Tree Partition...)

/**
 * Prepares `origin` to send a message to `destination` in a pairwise MRF by
 * installing the potential stored on edge (origin, destination) as the origin
 * node's `message_potential`.
 * NOTE(review): `edge(...).second` is not checked; the edge is assumed to exist.
 */
template <>
inline void setup_message_computation(DiscretePairwiseGraph & g,
                                      const graph_traits<DiscretePairwiseGraph>::vertex_descriptor origin,
                                      const graph_traits<DiscretePairwiseGraph>::vertex_descriptor destination)
{
    static_cast <MRFMessageNode <DiscreteRandomVariable, DiscretePotential, graph_traits<DiscretePairwiseGraph>::vertex_descriptor> *> (g[origin])->message_potential
        = g[edge (origin, destination, g).first ];
}

/**
 * Subgraph overload of setup_message_computation for pairwise MRFs; identical
 * behavior, operating on a SubDiscretePairwiseGraph.
 */
template <>
inline void setup_message_computation(SubDiscretePairwiseGraph & g,
                                      const graph_traits<SubDiscretePairwiseGraph>::vertex_descriptor origin,
                                      const graph_traits<SubDiscretePairwiseGraph>::vertex_descriptor destination)
{
    static_cast <MRFMessageNode <DiscreteRandomVariable, DiscretePotential, graph_traits<SubDiscretePairwiseGraph>::vertex_descriptor> *> (g[origin])->message_potential
        = g[edge (origin, destination, g).first ];
}

/**
 * Builds the Gibbs proposal for `node` in a pairwise MRF: a new chained
 * potential over the node's variable containing the node's own potential and
 * the potential of every incident edge. Afterwards the variable's prior
 * samples are cleared via remove_prior_samples().
 * NOTE(review): a fresh proposal is allocated with `new` on every call and any
 * previous `sample_proposal` is overwritten; ownership is not visible here.
 */
template <>
void setup_gibbs_proposal(DiscretePairwiseGraph & g, const graph_traits<DiscretePairwiseGraph>::vertex_descriptor node)
{
    graph_traits<DiscretePairwiseGraph>::out_edge_iterator current_edge, end_edge;
    MRFMessageNode <DiscreteRandomVariable, DiscretePotential, graph_traits<DiscretePairwiseGraph>::vertex_descriptor> * mrf_node =
        static_cast <MRFMessageNode <DiscreteRandomVariable, DiscretePotential, graph_traits<DiscretePairwiseGraph>::vertex_descriptor> *> (g[node]);

    mrf_node->sample_proposal = new ChainedSingleDiscretePotential(* mrf_node->variable_ptr);
    mrf_node->sample_proposal->add_potential( mrf_node->potential_ptr);
    for ( tie (current_edge, end_edge) = out_edges( node, g); current_edge != end_edge; ++current_edge)
    {
        mrf_node->sample_proposal->add_potential( g[*current_edge]);
    }
    mrf_node->variable_ptr->remove_prior_samples();
}

/**
 * Prepares a tree-partitioned pairwise MRF for tree Gibbs sampling.
 * For every vertex of every child subgraph (tree) of `g`:
 *   - zero-initialises its `sum_marginals` entry (one slot per variable value);
 *   - saves its current potential in `original_potentials`;
 *   - replaces it with a chained potential that starts with the original one;
 *   - adds the potential of each edge to a neighbour outside the same tree.
 * Both output vectors are indexed by the vertex index of `g`, as in the
 * factor-graph specialization (previously `sum_marginals` was indexed by the
 * raw descriptor; identical for vecS storage, now consistent).
 * @param g                   root subgraph whose children are the trees
 * @param sum_marginals       per-vertex accumulators; must already be sized
 * @param original_potentials per-vertex slots for the replaced potentials;
 *                            must already be sized
 */
template <>
void setup_tree_gibbs_sampler(SubDiscretePairwiseGraph & g, vector < vector <double> > & sum_marginals, vector < Potential * > & original_potentials)
{
    property_map<SubDiscretePairwiseGraph, vertex_index_t>::type vertex_index_map = get(vertex_index, g);
    graph_traits<SubDiscretePairwiseGraph>::vertex_iterator u, u_end;
    graph_traits<SubDiscretePairwiseGraph>::adjacency_iterator v, v_end;
    subgraph<DiscretePairwiseGraph>::children_iterator current_tree, tree_end;
    graph_traits<DiscretePairwiseGraph>::vertex_descriptor w;

    for ( tie(current_tree, tree_end) = g.children(); current_tree != tree_end; ++current_tree)
    {
        for (tie(u,u_end) = vertices (*current_tree); u!= u_end; ++u)
        {
            // Work on the vertex through its descriptor in the parent graph.
            w = current_tree->local_to_global(*u);
            MRFMessageNode <DiscreteRandomVariable, DiscretePotential, graph_traits<DiscretePairwiseGraph>::vertex_descriptor> * node_casted =
                static_cast < MRFMessageNode<DiscreteRandomVariable, DiscretePotential, graph_traits<DiscretePairwiseGraph>::vertex_descriptor> *> (g[w]);

            sum_marginals[get(vertex_index_map, w)] = vector < double> (node_casted->variable_ptr->get_number_values(),0.0);
            original_potentials[ get(vertex_index_map, w)] = static_cast < DiscretePotential *> (g[w]->get_potential());

            ChainedSingleDiscretePotential * p = new ChainedSingleDiscretePotential(* node_casted->variable_ptr);
            p->add_potential (static_cast < DiscretePotential *> (original_potentials[get(vertex_index_map, w)]));
            node_casted->potential_ptr = p;

            // Edges leaving the tree are folded into the node's own potential.
            for (tie(v, v_end)= adjacent_vertices ( w, g ); v != v_end; ++v)
            {
                if (!current_tree->find_vertex(*v).second)
                {
                    static_cast < ChainedSingleDiscretePotential *> (g[w]->get_potential())->add_potential(g [edge(w, *v, g).first]);
                }
            }
        }
    }
}

/**
 * Undoes setup_tree_gibbs_sampler for a pairwise MRF. For every vertex it
 * deletes the current potential and restores the one saved in
 * `original_potentials`.
 * NOTE(review): this touches every vertex of `g`, so it assumes every vertex was
 * given a chained potential during setup (i.e. belongs to some tree) — confirm.
 */
template <>
void cleanup_tree_gibbs_sampler(DiscretePairwiseGraph & g, vector < Potential * > & original_potentials)
{
    property_map<DiscretePairwiseGraph, vertex_index_t>::type vertex_index_map = get(vertex_index, g);
    graph_traits<DiscretePairwiseGraph>::vertex_iterator u, u_end;
    for (tie(u,u_end) = vertices (g); u!= u_end; ++u)
    {
        MRFMessageNode <DiscreteRandomVariable, DiscretePotential, graph_traits<DiscretePairwiseGraph>::vertex_descriptor> * node_casted =
            static_cast < MRFMessageNode<DiscreteRandomVariable, DiscretePotential, graph_traits<DiscretePairwiseGraph>::vertex_descriptor> *> (g[*u]);
        delete node_casted->potential_ptr;
        node_casted->potential_ptr = static_cast <DiscretePotential *> (original_potentials[ get(vertex_index_map, *u)] );
    }
}

// Pairwise MRFs need no temporary observations: both hooks are no-ops.
template <>
void remove_temp_observations(SubDiscretePairwiseGraph &)
{
}

template <>
void set_temp_observations(SubDiscretePairwiseGraph &)
{
}

/**
 * Assigns a random value to every variable of a pairwise MRF and pushes the
 * new value into the potential of each incident edge.
 */
template <>
void initialize_graph_variables_random(DiscretePairwiseGraph & g)
{
    graph_traits<DiscretePairwiseGraph>::vertex_iterator u, u_end;
    graph_traits<DiscretePairwiseGraph>::out_edge_iterator e, e_end;
    for (tie(u,u_end) = vertices (g); u!= u_end; ++u)
    {
        g[*u]->get_random_variable()->set_value_at_random();
        for (tie(e,e_end) = out_edges (*u, g); e!= e_end; ++e)
        {
            g[*e]->set_variable_value(* g[*u]->get_random_variable());
        }
    }
}

// Version for the DiscreteFactorGraph (nothing needs to be done here)

// Factor-graph messages need no per-edge setup: both overloads are no-ops.
template <>
inline void setup_message_computation(DiscreteFactorGraph &,
                                      const graph_traits<DiscreteFactorGraph>::vertex_descriptor,
                                      const graph_traits<DiscreteFactorGraph>::vertex_descriptor)
{
}

template <>
inline void setup_message_computation(SubDiscreteFactorGraph &,
                                      const graph_traits<SubDiscreteFactorGraph>::vertex_descriptor,
                                      const graph_traits<SubDiscreteFactorGraph>::vertex_descriptor)
{
}

/**
 * Builds the Gibbs proposal for `node` in a factor graph. It does nothing unless
 * `node` is a VariableMessageNode. Otherwise it creates a new chained potential
 * over the variable, containing the node's own potential and the potential of
 * every adjacent (factor) node, then clears the variable's prior samples.
 * NOTE(review): as in the pairwise version, `sample_proposal` is overwritten
 * with a fresh `new` allocation; ownership is not visible here.
 */
template <>
void setup_gibbs_proposal(DiscreteFactorGraph & g, const graph_traits<DiscreteFactorGraph>::vertex_descriptor node)
{
    graph_traits<DiscreteFactorGraph>::adjacency_iterator current_neighbour, end_neighbour;
    VariableMessageNode <DiscreteRandomVariable, DiscretePotential, graph_traits<DiscreteFactorGraph>::vertex_descriptor> * vmn =
        dynamic_cast < VariableMessageNode <DiscreteRandomVariable, DiscretePotential, graph_traits<DiscreteFactorGraph>::vertex_descriptor> *> ( g[node]);
    if ( vmn != NULL)
    {
        vmn->sample_proposal = new ChainedSingleDiscretePotential(* vmn->variable_ptr);
        vmn->sample_proposal->add_potential( vmn->potential_ptr);
        for ( tie (current_neighbour, end_neighbour) = adjacent_vertices( node, g); current_neighbour != end_neighbour; ++current_neighbour)
        {
            vmn->sample_proposal->add_potential( static_cast < DiscretePotential * > (g[*current_neighbour]->get_potential()));
        }
        vmn->variable_ptr->remove_prior_samples();
    }
}

/**
 * Prepares a tree-partitioned factor graph for tree Gibbs sampling.
 * For every variable node of every child subgraph (tree) of `g`:
 *   - zero-initialises its `sum_marginals` entry;
 *   - saves its potential in `original_potentials`;
 *   - replaces it with a chained potential that starts with the original one;
 *   - adds the potential of every adjacent factor node outside the same tree.
 * Factor nodes (where the dynamic_cast fails) are skipped.
 */
template <>
void setup_tree_gibbs_sampler(SubDiscreteFactorGraph & g, vector < vector <double> > & sum_marginals, vector < Potential * > & original_potentials)
{
    property_map<SubDiscreteFactorGraph, vertex_index_t>::type vertex_index_map = get(vertex_index, g);
    graph_traits<SubDiscreteFactorGraph>::vertex_iterator u, u_end;
    graph_traits<SubDiscreteFactorGraph>::adjacency_iterator v, v_end;
    graph_traits<DiscreteFactorGraph>::vertex_descriptor w;
    subgraph<DiscreteFactorGraph>::children_iterator current_tree, tree_end;

    for ( tie(current_tree, tree_end) = g.children(); current_tree != tree_end; ++current_tree)
    {
        for (tie(u,u_end) = vertices (*current_tree); u!= u_end; ++u)
        {
            w = current_tree->local_to_global(*u);
            VariableMessageNode <DiscreteRandomVariable, DiscretePotential, graph_traits<DiscreteFactorGraph>::vertex_descriptor> * vmn =
                dynamic_cast < VariableMessageNode <DiscreteRandomVariable, DiscretePotential, graph_traits<DiscreteFactorGraph>::vertex_descriptor> *> ( g[w]);
            if ( vmn != NULL)
            {
                sum_marginals[get(vertex_index_map, w)] = vector < double> (vmn->variable_ptr->get_number_values(),0.0);
                original_potentials[ get(vertex_index_map, w)] = static_cast < DiscretePotential *> (g[w]->get_potential());

                ChainedSingleDiscretePotential * p = new ChainedSingleDiscretePotential(* vmn->variable_ptr);
                p->add_potential (static_cast < DiscretePotential *> (original_potentials[get(vertex_index_map, w)]));
                vmn->potential_ptr = p;

                for (tie(v, v_end)= adjacent_vertices ( w, g ); v != v_end; ++v)
                {
                    if (! (current_tree->find_vertex(*v).second))
                    {
                        // Add that potential (Factor Graphs)
                        p->add_potential (static_cast < PotentialMessageNode <DiscreteRandomVariable, DiscretePotential, graph_traits<DiscreteFactorGraph>::vertex_descriptor> * > (g [*v])->potential_ptr);
                    }
                }
            }
        }
    }
}

/**
 * Clears temporary observations in a factor graph. For every variable node
 * whose variable is NOT conditioned (`is_conditionned()` false), the variable
 * is unset as observed in its own potential and in every adjacent factor's
 * potential.
 */
template <>
void remove_temp_observations(SubDiscreteFactorGraph & g)
{
    graph_traits<SubDiscreteFactorGraph>::vertex_iterator u, u_end;
    graph_traits<SubDiscreteFactorGraph>::adjacency_iterator v, v_end;
    for ( tie (u, u_end) = vertices(g); u != u_end; ++u)
    {
        VariableMessageNode <DiscreteRandomVariable, DiscretePotential, graph_traits<DiscreteFactorGraph>::vertex_descriptor> * vmn =
            dynamic_cast < VariableMessageNode <DiscreteRandomVariable, DiscretePotential, graph_traits<DiscreteFactorGraph>::vertex_descriptor> *> ( g[*u]);
        if (vmn != NULL)
        {
            if (!vmn->variable_ptr->is_conditionned() )
            {
                vmn->potential_ptr->unset_observed_variable(* vmn->variable_ptr);
                for ( tie (v, v_end) = adjacent_vertices(*u,g); v != v_end; ++v)
                {
                    static_cast < PotentialMessageNode <DiscreteRandomVariable, DiscretePotential, graph_traits<DiscreteFactorGraph>::vertex_descriptor> *> (g[*v])->potential_ptr->unset_observed_variable(* vmn->variable_ptr);
                }
            }
        }
    }
}

/**
 * Undoes setup_tree_gibbs_sampler for a factor graph. For every variable node
 * it deletes the current potential and restores the one saved in
 * `original_potentials`. Factor nodes are left untouched.
 */
template <>
void cleanup_tree_gibbs_sampler(DiscreteFactorGraph & g, vector < Potential * > & original_potentials)
{
    property_map<DiscreteFactorGraph, vertex_index_t>::type vertex_index_map = get(vertex_index, g);
    graph_traits<DiscreteFactorGraph>::vertex_iterator u, u_end;
    for (tie(u,u_end) = vertices (g); u!= u_end; ++u)
    {
        VariableMessageNode <DiscreteRandomVariable, DiscretePotential, graph_traits<DiscreteFactorGraph>::vertex_descriptor> * vmn =
            dynamic_cast < VariableMessageNode <DiscreteRandomVariable, DiscretePotential, graph_traits<DiscreteFactorGraph>::vertex_descriptor> *> ( g[*u]);
        if (vmn != NULL)
        {
            delete vmn->potential_ptr;
            vmn->potential_ptr = static_cast <DiscretePotential *> (original_potentials[ get(vertex_index_map, *u)]);
        }
    }
}

/**
 * Marks every variable as temporarily observed: for each variable node, its
 * variable is set as observed in its own potential and in every adjacent
 * factor's potential.
 * NOTE(review): unlike remove_temp_observations, this does not check
 * is_conditionned(); confirm the asymmetry is intended.
 */
template <>
void set_temp_observations(SubDiscreteFactorGraph & g)
{
    graph_traits<SubDiscreteFactorGraph>::vertex_iterator u, u_end;
    graph_traits<SubDiscreteFactorGraph>::adjacency_iterator v, v_end;
    for ( tie (u, u_end) = vertices(g); u != u_end; ++u)
    {
        VariableMessageNode <DiscreteRandomVariable, DiscretePotential, graph_traits<DiscreteFactorGraph>::vertex_descriptor> * vmn =
            dynamic_cast < VariableMessageNode <DiscreteRandomVariable, DiscretePotential, graph_traits<DiscreteFactorGraph>::vertex_descriptor> *> ( g[*u]);
        if (vmn != NULL)
        {
            vmn->potential_ptr->set_observed_variable(* vmn->variable_ptr);
            for ( tie (v, v_end) = adjacent_vertices(*u,g); v != v_end; ++v)
            {
                static_cast < PotentialMessageNode <DiscreteRandomVariable, DiscretePotential, graph_traits<DiscreteFactorGraph>::vertex_descriptor> *> (g[*v])->potential_ptr->set_observed_variable(* vmn->variable_ptr);
            }
        }
    }
}

/**
 * Assigns a random value to every variable node of a factor graph and pushes
 * the new value into the potential of every adjacent factor node.
 */
template <>
void initialize_graph_variables_random(DiscreteFactorGraph & g)
{
    graph_traits<DiscreteFactorGraph>::vertex_iterator u, u_end;
    graph_traits<DiscreteFactorGraph>::adjacency_iterator v, v_end;
    for (tie(u,u_end) = vertices (g); u!= u_end; ++u)
    {
        VariableMessageNode <DiscreteRandomVariable, DiscretePotential, graph_traits<DiscreteFactorGraph>::vertex_descriptor> * vmn =
            dynamic_cast < VariableMessageNode <DiscreteRandomVariable, DiscretePotential, graph_traits<DiscreteFactorGraph>::vertex_descriptor> *> ( g[*u]);
        if (vmn != NULL)
        {
            vmn->variable_ptr->set_value_at_random();
            for ( tie (v, v_end) = adjacent_vertices(*u,g); v != v_end; ++v)
            {
                static_cast < PotentialMessageNode <DiscreteRandomVariable, DiscretePotential, graph_traits<DiscreteFactorGraph>::vertex_descriptor> *> (g[*v])->potential_ptr->set_variable_value(* vmn->variable_ptr);
            }
        }
    }
}

//************** Visitor RecordParent Implementation *************//

template <class Graph>
RecordParent<Graph>::RecordParent ( Graph & a ) : g(a) {}

/**
 * For the edge it is invoked on, records the edge's source as the
 * `parent_node` of its target. (Presumably attached to tree-edge events of a
 * graph search — TODO confirm against the visitor's registration.)
 */
template <class Graph>
template <class Edge>
void RecordParent<Graph>::operator()(Edge e, const Graph& )
{
    g[target(e,g)]->parent_node = source(e,g);
    //cout << "I recorded " << source(e,g) << " as the parent of " << target(e,g) << endl;
}

//************** End of Visitor RecordParent Implementation *************//

//************** Visitor Collect Implementation *************//

template <class Graph>
Collect<Graph>::Collect(Graph & a): g(a)
{
}

/**
 * Collect phase: vertex `u` sends its message to its recorded parent `v`.
 * The root is recognised by being its own parent and sends nothing.
 */
template <class Graph>
template <class VertexDescriptor>
void Collect<Graph>::operator()(VertexDescriptor u, const Graph & )
{
    // Getting the parent of u
    VertexDescriptor v = g[u]->parent_node;
    if (u == v)
    {
        // Root Vertex. We don't do anything. (The root vertex is its own parent)
        return;
    }
    // u is the origin, v the destination.
    //cout << "Sending collect message from " << g[u]->index << " to " << g[v]->index << "." << endl;
    setup_message_computation (g, u , v);
    g[v]->receive_message ( g[u]->send_message( * g[v] ) );
}

//************** End of Visitor Collect Implementation *************//

//************** Visitor Distribute Implementation *************//

template <class Graph>
Distribute<Graph>::Distribute(Graph & a): g(a)
{
}

/**
 * Distribute phase: the recorded parent `u` of vertex `v` sends its message
 * down to `v`. The root (its own parent) receives nothing.
 */
template <class Graph>
template <class Vertex>
void Distribute<Graph>::operator()(Vertex v, const Graph& )
{
    // As in Collect, u is the origin and v the destination of the message.
    // The difference is which end the visitor is invoked on: here it is called
    // on the destination (the child), not on the origin.
    Vertex u = g[v]->parent_node;
    if (u == v)
    {
        // Root Vertex. We don't do anything.
        return;
    }
    // u is the origin, v the destination.
    //cout << "Sending distribute message from " << g[u]->index << " to " << g[v]->index << "." << endl;
    setup_message_computation (g, u , v);
    g[v]->receive_message ( g[u]->send_message( * g[v] ) );
}

//************** End of Visitor Distribute Implementation *************//

//************** Visitor BackwardSampling Implementation *************//

template <class Graph>
BackwardSampling<Graph>::BackwardSampling(Graph & a): g(a)
{
}

/**
 * Samples vertex `v`. The root (its own parent) samples via
 * sample_from_joint_root(). Every other vertex first sets up the message
 * computation from its parent `u`, then samples jointly given `u` via
 * sample_from_joint(). Recorded samples are kept in memory by the nodes.
 */
template <class Graph>
template <class Vertex>
void BackwardSampling<Graph>::operator()(Vertex v, const Graph& )
{
    // Note that as it samples, the recorded samples are kept in memory
    Vertex u = g[v]->parent_node;
    //cout << "Sampling from: " << g[v]->index << endl;
    if (u == v)
    {
        // Root Vertex. We sample from the marginal
        g[v]->sample_from_joint_root();
        return;
    }
    // We set up the message computation; it is also used by the joint sampling step.
    setup_message_computation(g, u, v);
    // This is the joint sampling part
    g[v]->sample_from_joint(* g[u]);
}

//************** End of Visitor BackwardSampling Implementation *************//

/**
 * Calls initialize() on every node of `g`, then clears temporary observations
 * with remove_temp_observations(g).
 */
template <class Graph>
void initialize_graph_contents(Graph & g)
{
    typename graph_traits<Graph>::vertex_iterator u, u_end;
    for (tie(u,u_end) = vertices (g); u!= u_end; ++u)
    {
        g[*u]->initialize();
    }
    remove_temp_observations(g);
}

template <class Graph>
void tree_gibbs_sampler(Graph & g, const unsigned int max_steps, const unsigned int burnoff_steps = 0, const bool rao_blackwell = true, const bool time = false)
{
    typename graph_traits<Graph>::vertex_iterator u, u_end;
    typename graph_traits<Graph>::adjacency_iterator v, v_end;
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -