⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 discretepotential.cc.svn-base

📁 Probabilistic graphical models in matlab.
💻 SVN-BASE
📖 第 1 页 / 共 2 页
字号:
#include <DiscretePotential.h>

#include <algorithm>  // std::copy
#include <cassert>
#include <cmath>
#include <iostream>
#include <numeric>    // std::inner_product
#include <utility>

using namespace std;

// ************** DiscretePotential Implementation ********************//

// Constructors, destructor

DiscretePotential::DiscretePotential()
{
}

// Builds a potential over the given variables, linking them in the order given.
// The table is sized by add_variable() with every entry set to 0.0; fill it
// afterwards with setup_potential_values().
DiscretePotential::DiscretePotential(const vector<DiscreteRandomVariable *> & a)
{
	for (vector<DiscreteRandomVariable *>::const_iterator it = a.begin(); it != a.end(); ++it)
	{
		add_variable(*(*it));
	}
}

DiscretePotential::~DiscretePotential()
{
}

// Normal methods

// Links a discrete random variable to the potential.
// Warning: this does NOT fill the potential table; it only (re)sizes it, padding
// with 0.0. The table should be filled once all variables have been linked.
//
// Table layout: the flat position of an assignment is
//   sum_k variable_values[k] * help_vector[k]
// with help_vector[0] == 1 and help_vector[k] == help_vector[k-1] * n_{k-1},
// i.e. the first linked variable varies fastest.
void DiscretePotential::add_variable(const DiscreteRandomVariable & a)
{
	// Done first: the inner position of the new variable equals the number of
	// variables linked so far (0-based index vs. size).
	correspondance_table[a.get_index()] = number_of_variable_values.size();

	if (number_of_variable_values.empty())
	{
		help_vector.push_back(1);
	}
	else
	{
		help_vector.push_back(help_vector.back() * number_of_variable_values.back());
	}

	number_of_variable_values.push_back(a.get_number_values());
	variable_values.push_back(a.last_sampled_value);
	fixed_variable.push_back(false);

	potential_table.resize(help_vector.back() * number_of_variable_values.back(), 0.0);
}

// Copies `a` verbatim into the potential table.
// Warning: this is dangerous because no reordering is done. Values must already be
// in the table's inner layout (variables in the order they were linked, the first
// one varying fastest); nothing is checked against the variables' global indexes.
// If a.size() differs from the table size, a message is printed to cout and the
// table is left unchanged.
void DiscretePotential::setup_potential_values(const vector<double> & a)
{
	if (a.size() != potential_table.size())
	{
		cout << "Mismatch. In order to instantiate the Potential, the exact number of values is requested.\n";
		cout << "That value is equal to " << potential_table.size() << endl;
		cout << "(I received " << a.size() << ")" << endl;
		return;
	}

	// Overwrite the already-sized table in place (an earlier version used an
	// insert adaptor, which grew the table instead).
	std::copy(a.begin(), a.end(), potential_table.begin());
}

// For the variable with global index `index`, returns through the out-parameters
// its stride in the flat table (left_block_size) and its number of values.
// NOTE(review): if correspondance_table is a std::map, operator[] silently inserts
// 0 for an index that was never linked — confirm callers only pass linked indexes.
void DiscretePotential::obtain_numbers(const unsigned int index, unsigned int & left_block_size, unsigned int & number_values)
{
	const unsigned int a = correspondance_table[index];

	left_block_size = help_vector[a];
	number_values = number_of_variable_values[a];
}

// Returns the sum, over all table entries consistent with the currently fixed
// variables, of potential_table[i] * msg_product[i], with `rv` temporarily fixed
// to rv.last_sampled_value. Variables marked observed via set_observed_variable()
// stay fixed to their stored value.
// msg_product must have the same size (and layout) as the potential table.
// It is recursive and thus probably slow; see the commented-out iterative variant below.
double DiscretePotential::sum_product_messages_fixed(const DiscreteRandomVariable & rv, const std::vector<double> & msg_product)
{
	const unsigned int a = correspondance_table[rv.get_index()];

	// The variable may already be observed, so remember its flag to restore it.
	const bool backup_bool = fixed_variable[a];

	fixed_variable[a] = true;
	variable_values[a] = rv.last_sampled_value;

	assert(msg_product.size() == get_table_size());

	const double b = recursive_internal_sum_product_messages_fixed(0, msg_product);

	// NOTE(review): only the fixed flag is restored; variable_values[a] keeps
	// rv.last_sampled_value. If an observed value can differ from it, consider
	// restoring the value too.
	fixed_variable[a] = backup_bool;

	return b;
}

// Recursive helper: enumerates the assignments of variables index..last.
// Fixed variables keep their value from variable_values; free ones are looped over.
// Returns the sum of potential_table[pos] * msg_product[pos] over enumerated positions.
// Side effect: overwrites variable_values of the non-fixed variables it visits.
double DiscretePotential::recursive_internal_sum_product_messages_fixed(unsigned int index, const vector<double> & msg_product)
{
	double sum(0.0);

	if (index == number_of_variable_values.size() - 1)
	{
		// Last variable: compute flat positions directly. The init is 0u so the
		// accumulation happens in the same unsigned type as the stored position.
		if (fixed_variable[index])
		{
			const unsigned int pos = inner_product(variable_values.begin(), variable_values.end(), help_vector.begin(), 0u);

			sum += potential_table[pos] * msg_product[pos];
		}
		else
		{
			variable_values[index] = 0;
			unsigned int pos = inner_product(variable_values.begin(), variable_values.end(), help_vector.begin(), 0u);

			// Walk this variable's values by stepping its stride (the last stride).
			for (unsigned int i = 0; i < number_of_variable_values[index]; ++i)
			{
				sum += potential_table[pos] * msg_product[pos];
				pos += help_vector.back();
			}
		}

		return sum;
	}

	if (fixed_variable[index])
	{
		sum += recursive_internal_sum_product_messages_fixed(index + 1, msg_product);
	}
	else
	{
		for (unsigned int i = 0; i < number_of_variable_values[index]; ++i)
		{
			variable_values[index] = i;
			sum += recursive_internal_sum_product_messages_fixed(index + 1, msg_product);
		}
	}

	return sum;
}

// Marks `rv` as observed: it stays fixed to its current last_sampled_value in
// subsequent sums until unset_observed_variable() is called.
void DiscretePotential::set_observed_variable(const RandomVariable & rv)
{
	const unsigned int a = correspondance_table[rv.get_index()];

	fixed_variable[a] = true;
	variable_values[a] = static_cast<const DiscreteRandomVariable &>(rv).last_sampled_value;
}

// Clears the observed flag of `rv`; its stored value is left as is.
void DiscretePotential::unset_observed_variable(const RandomVariable & rv)
{
	const unsigned int a = correspondance_table[rv.get_index()];

	fixed_variable[a] = false;
}

/*
double DiscretePotential::sum_product_messages(const DiscreteRandomVariable & rv, const std::vector<double> & msg_product)
{
	set_variable_value(rv);
	return internal_sum_product_messages(correspondance_table[rv.get_index()], msg_product);
}
*/

// Non-recursive variant kept for reference. It is probably much faster, but the
// risk of bugs is higher, and it does not work with observed variables.
/*
double DiscretePotential::internal_sum_product_messages(const unsigned int index, const vector<double> & vec)
{
	unsigned int left_block_size, right_block_size, right_blocks, fixed_value;

	// right_blocks: number of blocks "to the right", i.e. iterations of the outer loop.
	// right_block_size: size of those blocks, added once the inner loop is over.
	if (index == number_of_variable_values.size() - 1)
	{
		right_blocks = 1;
		right_block_size = 0; // Not used in this case
	}
	else
	{
		right_blocks = number_of_variable_values[number_of_variable_values.size() - 1] * help_vector[number_of_variable_values.size() - 1] / help_vector[index + 1];
		right_block_size = help_vector[index + 1];
	}

	// fixed_value: offset of the point where "the sum is frozen".
	// left_block_size: size of the blocks to the left, i.e. 1 when summing over the first variable.
	left_block_size = help_vector[index];
	fixed_value = variable_values[index] * left_block_size;

	double sum = 0;
	unsigned int a = fixed_value;

	for (unsigned int i = 0; i < right_blocks; ++i)
	{
		for (unsigned int j = 0; j < left_block_size; ++j)
		{
			sum += potential_table[a + j] * vec[a + j];
		}

		a += right_block_size;
	}

	return sum;
}
*/

// Copies rv.last_sampled_value into this potential's stored value for `rv`
// (does not touch the fixed/observed flag).
void DiscretePotential::set_variable_value(const RandomVariable & rv)
{
	variable_values[correspondance_table[rv.get_index()]] = static_cast<const DiscreteRandomVariable &>(rv).last_sampled_value;
}

/*
void DiscretePotential::set_fixed_variable_value(const RandomVariable & rv)
{
	unsigned int a = correspondance_table[rv.get_index()];
}
*/

// ******************* End of DiscretePotential Implementation **************//

// **************** ChainedSingleDiscretePotential Implementation ***********//

// Constructor, destructor

// Starts with no chained potentials and no precomputed products;
// precomputed_products gets one slot per value of `a`.
ChainedSingleDiscretePotential::ChainedSingleDiscretePotential(DiscreteRandomVariable & a)
	: single_rv(a), precomputed_products(a.get_number_values()), optimized(false)
{
}

ChainedSingleDiscretePotential::~ChainedSingleDiscretePotential()
{
}

// Inherited virtual methods

// Returns the product of the chained potentials' values. Once precompute_products()
// has run, this is a lookup indexed by single_rv.last_sampled_value.
double ChainedSingleDiscretePotential::get_potential_value() const
{
	if (optimized)
	{
		return precomputed_products[single_rv.last_sampled_value];
	}

	double a(1.0);

	for (list<DiscretePotential *>::const_iterator it = potential_lst.begin(); it != potential_lst.end(); ++it)
	{
		a = a * (*it)->get_potential_value();
	}

	return a;
}

// Forwards the variable's value to every chained potential.
void ChainedSingleDiscretePotential::set_variable_value(const RandomVariable & rv)
{
	for (list<DiscretePotential *>::iterator it = potential_lst.begin(); it != potential_lst.end(); ++it)
	{
		(*it)->set_variable_value(rv);
	}
}

// Forwards the observation of `rv` to every chained potential.
void ChainedSingleDiscretePotential::set_observed_variable(const RandomVariable & rv)
{
	for (list<DiscretePotential *>::iterator it = potential_lst.begin(); it != potential_lst.end(); ++it)
	{
		(*it)->set_observed_variable(rv);
	}
}

// Forwards the un-observation of `rv` to every chained potential.
void ChainedSingleDiscretePotential::unset_observed_variable(const RandomVariable & rv)
{
	for (list<DiscretePotential *>::iterator it = potential_lst.begin(); it != potential_lst.end(); ++it)
	{
		(*it)->unset_observed_variable(rv);
	}
}

// Normal methods

// Appends a potential to the chain (the pointer is stored, not copied).
// Any precomputed products no longer include every factor, so the cache is
// invalidated; call precompute_products() again to re-enable the fast path.
void ChainedSingleDiscretePotential::add_potential(DiscretePotential * const p)
{
	potential_lst.push_back(p);
	optimized = false;
}

// Caches the chain product for every value of single_rv.
// Assumes single_rv.loop_over() steps last_sampled_value through the variable's
// values and returns false once done — TODO confirm.
void ChainedSingleDiscretePotential::precompute_products()
{
	// Force live evaluation: on a repeated call, get_potential_value() would
	// otherwise read back the stale cached table instead of recomputing.
	optimized = false;

	while (single_rv.loop_over())
	{
		set_variable_value(single_rv);
		precomputed_products[single_rv.last_sampled_value] = get_potential_value();
	}

	optimized = true;
}

// **************** End of ChainedSingleDiscretePotential Implementation ***********//

// **************** SingleDiscretePotential Implementation *************//

// Constructor, destructor

// Potential over a single variable: starts at value 0 with every table entry 1.0.
SingleDiscretePotential::SingleDiscretePotential(const DiscreteRandomVariable & rv)
	: variable_value(0), single_potential_table(rv.get_number_values(), 1.0)
{
	fixed_variable.push_back(false);
}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -