⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 node.cc.svn-base

📁 Probabilistic graphical models in matlab.
💻 SVN-BASE
📖 第 1 页 / 共 2 页
字号:
#include <algorithm>
#include <fstream>
#include <functional>
#include <iostream>
#include <list>
#include <vector>

#include <boost/graph/adjacency_list.hpp>
#include <boost/graph/graph_traits.hpp>
#include <boost/mem_fn.hpp>

#include <Node.h>
#include <Generic.h>
#include <RandomVariable.h>
#include <Potential.h>
#include <Message.h>

using namespace std;
using namespace boost;

// ***************** Node Implementation ************************ //

// Node with the placeholder name "default_name" and index 0.
Node::Node() : name("default_name"), index(0)
{
}

// Node with the placeholder name "default_name" and the given index.
Node::Node(const unsigned int a) : name("default_name"), index(a)
{
}

// ***************** End of Node Implementation ************************ //

// ***************** GraphicalModelNode Implementation ************************ //

// Stores the given random-variable and potential pointers; nothing is copied
// or allocated here. NOTE(review): ownership of rv/p is not visible in this file.
GraphicalModelNode::GraphicalModelNode (RandomVariable * const rv, Potential * const  p) : variable_ptr(rv), potential_ptr (p)
{
}

// ***************** End of GraphicalModelNode Implementation ************************ //

// ***************** MessageNode Implementation ************************ //

// Constructors, destructor

template <class RanVar, class VertexDescriptor >
MessageNode< RanVar, VertexDescriptor >::MessageNode ()
{
}

// Forwards the index to the Node base.
template <class RanVar, class VertexDescriptor >
MessageNode< RanVar, VertexDescriptor >::MessageNode (const unsigned int a): Node(a)
{
}

template <class RanVar, class VertexDescriptor >
MessageNode< RanVar, VertexDescriptor >::~MessageNode()
{
}

// Normal methods

// Resets the node for a new round of message passing.
template <class RanVar, class VertexDescriptor >
void MessageNode <RanVar, VertexDescriptor >::initialize()
{
	// For now we only clear the list of Messages
	messages_list.clear();
}

// Virtual functions follow

// Appends m to the stored messages unconditionally (no de-duplication by index;
// see update_message() for replace-or-append semantics).
template <class RanVar, class VertexDescriptor >
void MessageNode < RanVar, VertexDescriptor >::receive_message(Message m)
{
	messages_list.push_back(m);
}

// Replaces the stored message whose Message::get_index() equals msg's index,
// or appends msg when no such message is stored yet.
template <class RanVar, class VertexDescriptor >
void MessageNode<RanVar, VertexDescriptor>::update_message(Message msg)
{
	const unsigned int new_index = msg.get_index();

	list < Message >::iterator it = find_if (this->messages_list.begin(), this->messages_list.end(),
		boost::bind (equal_to<unsigned int>(), boost::bind( & Message::get_index, _1), new_index) );

	if ( it != this->messages_list.end() )
	{
		*it = msg;
	}
	else
	{
		this->messages_list.push_back(msg);
	}
}

// Base-class stub: derived node types are expected to override this.
// Prints an error and returns NULL.
template <class RanVar, class VertexDescriptor >
RanVar * MessageNode <RanVar,  VertexDescriptor >::get_random_variable() const
{
	cout << "ERROR (FATAL): We called get_random_variable() on the base class MessageNode, that should never happen." << endl;
	return NULL;
}

// Base-class stub: derived node types are expected to override this.
// Prints an error and does nothing else.
template <class RanVar, class VertexDescriptor >
void MessageNode <RanVar,  VertexDescriptor >::set_potential(Potential *)
{
	cout << "ERROR (FATAL): We called set_potential() on the base class MessageNode, that should never happen." << endl;
}

// Looks up the stored message whose index equals rv.get_index() and returns its
// value at rv (Message::get_value). Returns 1.0 when no such message is stored,
// presumably so a missing message leaves a product of messages unchanged.
template <class RanVar, class VertexDescriptor >
double MessageNode <RanVar,  VertexDescriptor >::get_message_value( RanVar & rv)
{
	list < Message >::iterator it = find_if (this->messages_list.begin(), this->messages_list.end(),
		boost::bind (equal_to<unsigned int>(), boost::bind( & Message::get_index, _1), rv.get_index()) );

	if (it != this->messages_list.end())
	{
		return it->get_value(rv);
	}
	else
	{
		return 1.0;
	}
}

// *********************** End of MessageNode Implementation ********************** //

// *********************** MRFMessageNode Implementation ************************** //

// Constructor, destructor

// Stores the variable and potential pointers; no sample proposal is attached yet.
// NOTE(review): default arguments on an out-of-line definition of a class-template
// member are ill-formed per [dcl.fct.default]; they belong on the in-class
// declaration in Node.h. Kept here so existing callers are not affected.
template <class Var, class Pot, class VertexDescriptor >
MRFMessageNode < Var, Pot, VertexDescriptor >::MRFMessageNode (Var * v = NULL, Pot * p = NULL) : variable_ptr(v), potential_ptr(p), sample_proposal(NULL)
{
}

// Releases the sample proposal this node holds (deleted here, so owned by the node).
template <class Var, class Pot, class VertexDescriptor >
MRFMessageNode < Var, Pot, VertexDescriptor >::~MRFMessageNode()
{
	delete sample_proposal;
	sample_proposal = NULL;
}

// Virtual inherited methods

// Drops the current sample proposal, then clears the stored messages via the base class.
template <class RanVar, class Pot, class VertexDescriptor >
void MRFMessageNode<RanVar, Pot, VertexDescriptor>::initialize()
{
	delete sample_proposal;
	sample_proposal = NULL;

	MessageNode<RanVar, VertexDescriptor>::initialize();
}

// Fills the variable's probabilities from the sample proposal's potential values,
// normalizes them, draws a sample (recorded or not, per `record`), and finally
// pushes the sampled value back into the proposal.
template <class RanVar, class Pot, class VertexDescriptor >
void MRFMessageNode<RanVar, Pot, VertexDescriptor>::sample_from_proposal(bool record)
{
	// initialize() deletes and nulls sample_proposal; dereferencing it before a
	// new proposal is attached would be undefined behaviour.
	if (sample_proposal == NULL)
	{
		cout << "ERROR: sample_from_proposal() called on an MRFMessageNode without a sample proposal." << endl;
		return;
	}

	// loop_over() is used as an iterator over the variable's values
	// (assumed to advance the current value and return false when exhausted).
	while (this->variable_ptr->loop_over() )
	{
		sample_proposal->set_variable_value( *this->variable_ptr );
		this->variable_ptr->set_current_probability(sample_proposal->get_potential_value());
	}

	//cout << "Unnormalized";
	//this->variable_ptr->debug_display_probabilities();
	this->variable_ptr->normalize_probabilities();
	//this->variable_ptr->debug_display_probabilities();

	if ( record)
	{
		this->variable_ptr->sample();
	}
	else
	{
		this->variable_ptr->sample_without_recording();
	}

	// Don't forget to set all the potentials to the newly sampled value now
	sample_proposal->set_variable_value( *this->variable_ptr );

	/*
	 this->potential_ptr->set_variable_value( *this->variable_ptr );

	 for ( list <Potential * >::iterator current_edge = sample_proposal.begin(); current_edge != sample_proposal.end(); ++current_edge)
	 {
		 (*current_edge)->set_variable_value( *this->variable_ptr );
	 }
	 */
}

// Builds the outgoing message to destination_node_uncasted. For each value of
// the destination variable, the entry is a sum over this node's values of
//   potential(x) * message_potential(x, x_dest) * product of stored messages
// (the product is filtered by ExtractFromMessage; see its definition for which
// messages are excluded). Message::reduce() is applied before returning.
// The destination is assumed to really be an MRFMessageNode (unchecked static_cast).
template <class RanVar, class Pot, class VertexDescriptor >
Message MRFMessageNode<RanVar, Pot, VertexDescriptor>::send_message(MessageNode <RanVar, VertexDescriptor > & destination_node_uncasted )
{
	// We immediately static_cast to the right type
	MRFMessageNode<RanVar, Pot, VertexDescriptor> & destination_node = static_cast < MRFMessageNode<RanVar, Pot, VertexDescriptor> & > (destination_node_uncasted);

	Message outgoing_message(* this->variable_ptr, * destination_node.variable_ptr);

	double sum(0.0);

	// We first loop over the *destination RV* values
	while ( destination_node.variable_ptr->loop_over () )
	{
		this->message_potential->set_variable_value(* destination_node.variable_ptr );

		sum = 0.0;

		// Then we sum over the origin RV values
		while ( this->variable_ptr->loop_over () )
		{
			this->message_potential->set_variable_value(* this->variable_ptr );
			this->potential_ptr->set_variable_value(* this->variable_ptr );
			ExtractFromMessage ext(* this->variable_ptr, * destination_node.variable_ptr);

			sum += this->potential_ptr->get_potential_value() * this->message_potential->get_potential_value() * make_product_over (this->messages_list.begin(), this->messages_list.end(), ext );
		}

		outgoing_message.set_value( * destination_node.variable_ptr, sum);
	}

	outgoing_message.reduce();

	return outgoing_message;
}

// Sets each value's probability to potential(x) * product of stored messages,
// then normalizes.
template <class RanVar, class Pot, class VertexDescriptor >
void MRFMessageNode <RanVar, Pot, VertexDescriptor>::compute_marginals_from_messages()
{
	double a(0.0);

	while ( this->variable_ptr->loop_over() )
	{
		this->potential_ptr->set_variable_value( *this->variable_ptr);

		ExtractFromMessage ext(* this->variable_ptr);

		a = this->potential_ptr->get_potential_value() * make_product_over(this->messages_list.begin(), this->messages_list.end(), ext);

		this->variable_ptr->set_current_probability(a);
	}

	this->variable_ptr->normalize_probabilities();
}

// Delegates marginal estimation to the random variable's recorded samples.
template <class RanVar, class Pot, class VertexDescriptor >
void MRFMessageNode < RanVar, Pot, VertexDescriptor >::compute_marginals_from_samples()
{
	this->variable_ptr->obtain_inference_from_prior_samples();
}

// We have to sample *BEFORE* we receive the true Message (for the Rao Blackwell estimator), since we will here use the full list of messages
// Fixes the parent's variable to the parent's joint sample, then samples this
// node's variable from
//   potential(x) * message_potential(x, x_parent) * product of stored messages
// and stores the result in joint_sampled_value.
// The parent is assumed to really be an MRFMessageNode (unchecked static_cast).
template <class RanVar, class Pot, class VertexDescriptor >
void MRFMessageNode<RanVar, Pot, VertexDescriptor>::sample_from_joint(MessageNode <RanVar, VertexDescriptor > & parent_node_uncasted)
{
	MRFMessageNode<RanVar, Pot, VertexDescriptor> & parent_node = static_cast < MRFMessageNode<RanVar, Pot, VertexDescriptor> & > (parent_node_uncasted);

	parent_node.variable_ptr->last_sampled_value = parent_node.joint_sampled_value;

	this->message_potential->set_variable_value(* parent_node.variable_ptr);

	double a (0.0);

	while (this->variable_ptr->loop_over())
	{
		this->message_potential->set_variable_value(*this->variable_ptr);
		this->potential_ptr->set_variable_value(*this->variable_ptr);
		ExtractFromMessage ext(* this->variable_ptr);

		a = this->potential_ptr->get_potential_value() * this->message_potential->get_potential_value() * make_product_over (this->messages_list.begin(), this->messages_list.end(), ext );

		this->variable_ptr->set_current_probability(a);
	}

	// Don't forget to normalize
	this->variable_ptr->normalize_probabilities();

	joint_sampled_value = this->variable_ptr->sample();

	//cout << "Sampled: " << joint_sampled_value << endl;
}

// Root variant of sample_from_joint(): there is no parent, so the distribution is
// potential(x) * product of stored messages. Result goes to joint_sampled_value.
template <class RanVar, class Pot, class VertexDescriptor >
void MRFMessageNode<RanVar, Pot, VertexDescriptor>::sample_from_joint_root()
{
	double a(0.0);

	while (this->variable_ptr->loop_over())
	{
		this->potential_ptr->set_variable_value( *this->variable_ptr);

		ExtractFromMessage ext(* this->variable_ptr);

		a = this->potential_ptr->get_potential_value()  * make_product_over (this->messages_list.begin(), this->messages_list.end(), ext );

		this->variable_ptr->set_current_probability(a);
	}

	// Don't forget to normalize
	this->variable_ptr->normalize_probabilities();
	joint_sampled_value = this->variable_ptr->sample();

	// cout << "Sampled (root): " << joint_sampled_value << endl;
}

// Copies joint_sampled_value into the variable's last_sampled_value and pushes
// it into the node potential.
template <class RanVar, class Pot, class VertexDescriptor >
void MRFMessageNode < RanVar, Pot, VertexDescriptor >::set_variable_to_joint_sampled_value()
{
	// Normally there is no need to check if variable is observed or not, since at worse if RV is observed, we can only get its observed value as a joint sample
	this->variable_ptr->last_sampled_value = joint_sampled_value;

	// Here we can do that ourselves since there is only one thing to do; set the potential_ptr to the sampled value; I don't even think we have to mark it as observed
	this->potential_ptr->set_variable_value(*this->variable_ptr);
}

// Delegates serialization of the marginals to the random variable.
template <class RanVar, class Pot, class VertexDescriptor >
void MRFMessageNode < RanVar, Pot, VertexDescriptor >::serialize_marginals(std::ofstream & of) const
{
	this->variable_ptr->serialize_marginals(of);
}

// Element-wise adds the variable's sampling_probabilities into `estimates`.
// NOTE(review): assumes sampling_probabilities has at least estimates.size() entries.
template <class RanVar, class Pot, class VertexDescriptor >
void MRFMessageNode < RanVar, Pot, VertexDescriptor >::add_rao_blackwell_estimates(vector <double > & estimates) const
{
	transform( estimates.begin(), estimates.end(), this->variable_ptr->sampling_probabilities.begin(), estimates.begin(), plus<double>());
}

// Overwrites the variable's sampling_probabilities with `estimates`, then normalizes.
// NOTE(review): assumes sampling_probabilities has room for estimates.size() entries.
template <class RanVar, class Pot, class VertexDescriptor >
void MRFMessageNode < RanVar, Pot, VertexDescriptor >::set_rao_blackwell_estimates(const vector <double > & estimates)
{
	copy (estimates.begin(), estimates.end(), this->variable_ptr->sampling_probabilities.begin());
	this->variable_ptr->normalize_probabilities();
}

// ********************** End of MRFMessageNode Implementation ************** //

// ********************** VariableMessageNode Implementation ************** //

// Constructors, destructor

// Takes the node index from the variable; when no variable is given (the
// default NULL), falls back to index 0 instead of dereferencing a null pointer.
// NOTE(review): default arguments on an out-of-line definition of a class-template
// member are ill-formed per [dcl.fct.default]; they belong on the in-class
// declaration in Node.h. Kept here so existing callers are not affected.
template <class Var, class Pot, class VertexDescriptor >
VariableMessageNode < Var, Pot, VertexDescriptor >::VariableMessageNode (Var * a = NULL, Pot * b = NULL)
	: MessageNode < Var, VertexDescriptor > (a != NULL ? a->index : 0), variable_ptr(a), potential_ptr(b), sample_proposal(NULL)
{
	//this->index = variable_ptr->index;
}

// Releases the sample proposal this node holds.
template <class Var, class Pot, class VertexDescriptor >
VariableMessageNode < Var, Pot, VertexDescriptor >::~VariableMessageNode ()
{
	delete sample_proposal;
	sample_proposal = NULL;
}

// Inherited virtual methods

// Drops the current sample proposal, then clears the stored messages via the base class.
template <class RanVar, class Pot, class VertexDescriptor >
void VariableMessageNode<RanVar, Pot, VertexDescriptor>::initialize()
{
	delete sample_proposal;
	sample_proposal = NULL;

	MessageNode<RanVar, VertexDescriptor>::initialize();
}

// Builds the outgoing message from this variable node: for each value x of the
// variable, the entry is potential(x) * product of stored messages (filtered by
// ExtractFromMessage using the destination's index). Message::reduce() is
// applied before returning. The destination is assumed to be a
// PotentialMessageNode (unchecked static_cast).
template < class Var, class Pot, class VertexDescriptor >
Message VariableMessageNode < Var, Pot, VertexDescriptor >::send_message(MessageNode <Var, VertexDescriptor > & destination_node)
{
	double product (0.0);

	// The outgoing message needs to have the index of the current variable, and the size of the current variable too, so we send twice the current variable to the Message constructor
	Message outgoing_message(* this->variable_ptr, * this->variable_ptr);

	const unsigned int destination_index = static_cast < PotentialMessageNode < Var, Pot, VertexDescriptor > & > (destination_node).index;

	while ( this->variable_ptr->loop_over () )
	{
		this->potential_ptr->set_variable_value(* this->variable_ptr );

		ExtractFromMessage ext(* this->variable_ptr, destination_index);

		//cout << this->potential_ptr->get_potential_value() << endl;
		//cout << "product of msgs: " << make_product_over (this->messages_list.begin(), this->messages_list.end(), ext ) << endl;

		product = this->potential_ptr->get_potential_value() * make_product_over (this->messages_list.begin(), this->messages_list.end(), ext );

		outgoing_message.set_value( * this->variable_ptr, product);
	}

	outgoing_message.reduce();

	//outgoing_message.display();

	return outgoing_message;
}

template <class RanVar, class Pot, class VertexDescriptor >
void VariableMessageNode<RanVar, Pot, VertexDescriptor>::sample_from_joint(MessageNode <RanVar, VertexDescriptor > & parent_node_uncasted )
{
	// We cast to a Potential Message Node
	PotentialMessageNode<RanVar, Pot, VertexDescriptor> & parent_node = static_cast < PotentialMessageNode<RanVar, Pot, VertexDescriptor> & > (parent_node_uncasted);

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -