
gibbssampler.cpp

Directory: gibbs
Language: C++
Page 1 of 2
        }
    }

    // Initialize and burn-in all chains
    vector<VarConfig> chains(numChains, query);
    for (int c = 0; c < numChains; c++) {
        model.wholeSample(chains[c]);
        // Use a fixed number of burn-in iters, if appropriate
        if (fixedIters) {
            burnInChain(chains[c], evidence, burnInIters);
        }
    }

    // Sample, sample, sample until convergence
    double burnin_iter = 0;
    double sampling_iter = 0;
    double predicted_iters = minIters;
    bool burnin_done = fixedIters;
    while (1) {
        if (burnin_done) {
            sampling_iter++;
        } else {
            burnin_iter++;
        }
        for (int v = 0; v < model.getNumVars(); v++) {
            for (int c = 0; c < numChains; c++) {
                // Don't resample evidence variables
                if (evidence.isTested(v)) {
                    continue;
                }
                // If it's a free variable (neither evidence nor query),
                // just sample quickly and move on.
                if (find(queryVars.begin(), queryVars.end(), v) == queryVars.end()) {
                    chains[c][v] = model.MBsample(v, chains[c]);
                    continue;
                }

                // Sample
                vector<double> dist = model.MBdist(v, chains[c]);
#define RAOBLACKWELL
#ifdef RAOBLACKWELL
                // Update (Rao-Blackwellized) counts
                for (int val = 0; val < model.getRange(v); val++) {
                    chains[c][v] = val;
                    // Update counts using the distribution
                    int s = chains[c].getIndex();
                    counts[c][s] += dist[val];
                    sqCounts[c][s] += dist[val]*dist[val];
                }
#else
                // Update counts
                int s = chains[c].getIndex();
                // DEBUG
                //cout << "s = " << s << endl;
                counts[c][s]++;
                sqCounts[c][s]++;
#endif
                chains[c][v] = sampleFromDist(dist);
            }
        }

        // After completing some minimum number of iterations,
        // check for convergence of our burn-in period
        double R;
        if (!burnin_done && burnin_iter >= burnInIters
                && ((int)burnin_iter % 100 == 0)
                && ((R = testConvergence(counts, sqCounts,
                        (int)burnin_iter * queryVars.size()))
                    < convergenceRatio)) {
#if 0
            // DEBUG
            cout << "R = " << R << endl;
#endif
            // Stop burn-in
            burnin_done = true;
            // Throw away counts for burn-in period
            for (int c = 0; c < numChains; c++) {
                for (int s = 0; s < numSummaries; s++) {
                    counts[c][s] = 1.0/numSummaries;
                    sqCounts[c][s] = 1.0/numSummaries;
                    //counts[c][s] = 1.0;
                    //sqCounts[c][s] = 1.0;
                }
            }
        }

        // Test for convergence of the sampling
        // Go until our standard error among the different chains is
        // less than 5% of the predicted value.
        if (burnin_done && sampling_iter >= predicted_iters) {
            // Stop, if we're only running a fixed number of iterations
            if (fixedIters) {
                break;
            }

            predicted_iters = predictIters(counts, sqCounts,
                    (int)sampling_iter * queryVars.size()) / queryVars.size();
            // DEBUG
            cout << "Predicted iters = " << predicted_iters << endl;
            if (predicted_iters <= sampling_iter) {
                break;
            }
        }
    }

    // Return distribution
    Distribution d(numSummaries);
    for (int s = 0; s < numSummaries; s++) {
        d[s] = 0.0;
        for (int c = 0; c < numChains; c++) {
            d[s] += counts[c][s];
        }
    }
    // DEBUG
    if (!fixedIters) {
        cout << burnin_iter << "; " << sampling_iter << endl;
    }
    d.normalize();
    jointDistrib = d;
}

double GibbsSampler::singleConditionalLogProb(
        const list<int>& queryVars,
        const VarSet& evidence,
        const VarSet& answer) const
{
    VarSchema schema = model.getSchema();
    VarConfig query(evidence, queryVars, schema);
    double sqrt2pi = sqrt(2*PI);

    // Compute weight of uniform prior per configuration
    Prob prior = 1.0;
    {
        list<int>::const_iterator i;
        for (i = queryVars.begin(); i != queryVars.end(); i++) {
            if (schema.getRange(*i) > 0) {
                // Discrete variables: divide prior by range, since
                // probability is distributed uniformly across all values.
                prior *= 1.0/(double)schema.getRange(*i);
            } else {
                // Continuous variables: assume value is drawn from each
                // leaf with equal probability, so average across their
                // probabilities.  We must do something like this, since
                // there's no way to have a uniform distribution across
                // an infinite range.
                list<const Leaf*> leafList = model.getDecisionTree(*i)
                                                ->getLeafList();
                Prob p = 1.0;
                int n = 0;
                list<const Leaf*>::iterator li;
                for ( li = leafList.begin(); li != leafList.end(); li++ ) {
                    p.logplus((*li)->getLogProb(answer[*i]));
                    n++;
                }
                p *= 1.0/n;
                prior *= p;
            }
        }
    }

    vector<Prob> counts(numChains, prior);
    //vector<Prob> sqCounts(numChains, prior * prior);

    // Initialize and burn-in all chains
    vector<VarConfig> chains(numChains, query);
    for (int c = 0; c < numChains; c++) {
        model.wholeSample(chains[c]);
        // Use a fixed number of burn-in iters, if appropriate
        if (fixedIters) {
            burnInChain(chains[c], evidence, burnInIters);
        }
    }

    // Sample, sample, sample until convergence
    double burnin_iter = 0;
    double sampling_iter = 0;
    double predicted_iters = minIters;
    bool burnin_done = fixedIters;
    while (1) {
        if (burnin_done) {
            sampling_iter++;
        } else {
            burnin_iter++;
        }
        for (int v = 0; v < model.getNumVars(); v++) {
            for (int c = 0; c < numChains; c++) {
                // Don't resample evidence variables
                if (evidence.isTested(v)) {
                    continue;
                }
                // If it's a free variable (neither evidence nor query),
                // just sample quickly and move on.
                if (find(queryVars.begin(), queryVars.end(), v) == queryVars.end()) {
                    chains[c][v] = model.MBsample(v, chains[c]);
                    continue;
                }

                // HACK -- new code for handling continuous variables and such
                chains[c][v] = model.MBsample(v, chains[c]);
                Prob density = 1.0;
                double* valarray = chains[c].getArray();
                list<int>::const_iterator i;
                for (i = queryVars.begin(); i != queryVars.end(); i++) {
                    if (chains[c][*i] == answer[*i]) {
                        // If it matches, great!
                        continue;
                    } else if ( schema.getRange(*i) > 0 || !answer.isObserved(*i)
                            || !chains[c].isObserved(*i) ) {
                        // If a discrete variable differs, or a continuous
                        // variable is only missing in the answer or the
                        // current chain, then there's no match.
                        density = 0.0;
                        break;
                    } else if ( answer.isObserved(*i) ) {
                        // Smooth continuous values using kernel density
                        // estimation.  The mean of the normal distribution
                        // used is the last sampled value.  The variance is
                        // the variance at the chosen leaf of each decision
                        // tree.
                        double sigma = model.getSD(*i, valarray);
                        density.logplus(-0.5*(chains[c][*i] - answer[*i])
                                        *(chains[c][*i] - answer[*i])
                                        /(sigma*sigma));
                        density *= 1.0/(sigma * sqrt2pi);
                    }
                }
                counts[c] += density;
                //sqCounts[c] += density * density;

                /*
                // Sample
                vector<double> dist = model.MBdist(v, chains[c]);
#define RAOBLACKWELL
#ifdef RAOBLACKWELL
                // Update (Rao-Blackwellized) counts
                for (int val = 0; val < model.getRange(v); val++) {
                    chains[c][v] = val;
                    // Update counts using the distribution
                    if (chains[c] == answer) {
                        counts[c] += dist[val];
                        sqCounts[c] += dist[val]*dist[val];
                    }
                }
#else
                // Update counts
                if (chains[c] == answer) {
                    counts[c][s]++;
                    sqCounts[c][s]++;
                }
#endif
                chains[c][v] = sampleFromDist(dist);
                */
            }
        }

        // After completing some minimum number of iterations,
        // check for convergence of our burn-in period
#if 0
        double R;
#endif
        if (!burnin_done && burnin_iter >= burnInIters
                && ((int)burnin_iter % 100 == 0)
#if 0
                && (R = (testConvergence(counts, sqCounts,
                        (int)burnin_iter * queryVars.size()))
                    < convergenceRatio)
#endif
                ) {
#if 0
            // DEBUG
            cout << "R = " << R << endl;
#endif
            // Stop burn-in
            burnin_done = true;
            // Throw away counts for burn-in period
            for (int c = 0; c < numChains; c++) {
                counts[c]   = prior;
                //sqCounts[c] = prior * prior;
            }
        }
        // Test for convergence of the sampling
        // Go until our standard error among the different chains is
        // less than 5% of the predicted value.
        if (burnin_done && sampling_iter >= predicted_iters) {
            // Stop, if we're only running a fixed number of iterations
            if (fixedIters) {
                break;
            }
#if 0
            predicted_iters = predictIters(counts, sqCounts,
                    (int)sampling_iter * queryVars.size()) / queryVars.size();
            // DEBUG
            cout << "Predicted iters = " << predicted_iters << endl;
            if (predicted_iters <= sampling_iter) {
                break;
            }
#endif
        }
    }

#if 0
    // DEBUG
    if (!fixedIters) {
        cout << burnin_iter << "; " << sampling_iter << endl;
    }
#endif

    Prob testCounts = 0.0;
    double totalCounts = sampling_iter * numChains * queryVars.size()
        + numChains * 1.0; // Add in prior, of course
    for (int c = 0; c < numChains; c++) {
        testCounts += counts[c];
    }
    if (testCounts <= 0.0) {
        cerr << "ERROR: testCounts <= 0.0\n";
        cerr << "prior = " << prior << endl;
    }
    testCounts /= totalCounts;
    return testCounts.ln();
}
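
The listing calls a helper sampleFromDist() that is not defined on this page. For reference, the sketch below shows what such a helper typically does: draw an index from a discrete distribution given as a vector of probabilities. The name sampleFromDistSketch, the signature, and the use of rand() are assumptions for illustration, not the project's actual implementation.

#include <cstdlib>
#include <vector>

// Illustrative sketch only: draw an index i with probability dist[i] / sum(dist).
// The real sampleFromDist() used by gibbssampler.cpp is not shown on this page.
int sampleFromDistSketch(const std::vector<double>& dist)
{
    double total = 0.0;
    for (size_t i = 0; i < dist.size(); i++) {
        total += dist[i];
    }
    // Uniform draw in [0, total), then walk the cumulative sums
    double r = total * (std::rand() / (RAND_MAX + 1.0));
    double cum = 0.0;
    for (size_t i = 0; i < dist.size(); i++) {
        cum += dist[i];
        if (r < cum) {
            return (int)i;
        }
    }
    return (int)dist.size() - 1;  // guard against floating-point round-off
}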
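Both functions gate the end of burn-in on testConvergence(), whose body is not visible on this page. The use of multiple chains, per-chain counts and squared counts, and a convergenceRatio threshold suggests a Gelman-Rubin style potential scale reduction factor; the self-contained sketch below shows how that statistic is commonly computed from per-chain means and variances. It is an assumption about the flavor of the test, not the project's code.

#include <cmath>
#include <vector>

// Illustrative Gelman-Rubin potential scale reduction factor (R-hat).
// chainMean[c] and chainVar[c] are the mean and variance of some scalar
// summary within chain c after n kept iterations; values close to 1
// indicate the chains have mixed.
double gelmanRubinR(const std::vector<double>& chainMean,
                    const std::vector<double>& chainVar,
                    int n)
{
    const int m = (int)chainMean.size();
    double grand = 0.0;
    for (int c = 0; c < m; c++) {
        grand += chainMean[c];
    }
    grand /= m;

    // Between-chain variance of the chain means (B/n in the usual notation)
    double bOverN = 0.0;
    for (int c = 0; c < m; c++) {
        bOverN += (chainMean[c] - grand) * (chainMean[c] - grand);
    }
    bOverN /= (m - 1);

    // Average within-chain variance W
    double w = 0.0;
    for (int c = 0; c < m; c++) {
        w += chainVar[c];
    }
    w /= m;

    // Pooled variance estimate and the ratio
    double varPlus = ((n - 1.0) / n) * w + bOverN;
    return std::sqrt(varPlus / w);
}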
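In singleConditionalLogProb(), continuous query variables that do not exactly match the answer are scored with a normal kernel centered on the sampled value, with the standard deviation taken from the decision-tree leaf via model.getSD(). Assuming Prob::logplus(x) scales the stored probability by e^x (the Prob class is not shown on this page), the logplus() call and the 1/(sigma*sqrt(2*pi)) factor together multiply the running density by a Gaussian evaluated at the answer. Below is a plain-double sketch of that kernel weight; the helper name is hypothetical.

#include <cmath>

// Illustrative Gaussian kernel weight: the density of a normal distribution
// with mean `sampled` and standard deviation `sigma`, evaluated at `observed`.
// This mirrors the exp(-0.5*(x-mu)^2/sigma^2) / (sigma*sqrt(2*pi)) factor
// accumulated into `density` in the listing above.
double gaussianKernelWeight(double sampled, double observed, double sigma)
{
    const double sqrt2pi = std::sqrt(2.0 * 3.14159265358979323846);
    double z = (sampled - observed) / sigma;
    return std::exp(-0.5 * z * z) / (sigma * sqrt2pi);
}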
