📄 bp.cpp
字号:
#include "BayesNet.h"
#include "MarkovNet.h"
#include "VarSchema.h"
#include "VarSet.h"
#include <stdio.h>
#include <cmath>
#include <cstdlib>
#include <fstream>
#include <iomanip>
#include <iostream>
#include <limits>
#include <sstream>
#include <string>
#include <vector>
// Headers required for getpid() syscall
#include <sys/types.h>
#include <unistd.h>

#define SQR(x) ((x)*(x))

// Belief-propagation parameters; overwritten from the command line in main().
double convergenceThreshold = 0.001;
double damping = 0.001;

#if 0
double getLikelihood(VarSet& query, VarSet& evidence, MarkovNet& model)
{
  // DEBUG
  cout << "Running query...\n";
  return model.getLikelihood(query, evidence, convergenceThreshold, damping);
}
#endif

/*
 * Saves enough state to resume an interrupted run with "bp <checkpoint-file>".
 *
 * File "bp.<pid>" layout: argv[1..argc-1] one per line, a blank line, the
 * accumulated log-likelihood, and the number of examples processed so far.
 *
 * The data is first written to "bp.<pid>.tmp" and then renamed over the real
 * file. A kill during the write therefore leaves the previous complete
 * checkpoint in place instead of a truncated one; rename() replaces the
 * target atomically on POSIX file systems.
 */
void checkpoint(int argc, char** argv, double ll, int numExamples)
{
  std::ostringstream name;
  name << "bp." << getpid();
  const std::string filename = name.str();
  const std::string tmpname = filename + ".tmp";

  std::ofstream tmpfile(tmpname.c_str());
  if (!tmpfile) {
    std::cerr << "WARNING: could not write checkpoint \"" << tmpname << "\"\n";
    return;
  }
  for (int i = 1; i < argc; i++) {
    tmpfile << argv[i] << '\n';
  }
  tmpfile << '\n';
  // The default stream precision (6 significant digits) would silently
  // round the accumulated log-likelihood on every resume; 17 digits
  // round-trip a double exactly.
  tmpfile << std::setprecision(std::numeric_limits<double>::digits10 + 2)
          << ll << '\n';
  tmpfile << numExamples << '\n';
  tmpfile.close();
  if (!tmpfile) {
    std::cerr << "WARNING: could not write checkpoint \"" << tmpname << "\"\n";
    return;
  }
  if (rename(tmpname.c_str(), filename.c_str()) != 0) {
    std::cerr << "WARNING: could not update checkpoint \"" << filename << "\"\n";
  }
}

/*
 * Reads a checkpoint written by checkpoint().
 *
 * Arguments are read one per line (so they may contain spaces) up to the
 * blank separator line; then the log-likelihood and example count follow.
 * Returns false if the file is missing, has no separator line, or the two
 * trailing numbers cannot be parsed.
 */
static bool readCheckpoint(const char* path, std::vector<std::string>& args,
                           double& ll, int& numExamples)
{
  std::ifstream in(path);
  if (!in) {
    return false;
  }
  std::string line;
  while (std::getline(in, line) && !line.empty()) {
    args.push_back(line);
  }
  // The loop must have stopped on the blank separator, not on end of file.
  if (!in) {
    return false;
  }
  in >> ll >> numExamples;
  return !in.fail();
}

/*
 * Usage:
 *   bp <model> <testfile> <condfile> <c-threshold> <damping>
 *   bp <checkpoint-file>        (resume a run from a "bp.<pid>" checkpoint)
 *
 * For every line pair (query example, evidence/condition) in <testfile> and
 * <condfile>, this program:
 *   - clamps the evidence variables in a pairwise Markov network built from
 *     the Bayes net;
 *   - runs belief propagation (only when the evidence changes);
 *   - adds log P(query value | evidence) for the first tested, non-evidence
 *     variable of the example.
 * After every example it writes a checkpoint. At the end it prints the
 * average log-likelihood.
 */
int main(int argc, char* argv[])
{
  double ll = 0.0;
  int numExamples = 0;

  // Backing storage for arguments restored from a checkpoint. It must
  // outlive every use of argv below, so it lives at function scope.
  std::vector<std::string> restoredArgs;
  std::vector<char*> restoredArgv;

  if (argc == 2) {
    if (!readCheckpoint(argv[1], restoredArgs, ll, numExamples)) {
      std::cout << "ERROR: could not read checkpoint file \"" << argv[1]
                << "\"\n";
      return -1;
    }
    restoredArgv.push_back(argv[0]);
    for (size_t i = 0; i < restoredArgs.size(); i++) {
      // Pointers are taken only after restoredArgs is fully built, so no
      // reallocation can invalidate them; nothing writes through argv.
      restoredArgv.push_back(const_cast<char*>(restoredArgs[i].c_str()));
    }
    restoredArgv.push_back(0);  // conventional argv[argc] == NULL
    argc = static_cast<int>(restoredArgv.size()) - 1;
    argv = &restoredArgv[0];
  }
  if (argc < 6) {
    std::cout << "Usage: bp <model> <testfile> <condfile> ";
    std::cout << "<c-threshold> <damping>\n";
    std::cout << "       bp <checkpoint-file>\n";
    return -1;
  }

  BayesNet model;
  /* ifstream modelIn(argv[1], ios::binary); */
  FILE* modelIn = fopen(argv[1], "r");
  if (!modelIn) {
    std::cout << "ERROR: could not open model file \"" << argv[1] << "\"\n";
    return -1;
  }
  loadModel(modelIn, model);
  fclose(modelIn);

  // Generate a pairwise Markov random field from this Bayes net
  MarkovNet mnet(model);
  // DEBUG
  std::cout << "Successfully loaded model.\n";

  std::ifstream testIn(argv[2]);
  if (!testIn) {
    std::cout << "ERROR: could not open test file \"" << argv[2] << "\"\n";
    return -1;
  }
  std::ifstream condIn(argv[3]);
  if (!condIn) {
    std::cout << "ERROR: could not open conditions file \"" << argv[3]
              << "\"\n";
    return -1;
  }
  convergenceThreshold = atof(argv[4]);
  damping = atof(argv[5]);

  VarSet example;
  VarSet cond;

  // Skip the examples that were already scored before the checkpoint.
  for (int i = 0; i < numExamples; i++) {
    testIn >> example;
    condIn >> cond;
  }
  // DEBUG
  std::cout << "Loaded all tests.\n";

  VarSet prevCond;
  // Forces BP on the first example, even if its evidence happens to compare
  // equal to a default-constructed VarSet.
  bool networkReady = false;
  for (;;) {
    testIn >> example;
    condIn >> cond;
    // Check the streams *after* reading. Testing before the read (as in
    // "while (testIn)") scores the stale last example a second time once
    // the read at end of file fails.
    // NOTE(review): assumes VarSet's operator>> sets failbit when no example
    // could be read — confirm against VarSet.h.
    if (!testIn || !condIn) {
      break;
    }

    double l0 = 0.0;
#if 1
    if (!networkReady || cond != prevCond) {
      prevCond = cond;
      networkReady = true;
      // DEBUG
      std::cout << "Resetting network.\n";
      mnet.resetAllNodes();
      mnet.resetAllMessages();

      // Fix evidence variables
      for (int i = 0; i < cond.getNumVars(); i++) {
        if (cond.isTested(i)) {
          mnet.fixNodeValue(i, static_cast<int>(cond[i]));
        }
      }
      // Run belief propagation to convergence
      mnet.runBP(convergenceThreshold, damping);
    }

    // Get marginal for actual query value: the first variable that is
    // observed in the example but not part of the evidence.
    for (int i = 0; i < example.getNumVars(); i++) {
      if (example.isTested(i) && !cond.isTested(i)) {
        l0 = mnet.getMarginal(i).get(static_cast<int>(example[i]));
        break;
      }
    }
#else
    l0 = mnet.getLikelihood(example, cond, convergenceThreshold, damping);
#endif

    // DEBUG -- also fires when the example has no query variable, since l0
    // then keeps its initial value of 0.
    if (l0 == 0.0) {
      std::cout << "ERROR: zero likelihood!\n";
      std::cout << "Query: " << example << std::endl;
      std::cout << "Evidence: " << cond << std::endl;
    }
    double l = std::log(l0);
    ll += l;
    numExamples++;
    checkpoint(argc, argv, ll, numExamples);
    // DEBUG
    //cout << "ll = " << ll/numExamples << endl;
    //cout << "ll = " << l << endl;
  }

  if (numExamples == 0) {
    // Avoid printing 0/0 = NaN as if it were a result.
    std::cout << "ERROR: no test examples were read.\n";
    return -1;
  }
  std::cout << "log(likelihood) = " << ll / numExamples << std::endl;
  return 0;
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -