⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 gibbssampler.cpp

📁 gibbs
💻 CPP
📖 第 1 页 / 共 2 页
字号:
#include "GibbsSampler.h"
#include "VarConfig.h"
#include <algorithm>
#include <cmath>
#include <cstdlib>
#include "Prob.h"

// Debugging aid: prints a "Pre-" marker, performs a throw-away heap
// allocation, prints a "Post-" marker, and then releases the allocation
// so that repeated use of the macro does not leak memory.
#define CHECKMEM(x) \
{ \
    cout << "Pre-" << x << endl; \
    int* foo = new int[1000]; \
    cout << "Post-" << x << endl; \
    delete[] foo; \
}

// Square of x.  The argument is evaluated twice, so it must be free of
// side effects.
#define SQR(x) ((x) * (x))

#if 0
// Unused -- we use the VarConfig class instead.
int GibbsSampler::configIndex(VarSet& allVars, list<int>& testIndices)
{
    int rangeProduct = 1;
    int ret = 0;
    list<int>::iterator i;
    for (i = testIndices.begin(); i != testIndices.end(); i++) {
        // Test vars must be discrete.
        ret += rangeProduct * (int)allVars[*i];
        rangeProduct *= model.getRange(*i);
    }
    return ret;
}
#endif

// Old way of burning in -- run a constant number of iterations.
//
// Performs burnInIters sweeps over the chain.  In each sweep, every
// variable index v in [0, evidence.getNumVars()) for which
// evidence.isTested(v) is false is overwritten with
// model.MBsample(v, chain); variables flagged as tested are left alone.
//
//   chain       - state to update in place
//   evidence    - supplies the variable count and the per-variable
//                 isTested() flag
//   burnInIters - number of full sweeps to run (<= 0 means no sweeps)
void GibbsSampler::burnInChain(VarSet &chain, const VarSet& evidence,
        int burnInIters) const
{
    for (int iter = 0; iter < burnInIters; iter++) {
        for (int v = 0; v < evidence.getNumVars(); v++) {
            if (!evidence.isTested(v)) {
                chain[v] = model.MBsample(v, chain);
            }
        }
    }
}

// Draws an index from a discrete distribution using rand().
//
//   dist - probability of each index; expected to sum to (about) 1
//
// Returns the first index at which the running sum of dist exceeds a
// uniform draw from [0, 1).
//
// If the running sum never exceeds the draw (dist sums to less than the
// draw, e.g. because of floating-point round-off), a diagnostic is
// printed and the last index with positive probability is returned, so
// that a zero-probability value is never sampled when an alternative
// exists.  If no entry is positive the last index is returned, and an
// empty dist yields 0 rather than wrapping around to a huge unsigned
// value.
unsigned int GibbsSampler::sampleFromDist(const vector<double>& dist) const
{
    // rand() may return RAND_MAX itself, so divide by RAND_MAX + 1 to keep
    // the draw strictly below 1.  Otherwise a correctly normalized
    // distribution could fall through the loop below.
    double p = static_cast<double>(rand()) / (RAND_MAX + 1.0);
    for (unsigned int v = 0; v < dist.size(); v++) {
        p -= dist[v];
        if (p < 0.0) {
            return v;
        }
    }

    // DEBUG
    cout << "Error: returning last value in sampleFromDist()\n";

    // Fall back to the last value that has any probability mass.
    for (unsigned int v = static_cast<unsigned int>(dist.size()); v-- > 0; ) {
        if (dist[v] > 0.0) {
            return v;
        }
    }
    if (dist.empty()) {
        return 0;
    }
    return static_cast<unsigned int>(dist.size() - 1);
}

// Computes a convergence statistic comparing within-chain and
// between-chain variance of each summary statistic (the formula has the
// shape of a potential-scale-reduction statistic).
//
//   summaries   - summaries[c][s]: sum of statistic s over n iterations
//                 of chain c
//   sqSummaries - sqSummaries[c][s]: matching sum of squares
//   n           - number of iterations the sums were accumulated over
//
// For every statistic s:
//   W = average over chains of (sqSum - sum^2/n)/(n-1)
//   B = sample variance across chains of the per-chain means sum/n
//   R = ((n-1)/n * W + B) / W
// Returns sqrt of the largest R found (0 if no R compared greater than
// 0; a NaN R fails the comparison and is therefore ignored).
//
// Preconditions (not checked): sqSummaries is non-empty, every chain has
// at least sqSummaries[0].size() entries, n > 1, and there are at least
// two chains; otherwise the divisions below are by zero.
double GibbsSampler::testConvergence(vector<vector<double> > summaries,
        vector<vector<double> > sqSummaries, int n) const
{
    double maxR = 0.0;
    int numChains = summaries.size();
    int numSummaries = sqSummaries[0].size();
    vector<double> avgSummaryVariance(numSummaries);

    // Iterate through all chains
    for (int c = 0; c < numChains; c++) {

        // Compute within-chain variance for each summary statistic
        for (int s = 0; s < numSummaries; s++) {
#if 0
            // HACK: We require that every single state receive at least a
            // fractional count in some chain before we converge.
            if (summaries[c][s] == 1.0) {
                // DEBUG
                cout << "c = " << c << "; s = " << s << endl;
                return 1000.0;
            }
#endif
            double chainVar = (sqSummaries[c][s] - SQR(summaries[c][s])/n)/(n-1);
            avgSummaryVariance[s] += chainVar/numChains;
        }
    }

    for (int s = 0; s < numSummaries; s++) {

        // Compute between-chain variance for each summary statistic
        double squareSum = 0.0;
        double sum = 0.0;
        for (int c = 0; c < numChains; c++) {
            squareSum += SQR(summaries[c][s]/n);
            sum += summaries[c][s]/n;
        }
        double betweenChainVariance
            = (squareSum - SQR(sum)/numChains)/(numChains-1);

        // Compute convergence criteria R
        double R = ((n-1.0)/n*avgSummaryVariance[s]
                      + betweenChainVariance)/avgSummaryVariance[s];

        // Report largest convergence statistic
        if (R > maxR) {
            maxR = R;
#if 0
        // DEBUG
        cout << "Counts:";
        for (int c = 0; c < numChains; c++) {
            cout << " " << summaries[c][s];
            cout << " (" << sqSummaries[c][s] << ")";
        }
        cout << "\n";
        cout << "Within: " << avgSummaryVariance[s];
        cout << "; Between: " << betweenChainVariance;
        cout << "; R: " << sqrt(R) << endl;
        // END DEBUG
#endif
        }
    }

    // DEBUG
    //cout << "sqrt(R) = " << sqrt(maxR) << endl;
    return sqrt(maxR);
}

// Estimates how many sampling iterations are needed, based on how much
// the chains currently disagree about each summary statistic.
//
//   summaries   - summaries[c][s]: sum of statistic s over n iterations
//                 of chain c
//   sqSummaries - only sqSummaries[0].size() is read, to obtain the
//                 number of summary statistics
//   n           - number of iterations the sums were accumulated over
//
// For every statistic, the per-chain means are treated as independent
// estimates; their spread gives sigma_hat, and the iteration count is
// (1.96 * sigma_hat / epsilon)^2 with epsilon equal to 5% of the mean
// across chains.  Returns the largest such count (0 if none compared
// greater than 0; a NaN count, e.g. from a statistic whose mean is 0,
// fails the comparison and is ignored).
//
// Precondition (not checked): sqSummaries is non-empty.
double GibbsSampler::predictIters(vector<vector<double> > summaries,
        vector<vector<double> > sqSummaries, int n) const
{
    double maxV = 0.0;
    int numChains = summaries.size();
    int numSummaries = sqSummaries[0].size();

    // Iterate through all summary statistics
    for (int s = 0; s < numSummaries; s++) {

        // Consider the chains as independent estimates of each summary
        // statistic, and compute their standard deviation.
        double squareSum = 0.0;
        double sum = 0.0;
        for (int c = 0; c < numChains; c++) {
            squareSum += SQR(summaries[c][s]/n);
            sum += summaries[c][s]/n;
        }

        // See page 740 of DeGroot and Schervish, 3rd ed.
        double S = sqrt((squareSum - SQR(sum)/numChains)/numChains);
        double sigma_hat = sqrt((double)n) * S;

        // Compute number of expected iterations (with 95% certainty)
        // to get the estimate correct within 5%.
        // (See page 707, eqn 11.1.5 of DeGroot and Schervish, 3rd ed.)
        double epsilon = 0.05 * sum/numChains;
        double v = SQR(1.96 * sigma_hat/epsilon);

        // Report largest number of iterations to run
        if (v > maxV) {
            maxV = v;
#if 0
            // DEBUG
            cout << "epsilon = " << epsilon << endl;
            cout << "sigma_hat = " << sigma_hat << endl;
            cout << "n = " << n << endl;
            cout << "sum = " << sum << endl;
            cout << "squareSum = " << squareSum << endl;
#endif
#if 0
            double mean = sum/numChains;
            double maxRatio = 1.0;
            for (int c = 0; c < numChains; c++) {
                double currStat = summaries[c][s]/n;
                if (currStat/mean > maxRatio) {
                    maxRatio = currStat/mean;
                }
                if (mean/currStat > maxRatio) {
                    maxRatio = mean/currStat;
                }
            }
            cout << "Max ratio: " << maxRatio << endl;
#endif
        }
    }

    return maxV;
}

// Estimates the marginal distribution of every variable that is not
// flagged as tested in the evidence, by running numChains chains in
// parallel and storing the normalized counts in marginals[v].
//
//   evidence - initial state for every chain; variables for which
//              evidence.isTested(v) is true are never resampled
//
// Outline:
//   1. Map each (variable, value) pair to a summary-statistic index.
//   2. Initialize each chain from the evidence via model.wholeSample();
//      when fixedIters is set, also run burnInChain() for burnInIters
//      sweeps.
//   3. Sweep all chains repeatedly, accumulating per-chain counts and
//      squared counts.  Without fixedIters the loop first burns in until
//      testConvergence() drops below convergenceRatio (checked every 100
//      iterations once burnInIters have elapsed), discards those counts,
//      and then samples until predictIters() no longer asks for more
//      iterations.  With fixedIters it samples for minIters iterations.
//   4. Sum the counts over chains, normalize, and store the marginals.
//
// Writes progress/debug output to cout.
void GibbsSampler::runMarginalInference(const VarSet& evidence)
{
    // We use this vector to convert var/value pairs into summary
    // statistic indices.
    vector<vector<int> > index(model.getNumVars());
    int numSummaries = 0;
    for (int v = 0; v < model.getNumVars(); v++) {
        for (int val = 0; val < model.getRange(v); val++) {
            index[v].push_back(numSummaries++);
        }
    }

    // Keep track of marginal counts for all test variables
    vector<vector<double> > counts(numChains, vector<double>(numSummaries));
    vector<vector<double> > sqCounts(numChains, vector<double>(numSummaries));
    for (int c = 0; c < numChains; c++) {
        for (int s = 0; s < numSummaries; s++) {
            counts[c][s] = 0.0;
            sqCounts[c][s] = 0.0;
            //counts[c][s] = 1.0;
            //sqCounts[c][s] = 1.0;
        }
    }

    // Initialize and burn-in all chains
    vector<VarSet> chains(numChains);
    for (int c = 0; c < numChains; c++) {
        chains[c] = evidence;
        model.wholeSample(chains[c]);

        // Use a fixed number of burn-in iters, if appropriate
        if (fixedIters) {
            burnInChain(chains[c], evidence, burnInIters);
        }
    }

    // Sample, sample, sample until convergence
    double burnin_iter = 0;
    double sampling_iter = 0;
    double predicted_iters = minIters;
    // With a fixed iteration budget the burn-in already happened above.
    bool burnin_done = fixedIters;
    while (1) {

        if (burnin_done) {
            sampling_iter++;
        } else {
            burnin_iter++;
        }

        // Sample all variables and increase counts
        for (int c = 0; c < numChains; c++) {
            for (int v = 0; v < model.getNumVars(); v++) {

                // Don't resample evidence variables
                if (evidence.isTested(v)) {
                    continue;
                }

                // Sample
                vector<double> dist = model.MBdist(v, chains[c]);
                chains[c][v] = sampleFromDist(dist);

#define RAOBLACKWELL
#ifdef RAOBLACKWELL
                // Update (Rao-Blackwellized) counts
                for (int val = 0; val < model.getRange(v); val++) {
                    // Update counts using the distribution
                    int s = index[v][val];
                    counts[c][s] += dist[val];
                    sqCounts[c][s] += dist[val]*dist[val];
                }
#else
                // Update counts
                int s = index[v][(int)chains[c][v]];
                // DEBUG
                //cout << "s = " << s << endl;
                counts[c][s]++;
                sqCounts[c][s]++;
#endif
            }
        }

        // After completing some minimum number of iterations,
        // check for convergence of our burn-in period
        if (!burnin_done && burnin_iter >= burnInIters
                 && ((int)burnin_iter % 100 == 0)
                 && (testConvergence(counts, sqCounts, (int)burnin_iter)
                    < convergenceRatio)) {

            // Stop burn-in
            burnin_done = true;

            // Throw away counts for burn-in period
            for (int c = 0; c < numChains; c++) {
                for (int s = 0; s < numSummaries; s++) {
                    counts[c][s] = 0.0;
                    sqCounts[c][s] = 0.0;
                    //counts[c][s] = 1.0;
                    //sqCounts[c][s] = 1.0;
                }
            }
        }

        // Test for convergence of the sampling
        // Go until our standard error among the different chains is
        // less than 5% of the predicted value.
        if (burnin_done && sampling_iter >= predicted_iters) {

            // Stop, if we're only running a fixed number of iterations
            if (fixedIters) {
                break;
            }

            predicted_iters = predictIters(counts, sqCounts, (int)sampling_iter);
            // DEBUG
            cout << "Predicted iters = " << predicted_iters << endl;
            if (predicted_iters <= sampling_iter) {
                break;
            }
        }
    }

    // Save distributions
    for (int v = 0; v < model.getNumVars(); v++) {
        if (evidence.isTested(v)) {
            continue;
        }

        Distribution m(model.getRange(v));
        for (int val = 0; val < model.getRange(v); val++) {
            m[val] = 0;
            for (int c = 0; c < numChains; c++) {
                m[val] += counts[c][index[v][val]];
            }
        }
        m.normalize();
        marginals[v] = m;
#ifdef DEBUG
        if (!evidence.isTested(v)) {
            cout << v << ": " << m << endl;
        }
#endif
    }

    // DEBUG
    if (!fixedIters) {
        cout << burnin_iter << "; " << sampling_iter << endl;
    }
}

void GibbsSampler::runJointInference(const list<int>& queryVars,
        const VarSet& evidence)
{
    VarSchema schema = model.getSchema();
    VarConfig query(evidence, queryVars, schema);
    int numSummaries = query.getMaxIndex() + 1;

    // Keep track of counts for all test configurations
    vector<vector<double> > counts(numChains, vector<double>(numSummaries));
    vector<vector<double> > sqCounts(numChains, vector<double>(numSummaries));
    for (int c = 0; c < numChains; c++) {
        for (int s = 0; s < numSummaries; s++) {
            counts[c][s] = 1.0/numSummaries;
            sqCounts[c][s] = 1.0/numSummaries;

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -