📄 probabilisticmodel.cc.svn-base
字号:
// Implementation of the probabilistic model classes: ProbabilisticModel,
// DerivedProbabilisticModel<G> (parsing of the tag-based model description
// format, inference entry points, serialization of results) and
// DiscreteProbabilisticModel.
#include <ProbabilisticModel.h>
#include <boost/graph/graph_traits.hpp>
#include <boost/graph/subgraph.hpp>
#include <boost/graph/copy.hpp>
#include <Node.h>
#include <RandomVariable.h>
#include <Potential.h>
#include <GraphAlgorithms.h>
#include <GraphTreePartition.h>
#include <OutputGraph.h>
#include <InputFunctions.h>
#include <Generic.h>

// Standard headers used directly in this file, included explicitly rather
// than relying on the project headers above to pull them in.
#include <fstream>
#include <iostream>
#include <string>
#include <vector>

using namespace std;
using namespace boost;

// Advances input_file to the next '(' or '<'.
//
// Returns true when a '(' (the start of a data item) has been consumed.
// Returns false when the stream ends or fails first, or when a '<' is found;
// in the '<' case the character is pushed back so the caller can read the
// next command tag.
bool get_next_begin_token(ifstream & input_file)
{
    int c = input_file.get();
    // good() rather than !eof(): a stream stuck in a failed (non-EOF) state
    // keeps returning EOF without setting eofbit, which would spin forever.
    while (input_file.good() && c != '(' && c != '<')
    {
        c = input_file.get();
    }
    if (!input_file.good())
    {
        return false;
    }
    if (c == '<')
    {
        input_file.unget();
        return false;
    }
    return true;
}

namespace
{

// Consumes characters up to and including the first `delimiter`.
// Returns false if the stream ends or fails before the delimiter is found.
bool skip_past(ifstream & input_file, const char delimiter)
{
    const int wanted = ifstream::traits_type::to_int_type(delimiter);
    int c = input_file.get();
    while (input_file.good() && c != wanted)
    {
        c = input_file.get();
    }
    return input_file.good();
}

// Outcome of read_vertex_record().
enum VertexRecordStatus
{
    VERTEX_RECORD_READ,   // (u, values) hold a freshly parsed record
    VERTEX_RECORD_END,    // no further '(' before the next '<' or end of file
    VERTEX_RECORD_ERROR   // malformed record; a message has been printed
};

// Reads the head of one "(index) (v1 v2 ...)" record: a 1-based vertex index
// followed by a parenthesised list of doubles. On success, u is the indexed
// vertex of g and values holds the numbers. The closing ')' of the value list
// is left for finish_vertex_record(). tag_name is used only in error messages.
template <class GraphT>
VertexRecordStatus read_vertex_record(ifstream & input_file,
                                      GraphT & g,
                                      const char * tag_name,
                                      typename graph_traits<GraphT>::vertex_descriptor & u,
                                      vector<double> & values)
{
    if (!get_next_begin_token(input_file))
    {
        return VERTEX_RECORD_END;
    }

    unsigned int index;
    if (!(input_file >> index))
    {
        input_file.clear();
        cout << "Error: Malformed '" << tag_name << "' data tag" << endl;
        return VERTEX_RECORD_ERROR;
    }
    // Indices in the file are 1-based. Without this check an index of 0 wraps
    // to UINT_MAX in `index - 1`, and any index past the end names a vertex
    // that does not exist; both lead to an out-of-bounds g[u] in the caller.
    if (index == 0 || index > num_vertices(g))
    {
        cout << "Error: vertex index " << index << " out of range in '"
             << tag_name << "' data tag" << endl;
        return VERTEX_RECORD_ERROR;
    }
    u = vertex(index - 1, g);

    // Skip the ')' closing the index and the '(' opening the value list.
    if (!skip_past(input_file, ')') || !skip_past(input_file, '('))
    {
        cout << "Error: malformed '" << tag_name << "' data tag" << endl;
        return VERTEX_RECORD_ERROR;
    }

    values.clear();
    double value;
    while (input_file >> value)
    {
        values.push_back(value);
    }
    // Extraction stops by failing on the first non-number (normally the
    // closing ')'); reset the stream so the record can be finished.
    input_file.clear();
    return VERTEX_RECORD_READ;
}

// Consumes the ')' that closes a record's value list. Returns false (after
// printing a message naming tag_name) if the stream ends first.
bool finish_vertex_record(ifstream & input_file, const char * tag_name)
{
    if (!skip_past(input_file, ')'))
    {
        cout << "Error: malformed '" << tag_name << "' data tag" << endl;
        return false;
    }
    return true;
}

} // namespace

//***** ProbabilisticModel implementation ******//

ProbabilisticModel::~ProbabilisticModel()
{
}

//***** End of ProbabilisticModel implementation ******//

//***** DerivedProbabilisticModel implementation ******//

// Constructor: n is forwarded to the constructor of the graph member g
// (presumably the number of vertices — confirm against G).
template <class G>
DerivedProbabilisticModel <G>::DerivedProbabilisticModel(const unsigned int n)
    : g(n)
{
}

// Parses the model description file `input_file_name`.
//
// The file is a sequence of command tags "<name>", each followed by the data
// consumed by the matching unserialize_* member. Unrecognized tags are
// reported and skipped.
//
// Returns false if the file cannot be opened, if no recognized command tag is
// found, or if reading fails for a reason other than end of file; true
// otherwise. Failures inside the individual unserialize_* members do not make
// this function fail (they print their own messages).
template <class G>
bool DerivedProbabilisticModel <G>::unserialize(const string & input_file_name)
{
    ifstream input_file(input_file_name.c_str());
    if (!input_file.good())
    {
        cout << "Failed to open the file '" << input_file_name << "'" << endl;
        return false;
    }

    bool command_tag = false;   // set once any recognized tag has been handled
    string code_string;
    while (true)
    {
        // Advance past the '<' that opens the next command tag.
        if (!skip_past(input_file, '<'))
        {
            if (input_file.eof())
            {
                if (!command_tag)
                {
                    cout << "No valid Command Tag found" << endl;
                    return false;
                }
                break;
            }
            // Failed without reaching end of file: getline() would fail as
            // well, leaving code_string unchanged and re-dispatching the
            // previous tag indefinitely.
            cout << "Error: failed to read from the file '" << input_file_name << "'" << endl;
            return false;
        }
        getline(input_file, code_string, '>');

        if (code_string == "structure")
        {
            cout << "Found command tag 'structure'" << endl;
            unserialize_structure(input_file);
        }
        else if (code_string == "link")
        {
            cout << "Found command tag 'link'" << endl;
            unserialize_links(input_file);
        }
        else if (code_string == "all random variable sizes")
        {
            cout << "Found command tag 'all random variable sizes'" << endl;
            unserialize_all_random_variable_sizes(input_file);
        }
        else if (code_string == "random variable sizes")
        {
            cout << "Found command tag 'random variable sizes'" << endl;
            unserialize_random_variable_sizes(input_file);
        }
        else if (code_string == "observations")
        {
            cout << "Found command tag 'observations'" << endl;
            unserialize_observations(input_file);
        }
        else if (code_string == "potential tables")
        {
            cout << "Found command tag 'potential tables'" << endl;
            unserialize_potential_tables(input_file);
        }
        else if (code_string == "internal potentials")
        {
            cout << "Found command tag 'internal potentials'" << endl;
            unserialize_internal_potentials(input_file);
        }
        else if (code_string == "all potential tables")
        {
            cout << "Found command tag 'all potential tables'" << endl;
            unserialize_all_potential_tables(input_file);
        }
        else if (code_string == "all internal potentials")
        {
            cout << "Found command tag 'all internal potentials'" << endl;
            unserialize_all_internal_potentials(input_file);
        }
        else if (code_string == "ground truth")
        {
            cout << "Found command tag 'ground truth'" << endl;
            if (!unserialize_ground_truth(input_file))
            {
                cout << "Parsing error occurred in command tag 'ground truth'." << endl;
            }
        }
        else
        {
            cout << "Warning: ignoring unknown command tag '" << code_string << "'" << endl;
            continue;
        }
        command_tag = true;
    }
    return true;
}

// Parses "(index) (v1 v2 ...)" records until the next command tag or end of
// file and installs each value list as the potential of the indexed vertex
// (1-based) via setup_potential_values().
// Returns true when the records end cleanly, false on a malformed record.
template <class Graph>
bool DerivedProbabilisticModel <Graph>::unserialize_internal_potentials(ifstream & input_file)
{
    typedef typename graph_traits<Graph>::vertex_descriptor VertexDescriptor;
    for (;;)
    {
        VertexDescriptor u = VertexDescriptor();
        vector<double> potential_table;
        const VertexRecordStatus status =
            read_vertex_record(input_file, g, "internal potentials", u, potential_table);
        if (status == VERTEX_RECORD_END)
        {
            return true;
        }
        if (status == VERTEX_RECORD_ERROR)
        {
            return false;
        }
        // NOTE(review): unchecked downcast; assumes every vertex's potential
        // is a DiscretePotential — confirm.
        static_cast < DiscretePotential *> (g[u]->get_potential())->setup_potential_values(potential_table);
        if (!finish_vertex_record(input_file, "internal potentials"))
        {
            return false;
        }
    }
}

// Parses reference ("ground truth") marginals for the vertices.
//
// The data starts with a source specifier:
//   (normal)                - the records follow inline in original_input_file;
//   (external)(file name)   - the records are read from that file, whose first
//                             command tag must be <marginals>.
// Each record "(index) (p1 p2 ...)" is stored on the indexed vertex (1-based)
// via set_reference_marginals().
// Returns true when the records end cleanly, false on any error.
template <class Graph>
bool DerivedProbabilisticModel <Graph>::unserialize_ground_truth(ifstream & original_input_file)
{
    typedef typename graph_traits<Graph>::vertex_descriptor VertexDescriptor;
    string code_string;
    ifstream external_input_file;
    ifstream * input_file_ptr = & original_input_file;

    if (!get_next_begin_token(original_input_file))
    {
        cout << "Error: missing source specifier in 'ground truth' data tag" << endl;
        return false;
    }
    getline(original_input_file, code_string, ')');
    if (code_string == "external")
    {
        if (!get_next_begin_token(original_input_file))
        {
            cout << "Error: missing file name in 'ground truth' data tag" << endl;
            return false;
        }
        getline(original_input_file, code_string, ')');
        external_input_file.open(code_string.c_str());
        if (!external_input_file.good())
        {
            cout << "Failed to open the external file '" << code_string << "'" << endl;
            return false;
        }
        if (!skip_past(external_input_file, '<'))
        {
            cout << "Not a ground truth file." << endl;
            return false;
        }
        getline(external_input_file, code_string, '>');
        if (code_string != "marginals")
        {
            cout << "Error: the first tag of an external ground truth file must be <marginals>" << endl;
            return false;
        }
        input_file_ptr = & external_input_file;
    }
    else if (code_string != "normal")
    {
        cout << "Error: unknown ground truth source '" << code_string << "'" << endl;
        return false;
    }
    ifstream & input_file = * input_file_ptr;

    for (;;)
    {
        VertexDescriptor u = VertexDescriptor();
        vector<double> true_marginals;
        const VertexRecordStatus status =
            read_vertex_record(input_file, g, "ground truth", u, true_marginals);
        if (status == VERTEX_RECORD_END)
        {
            return true;
        }
        if (status == VERTEX_RECORD_ERROR)
        {
            return false;
        }
        g[u]->get_random_variable()->set_reference_marginals(true_marginals);
        if (!finish_vertex_record(input_file, "ground truth"))
        {
            return false;
        }
    }
}

// Runs belief_propagation on g.
template <class G>
void DerivedProbabilisticModel <G>::get_inference_mp()
{
    belief_propagation(g);
}

// Runs loopy_belief_propagation on g; a and timer are passed through unchanged.
template <class G>
void DerivedProbabilisticModel <G>::get_inference_loopy(const unsigned int a, const bool timer)
{
    loopy_belief_propagation(g, a, timer);
}

// Runs gibbs_sampler on g for gibbs_steps steps (third argument fixed to 0).
template <class G>
void DerivedProbabilisticModel <G>::get_inference_gibbs(const unsigned int gibbs_steps, const bool timer)
{
    gibbs_sampler(g, gibbs_steps, 0, timer);
}

// Runs tree_gibbs_sampler on g for gibbs_steps steps (third argument fixed to 0).
template <class G>
void DerivedProbabilisticModel <G>::get_inference_tree_gibbs(const unsigned int gibbs_steps, const bool rao_blackwell, const bool timer)
{
    tree_gibbs_sampler(g, gibbs_steps, 0, rao_blackwell, timer);
}

// Partitions a subgraph copy of g with partition_graph_into_trees, displays
// it, and appends the number of child subgraphs (counted here as trees) to
// number_of_trees.
template <class G>
void DerivedProbabilisticModel <G>::get_tree_partition()
{
    subgraph <G> subgraph_g;
    copy_graph(g, subgraph_g);
    partition_graph_into_trees(subgraph_g);
    display_subgraph(subgraph_g);

    unsigned int tree_count = 0;
    typename subgraph<G>::children_iterator s, s_end;
    for (tie(s, s_end) = subgraph_g.children(); s != s_end; ++s)
    {
        ++tree_count;
    }
    number_of_trees.push_back(tree_count);
}

// Writes a "<marginals> " line holding every vertex's marginals (in vertex
// order) to output_file. Always returns true.
template <class G>
bool DerivedProbabilisticModel <G>::serialize_marginals(ofstream & output_file)
{
    output_file << "<marginals> ";
    typename graph_traits<G>::vertex_iterator u, u_end;
    for (tie(u, u_end) = vertices(g); u != u_end; ++u)
    {
        g[*u]->serialize_marginals(output_file);
    }
    output_file << endl;
    return true;
}

// Writes a "<partition> " line listing the tree counts recorded by
// get_tree_partition(), space-separated. Always returns true.
template <class G>
bool DerivedProbabilisticModel <G>::serialize_partition(ofstream & output_file)
{
    output_file << "<partition> ";
    for (vector<unsigned int>::iterator it = number_of_trees.begin(); it != number_of_trees.end(); ++it)
    {
        output_file << *it << " ";
    }
    output_file << endl;
    return true;
}

// Serializes the model results; currently only the marginals.
template <class G>
bool DerivedProbabilisticModel <G>::serialize(ofstream & output_file)
{
    return serialize_marginals(output_file);
}

//***** End of DerivedProbabilisticModel implementation ******//

//***** DiscreteProbabilisticModel implementation ******//

// Constructor: forwards n to DerivedProbabilisticModel<DiscretePairwiseGraph>.
DiscreteProbabilisticModel::DiscreteProbabilisticModel(unsigned int n)
    : DerivedProbabilisticModel<DiscretePairwiseGraph>(n)
{
}

// Virtual functions

// Parses a "(type)(parameters)" structure description and builds g from it:
//   (square-lattice)(a b) -> create_pairwise_square_lattice(a, b, g)
//   (random)(c)           -> create_random_graph(c, g)
// Returns false on missing/malformed data or an unknown structure type.
bool DiscreteProbabilisticModel::unserialize_structure(ifstream & input_file)
{
    if (!get_next_begin_token(input_file))
    {
        return false;
    }
    string code_string;
    getline(input_file, code_string, ')');

    bool known_structure = true;
    if (code_string == "square-lattice")
    {
        if (!skip_past(input_file, '('))
        {
            cout << "Error, no parameters given to square-lattice" << endl;
            return false;
        }
        unsigned int a, b;
        if (!(input_file >> a >> b))
        {
            input_file.clear();
            cout << "Error: Malformed 'structure' data tag" << endl;
            return false;
        }
        create_pairwise_square_lattice(a, b, g);
    }
    else if (code_string == "random")
    {
        if (!skip_past(input_file, '('))
        {
            cout << "Error, no parameters given to random" << endl;
            return false;
        }
        double c;
        if (!(input_file >> c))
        {
            input_file.clear();
            cout << "Error: Malformed 'structure' data tag" << endl;
            return false;
        }
        create_random_graph(c, g);
    }
    else
    {
        cout << "Error: unknown structure type '" << code_string << "'" << endl;
        known_structure = false;
    }

    // Consume the ')' that closes the parameter list.
    input_file.clear();
    if (!skip_past(input_file, ')'))
    {
        cout << "Error: malformed 'structure' data tag" << endl;
        return false;
    }
    return known_structure;
}

// NOTE(review): the definition below is incomplete here (its remainder is not
// present), so its visible part is kept verbatim.
bool DiscreteProbabilisticModel::unserialize_all_random_variable_sizes(ifstream & input_file){ typedef graph_traits<Graph>::vertex_descriptor VertexDescriptor; unsigned int a, b; while (true) { if (!get_next_begin_token(input_file)) { break; } if (!(input_file >> b)) { input_file.clear(); cout << "Error: Malformed 'link' data tag" << endl; return false; }
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -