📄 gibbssampler.cpp
字号:
} } // Initialize and burn-in all chains vector<VarConfig> chains(numChains, query); for (int c = 0; c < numChains; c++) { model.wholeSample(chains[c]); // Use a fixed number of burn-in iters, if appropriate if (fixedIters) { burnInChain(chains[c], evidence, burnInIters); } } // Sample, sample, sample until convergence double burnin_iter = 0; double sampling_iter = 0; double predicted_iters = minIters; bool burnin_done = fixedIters; while (1) { if (burnin_done) { sampling_iter++; } else { burnin_iter++; } for (int v = 0; v < model.getNumVars(); v++) { for (int c = 0; c < numChains; c++) { // Don't resample evidence variables if (evidence.isTested(v)) { continue; } // If it's a free variable (neither evidence nor query), // just sample quickly and move on. if (find(queryVars.begin(), queryVars.end(), v) == queryVars.end()) { chains[c][v] = model.MBsample(v, chains[c]); continue; } // Sample vector<double> dist = model.MBdist(v, chains[c]);#define RAOBLACKWELL#ifdef RAOBLACKWELL // Update (Rao-Blackwellized) counts for (int val = 0; val < model.getRange(v); val++) { chains[c][v] = val; // Update counts using the distribution int s = chains[c].getIndex(); counts[c][s] += dist[val]; sqCounts[c][s] += dist[val]*dist[val]; }#else // Update counts int s = chains[c].getIndex(); // DEBUG //cout << "s = " << s << endl; counts[c][s]++; sqCounts[c][s]++;#endif chains[c][v] = sampleFromDist(dist); } } // After completing some minimum number of iterations, // check for convergence of our burn-in period double R; if (!burnin_done && burnin_iter >= burnInIters && ((int)burnin_iter % 100 == 0) && (R = (testConvergence(counts, sqCounts, (int)burnin_iter * queryVars.size())) < convergenceRatio)) {#if 0 // DEBUG cout << "R = " << R << endl;#endif // Stop burn-in burnin_done = true; // Throw away counts for burn-in period for (int c = 0; c < numChains; c++) { for (int s = 0; s < numSummaries; s++) { counts[c][s] = 1.0/numSummaries; sqCounts[c][s] = 1.0/numSummaries; //counts[c][s] = 1.0; //sqCounts[c][s] = 1.0; } } } // Test for convergence of the sampling // Go until our standard error among the different chains is // less than 5% of the predicted value. if (burnin_done && sampling_iter >= predicted_iters) { // Stop, if we're only running a fixed number of iterations if (fixedIters) { break; } predicted_iters = predictIters(counts, sqCounts, (int)sampling_iter * queryVars.size()) / queryVars.size(); // DEBUG cout << "Predicted iters = " << predicted_iters << endl; if (predicted_iters <= sampling_iter) { break; } } } // Return distribution Distribution d(numSummaries); for (int s = 0; s < numSummaries; s++) { d[s] = 0.0; for (int c = 0; c < numChains; c++) { d[s] += counts[c][s]; } } // DEBUG if (!fixedIters) { cout << burnin_iter << "; " << sampling_iter << endl; } d.normalize(); jointDistrib = d;}double GibbsSampler::singleConditionalLogProb( const list<int>& queryVars, const VarSet& evidence, const VarSet& answer) const{ VarSchema schema = model.getSchema(); VarConfig query(evidence, queryVars, schema); double sqrt2pi = sqrt(2*PI); // Compute weight of uniform prior per configuration Prob prior = 1.0; { list<int>::const_iterator i; for (i = queryVars.begin(); i != queryVars.end(); i++) { if (schema.getRange(*i) > 0) { // Discrete variables: divide prior by range, since // probability is distributed uniformly across all values. prior *= 1.0/(double)schema.getRange(*i); } else { // Continuous variables: assume value is drawn from each // leaf with equal probability, so average across their // probabilities. We must do something like this, since // there's no way to have a uniform distribution across // an infinite range. list<const Leaf*> leafList = model.getDecisionTree(*i) ->getLeafList(); Prob p = 1.0; int n = 0; list<const Leaf*>::iterator li; for ( li = leafList.begin(); li != leafList.end(); li++ ) { p.logplus((*li)->getLogProb(answer[*i])); n++; } p *= 1.0/n; prior *= p; } } } vector<Prob> counts(numChains, prior); //vector<Prob> sqCounts(numChains, prior * prior); // Initialize and burn-in all chains vector<VarConfig> chains(numChains, query); for (int c = 0; c < numChains; c++) { model.wholeSample(chains[c]); // Use a fixed number of burn-in iters, if appropriate if (fixedIters) { burnInChain(chains[c], evidence, burnInIters); } } // Sample, sample, sample until convergence double burnin_iter = 0; double sampling_iter = 0; double predicted_iters = minIters; bool burnin_done = fixedIters; while (1) { if (burnin_done) { sampling_iter++; } else { burnin_iter++; } for (int v = 0; v < model.getNumVars(); v++) { for (int c = 0; c < numChains; c++) { // Don't resample evidence variables if (evidence.isTested(v)) { continue; } // If it's a free variable (neither evidence nor query), // just sample quickly and move on. if (find(queryVars.begin(), queryVars.end(), v) == queryVars.end()) { chains[c][v] = model.MBsample(v, chains[c]); continue; } // HACK -- new code for handling continuous variables and such chains[c][v] = model.MBsample(v, chains[c]); Prob density = 1.0; double* valarray = chains[c].getArray(); list<int>::const_iterator i; for (i = queryVars.begin(); i != queryVars.end(); i++) { if (chains[c][*i] == answer[*i]) { // If it matches, great! continue; } else if ( schema.getRange(*i) > 0 || !answer.isObserved(*i) || !chains[c].isObserved(*i) ) { // If a discrete variable differs, or a continuous // variable is only missing in the answer or the // current chain, then there's no match. density = 0.0; break; } else if ( answer.isObserved(*i) ) { // Smooth continuous values using kernel density // estimation. The mean of the normal distribution // used is the last sampled value. The variance is // the variance at the chosen leaf of each decision // tree. double sigma = model.getSD(*i, valarray); density.logplus(-0.5*(chains[c][*i] - answer[*i]) *(chains[c][*i] - answer[*i]) /(sigma*sigma)); density *= 1.0/(sigma * sqrt2pi); } } counts[c] += density; //sqCounts[c] += density * density; /* // Sample vector<double> dist = model.MBdist(v, chains[c]);#define RAOBLACKWELL#ifdef RAOBLACKWELL // Update (Rao-Blackwellized) counts for (int val = 0; val < model.getRange(v); val++) { chains[c][v] = val; // Update counts using the distribution if (chains[c] == answer) { counts[c] += dist[val]; sqCounts[c] += dist[val]*dist[val]; } }#else // Update counts if (chains[c] == answer) { counts[c][s]++; sqCounts[c][s]++; }#endif chains[c][v] = sampleFromDist(dist); */ } } // After completing some minimum number of iterations, // check for convergence of our burn-in period#if 0 double R;#endif if (!burnin_done && burnin_iter >= burnInIters && ((int)burnin_iter % 100 == 0) #if 0 && (R = (testConvergence(counts, sqCounts, (int)burnin_iter * queryVars.size())) < convergenceRatio)#endif ) {#if 0 // DEBUG cout << "R = " << R << endl;#endif // Stop burn-in burnin_done = true; // Throw away counts for burn-in period for (int c = 0; c < numChains; c++) { counts[c] = prior; //sqCounts[c] = prior * prior; } } // Test for convergence of the sampling // Go until our standard error among the different chains is // less than 5% of the predicted value. if (burnin_done && sampling_iter >= predicted_iters) { // Stop, if we're only running a fixed number of iterations if (fixedIters) { break; } #if 0 predicted_iters = predictIters(counts, sqCounts, (int)sampling_iter * queryVars.size()) / queryVars.size(); // DEBUG cout << "Predicted iters = " << predicted_iters << endl; if (predicted_iters <= sampling_iter) { break; }#endif } }#if 0 // DEBUG if (!fixedIters) { cout << burnin_iter << "; " << sampling_iter << endl; }#endif Prob testCounts = 0.0; double totalCounts = sampling_iter * numChains * queryVars.size() + numChains * 1.0; // Add in prior, of course for (int c = 0; c < numChains; c++) { testCounts += counts[c]; } if (testCounts <= 0.0) { cerr << "ERROR: testCounts <= 0.0\n"; cerr << "prior = " << prior << endl; } testCounts /= totalCounts; return testCounts.ln();}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -