📄 node.h.svn-base
字号:
#ifndef NODE_H#define NODE_H#include <string>#include <list>#include <vector>#include <boost/graph/graph_traits.hpp>#ifndef NDEBUG#include <iostream>#endifclass Message;class Potential;class DiscretePotential;class ChainedSingleDiscretePotential;class RandomVariable;class DiscreteRandomVariable;//template <class A> class ComputeMessage; // forward declaration of templated class// ********************* Node Declaration **********************//// Currently, the top class, Node is never used; its private variable are also not used.class Node{ public: // Constructors Node(); Node(const unsigned int); // Debug methods #ifndef NDEBUG template <class Graph, class C, class P> friend bool get_tree(Graph & , C, P); template <class Graph, class C, class P> friend bool check_rv( Graph &, typename boost::graph_traits<Graph>::vertex_descriptor, std::list < typename boost::graph_traits<Graph>::vertex_descriptor> &, C &, P &);#endif protected: const std::string name; const unsigned int index;};// ********************* End of Node Declaration **********************//// ********************* GraphicalModelNode Declaration **********************////template <class RanVar, class VertexDescriptor>class GraphicalModelNode: public Node{public: GraphicalModelNode (RandomVariable * const, Potential * const); RandomVariable * get_random_variable() const; Potential * get_potential() const; void set_potential(Potential * const); private: RandomVariable * variable_ptr; Potential * potential_ptr;};inline RandomVariable * GraphicalModelNode::get_random_variable() const{ return variable_ptr;}inline Potential * GraphicalModelNode::get_potential() const{ return potential_ptr;}inline void GraphicalModelNode::set_potential( Potential * const p){ potential_ptr = p;}// ********************* End of GraphicalModelNode Declaration **********************//// ********************* MessageNode Declaration **********************//// Currently, I derive 3 classes from this base class; MRF, Variable, and Potential. MRF corresponds to an MRF, while Variable and // Potential are both used in a Factor Graphtemplate <class RanVar, class VertexDescriptor>class MessageNode: public Node{ public: typedef RanVar random_variable_type; // Constructor, destructor MessageNode (); MessageNode (const unsigned int); virtual ~MessageNode(); // Virtual methods (interface) virtual void initialize() = 0; virtual Message send_message(MessageNode &) = 0; virtual void receive_message(Message); virtual void update_message(Message); virtual void sample_from_proposal(bool) = 0; virtual void sample_from_joint_root () = 0; virtual void sample_from_joint (MessageNode &) = 0; virtual void serialize_marginals(std::ofstream &) const = 0; virtual void set_variable_to_joint_sampled_value() = 0; virtual void compute_marginals_from_messages() = 0; virtual void compute_marginals_from_samples() = 0; // Deprecated methods virtual Potential * get_potential() const = 0; virtual void set_potential(Potential *); // should never be called from the base class virtual RanVar * get_random_variable() const; // should never be called from the base class // Normal methods std::list <Message> & obtain_messages_list(); // Friends template <class G> friend void belief_propagation (G &); template <class G> friend void belief_propagation_with_backward_sampling (G &); template <class G> friend class RecordParent; template <class G> friend class BackwardSampling; template <class G> friend class Distribute; template <class G> friend class Collect; template <class G> friend void tree_gibbs_sampler (G &, const unsigned int, const unsigned int, const bool, const bool); protected: double get_message_value( RanVar & rv); std::list <Message> messages_list; // one of the things these classes all have in common is that they store Messages VertexDescriptor parent_node; // we also need to store the parent of each node private: virtual void add_rao_blackwell_estimates( std::vector <double > &) const = 0; virtual void set_rao_blackwell_estimates( const std::vector <double > &) = 0;};// Inline functions followtemplate <class RanVar, class VertexDescriptor >inline std::list <Message> & MessageNode <RanVar, VertexDescriptor >::obtain_messages_list(){ return messages_list;}// ********************* End MessageNode Declaration **********************//// ********************** MRFMessageNode Declaration **********************//template <class RanVar, class Pot, class VertexDescriptor>class MRFMessageNode : public MessageNode <RanVar, VertexDescriptor>{ public: typedef Pot potential_type; // Constructor, destructor MRFMessageNode (RanVar *, Pot *); ~MRFMessageNode(); // Inherited virtual functions void initialize(); Message send_message(MessageNode <RanVar, VertexDescriptor > &); void sample_from_proposal(bool); void serialize_marginals(std::ofstream &) const; void sample_from_joint_root (); void sample_from_joint (MessageNode <RanVar, VertexDescriptor > &); void set_variable_to_joint_sampled_value(); void compute_marginals_from_messages(); void compute_marginals_from_samples(); // Deprecated methods RanVar * get_random_variable() const; Pot * get_potential() const; // Normal methods void set_potential( Potential * ); // Friends template <class Graph> friend void setup_message_computation (Graph &, const typename boost::graph_traits<Graph>::vertex_descriptor, const typename boost::graph_traits<Graph>::vertex_descriptor); template <class Graph> friend void setup_gibbs_proposal (Graph &, const typename boost::graph_traits<Graph>::vertex_descriptor); template <class Graph> friend void setup_tree_gibbs_sampler(Graph &, std::vector < std::vector <double> > &, std::vector <Potential * > &); template <class Graph> friend void cleanup_tree_gibbs_sampler(Graph &, std::vector <Potential *> & ); private: void add_rao_blackwell_estimates( std::vector <double > &) const; void set_rao_blackwell_estimates( const std::vector <double > &); RanVar * variable_ptr; Pot * potential_ptr; unsigned int joint_sampled_value; Potential * message_potential; ChainedSingleDiscretePotential * sample_proposal;};template <class RanVar, class Pot, class VertexDescriptor >inline RanVar * MRFMessageNode < RanVar, Pot, VertexDescriptor >::get_random_variable() const{ return variable_ptr;}template <class RanVar, class Pot, class VertexDescriptor >inline Pot * MRFMessageNode<RanVar, Pot, VertexDescriptor >::get_potential() const{ return potential_ptr;}template <class RanVar, class Pot, class VertexDescriptor >inline void MRFMessageNode<RanVar, Pot, VertexDescriptor >::set_potential( Potential * p){ potential_ptr = static_cast < Pot * > (p);}// ***************** End of MRFMessageNode Declaration ****************//// ***************** VariableMessageNode Declaration ****************//template <class RanVar, class Pot, class VertexDescriptor>class VariableMessageNode : public MessageNode <RanVar, VertexDescriptor>{ public: // Constructors, destructor VariableMessageNode (RanVar *, Pot *); ~VariableMessageNode (); // Inherited virtual methods void initialize(); Message send_message(MessageNode <RanVar, VertexDescriptor > &); void sample_from_proposal(bool); void sample_from_joint_root (); void sample_from_joint (MessageNode <RanVar, VertexDescriptor > &); void serialize_marginals(std::ofstream &) const; void set_variable_to_joint_sampled_value(); void compute_marginals_from_samples(); void compute_marginals_from_messages(); // Deprecated methods RanVar * get_random_variable() const; Pot * get_potential() const; // Friends template <class R, class P, class VD> friend class PotentialMessageNode; template <class Graph> friend void setup_gibbs_proposal (Graph &, const typename boost::graph_traits<Graph>::vertex_descriptor); template <class Graph> friend void setup_tree_gibbs_sampler(Graph &, std::vector < std::vector <double> > &, std::vector <Potential *> & ); template <class Graph> friend void cleanup_tree_gibbs_sampler(Graph &, std::vector <Potential *> & ); template <class Graph> friend void remove_temp_observations(Graph &); template <class Graph> friend void set_temp_observations(Graph &); template <class Graph> friend void initialize_graph_variables_random(Graph &); private: void add_rao_blackwell_estimates( std::vector <double > &) const; void set_rao_blackwell_estimates( const std::vector <double > &); RanVar * variable_ptr; Pot * potential_ptr; ChainedSingleDiscretePotential * sample_proposal; unsigned int joint_sampled_value;};template <class RanVar, class Pot, class VertexDescriptor >inline RanVar * VariableMessageNode < RanVar, Pot, VertexDescriptor >::get_random_variable() const{ return variable_ptr;}template <class RanVar, class Pot, class VertexDescriptor >inline Pot * VariableMessageNode <RanVar, Pot, VertexDescriptor >::get_potential() const{ return potential_ptr;}// ***************** End of VariableMessageNode Declaration ****************//// ***************** PotentialMessageNode Declaration ****************//template <class RanVar, class Pot, class VertexDescriptor>class PotentialMessageNode : public MessageNode <RanVar, VertexDescriptor>{public: // Constructors, destructor PotentialMessageNode (const unsigned int, Pot *); ~PotentialMessageNode (); // Inherited virtual methods void initialize (); Message send_message(MessageNode <RanVar, VertexDescriptor > &); void receive_message(Message); // in PotentialMessageNode, we need a special implementation of receive_message, so that we will store the product of the messages void update_message(Message); void sample_from_joint_root (); void sample_from_joint (MessageNode <RanVar, VertexDescriptor > &); void sample_from_proposal(bool); void serialize_marginals(std::ofstream &) const; void set_variable_to_joint_sampled_value(); void compute_marginals_from_samples(); void compute_marginals_from_messages(); Pot * get_potential() const; bool get_phantom() const { return phantom_potential; } // Friends template <class R, class P, class VD> friend class VariableMessageNode; template <class Graph> friend void setup_tree_gibbs_sampler(Graph &, std::vector < std::vector <double> > &, std::vector <Potential * > &); template <class Graph> friend void remove_temp_observations(Graph &); template <class Graph> friend void set_temp_observations(Graph &); template <class Graph> friend void initialize_graph_variables_random(Graph &); template <class Graph, class ColorMap> friend unsigned int check_phantom_potential_status(typename boost::graph_traits<Graph>::vertex_descriptor, ColorMap &, Graph &); //friend bool get_phantom_potential(MessageNode *); private: void add_rao_blackwell_estimates( std::vector <double > &) const; void set_rao_blackwell_estimates( const std::vector <double > &); bool phantom_potential; Pot * potential_ptr; std::vector <double> product_messages; // we need this extra vector to store conveniently the product of received messages };template <class RanVar, class Pot, class VertexDescriptor >inline Pot * PotentialMessageNode < RanVar, Pot, VertexDescriptor >::get_potential() const{ return potential_ptr;}// ***************** End of PotentialMessageNode Declaration ****************//#endif
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -