⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 gibbs.cpp

📁 gibbs
💻 CPP
字号:
// gibbs.cpp -- estimate the average conditional log-likelihood of a set of
// test examples under a Bayesian network, using several Gibbs-sampling chains
// and a Gelman-Rubin-style convergence test.
//
// Usage:
//   gibbs <model> <testfile> <condfile> <sampling iters> <c-ratio>
//   gibbs <checkpoint file>   (resume from a file written by checkpoint())

#include "BayesNet.h"
#include "VarSchema.h"
#include "VarSet.h"
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <math.h>
#include <iostream>
#include <fstream>
#include <list>
#include <string>
#include <vector>
// Headers required for getpid() syscall
#include <sys/types.h>
#include <unistd.h>

#define SQR(x) ((x)*(x))

// Use Rao-Blackwellized counts (add the full Markov-blanket distribution of a
// test variable on every resample) instead of single-sample hit counts.
#define RAOBLACKWELL

float gasdev();

// Sampler settings.  minIters and convergenceRatio are overwritten from the
// command line in main().
int numChains = 10;            // number of independent Gibbs chains
double convergenceRatio = 1.1; // converged once sqrt(R) <= this for all configs
int burnInIters = 100;         // sweeps discarded before counting starts
int minIters = 20;             // sweeps before convergence is first tested
int numZeroes = 0;             // examples whose estimated likelihood was 0

// Map the values assigned by allVars to the variables listed in testIndices
// to a single mixed-radix index in [0, product of their ranges).  The first
// variable in the list is the least-significant digit.  Test variables must
// be discrete.
int configIndex(VarSet& allVars, list<int>& testIndices, BayesNet& bn)
{
    int rangeProduct = 1;
    int ret = 0;
    list<int>::iterator i;
    for (i = testIndices.begin(); i != testIndices.end(); ++i) {
        ret += rangeProduct * (int)allVars[*i];
        rangeProduct *= bn.getRange(*i);
    }
    return ret;
}

// Run burnInIters full Gibbs sweeps over chain, resampling every variable
// that is not tested in evidence from its Markov-blanket conditional.
void burnInChain(VarSet &chain, VarSet& evidence,
        int burnInIters, BayesNet& model)
{
    for (int iter = 0; iter < burnInIters; iter++) {
        for (int v = 0; v < evidence.getNumVars(); v++) {
            if (!evidence.isTested(v)) {
                chain[v] = model.MBsample(v, chain);
            }
        }
    }
}

// Draw an index from the discrete distribution dist (expected to sum to ~1).
unsigned int sampleFromDist(const vector<double>& dist)
{
    // Divide by RAND_MAX + 1 so that p lies in [0, 1).  Dividing by RAND_MAX
    // could produce exactly 1.0, which never drops below zero and always
    // fell through to the error path below.
    double p = rand() / (RAND_MAX + 1.0);
    for (unsigned int v = 0; v < dist.size(); v++) {
        p -= dist[v];
        if (p < 0.0) {
            return v;
        }
    }
    // Only reachable through rounding error or if dist sums to less than 1.
    cout << "Error: returning last value in sampleFromDist()\n";
    return (dist.size() - 1);
}

// Estimate P(query | evidence) under model.
//
// The "test" variables are those tested in query but not in evidence.
// numChains chains are initialized with model.wholeSample() from the
// evidence, burned in, and then swept until, for every joint configuration
// of the test variables, sqrt(R) <= convergenceRatio.  R compares the
// within-chain and between-chain variance of the per-configuration counts.
// Every configuration starts with a pseudo-count of 1.
//
// Returns the count of query's configuration summed over chains, divided by
// the total count mass of all chains.
double getLikelihood(VarSet& query, VarSet& evidence, BayesNet& model)
{
    // Collect the test variables and the number of their joint configurations.
    int numVars = 0;
    int testRangeProduct = 1;
    list<int> testIndices;
    for (int v = 0; v < query.getNumVars(); v++) {
        if (!evidence.isTested(v) && query.isTested(v)) {
            testIndices.push_back(v);
            testRangeProduct *= model.getRange(v);
            numVars++;
        }
    }

    // Per-chain running sums (and sums of squares) of the counts for every
    // configuration of the test variables, seeded with a pseudo-count of 1.
    vector<vector<double> > counts(numChains,
                                   vector<double>(testRangeProduct, 1.0));
    vector<vector<double> > sqCounts(numChains,
                                     vector<double>(testRangeProduct, 1.0));

    // Initialize and burn in all chains.
    vector<VarSet> chains(numChains);
    for (int c = 0; c < numChains; c++) {
        chains[c] = evidence;
        model.wholeSample(chains[c]);
        burnInChain(chains[c], evidence, burnInIters, model);
    }

    // The within-chain variance below divides by (iter - 1), so convergence
    // can only be tested once at least two sweeps have been made.
    const int firstCheckIter = (minIters > 2) ? minIters : 2;

    // Sample until convergence.
    bool notConverged = true;
    double iter = 0;
    // Total count mass per chain: the pseudo-counts plus one unit per test
    // variable per sweep (assuming MBdist() returns a normalized distribution).
    int maxCount = testRangeProduct;
    do {
        vector<double> avgChainVariance(testRangeProduct, 0.0);
        iter++;
        maxCount += numVars;
        for (int c = 0; c < numChains; c++) {
            for (int v = 0; v < query.getNumVars(); v++) {
                // Don't resample evidence variables
                if (evidence.isTested(v)) {
                    continue;
                }
#ifdef RAOBLACKWELL
                // Update (Rao-Blackwellized) counts
                if (query.isTested(v)) {
                    vector<double> dist = model.MBdist(v, chains[c]);
                    for (int val = 0; val < model.getRange(v); val++) {
                        // Credit every value of v with its probability.
                        chains[c][v] = val;
                        int config = configIndex(chains[c], testIndices, model);
                        counts[c][config] += dist[val];
                        // DEBUG
                        if (isnan(dist[val])) {
                            cout << "NaN count -- var: " << v << "; val: " << val << endl;
                        }
                        sqCounts[c][config] += dist[val]*dist[val];
                    }
                    // Assign new value
                    chains[c][v] = sampleFromDist(dist);
                } else {
                    // Simply sample
                    chains[c][v] = model.MBsample(v, chains[c]);
                }
#else
                // Sample
                chains[c][v] = model.MBsample(v, chains[c]);
                if (query.isTested(v)) {
                    int config = configIndex(chains[c], testIndices, model);
                    counts[c][config]++;
                    sqCounts[c][config]++;
                }
#endif
            }

            // HACK: what should actually go here?
            if (iter < firstCheckIter) {
                continue;
            }

            // Convergence criteria:
            // Compute within-chain variance (and add to total average)
            for (int config = 0; config < testRangeProduct; config++) {
                double chainVar = (sqCounts[c][config]
                        - SQR(counts[c][config])/iter)/(iter-1);
                avgChainVariance[config] += chainVar/numChains;
            }
        }

        // Don't check for convergence until we've completed some
        // minimum number of iterations.
        if (iter < firstCheckIter) {
            continue;
        }

        // Convergence criteria:
        // Compute between-chain variance
        notConverged = false;
        for (int config = 0; config < testRangeProduct; config++) {
            double squareSum = 0.0;
            double sum = 0.0;
            for (int c = 0; c < numChains; c++) {
                squareSum += SQR(counts[c][config]/iter);
                sum += counts[c][config]/iter;
            }

            double betweenChainVariance
                = (squareSum - SQR(sum)/numChains)/(numChains-1);
            double R = ((iter-1)/iter*avgChainVariance[config]
                        + betweenChainVariance)/avgChainVariance[config];
            if (sqrt(R) > convergenceRatio) {
                notConverged = true;
                break;
            }
        }
    } while (notConverged);

    int queryIndex = configIndex(query, testIndices, model);
    double totalCounts = 0.0;
    for (int c = 0; c < numChains; c++) {
        totalCounts += counts[c][queryIndex];
    }

    // DEBUG
    if (isnan(totalCounts)) {
        cout << "totalCounts nan!\n";
    }
    if (maxCount == 0) {
        cout << "maxCount = 0!\n";
    }
    return totalCounts/(maxCount * numChains);
}

// Write a resumable snapshot to "gibbs.<pid>".  The format is
// argv[1..argc-1] one per line, a blank line, then ll, numExamples and
// numZeroes.  main() resumes from this file when given it as its only
// argument.
void checkpoint(int argc, char** argv, double ll,
        int numExamples, int numZeroes)
{
    char filename[1024];
    char tmpname[1024];
    snprintf(filename, sizeof(filename), "gibbs.%d", (int)getpid());
    snprintf(tmpname, sizeof(tmpname), "%s.tmp", filename);

    // Write to a temporary file, then rename() it over the checkpoint, so an
    // interruption mid-write never leaves a truncated checkpoint behind.
    ofstream tmpfile(tmpname);
    // Keep full double precision.  The default of 6 significant digits
    // would silently lose accuracy in ll on every resume.
    tmpfile.precision(17);
    for (int i = 1; i < argc; i++) {
        tmpfile << argv[i] << '\n';
    }
    tmpfile << '\n';
    tmpfile << ll << '\n';
    tmpfile << numExamples << '\n';
    tmpfile << numZeroes << '\n';
    tmpfile.close();
    if (!tmpfile) {
        cout << "ERROR: could not write checkpoint file \"" << tmpname << "\"\n";
        return;
    }
    if (rename(tmpname, filename) != 0) {
        cout << "ERROR: could not write checkpoint file \"" << filename << "\"\n";
    }
}

int main(int argc, char* argv[])
{
    double ll = 0.0;
    int numExamples = 0;

    // Backing storage for arguments restored from a checkpoint; it must
    // outlive every later use of argv.
    std::vector<std::string> restoredArgs;
    std::vector<char*> restoredArgv;

    if (argc == 2) {
        ifstream checkfile(argv[1]);
        if (!checkfile) {
            cout << "ERROR: could not open checkpoint file \"" << argv[1] << "\"\n";
            return -1;
        }
        // checkpoint() writes one argument per line, terminated by a blank
        // line, so read lines up to that separator rather than a fixed
        // number of tokens.
        restoredArgs.push_back(argv[0]);
        std::string line;
        while (std::getline(checkfile, line) && !line.empty()) {
            restoredArgs.push_back(line);
        }
        checkfile >> ll;
        checkfile >> numExamples;
        checkfile >> numZeroes;
        // Take pointers only after restoredArgs has stopped growing.
        for (size_t i = 0; i < restoredArgs.size(); i++) {
            restoredArgv.push_back(const_cast<char*>(restoredArgs[i].c_str()));
        }
        restoredArgv.push_back(NULL);
        argc = (int)restoredArgs.size();
        argv = &restoredArgv[0];
    }
    if (argc < 6) {
        cout << "Usage: gibbs <model> <testfile> <condfile> ";
        cout << "<sampling iters> <c-ratio>\n";
        return -1;
    }

    BayesNet model;
    FILE* modelIn = fopen(argv[1], "r");
    if (!modelIn) {
        cout << "ERROR: could not open model file \"" << argv[1] << "\"\n";
        return -1;
    }
    // NOTE(review): modelIn is never fclose()d here; confirm whether
    // loadModel() takes ownership of the handle before adding a close.
    loadModel(modelIn, model);
    // DEBUG
    cout << "Successfully loaded model.\n";

    ifstream testIn(argv[2]);
    if (!testIn) {
        cout << "ERROR: could not open test file \"" << argv[2] << "\"\n";
        return -1;
    }
    ifstream condIn(argv[3]);
    if (!condIn) {
        cout << "ERROR: could not open conditions file \""
             << argv[3] << "\"\n";
        return -1;
    }

    minIters = atoi(argv[4]);
    convergenceRatio = atof(argv[5]);

    VarSet example;
    VarSet cond;
    // Skip the examples already processed before the checkpoint was taken.
    for (int i = 0; i < numExamples; i++) {
        testIn >> example;
        condIn >> cond;
    }

    // Test the streams after reading, not before.  Otherwise the final
    // failed read re-processes the previous (stale) example.
    for (;;) {
        testIn >> example;
        condIn >> cond;
        if (!testIn || !condIn) {
            break;
        }

        double l0 = getLikelihood(example, cond, model);
        if (l0 == 0.0) {
            numZeroes++;
            cout << "ERROR: zero likelihood!\n";
            cout << "Query: " << example << endl;
            cout << "Evidence: " << cond << endl;
        }
        double l = log(l0);
        ll += l;
        numExamples++;
        checkpoint(argc, argv, ll, numExamples, numZeroes);
    }

    cout << "numZeroes = " << numZeroes << endl;
    if (numExamples > 0) {
        cout << "log(likelihood) = " << ll/numExamples << endl;
    } else {
        cout << "ERROR: no test examples were read.\n";
    }

    return 0;
}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -