📄 inftest.cpp
字号:
#include "VarSet.h"
#include "VarSchema.h"
#include "NumSet.h"

// Inference methods
#include "InferenceMethod.h"
#include "initinference.h"
#include "VarConfig.h"

#include <stdio.h>
#include <math.h>
#include <string.h>
#include <cmath>
#include <iostream>
#include <fstream>
#include <list>
#include <string>

// Headers required for getpid() syscall
#include <sys/types.h>
#include <unistd.h>

#define SQR(x) ((x)*(x))
#define STDEV(a, sq_a, n) (((a) - (sq_a)/(n))/(n))

/**
 * Scores one marginal query.
 *
 * Runs marginal inference conditioned on `evidence`, then takes the log of
 * the marginal probability that `method` reports for the value `query`
 * assigns to the FIRST variable that is tested in `query` and not tested in
 * `evidence`. The loop stops after that first variable, so at most one
 * marginal contributes.
 *
 * @param query     Variable assignment being scored.
 * @param evidence  Variable assignment used as conditioning evidence.
 * @param method    Inference engine; must be non-null.
 * @return The accumulated log probability divided by the number of marginals
 *         counted. When no variable qualifies the divisor is zero and the
 *         result is 0.0/0 (NaN on IEEE-754 platforms); main() reports such
 *         values as an undefined log likelihood.
 */
double getMargLogLikelihood(VarSet& query, VarSet& evidence,
                            InferenceMethod* method)
{
    double totalLogLikelihood = 0.0;
    int numMarginals = 0;

    // Run inference
    method->runMarginalInference(evidence);

    // Get marginal for actual query
    for (int i = 0; i < query.getNumVars(); i++) {
        if (query.isTested(i) && !evidence.isTested(i)) {
            totalLogLikelihood += std::log(method->getMarginalProb(i, query[i]));
            numMarginals++;
            break;
        }
    }

    // This query makes a contribution (averaged over the number of
    // subqueries) to the overall log likelihood.
    return totalLogLikelihood / numMarginals;
}

/**
 * Scores one joint query.
 *
 * Collects the indices of every variable that is tested in `query` but not
 * in `evidence`, and forwards them to
 * InferenceMethod::singleConditionalLogProb().
 *
 * @param query     Variable assignment being scored.
 * @param evidence  Variable assignment used as conditioning evidence.
 * @param method    Inference engine; must be non-null.
 * @return Whatever singleConditionalLogProb() returns -- presumably the log
 *         probability of the query variables given the evidence (confirm
 *         against InferenceMethod).
 */
double getJointLogLikelihood(VarSet& query, VarSet& evidence,
                             InferenceMethod* method)
{
    std::list<int> queryVars;
    for (int i = 0; i < query.getNumVars(); i++) {
        if (query.isTested(i) && !evidence.isTested(i)) {
            queryVars.push_back(i);
        }
    }

    // Run inference
#if 0
    method->runJointInference(queryVars, evidence);

    // Return prob of actual query
    VarSchema schema(1, query.getNumVars());
    for (int i = 0; i < query.getNumVars(); i++) {
        schema[i] = method->getRange(i);
    }
    VarConfig config(query, queryVars, schema);
    return log(method->getJointProb(config));
#else
    return method->singleConditionalLogProb(queryVars, evidence, query);
#endif
}

/**
 * Writes the run's progress to the file "inftest.<pid>" in the current
 * directory so that an interrupted run can be resumed.
 *
 * File layout: argv[1..argc-1] one per line, the line "ENDPARAMETERS", a
 * blank line, then ll, sq_ll, t, sq_t and numExamples one per line.
 *
 * @param argc,argv    Command line to record (argv[0] is not written).
 * @param ll           Running sum of log likelihoods.
 * @param sq_ll        Running sum of squared log likelihoods.
 * @param t            Running sum of per-example times.
 * @param sq_t         Running sum of squared per-example times.
 * @param numExamples  Number of examples processed so far.
 */
void checkpoint(int argc, char** argv, double ll, double sq_ll,
                double t, double sq_t, int numExamples)
{
    char filename[1024];
    snprintf(filename, sizeof(filename), "inftest.%d",
             static_cast<int>(getpid()));

    std::ofstream out(filename);

    // 17 significant digits round-trip an IEEE-754 double exactly; the
    // stream default of 6 would truncate the running sums on every restore.
    out.precision(17);

    for (int i = 1; i < argc; i++) {
        out << argv[i] << '\n';
    }
    out << "ENDPARAMETERS\n";
    out << '\n';
    out << ll << '\n';
    out << sq_ll << '\n';
    out << t << '\n';
    out << sq_t << '\n';
    out << numExamples << '\n';
    out.close();
}

// HACK
#include <stdio.h>
#include <sys/times.h>

// Process times captured by the most recent resetTimer() call.
struct tms startTime;

/** Records the current process times as the timer's starting point. */
void resetTimer()
{
    times(&startTime);
}

/**
 * @return User CPU time (tms_utime) consumed since the last resetTimer()
 *         call, in seconds (clock ticks divided by sysconf(_SC_CLK_TCK)).
 */
double getTimer()
{
    struct tms endTime;
    const double ticksPerSecond = sysconf(_SC_CLK_TCK);
    times(&endTime);
    return (double)(endTime.tms_utime - startTime.tms_utime) / ticksPerSecond;
}

/**
 * Reloads the command line and running statistics from a checkpoint file
 * produced by checkpoint().
 *
 * Parameters are read as whitespace-delimited tokens up to the
 * "ENDPARAMETERS" marker, so an argument that itself contained whitespace is
 * not round-tripped. On success `argv` is replaced by a freshly allocated,
 * NULL-terminated array whose slot 0 keeps the original program name; the
 * array and its strings are intentionally never freed because they are used
 * for the remainder of the process.
 *
 * @param path  Checkpoint file to read.
 * @param argc,argv  Replaced with the recorded command line on success.
 * @param ll,sq_ll,t,sq_t,numExamples  Receive the recorded statistics.
 * @return true on success; false (after printing an error to stdout) when
 *         the file cannot be opened or is incomplete. Outputs are untouched
 *         on failure except possibly the statistics.
 */
static bool restoreCheckpoint(const char* path, int& argc, char**& argv,
                              double& ll, double& sq_ll,
                              double& t, double& sq_t, int& numExamples)
{
    std::ifstream checkfile(path);
    if (!checkfile) {
        std::cout << "ERROR: could not open checkpoint file \""
                  << path << "\"\n";
        return false;
    }

    // Read in all command line parameters, one token at a time, until the
    // terminating marker. std::string removes the fixed-size buffer the
    // tokens used to be read into.
    std::list<std::string> params;
    std::string token;
    bool sawTerminator = false;
    while (checkfile >> token) {
        if (token == "ENDPARAMETERS") {
            sawTerminator = true;
            break;
        }
        params.push_back(token);
    }
    if (!sawTerminator) {
        std::cout << "ERROR: checkpoint file \"" << path
                  << "\" has no ENDPARAMETERS marker\n";
        return false;
    }

    // Read in statistics of completed tests
    checkfile >> ll >> sq_ll >> t >> sq_t >> numExamples;
    if (!checkfile) {
        std::cout << "ERROR: checkpoint file \"" << path
                  << "\" is missing statistics\n";
        return false;
    }

    // Put command line parameters into a new char**. Slot 0 keeps the real
    // program name so that argv[0] is never left uninitialized.
    const int restoredArgc = static_cast<int>(params.size()) + 1;
    char** restoredArgv = new char*[restoredArgc + 1];
    restoredArgv[0] = argv[0];

    int slot = 1;
    for (std::list<std::string>::const_iterator it = params.begin();
         it != params.end(); ++it, ++slot) {
        char* copy = new char[it->size() + 1];
        memcpy(copy, it->c_str(), it->size() + 1);
        restoredArgv[slot] = copy;
    }
    restoredArgv[slot] = NULL;

    argc = restoredArgc;
    argv = restoredArgv;
    return true;
}

/**
 * Standard error of the mean from running sums:
 * sqrt((sumSq - sum*sum/n) / n^2).
 *
 * The arithmetic is done in double so n*n cannot overflow, and a spread that
 * rounding pushed slightly below zero is clamped to zero instead of
 * producing NaN. A genuinely NaN input still yields NaN.
 *
 * @param sum    Sum of the samples.
 * @param sumSq  Sum of the squared samples.
 * @param n      Number of samples; must be positive.
 */
static double standardError(double sum, double sumSq, int n)
{
    const double count = static_cast<double>(n);
    double spread = sumSq - sum * sum / count;
    if (spread < 0.0) {
        spread = 0.0;
    }
    return std::sqrt(spread / SQR(count));
}

/**
 * Entry point.
 *
 * Usage: inftest -[jm] <testfile> <condfile> -[gbn] [other parameters]
 *    or: inftest <checkpointfile>   (resume from a checkpoint() file)
 *
 * Reads query/evidence pairs from the two files in lock step, scores each
 * one with the marginal (-m) or joint (any other flag) routine, prints the
 * per-example log likelihood and time, checkpoints after every example, and
 * finally prints the mean and standard error of both quantities.
 *
 * @return 0 on success, -1 on a usage or input error.
 */
int main(int argc, char* argv[])
{
    double ll = 0.0;
    double sq_ll = 0.0;
    double t = 0.0;
    double sq_t = 0.0;
    int numExamples = 0;

    // A single argument names a checkpoint file to resume from.
    if (argc == 2) {
        if (!restoreCheckpoint(argv[1], argc, argv,
                               ll, sq_ll, t, sq_t, numExamples)) {
            return -1;
        }
    }

    // Checked after a restore as well, so a damaged checkpoint cannot lead
    // to out-of-range argv accesses below.
    if (argc < 5) {
        std::cout << "Usage: inftest -[jm] <testfile> <condfile> -[gbn] ";
        std::cout << "[other parameters]\n";
        return -1;
    }

    // -m => marginal queries; -j => joint queries
    const bool margTest = (argv[1][0] != '\0' && argv[1][1] == 'm');

    InferenceMethod* method = initInferenceMethod(argc, argv, 4);

    // Load files containing queries and evidence
    std::ifstream testIn(argv[2]);
    if (!testIn) {
        std::cout << "ERROR: could not open test file \"" << argv[2] << "\"\n";
        return -1;
    }
    std::ifstream condIn(argv[3]);
    if (!condIn) {
        std::cout << "ERROR: could not open conditions file \""
                  << argv[3] << "\"\n";
        return -1;
    }

    VarSet example;
    VarSet cond;

    // Skip through file until we reach the last checkpoint location
    for (int i = 0; i < numExamples; i++) {
        testIn >> example;
        condIn >> cond;
    }

    while (testIn) {
        testIn >> example;
        condIn >> cond;
        if (!testIn) {
            break;
        }

        resetTimer();
        double l;
        if (margTest) {
            l = getMargLogLikelihood(example, cond, method);
        } else {
            l = getJointLogLikelihood(example, cond, method);
        }
        const double elapsed = getTimer();

        // DEBUG
        if (std::isinf(l) || std::isnan(l)) {
            std::cout << "ERROR: undefined log likelihood!\n";
            std::cout << "Query: " << example << std::endl;
            std::cout << "Evidence: " << cond << std::endl;
        }

        ll += l;
        sq_ll += l * l;
        t += elapsed;
        sq_t += elapsed * elapsed;
        numExamples++;
        checkpoint(argc, argv, ll, sq_ll, t, sq_t, numExamples);

        // DEBUG
        std::cout << l << " " << elapsed << std::endl;
    }

    // Without any examples the means below would divide by zero.
    if (numExamples == 0) {
        std::cout << "ERROR: no examples were processed\n";
        return -1;
    }

    const double ll_stderr = standardError(ll, sq_ll, numExamples);
    std::cout << "log(likelihood) = " << ll / numExamples;
    std::cout << " +/- " << ll_stderr << std::endl;

    double t_stderr = standardError(t, sq_t, numExamples);
    if (std::isnan(t_stderr)) {
        t_stderr = 0;
    }
    std::cout << "time = " << t / numExamples;
    std::cout << " +/- " << t_stderr << std::endl;

    return 0;
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -