⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 bp.cpp

📁 gibbs
💻 CPP
字号:
#include "BayesNet.h"
#include "MarkovNet.h"
#include "VarSchema.h"
#include "VarSet.h"

#include <stdio.h>

#include <cmath>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <string>
#include <vector>

// Headers required for getpid() syscall
#include <sys/types.h>
#include <unistd.h>

#define SQR(x) ((x)*(x))

// Belief-propagation settings. These are only defaults: main() overwrites
// both from the command line (argv[4] and argv[5]) before any inference runs.
double convergenceThreshold = 0.001;
double damping = 0.001;

#if 0
// Disabled helper kept for reference: forwards a query to the model using
// the global convergence threshold and damping.
double getLikelihood(VarSet& query, VarSet& evidence, MarkovNet& model)
{
    // DEBUG
    cout << "Running query...\n";

    return model.getLikelihood(query, evidence, convergenceThreshold, damping);
}
#endif

/**
 * Writes a resume file named "bp.<pid>" in the current directory.
 *
 * File layout (one item per line):
 *   argv[1] .. argv[argc-1]
 *   <blank line>
 *   ll
 *   numExamples
 *
 * @param argc         number of entries in argv (argv[0] is not written)
 * @param argv         command-line arguments of the current run
 * @param ll           accumulated log-likelihood so far
 * @param numExamples  number of examples already accounted for in ll
 *
 * Best effort: if the file cannot be opened nothing is written and no error
 * is reported.
 */
void checkpoint(int argc, char** argv, double ll, int numExamples)
{
    char filename[1024];
    snprintf(filename, sizeof(filename), "bp.%d", static_cast<int>(getpid()));

    std::ofstream tmpfile(filename);
    if (!tmpfile) {
        return;
    }

    // 17 significant digits are enough to read a double back unchanged; the
    // stream default of 6 would lose precision on every resume.
    tmpfile.precision(17);

    for (int i = 1; i < argc; i++) {
        tmpfile << argv[i] << '\n';
    }
    tmpfile << '\n';
    tmpfile << ll << '\n';
    tmpfile << numExamples << '\n';
    tmpfile.close();
}

/**
 * Reads a file written by checkpoint().
 *
 * Argument lines are appended to `args` until the blank separator line, so
 * the number of saved arguments does not have to be known in advance.
 *
 * @param path         checkpoint file to read
 * @param args         receives the saved arguments, appended in file order
 * @param ll           receives the saved log-likelihood
 * @param numExamples  receives the saved example count
 * @return true on success; false if the file cannot be opened or is
 *         malformed, in which case ll and numExamples are left untouched
 *         (args may already hold some of the saved lines).
 */
static bool restoreCheckpoint(const char* path,
                              std::vector<std::string>& args,
                              double& ll,
                              int& numExamples)
{
    std::ifstream checkfile(path);
    if (!checkfile) {
        return false;
    }

    std::string line;
    while (std::getline(checkfile, line) && !line.empty()) {
        args.push_back(line);
    }

    // The likelihood is read as text and converted with strtod because
    // stream extraction of a double may reject the "inf"/"nan" spellings
    // that operator<< can produce; whether strtod accepts them is up to the
    // C library.
    std::string llText;
    int examplesDone = 0;
    if (!(checkfile >> llText >> examplesDone) || examplesDone < 0) {
        return false;
    }

    char* parseEnd = 0;
    const double savedLl = std::strtod(llText.c_str(), &parseEnd);
    if (parseEnd == llText.c_str() || *parseEnd != '\0') {
        return false;
    }

    ll = savedLl;
    numExamples = examplesDone;
    return true;
}

/**
 * Entry point.
 *
 * Usage: bp <model> <testfile> <condfile> <c-threshold> <damping>
 *        bp <checkpoint-file>        (resume from a file written by checkpoint())
 *
 * For every (example, condition) pair read from the test and condition files
 * the marginal of the first query variable that is tested in the example but
 * not in the condition is looked up, and its log is added to a running sum.
 * The program prints the sum divided by the number of examples.
 *
 * @return 0 on success, -1 on a usage error or when an input file cannot be
 *         opened or read.
 */
int main(int argc, char* argv[])
{
    double ll = 0.0;
    int numExamples = 0;

    // Backing storage for the argument vector rebuilt from a checkpoint; it
    // must outlive every later use of argv.
    std::vector<std::string> restoredArgs;
    std::vector<char*> restoredArgv;

    if (argc == 2) {
        restoredArgs.push_back(argv[0]);
        if (!restoreCheckpoint(argv[1], restoredArgs, ll, numExamples)) {
            std::cout << "ERROR: could not read checkpoint file \""
                      << argv[1] << "\"\n";
            return -1;
        }

        // Take the pointers only after the string vector has stopped
        // growing, otherwise a reallocation would invalidate them.
        for (std::vector<std::string>::size_type i = 0;
                i < restoredArgs.size(); i++) {
            restoredArgv.push_back(const_cast<char*>(restoredArgs[i].c_str()));
        }
        restoredArgv.push_back(0);

        argc = static_cast<int>(restoredArgs.size());
        argv = &restoredArgv[0];
    }

    // Checked for restored arguments as well, so a truncated checkpoint
    // cannot lead to indexing past the end of argv below.
    if (argc < 6) {
        std::cout << "Usage: bp <model> <testfile> <condfile> ";
        std::cout << "<c-threshold> <damping>\n";
        return -1;
    }

    BayesNet model;
    /* ifstream modelIn(argv[1], ios::binary); */
    // NOTE(review): this handle is never closed. Whether loadModel() keeps
    // using it after returning is not visible from this file — confirm
    // before adding an fclose().
    FILE* modelIn = fopen(argv[1], "r");
    if (!modelIn) {
        std::cout << "ERROR: could not open model file \"" << argv[1] << "\"\n";
        return -1;
    }
    loadModel(modelIn, model);

    // Generate a pairwise Markov random field from this Bayes net
    MarkovNet mnet(model);

    // DEBUG
    std::cout << "Successfully loaded model.\n";

    std::ifstream testIn(argv[2]);
    if (!testIn) {
        std::cout << "ERROR: could not open test file \"" << argv[2] << "\"\n";
        return -1;
    }

    std::ifstream condIn(argv[3]);
    if (!condIn) {
        std::cout << "ERROR: could not open conditions file \""
                  << argv[3] << "\"\n";
        return -1;
    }

    convergenceThreshold = std::atof(argv[4]);
    damping = std::atof(argv[5]);

    VarSet example;
    VarSet cond;

    // When resuming, skip the examples already accounted for in ll.
    for (int i = 0; i < numExamples; i++) {
        testIn >> example;
        condIn >> cond;
    }

    // DEBUG
    std::cout << "Loaded all tests.\n";

    VarSet prevCond;
    for (;;) {
        testIn >> example;
        condIn >> cond;

        // Test the stream after the read, not before: otherwise the
        // iteration in which the read fails would still be scored.
        // NOTE(review): assumes VarSet's operator>> leaves the stream in a
        // failed state only when no example was read — confirm.
        if (!testIn) {
            break;
        }

        double l0 = 0.0;
#if 1
        if (cond != prevCond) {
            prevCond = cond;

            // DEBUG
            std::cout << "Resetting network.\n";
            mnet.resetAllNodes();
            mnet.resetAllMessages();

            // Fix evidence variables
            for (int i = 0; i < cond.getNumVars(); i++) {
                if (cond.isTested(i)) {
                    mnet.fixNodeValue(i, static_cast<int>(cond[i]));
                }
            }

            // Run belief propagation to convergence
            mnet.runBP(convergenceThreshold, damping);
        }

        // Get marginal for actual query value: only the first variable that
        // is tested in the example but not in the condition is used.
        for (int i = 0; i < example.getNumVars(); i++) {
            if (example.isTested(i) && !cond.isTested(i)) {
                l0 = mnet.getMarginal(i).get(static_cast<int>(example[i]));
                break;
            }
        }
#else
        l0 = mnet.getLikelihood(example, cond,
                 convergenceThreshold, damping);
#endif

        // DEBUG
        if (l0 == 0.0) {
            std::cout << "ERROR: zero likelihood!\n";
            std::cout << "Query: " << example << std::endl;
            std::cout << "Evidence: " << cond << std::endl;
        }

        const double l = std::log(l0);
        ll += l;
        numExamples++;
        checkpoint(argc, argv, ll, numExamples);

        // DEBUG
        //cout << "ll = " << ll/numExamples << endl;
        //cout << "ll = " << l << endl;
    }

    std::cout << "log(likelihood) = " << ll/numExamples << std::endl;

    return 0;
}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -