📄 experiment.cc.svn-base
字号:
#include <Experiment.h>
#include <Potential.h>
#include <Node.h>
#include <RandomVariable.h>
#include <ParticleFilters.h>
#include <OutputGraph.h>
#include <InputFunctions.h>
#include <GraphAlgorithms.h>
#include <iostream>
#include <fstream>
#include <string>
#include <vector>
#include <boost/graph/adjacency_list.hpp>

using namespace std;
using namespace boost;

namespace {

typedef ifstream::traits_type stream_traits;

// Consumes characters up to and including the next `delimiter`.
// Returns true if the delimiter was found, false if the stream ran out
// (end of file or any other stream failure) first.
bool skip_past(ifstream & input_file, char delimiter)
{
    const stream_traits::int_type target = stream_traits::to_int_type(delimiter);
    for (;;) {
        const stream_traits::int_type c = input_file.get();
        if (stream_traits::eq_int_type(c, stream_traits::eof())) {
            return false;
        }
        if (c == target) {
            return true;
        }
    }
}

// Reads the "( value )" argument that follows an iterative method name.
// Prints `missing_message` if no '(' is found and `malformed_message` if the
// value cannot be parsed; returns false in both cases. A missing closing ')'
// is not reported here: it leaves the stream at EOF, which the caller checks.
template <typename T>
bool read_method_argument(ifstream & input_file, T & value,
                          const char * missing_message,
                          const char * malformed_message)
{
    if (!skip_past(input_file, '(')) {
        cout << missing_message << endl;
        return false;
    }
    if (!(input_file >> value)) {
        input_file.clear();
        cout << malformed_message << endl;
        return false;
    }
    skip_past(input_file, ')');
    return true;
}

// Reads the "( n )" data tag of a numeric command tag such as
// <number of runs>. `tag_name` is only used in the diagnostic message.
bool read_unsigned_data_tag(ifstream & input_file, unsigned int & value,
                            const char * tag_name)
{
    if (!skip_past(input_file, '(')) {
        cout << "Command Tag with no data tag" << endl;
        return false;
    }
    if (!(input_file >> value)) {
        input_file.clear();
        cout << "Malformed '" << tag_name << "' data tag" << endl;
        return false;
    }
    if (!skip_past(input_file, ')')) {
        cout << "Malformed data tag" << endl;
        return false;
    }
    return true;
}

} // namespace

//******* Run Implementation *************//

// Appends one error sample to this run.
void Run::record_errors (double a)
{
    errors.push_back(a);
}

// Appends a copy of the most recent sample (or -1 when nothing has been
// recorded yet); callback_monitoring uses it to fill skipped sampling slots.
void Run::record_previous_errors ()
{
    errors.push_back(errors.empty() ? -1.0 : errors.back());
}

// Writes every recorded sample to `output_file`, each followed by a space.
// Returns false if the stream is in a failed state afterwards.
bool Run::export_run_errors(ofstream & output_file) const
{
    for (vector<double>::const_iterator it = errors.begin(); it != errors.end(); ++it) {
        output_file << *it << ' ';
    }
    return !output_file.fail();
}

//******* End of Run Implementation *************//

//******* Experiment Implementation *************//

// Constructor: exact belief propagation, one run, no timer, 0.2 sampling
// interval, strict error timing.
Experiment::Experiment():
    inference_method(exact_belief_propagation),
    using_rao_blackwell(true),
    using_timer(false),
    sampling_interval(0.2),
    number_runs(1),
    only_partition(false),
    strict_error_timing(true)
{
    // Assigned here rather than in the initializer list so the value never
    // depends on the member declaration order in Experiment.h.
    next_sample_time = sampling_interval;
    // These were previously left uninitialised although they are read later
    // (monitoring in callback_monitoring/export_results_to_file, steps in
    // serialize_errors, model in export_results_to_file).
    monitoring = false;
    steps = 0;
    model = 0;
}

// Monitoring hook; presumably invoked by the model during inference with the
// elapsed time (NOTE(review): caller not visible here — confirm).
// Without a timer every call records a sample. With a timer a sample is
// recorded only once `time` passes next_sample_time; every whole interval that
// was skipped over is filled with a copy of the previous sample first.
void Experiment::callback_monitoring(const double time)
{
    if (!monitoring) {
        return;
    }
    if (!using_timer) {
        record_errors();
        return;
    }
    // Written as !(a > b) so a NaN time records nothing, as before.
    if (!(time > next_sample_time)) {
        return;
    }
    next_sample_time += sampling_interval;
    while (time > next_sample_time) {
        next_sample_time += sampling_interval;
        current_run->record_previous_errors();
    }
    record_errors();
}

// Records the model's current error against its reference into the active run.
void Experiment::record_errors()
{
    const double current_error = model->compute_error_to_reference();
    current_run->record_errors(current_error);
}

// Parses the data tag of a <method> command tag: "(name)" optionally followed
// by "(steps)" for the iterative methods. Sets inference_method (and the
// related flags) and returns false, after printing a diagnostic, on any
// malformed input.
bool Experiment::unserialize_method(ifstream & input_file)
{
    if (!skip_past(input_file, '(')) {
        cout << "Command Tag with no data tag" << endl;
        return false;
    }
    string code_string;
    getline(input_file, code_string, ')');

    if (code_string == "loopy") {
        inference_method = loopy_belief_propagation;
        if (!read_method_argument(input_file, steps,
                                  "We needed an argument for Loopy.",
                                  "Malformed loopy argument data tag")) {
            return false;
        }
    } else if (code_string == "gibbs") {
        inference_method = simple_gibbs_sampling;
        if (!read_method_argument(input_file, steps,
                                  "We needed an argument for Gibbs.",
                                  "Malformed Gibbs argument data tag")) {
            return false;
        }
    } else if (code_string == "tree_mcmc") {
        inference_method = tree_gibbs_sampling;
        if (!read_method_argument(input_file, steps,
                                  "We need an argument for Tree Gibbs.",
                                  "Malformed Tree Gibbs argument data tag")) {
            return false;
        }
    } else if (code_string == "tree_mcmc_no_rb") {
        inference_method = tree_gibbs_sampling;
        using_rao_blackwell = false;
        if (!read_method_argument(input_file, steps,
                                  "We need an argument for Tree Gibbs (without Rao-Blackwellization).",
                                  "Malformed Tree Gibbs argument data tag")) {
            return false;
        }
    } else if (code_string == "exact_bp") {
        inference_method = exact_belief_propagation;
    } else if (code_string == "tree_partition") {
        inference_method = tree_partition;
        only_partition = true;
    } else {
        cout << "Unknown method or Malformed data tag" << endl;
        return false;
    }

    if (input_file.eof()) {
        cout << "Malformed data tag" << endl;
        return false;
    }
    return true;
}

// Reads the experiment description from `input_file_name`: a sequence of
// <command tag> entries, most followed by a "(data)" tag. Unknown command tags
// are ignored. Creates the model and, once both a model type and a method have
// been found, lets the model read the same file via model->unserialize.
// Returns false (after printing a diagnostic) on any malformed input.
bool Experiment::set_up_model_from_file(const std::string & input_file_name)
{
    unsigned int a = 0;
    unsigned int b = 0;
    unsigned int number_vertices = 0;
    bool no_command_tag = true;
    bool type_found = false;
    bool method_found = false;

    ifstream input_file (input_file_name.c_str());
    if (!input_file.good()) {
        cout << "Failed to open the file '" << input_file_name << "'" << endl;
        return false;
    }

    while (true) {
        // Advance to the next command tag; stop on EOF (or any stream failure,
        // which previously could loop forever on a stale tag name).
        if (!skip_past(input_file, '<')) {
            if (no_command_tag) {
                cout << "No valid Command Tag found" << endl;
                return false;
            }
            break;
        }
        string code_string;
        getline(input_file, code_string, '>');

        if (code_string == "number of runs") {
            no_command_tag = false;
            if (!read_unsigned_data_tag(input_file, a, "number of runs")) {
                return false;
            }
            number_runs = a;
        } else if (code_string == "number of vertices") {
            no_command_tag = false;
            if (!read_unsigned_data_tag(input_file, a, "number of vertices")) {
                return false;
            }
            number_vertices = a;
        } else if (code_string == "structure") {
            // The structure data tag is optional. If another command tag ('<')
            // comes first, hand it back to this loop instead of aborting the
            // parse (previously this returned true before the model was
            // created or unserialized).
            stream_traits::int_type c = input_file.get();
            while (c != '(' && c != '<'
                   && !stream_traits::eq_int_type(c, stream_traits::eof())) {
                c = input_file.get();
            }
            if (stream_traits::eq_int_type(c, stream_traits::eof())) {
                break;
            }
            if (c == '<') {
                input_file.unget();
                continue;
            }
            string structure_name;
            getline(input_file, structure_name, ')');
            if (structure_name == "square-lattice") {
                if (!skip_past(input_file, '(')) {
                    cout << "Error, no parameters given to square-lattice" << endl;
                    return false;
                }
                if (!(input_file >> a >> b)) {
                    input_file.clear();
                    cout << "Error: Malformed 'structure' data tag" << endl;
                    return false;
                }
                number_vertices = a * b;
            }
        } else if (code_string == "model type") {
            no_command_tag = false;
            if (!skip_past(input_file, '(')) {
                cout << "Command Tag with no data tag" << endl;
                return false;
            }
            string type_name;
            getline(input_file, type_name, ')');
            // NOTE(review): a repeated <model type> tag replaces the previous
            // model without deleting it; ownership of `model` is defined in
            // Experiment.h — confirm before adding a delete here.
            if (type_name == "pairwise-discrete") {
                type_found = true;
                cout << "Creating Model with " << number_vertices << " vertices." << endl;
                model = new DiscreteProbabilisticModel(number_vertices);
            } else if (type_name == "factorgraph-discrete") {
                type_found = true;
                cout << "Creating Model with " << number_vertices << " vertices." << endl;
                model = new DiscreteFactorGraphProbabilisticModel(number_vertices);
            }
            if (input_file.eof()) {
                cout << "Malformed data tag" << endl;
                return false;
            }
        } else if (code_string == "method") {
            no_command_tag = false;
            method_found = true;
            if (!unserialize_method(input_file)) {
                return false;
            }
        } else if (code_string == "monitor") {
            no_command_tag = false;
            cout << "Found command tag 'monitor'" << endl;
            monitoring = true;
            current_experiment = this;
        } else if (code_string == "timer") {
            no_command_tag = false;
            cout << "Found command tag 'timer'" << endl;
            using_timer = true;
        }
    }
    input_file.close();

    if (type_found && method_found) {
        model->unserialize(input_file_name);
        return true;
    }
    return false;
}

// Writes the results to `output_file_name`: the serialized model (unless only
// a partition was requested), the recorded errors when monitoring, and the
// partition when only_partition is set. Returns false if the file cannot be
// opened, there is no model, or a write fails.
bool Experiment::export_results_to_file(const std::string & output_file_name)
{
    ofstream output_file (output_file_name.c_str());
    if (!output_file) {
        cout << "Failed to open the file '" << output_file_name << "'" << endl;
        return false;
    }
    if (model == 0) {
        cout << "No model to export" << endl;
        return false;
    }
    if (!only_partition) {
        model->serialize(output_file);
    }
    if (monitoring) {
        serialize_errors(output_file);
    }
    if (only_partition) {
        model->serialize_partition(output_file);
    }
    return !output_file.fail();
}

// Trims every run's error series to a common length and writes them as
// "<exported errors> ( 1 ) ( e e ... ) ( 2 ) ( ... ) ".
// With strict_error_timing, each series is cut to one less than the nominal
// sample count (steps, or steps / sampling_interval with a timer). std::vector
// resize zero-pads shorter series (NOTE(review): a 0 pad is indistinguishable
// from a real zero error). Otherwise the shortest run's length is used.
bool Experiment::serialize_errors(ofstream & output_file)
{
    vector<double>::size_type target_size = 0;
    if (strict_error_timing) {
        const unsigned int expected_size = using_timer
            ? static_cast<unsigned int>(steps / sampling_interval)
            : static_cast<unsigned int>(steps);
        // Guard: with 0 expected samples "expected_size - 1" used to wrap
        // around and request a ~4e9-element resize.
        target_size = (expected_size > 0) ? expected_size - 1 : 0;
    } else if (!runs.empty()) {
        target_size = runs.front().errors.size();
        for (vector<Run>::const_iterator it = runs.begin(); it != runs.end(); ++it) {
            if (it->errors.size() < target_size) {
                target_size = it->errors.size();
            }
        }
    }
    for (vector<Run>::iterator it = runs.begin(); it != runs.end(); ++it) {
        it->errors.resize(target_size);
    }

    output_file << "<exported errors> ";
    unsigned int run_number = 0;
    for (vector<Run>::iterator it = runs.begin(); it != runs.end(); ++it) {
        ++run_number;
        output_file << "( " << run_number << " ) ( ";
        it->export_run_errors(output_file);
        output_file << ") ";
    }
    return true;
}

// Performs number_runs inference runs with the configured method. Each run
// appends a fresh Run and points current_run at it so that monitoring samples
// recorded during that run land in it.
void Experiment::get_inference()
{
    for (unsigned int i = 0; i < number_runs; ++i) {
        runs.push_back(Run());
        current_run = &runs.back();
        next_sample_time = sampling_interval;

        switch (inference_method) {
            case exact_belief_propagation:
                model->get_inference_mp();
                break;
            case loopy_belief_propagation:
                model->get_inference_loopy(steps, using_timer);
                break;
            case simple_gibbs_sampling:
                model->get_inference_gibbs(steps, using_timer);
                break;
            case tree_gibbs_sampling:
                model->get_inference_tree_gibbs(steps, using_rao_blackwell, using_timer);
                break;
            case tree_partition:
                model->get_tree_partition();
                break;
        }
    }
}

//******* End of Experiment Implementation *************//
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -