📄 node.cc.svn-base
字号:
// Here there may be a problem with the index, but I don't think so. Message incoming_message( 0, *variable_ptr); double a(0.0); while ( variable_ptr->loop_over () ) { //We don't even have to set the potential to the correct value as it is done... in sum_product_messages() a = parent_node.potential_ptr->sum_product_messages_fixed ( * variable_ptr, parent_node.product_messages); // WE NEED TO DIVIDE !!!!!! a = a / parent_node.get_message_value (* variable_ptr ); incoming_message.set_value( * variable_ptr, a ); } incoming_message.reduce(); //incoming_message.display(); //cout << "Sampling from the variable " << this->index << " with probabilities: " ; // For the following computations we add this message to our list this->messages_list.push_back(incoming_message); while (variable_ptr->loop_over()) { potential_ptr->set_variable_value( *variable_ptr); ExtractFromMessage ext(* variable_ptr); a = potential_ptr->get_potential_value() * make_product_over (this->messages_list.begin(), this->messages_list.end(), ext ) ; variable_ptr->set_current_probability(a); //cout << a << ", " ; } variable_ptr->normalize_probabilities(); joint_sampled_value = variable_ptr->sample(); //cout << ". Sampled: " << joint_sampled_value << endl; // We must delete the "artificial message" that we just received this->messages_list.pop_back(); // At the end don't forget to send information that the current Variable Node has been sampled, to the parent potential (we should do that with the child as well) parent_node.potential_ptr->set_observed_variable( *variable_ptr);}template <class RanVar, class Pot, class VertexDescriptor >void VariableMessageNode<RanVar, Pot, VertexDescriptor>::sample_from_joint_root(){ double a(0.0); //cout << "Sampling from the root, " << this->index << " with probabilities: " ; while (variable_ptr->loop_over()) { potential_ptr->set_variable_value( *variable_ptr); ExtractFromMessage ext(* variable_ptr); //cout << " ( " << potential_ptr->get_potential_value() << " ) "; //cout << make_product_over (this->messages_list.begin(), this->messages_list.end(), ext ) << endl; a = potential_ptr->get_potential_value() * make_product_over (this->messages_list.begin(), this->messages_list.end(), ext ); variable_ptr->set_current_probability(a); //cout << a << ", " ; } // Don't forget to normalize variable_ptr->normalize_probabilities(); joint_sampled_value = variable_ptr->sample(); //cout << ". Sampled: " << joint_sampled_value << endl; }/*template <class RanVar, class Pot, class VertexDescriptor >void VariableMessageNode<RanVar, Pot, VertexDescriptor>::sample(){ joint_sampled_value = variable_ptr->sample();}*/template <class RanVar, class Pot, class VertexDescriptor >void VariableMessageNode<RanVar, Pot, VertexDescriptor>::sample_from_proposal(bool record){ while (this->variable_ptr->loop_over() ) { sample_proposal->set_variable_value( *this->variable_ptr ); this->variable_ptr->set_current_probability(sample_proposal->get_potential_value()); } this->variable_ptr->normalize_probabilities(); if ( record) { this->variable_ptr->sample(); } else { this->variable_ptr->sample_without_recording(); } // Don't forget to set all the potentials to the newly sampled value now sample_proposal->set_variable_value( *this->variable_ptr ); }template <class RanVar, class Pot, class VertexDescriptor >void VariableMessageNode<RanVar, Pot, VertexDescriptor>::compute_marginals_from_messages(){ double a; while ( variable_ptr->loop_over() ) { potential_ptr->set_variable_value( *variable_ptr); ExtractFromMessage ext(* variable_ptr); a = potential_ptr->get_potential_value() * make_product_over(this->messages_list.begin(), this->messages_list.end(), ext); //cout << potential_ptr->get_potential_value() << endl; variable_ptr->set_current_probability(a); } variable_ptr->normalize_probabilities();}template <class RanVar, class Pot, class VertexDescriptor >void VariableMessageNode < RanVar, Pot, VertexDescriptor >::compute_marginals_from_samples(){ variable_ptr->obtain_inference_from_prior_samples();}template <class RanVar, class Pot, class VertexDescriptor >void VariableMessageNode < RanVar, Pot, VertexDescriptor >::set_variable_to_joint_sampled_value(){ // This is completely full of bugs // Normally there is no need to check if variable is observed or not, since at worse if RV is observed, we can only get its observed value as a joint sample variable_ptr->last_sampled_value = joint_sampled_value;}template <class RanVar, class Pot, class VertexDescriptor >void VariableMessageNode < RanVar, Pot, VertexDescriptor >::serialize_marginals(std::ofstream & of) const{ variable_ptr->serialize_marginals(of);}template <class RanVar, class Pot, class VertexDescriptor >void VariableMessageNode < RanVar, Pot, VertexDescriptor >::add_rao_blackwell_estimates(vector <double > & estimates) const{ transform( estimates.begin(), estimates.end(), variable_ptr->sampling_probabilities.begin(), estimates.begin(), plus<double>());}template <class RanVar, class Pot, class VertexDescriptor >void VariableMessageNode < RanVar, Pot, VertexDescriptor >::set_rao_blackwell_estimates(const vector <double > & estimates){ copy (estimates.begin(), estimates.end(), variable_ptr->sampling_probabilities.begin()); variable_ptr->normalize_probabilities();}// ********************** End of VariableMessageNode Implementation **************//// ********************** PotentialMessageNode Implementation **************//template <class RanVar, class Pot, class VertexDescriptor >PotentialMessageNode < RanVar, Pot, VertexDescriptor >::~PotentialMessageNode (){}// In this constructor we correctly initialize the product_messages to 1.0template <class RanVar, class Pot, class VertexDescriptor >PotentialMessageNode < RanVar, Pot, VertexDescriptor >::PotentialMessageNode (const unsigned int idx, Pot * a) : MessageNode < RanVar, VertexDescriptor >(idx), phantom_potential(false), potential_ptr(a), product_messages(a->get_table_size(), 1.0){}// Inherited virtual methodstemplate <class RanVar, class Pot, class VertexDescriptor >void PotentialMessageNode<RanVar, Pot, VertexDescriptor>::initialize(){ // Here we should reset the fixed_values vectors in our potential, AND reset our product_messages... NO // fill(potential_ptr->fixed_variable.begin(), potential_ptr->fixed_variable.end(), false); fill(product_messages.begin(), product_messages.end(), 1.0); MessageNode<RanVar, VertexDescriptor>::initialize();}template <class RanVar, class Pot, class VertexDescriptor >Message PotentialMessageNode < RanVar, Pot, VertexDescriptor >::send_message(MessageNode < RanVar, VertexDescriptor > & destination_node_uncasted){ // We immediately static_cast to the right type VariableMessageNode<RanVar, Pot, VertexDescriptor> & destination_node = static_cast < VariableMessageNode<RanVar, Pot, VertexDescriptor> & > (destination_node_uncasted); Message outgoing_message( this->index, * destination_node.variable_ptr); double sum_product(0.0); // We first loop over the *destination RV* values //this->potential_ptr->display_table(); while ( destination_node.variable_ptr->loop_over () ) { //We don't even have to set the potential to the correct value as it is done... in sum_product_messages() sum_product = this->potential_ptr->sum_product_messages_fixed ( * destination_node.variable_ptr, product_messages); //cout << "SP: " << sum_product << endl; // We now divide the result by the value of the message coming from the destination variable // Note that we could not multiply in the first place, but currently we multiply all the way and then divide sum_product = sum_product / get_message_value (* destination_node.variable_ptr ); outgoing_message.set_value( * destination_node.variable_ptr, sum_product ); } outgoing_message.reduce(); //outgoing_message.display(); return outgoing_message;}// An implementation storing the product of messages received, in such a way that it is correctly linked to the underlying potential// table. This implementation needs to be double checked for errors...template <class RanVar, class Pot, class VertexDescriptor >void PotentialMessageNode < RanVar, Pot, VertexDescriptor >::receive_message(Message m){ this->messages_list.push_back(m); // need size (easy) // need left_block_size, that's all unsigned int left_block_size, number_values; potential_ptr->obtain_numbers(m.get_index(), left_block_size, number_values); // here the two last arguments are passed by reference double product; unsigned int k = 0; unsigned int i = 0; while ( i < product_messages.size()) { product = m.get_value(k); for (unsigned int j = 0; j < left_block_size; ++j) { product_messages [i] = product_messages [i] * product; ++i; } ++k; if (k == number_values) // when we attain the number of values { k = 0; } } /* cout << "Displaying product_messages: "; for (unsigned int i = 0; i < product_messages.size(); i++) { cout << product_messages [i] << ", "; } cout << endl; */}template <class RanVar, class Pot, class VertexDescriptor >void PotentialMessageNode < RanVar, Pot, VertexDescriptor >::update_message(Message msg){ unsigned int new_index = msg.get_index(); list < Message >::iterator it = find_if (this->messages_list.begin(), this->messages_list.end(), boost::bind (equal_to<unsigned int>(), boost::bind( & Message::get_index, _1), new_index) ); if ( it != this->messages_list.end() ) { unsigned int left_block_size, number_values; potential_ptr->obtain_numbers(it->get_index(), left_block_size, number_values); // here the two last arguments are passed by reference double product; unsigned int k = 0; unsigned int i = 0; while ( i < product_messages.size()) { product = it->get_value(k); for (unsigned int j = 0; j < left_block_size; ++j) { product_messages [i] = product_messages [i] / product; ++i; } ++k; if (k == number_values) // when we attain the number of values { k = 0; } } this->messages_list.erase(it); } receive_message(msg); }template <class RanVar, class Pot, class VertexDescriptor >void PotentialMessageNode<RanVar, Pot, VertexDescriptor>::compute_marginals_from_messages(){ // We don't need to do anything here}template <class RanVar, class Pot, class VertexDescriptor >void PotentialMessageNode < RanVar, Pot, VertexDescriptor >::compute_marginals_from_samples(){ // We don't need to do anything here}template <class RanVar, class Pot, class VertexDescriptor >void PotentialMessageNode<RanVar, Pot, VertexDescriptor>::sample_from_proposal(bool){ // We don't need to do anything here}template <class RanVar, class Pot, class VertexDescriptor >void PotentialMessageNode < RanVar, Pot, VertexDescriptor >::serialize_marginals(std::ofstream &) const{ // We don't need to do anything here}template <class RanVar, class Pot, class VertexDescriptor >void PotentialMessageNode<RanVar, Pot, VertexDescriptor>::sample_from_joint(MessageNode <RanVar, VertexDescriptor > & mn){ // We DO need to do something here: get the sampled value from our parent potential_ptr->set_observed_variable( * static_cast< VariableMessageNode<RanVar, Pot, VertexDescriptor> & > (mn).variable_ptr);}template <class RanVar, class Pot, class VertexDescriptor >void PotentialMessageNode<RanVar, Pot, VertexDescriptor>::sample_from_joint_root(){ // We don't need to do anything here}template <class RanVar, class Pot, class VertexDescriptor >void PotentialMessageNode < RanVar, Pot, VertexDescriptor >::set_variable_to_joint_sampled_value(){ // We don't need to do anything here}template <class RanVar, class Pot, class VertexDescriptor >void PotentialMessageNode < RanVar, Pot, VertexDescriptor >::add_rao_blackwell_estimates(vector <double > &) const{ // We don't need to do anything here}template <class RanVar, class Pot, class VertexDescriptor >void PotentialMessageNode < RanVar, Pot, VertexDescriptor >::set_rao_blackwell_estimates(const vector <double > &){ // We don't need to do anything here}// ********************** End of PotentialMessageNode Implementation **************//// Explicit Template Instantiations //typedef boost::adjacency_list_traits<vecS, vecS, undirectedS>::vertex_descriptor VertexDescriptor;template class MessageNode <DiscreteRandomVariable, VertexDescriptor>;template class MRFMessageNode <DiscreteRandomVariable, DiscretePotential, VertexDescriptor>;template class VariableMessageNode <DiscreteRandomVariable, DiscretePotential, VertexDescriptor>;template class PotentialMessageNode <DiscreteRandomVariable, DiscretePotential, VertexDescriptor>;// End Template Instantiations //
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -