⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 inftest.cpp

📁 gibbs
💻 CPP
字号:
// inftest: evaluates an inference method on a test set of queries.
//
// Usage:
//   inftest -[jm] <testfile> <condfile> -[gbn] [other parameters]
//   inftest <checkpointfile>
//
// Records are read pairwise: the i-th VarSet from <testfile> is the query and
// the i-th VarSet from <condfile> is its evidence. For every pair the program
// prints "<log likelihood> <user CPU seconds>" and rewrites the checkpoint
// file "inftest.<pid>", so an interrupted run can be resumed by passing that
// file as the only argument. At the end it prints the mean log likelihood
// and the mean time, each followed by its standard error.
//
// "-m" as the first argument scores marginal queries; any other value
// (e.g. "-j") scores joint queries.

#include "VarSet.h"
#include "VarSchema.h"
#include "NumSet.h"

// Inference methods
#include "InferenceMethod.h"
#include "initinference.h"
#include "VarConfig.h"

#include <stdio.h>
#include <math.h>
#include <iostream>
#include <fstream>
#include <limits>
#include <list>
#include <string>
#include <vector>

// getpid(), sysconf() and times()
#include <sys/types.h>
#include <sys/times.h>
#include <unistd.h>

// Standard error of the mean of n samples, given their sum and the sum of
// their squares: sqrt(((sumSq - sum^2/n) / n) / n), i.e. the population
// standard deviation divided by sqrt(n). Returns NaN for n <= 0 or when the
// sums themselves are non-finite.
static double standardError(double sum, double sumSq, int n)
{
    if (n <= 0) {
        return std::numeric_limits<double>::quiet_NaN();
    }
    // Computed in double so n*n cannot overflow an int.
    const double count = static_cast<double>(n);
    double scatter = sumSq - sum * sum / count;   // n * population variance
    // Cancellation can push a true zero slightly below zero; clamp so sqrt()
    // does not report NaN for a (nearly) constant series. NaN < 0 is false,
    // so genuinely undefined sums still produce NaN.
    if (scatter < 0.0) {
        scatter = 0.0;
    }
    return sqrt(scatter / (count * count));
}

// Scores a marginal query. Runs marginal inference given `evidence`, then
// returns log(method->getMarginalProb(i, query[i])) for the first variable i
// that is tested in `query` but not in `evidence`.
//
// NOTE(review): because of the `break`, at most one variable is scored, so
// the "average over subqueries" is over a single marginal even when the
// query tests several variables. Kept as-is; confirm whether averaging over
// all of them was intended.
//
// Returns NaN when no such variable exists; main() reports that as an
// undefined log likelihood.
double getMargLogLikelihood(VarSet& query, VarSet& evidence,
        InferenceMethod* method)
{
    double totalLogLikelihood = 0.0;
    int numMarginals = 0;

    method->runMarginalInference(evidence);

    for (int i = 0; i < query.getNumVars(); i++) {
        if (query.isTested(i) && !evidence.isTested(i)) {
            totalLogLikelihood += log(method->getMarginalProb(i, query[i]));
            numMarginals++;
            break;
        }
    }

    if (numMarginals == 0) {
        return std::numeric_limits<double>::quiet_NaN();
    }
    return totalLogLikelihood / numMarginals;
}

// Scores a joint query. Collects every variable tested in `query` but not
// in `evidence` and returns method->singleConditionalLogProb() for them
// (presumably log P(query values | evidence) -- confirm in
// InferenceMethod.h).
double getJointLogLikelihood(VarSet& query, VarSet& evidence,
        InferenceMethod* method)
{
    std::list<int> queryVars;
    for (int i = 0; i < query.getNumVars(); i++) {
        if (query.isTested(i) && !evidence.isTested(i)) {
            queryVars.push_back(i);
        }
    }

#if 0
    // Disabled alternative: run joint inference and look up the probability
    // of the actual query configuration.
    method->runJointInference(queryVars, evidence);
    VarSchema schema(1, query.getNumVars());
    for (int i = 0; i < query.getNumVars(); i++) {
        schema[i] = method->getRange(i);
    }
    VarConfig config(query, queryVars, schema);
    return log(method->getJointProb(config));
#else
    return method->singleConditionalLogProb(queryVars, evidence, query);
#endif
}

// Saves the command line and running statistics to "inftest.<pid>".
//
// File format: argv[1..argc-1], one per line; a line "ENDPARAMETERS"; a
// blank line; then ll, sq_ll, t, sq_t and numExamples, one per line.
// The data is written to "inftest.<pid>.tmp" first and then renamed over
// the checkpoint, so a crash mid-write leaves the previous checkpoint
// intact. Failures are reported on stderr but do not stop the run.
void checkpoint(int argc, char** argv, double ll, double sq_ll,
        double t, double sq_t, int numExamples)
{
    const int pid = static_cast<int>(getpid());
    char filename[64];
    char tmpname[64];
    sprintf(filename, "inftest.%d", pid);
    sprintf(tmpname, "inftest.%d.tmp", pid);

    std::ofstream out(tmpname);
    // Enough digits for doubles to round-trip exactly; the default of 6
    // would make a resumed run's sums drift from an uninterrupted one.
    out.precision(std::numeric_limits<double>::digits10 + 2);
    for (int i = 1; i < argc; i++) {
        out << argv[i] << '\n';
    }
    out << "ENDPARAMETERS\n";
    out << '\n';
    out << ll << '\n';
    out << sq_ll << '\n';
    out << t << '\n';
    out << sq_t << '\n';
    out << numExamples << '\n';
    out.close();   // flushes; sets failbit if the write did not complete

    if (!out) {
        std::cerr << "WARNING: could not write checkpoint file \""
                  << tmpname << "\"\n";
        return;
    }
    if (rename(tmpname, filename) != 0) {
        std::cerr << "WARNING: could not update checkpoint file \""
                  << filename << "\"\n";
    }
}

// Reads a checkpoint written by checkpoint(). On success stores the saved
// parameters (argv[1..], one per line, so arguments containing spaces
// survive) in `params`, sets the running statistics, and returns true.
// Returns false if the file cannot be opened, has no "ENDPARAMETERS" line,
// or the statistics cannot be parsed (e.g. a truncated file, or sums saved
// as "nan"/"inf", which operator>> does not accept).
static bool loadCheckpoint(const char* filename,
        std::vector<std::string>& params, double& ll, double& sq_ll,
        double& t, double& sq_t, int& numExamples)
{
    std::ifstream in(filename);
    if (!in) {
        return false;
    }

    std::string line;
    bool terminated = false;
    while (std::getline(in, line)) {
        if (line == "ENDPARAMETERS") {
            terminated = true;
            break;
        }
        params.push_back(line);
    }
    if (!terminated) {
        return false;
    }

    in >> ll >> sq_ll >> t >> sq_t >> numExamples;
    return !in.fail() && numExamples >= 0;
}

// CPU-time stopwatch for timing a single query; only this process's user
// time (tms_utime) is measured.
struct tms startTime;

void resetTimer()
{
    times(&startTime);
}

// Returns user CPU seconds elapsed since the last resetTimer().
double getTimer()
{
    struct tms endTime;
    const double ticksPerSecond = static_cast<double>(sysconf(_SC_CLK_TCK));
    times(&endTime);
    return static_cast<double>(endTime.tms_utime - startTime.tms_utime)
        / ticksPerSecond;
}

int main(int argc, char* argv[])
{
    // Running sums of the per-example log likelihood (ll) and time (t),
    // plus their sums of squares for the standard errors. Non-zero at the
    // start only when resuming from a checkpoint.
    double ll = 0.0;
    double sq_ll = 0.0;
    double t = 0.0;
    double sq_t = 0.0;
    int numExamples = 0;

    // When resuming, argv is rebuilt from the checkpoint. These own the
    // writable argument strings and the pointer array, and live until
    // main() returns because argv (and possibly the inference method)
    // keeps referring to them.
    std::vector<std::vector<char> > restoredArgStorage;
    std::vector<char*> restoredArgv;

    if (argc == 2) {
        std::vector<std::string> params;
        if (!loadCheckpoint(argv[1], params, ll, sq_ll, t, sq_t,
                numExamples)) {
            std::cout << "ERROR: could not read checkpoint file \""
                      << argv[1] << "\"\n";
            return -1;
        }
        // Size the storage once so the pointers taken below stay valid.
        restoredArgStorage.resize(params.size());
        restoredArgv.push_back(argv[0]);
        for (size_t i = 0; i < params.size(); i++) {
            restoredArgStorage[i].assign(params[i].begin(), params[i].end());
            restoredArgStorage[i].push_back('\0');
            restoredArgv.push_back(&restoredArgStorage[i][0]);
        }
        restoredArgv.push_back(static_cast<char*>(0));  // argv[argc] == NULL
        argc = static_cast<int>(restoredArgv.size()) - 1;
        argv = &restoredArgv[0];
    }

    // Applies to both a real and a restored command line.
    if (argc < 5) {
        std::cout << "Usage: inftest -[jm] <testfile> <condfile> -[gbn] ";
        std::cout << "[other parameters]\n";
        return -1;
    }

    // -m => marginal queries; -j => joint queries
    const bool margTest = (argv[1][1] == 'm');
    // NOTE(review): the meaning of the start index 4 is defined in
    // initinference.h -- presumably where the method's own options begin.
    InferenceMethod* method = initInferenceMethod(argc, argv, 4);
    if (method == 0) {
        // Guard in case initInferenceMethod() signals failure with NULL.
        std::cout << "ERROR: could not initialize inference method\n";
        return -1;
    }

    // Load files containing queries and evidence
    std::ifstream testIn(argv[2]);
    if (!testIn) {
        std::cout << "ERROR: could not open test file \"" << argv[2] << "\"\n";
        return -1;
    }
    std::ifstream condIn(argv[3]);
    if (!condIn) {
        std::cout << "ERROR: could not open conditions file \""
                  << argv[3] << "\"\n";
        return -1;
    }

    VarSet example;
    VarSet cond;

    // Skip the examples already scored before the checkpoint was written.
    for (int i = 0; i < numExamples; i++) {
        testIn >> example;
        condIn >> cond;
    }

    int status = 0;
    for (;;) {
        testIn >> example;
        condIn >> cond;
        if (!testIn) {
            break;
        }
        // Without this check a shorter conditions file would silently pair
        // the remaining queries with stale evidence.
        if (!condIn) {
            std::cout << "ERROR: conditions file \"" << argv[3]
                      << "\" has fewer entries than test file \""
                      << argv[2] << "\"\n";
            status = -1;
            break;
        }

        resetTimer();
        const double l = margTest
            ? getMargLogLikelihood(example, cond, method)
            : getJointLogLikelihood(example, cond, method);
        const double elapsed = getTimer();

        // Reported but still accumulated, so the final mean becomes
        // inf/NaN as well.
        if (isinf(l) || isnan(l)) {
            std::cout << "ERROR: undefined log likelihood!\n";
            std::cout << "Query: " << example << std::endl;
            std::cout << "Evidence: " << cond << std::endl;
        }

        ll    += l;
        sq_ll += l * l;
        t     += elapsed;
        sq_t  += elapsed * elapsed;
        numExamples++;

        checkpoint(argc, argv, ll, sq_ll, t, sq_t, numExamples);

        // Per-example progress; endl flushes so it can be watched live.
        std::cout << l << " " << elapsed << std::endl;
    }

    if (numExamples == 0) {
        std::cout << "ERROR: no examples were evaluated\n";
        return -1;
    }

    std::cout << "log(likelihood) = " << ll / numExamples;
    std::cout << " +/- " << standardError(ll, sq_ll, numExamples) << std::endl;

    std::cout << "time = " << t / numExamples;
    std::cout << " +/- " << standardError(t, sq_t, numExamples) << std::endl;

    return status;
}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -