⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 experiment.cc.svn-base

📁 Probabilistic graphical models in matlab.
💻 SVN-BASE
字号:
#include <Experiment.h>
#include <Potential.h>
#include <Node.h>
#include <RandomVariable.h>
#include <ParticleFilters.h>
#include <OutputGraph.h>
#include <InputFunctions.h>
#include <GraphAlgorithms.h>
#include <iostream>
#include <fstream>
#include <string>
#include <vector>
#include <boost/graph/adjacency_list.hpp>

using namespace std;
using namespace boost;

//******* Parsing helpers *************//

namespace
{
	// Consumes characters from `input` up to and including the first
	// occurrence of `delimiter`.
	// Returns true if the delimiter was found, false if the stream ran out
	// (or was already in a failed state) before it was seen. Testing the
	// stream state after every read means this cannot spin forever on a
	// stream whose failbit is set without eofbit.
	bool skip_past(istream & input, char delimiter)
	{
		const int wanted = static_cast<unsigned char>(delimiter);

		for (int c = input.get(); input; c = input.get())
		{
			if (c == wanted)
			{
				return true;
			}
		}

		return false;
	}

	// Parses a data tag of the form "( value )" from `input` into `value`.
	// Prints `missing_message` if no '(' follows, `malformed_message` if the
	// value cannot be extracted (the stream error state is cleared in that
	// case), and "Malformed data tag" if the closing ')' never appears.
	// Returns true only when the whole "( value )" group was consumed.
	template <typename T>
	bool read_parenthesized_argument(istream & input, T & value,
	                                 const char * missing_message,
	                                 const char * malformed_message)
	{
		if (!skip_past(input, '('))
		{
			cout << missing_message << endl;
			return false;
		}

		if (!(input >> value))
		{
			input.clear();
			cout << malformed_message << endl;
			return false;
		}

		// Discard anything remaining before the closing parenthesis.
		if (!skip_past(input, ')'))
		{
			cout << "Malformed data tag" << endl;
			return false;
		}

		return true;
	}
}

//******* Run Implementation *************//

// Appends one error sample `a` to this run's error trace.
void Run::record_errors (double a)
{
	errors.push_back(a);
}

// Appends a copy of the most recent error sample, or -1 when no sample has
// been recorded yet. Used to fill sampling intervals that elapsed without a
// fresh measurement.
void Run::record_previous_errors ()
{
	errors.push_back(errors.empty() ? -1.0 : errors.back());
}

// Writes every recorded error sample to `output_file`, each followed by a
// single space. Returns false if the stream reports a write failure.
bool Run::export_run_errors(ofstream & output_file) const
{
	for (vector<double>::const_iterator it = errors.begin(); it != errors.end(); ++it)
	{
		output_file << *it << " ";
	}

	return !output_file.fail();
}

//******* End of Run Implementation *************//

//******* Experiment Implementation *************//

// Constructor: exact belief propagation, one run, no timer, 0.2 sampling
// interval, strict error timing.
Experiment::Experiment():
	inference_method(exact_belief_propagation),
	using_rao_blackwell(true),
	using_timer(false),
	sampling_interval(0.2),
	number_runs(1),
	only_partition(false),
	strict_error_timing(true)
{
	// Assigned in the body rather than the initializer list: members are
	// initialized in header declaration order, so reading sampling_interval
	// from the init list was only correct if it happened to be declared first.
	next_sample_time = sampling_interval;

	// These were previously left uninitialized even though they are read
	// before (or without) ever being set, e.g. `monitoring` when no 'monitor'
	// tag is present, or `steps` for methods that take no step argument.
	model = 0;
	current_run = 0;
	monitoring = false;
	steps = 0;
}

// Monitoring hook, called with the current time value (the caller is not
// visible in this file).
// Without a timer, every call records one error sample. With a timer, a
// sample is recorded only once `time` passes next_sample_time; intervals
// skipped entirely since the previous sample are filled by repeating the
// previous error so the trace stays aligned to the sampling grid.
void Experiment::callback_monitoring(const double time)
{
	// Ignore callbacks when monitoring is off or no run is in progress.
	if (!monitoring || current_run == 0)
	{
		return;
	}

	if (!using_timer)
	{
		record_errors();
		return;
	}

	if (time <= next_sample_time)
	{
		return;
	}

	next_sample_time += sampling_interval;

	while (time > next_sample_time)
	{
		next_sample_time += sampling_interval;
		current_run->record_previous_errors();
	}

	record_errors();
}

// Records the model's current error against its reference in the active run.
void Experiment::record_errors()
{
	const double current_error = model->compute_error_to_reference();

	current_run->record_errors(current_error);
}

// Parses the data of a <method> command tag, e.g. "(loopy) (100)" or
// "(exact_bp)", and configures inference_method / steps /
// using_rao_blackwell / only_partition accordingly.
// Returns false (after printing a diagnostic) on an unknown method or a
// malformed data tag.
bool Experiment::unserialize_method(ifstream & input_file)
{
	if (!skip_past(input_file, '('))
	{
		cout << "Command Tag with no data tag" << endl;
		return false;
	}

	string code_string;
	getline(input_file, code_string, ')');

	if (code_string == "loopy")
	{
		inference_method = loopy_belief_propagation;

		if (!read_parenthesized_argument(input_file, steps,
		        "We needed an argument for Loopy.",
		        "Malformed loopy argument data tag"))
		{
			return false;
		}
	}
	else if (code_string == "gibbs")
	{
		inference_method = simple_gibbs_sampling;

		if (!read_parenthesized_argument(input_file, steps,
		        "We needed an argument for Gibbs.",
		        "Malformed Gibbs argument data tag"))
		{
			return false;
		}
	}
	else if (code_string == "tree_mcmc")
	{
		inference_method = tree_gibbs_sampling;

		if (!read_parenthesized_argument(input_file, steps,
		        "We need an argument for Tree Gibbs.",
		        "Malformed Tree Gibbs argument data tag"))
		{
			return false;
		}
	}
	else if (code_string == "tree_mcmc_no_rb")
	{
		inference_method = tree_gibbs_sampling;
		using_rao_blackwell = false;

		if (!read_parenthesized_argument(input_file, steps,
		        "We need an argument for Tree Gibbs (without Rao-Blackwellization).",
		        "Malformed Tree Gibbs argument data tag"))
		{
			return false;
		}
	}
	else if (code_string == "exact_bp")
	{
		inference_method = exact_belief_propagation;
	}
	else if (code_string == "tree_partition")
	{
		inference_method = tree_partition;
		only_partition = true;
	}
	else
	{
		cout << "Unknown method or Malformed data tag" << endl;
		return false;
	}

	// Catches a method name that was never closed by ')'.
	if (input_file.eof())
	{
		cout << "Malformed data tag" << endl;
		return false;
	}

	return true;
}

// Reads the experiment description in `input_file_name`: a sequence of
// "<command tag>" entries, each optionally followed by "( data )" groups.
// Recognized tags: number of runs, number of vertices, structure,
// model type, method, monitor, timer. Unknown tags are skipped.
// On success (both a model type and a method were given) the model is
// created and then asked to unserialize itself from the same file.
// Returns false, after printing a diagnostic, on any parse failure.
// NOTE(review): the vertex count is captured when <model type> is seen, so
// <number of vertices>/<structure> must precede it. A repeated
// <model type> tag replaces `model` without freeing the previous one.
bool Experiment::set_up_model_from_file(const std::string & input_file_name)
{
	ifstream input_file (input_file_name.c_str());

	if (!input_file.good())
	{
		cout << "Failed to open the file '" << input_file_name << "'" << endl;
		return false;
	}

	string tag;
	string data;
	unsigned int a(0);
	unsigned int b(0);
	unsigned int number_vertices(0);
	bool no_command_tag(true);
	bool type_found(false);
	bool method_found(false);

	while (skip_past(input_file, '<'))
	{
		getline(input_file, tag, '>');

		if (tag == "number of runs")
		{
			no_command_tag = false;

			if (!read_parenthesized_argument(input_file, a,
			        "Command Tag with no data tag",
			        "Malformed 'number of runs' data tag"))
			{
				return false;
			}

			number_runs = a;
		}
		else if (tag == "number of vertices")
		{
			no_command_tag = false;

			if (!read_parenthesized_argument(input_file, a,
			        "Command Tag with no data tag",
			        "Malformed 'number of vertices' data tag"))
			{
				return false;
			}

			number_vertices = a;
		}
		else if (tag == "structure")
		{
			no_command_tag = false;

			// The structure data is optional: stop at whichever comes first,
			// its '(' or the '<' that opens the next command tag.
			int c = input_file.get();

			while (input_file && c != '(' && c != '<')
			{
				c = input_file.get();
			}

			if (!input_file)
			{
				// End of file: fall through to the final model checks instead
				// of reporting success without ever building the model.
				break;
			}

			if (c == '<')
			{
				// Hand the next command tag back to the main loop.
				input_file.unget();
				continue;
			}

			getline(input_file, data, ')');

			if (data == "square-lattice")
			{
				if (!skip_past(input_file, '('))
				{
					cout << "Error, no parameters given to square-lattice" << endl;
					return false;
				}

				if (!(input_file >> a >> b))
				{
					input_file.clear();
					cout << "Error: Malformed 'structure' data tag" << endl;
					return false;
				}

				number_vertices = a * b;
			}
		}
		else if (tag == "model type")
		{
			no_command_tag = false;

			if (!skip_past(input_file, '('))
			{
				cout << "Command Tag with no data tag" << endl;
				return false;
			}

			getline(input_file, data, ')');

			// Reject an unterminated data tag before allocating anything.
			if (input_file.eof())
			{
				cout << "Malformed data tag" << endl;
				return false;
			}

			if (data == "pairwise-discrete")
			{
				type_found = true;
				cout << "Creating Model with " << number_vertices << " vertices." << endl;
				model = new DiscreteProbabilisticModel(number_vertices);
			}
			else if (data == "factorgraph-discrete")
			{
				type_found = true;
				cout << "Creating Model with " << number_vertices << " vertices." << endl;
				model = new DiscreteFactorGraphProbabilisticModel(number_vertices);
			}
			else
			{
				cout << "Unknown model type '" << data << "'" << endl;
			}
		}
		else if (tag == "method")
		{
			no_command_tag = false;
			method_found = true;

			if (!unserialize_method(input_file))
			{
				return false;
			}
		}
		else if (tag == "monitor")
		{
			no_command_tag = false;
			cout << "Found command tag 'monitor'" << endl;
			monitoring = true;
			current_experiment = this;
		}
		else if (tag == "timer")
		{
			no_command_tag = false;
			cout << "Found command tag 'timer'" << endl;
			using_timer = true;
		}
	}

	if (no_command_tag)
	{
		cout << "No valid Command Tag found" << endl;
		return false;
	}

	input_file.close();

	if (!(type_found && method_found))
	{
		cout << "Missing 'model type' or 'method' command tag" << endl;
		return false;
	}

	model->unserialize(input_file_name);
	return true;
}

// Writes the results to `output_file_name`: the serialized model (unless only
// a partition was requested), the recorded error traces (when monitoring),
// and the tree partition (when only_partition is set).
// Returns false if the file cannot be opened, there is no model, or a write
// failure is reported by the stream.
bool Experiment::export_results_to_file(const std::string & output_file_name)
{
	ofstream output_file (output_file_name.c_str());

	if (!output_file.good())
	{
		cout << "Failed to open the file '" << output_file_name << "'" << endl;
		return false;
	}

	if (model == 0)
	{
		cout << "No model to export" << endl;
		return false;
	}

	if (!only_partition)
	{
		model->serialize(output_file);
	}

	if (monitoring)
	{
		serialize_errors(output_file);
	}

	if (only_partition)
	{
		model->serialize_partition(output_file);
	}

	return !output_file.fail();
}

// Writes every run's error trace as "<exported errors> ( 1 ) ( e e ... ) ...".
// All runs are first cut to a common length. With strict_error_timing the
// length is one less than the nominal sample count (steps, or
// steps / sampling_interval with a timer); runs shorter than that are padded
// by vector::resize. The traces are then truncated to the shortest run.
bool Experiment::serialize_errors(ofstream & output_file)
{
	if (runs.empty())
	{
		// No inference has been run: emit an empty section instead of
		// dereferencing runs.front().
		output_file << "<exported errors> ";
		return true;
	}

	vector<double>::size_type errors_size = runs.front().errors.size();

	if (strict_error_timing)
	{
		unsigned int nominal_size = static_cast<unsigned int>(steps);

		if (using_timer)
		{
			nominal_size = static_cast<unsigned int>(steps / sampling_interval);
		}

		// Guard the "- 1": with zero steps (e.g. exact_bp) the unsigned
		// subtraction used to wrap and request a ~4-billion-element resize.
		const unsigned int strict_size = (nominal_size > 0) ? nominal_size - 1 : 0;

		for (vector <Run>::iterator it = runs.begin(); it != runs.end(); ++it)
		{
			it->errors.resize(strict_size);
		}

		errors_size = strict_size;
	}

	for (vector <Run>::iterator it = runs.begin(); it != runs.end(); ++it)
	{
		if (it->errors.size() < errors_size)
		{
			errors_size = it->errors.size();
		}
	}

	for (vector <Run>::iterator it = runs.begin(); it != runs.end(); ++it)
	{
		it->errors.resize(errors_size);
	}

	output_file << "<exported errors> ";

	unsigned int run_number(0);

	for (vector <Run>::iterator it = runs.begin(); it != runs.end(); ++it)
	{
		++run_number;

		output_file << "( " << run_number << " ) ( ";

		it->export_run_errors(output_file);

		output_file << ") ";
	}

	return !output_file.fail();
}

// Performs number_runs inference runs with the configured method. Each run
// gets a fresh Run entry (the target of callback_monitoring) and a reset
// sampling clock.
void Experiment::get_inference()
{
	if (model == 0)
	{
		cout << "No model has been set up; cannot run inference" << endl;
		return;
	}

	// Reserve up front so existing Run objects (and their error vectors)
	// are not copied as the vector grows.
	runs.reserve(runs.size() + number_runs);

	for (unsigned int i = 0; i < number_runs; ++i)
	{
		runs.push_back(Run());

		// Safe to keep: nothing is pushed into `runs` until the next iteration.
		current_run = & runs.back();

		next_sample_time = sampling_interval;

		switch (inference_method)
		{
			case exact_belief_propagation:
				model->get_inference_mp();
				break;

			case loopy_belief_propagation:
				model->get_inference_loopy(steps, using_timer);
				break;

			case simple_gibbs_sampling:
				model->get_inference_gibbs(steps, using_timer);
				break;

			case tree_gibbs_sampling:
				model->get_inference_tree_gibbs(steps, using_rao_blackwell, using_timer);
				break;

			case tree_partition:
				model->get_tree_partition();
				break;
		}
	}
}

//******* End of Experiment Implementation *************//

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -