📄 gibbs.cpp
字号:
// gibbs.cpp
//
// Command-line tool that loads a Bayesian network model, reads paired
// query/evidence examples from two files, estimates the likelihood of
// each query with multi-chain Gibbs sampling (getLikelihood), and prints
// the average log-likelihood. Progress is written to "gibbs.<pid>" after
// every example so that an interrupted run can be resumed by passing that
// file as the only argument.

#include "BayesNet.h"
#include "VarSchema.h"
#include "VarSet.h"

#include <stdio.h>
#include <cmath>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <list>
#include <string>
#include <vector>

// Headers required for getpid() syscall
#include <sys/types.h>
#include <unistd.h>

#define SQR(x) ((x)*(x))

// When defined, getLikelihood() adds the full conditional distribution of a
// test variable to the counts instead of a single count for the sampled value.
#define RAOBLACKWELL

float gasdev();

// Number of independent Gibbs chains run per query.
int numChains = 10;
// Sampling stops once sqrt(R) <= convergenceRatio for every configuration.
// Overwritten from argv[5] in main().
double convergenceRatio = 1.1;
// Number of burn-in sweeps applied to each chain before counting starts.
int burnInIters = 100;
// Minimum number of sampling iterations before convergence is tested.
// Overwritten from argv[4] in main().
int minIters = 20;
// Written to the checkpoint and printed at exit; never incremented in this file.
int numZeroes = 0;


// Maps the values that `allVars` assigns to the variables listed in
// `testIndices` to a single mixed-radix index: the first listed variable is
// the least-significant digit and each variable's radix is bn.getRange(var).
int configIndex(VarSet& allVars, std::list<int>& testIndices, BayesNet& bn)
{
    int rangeProduct = 1;
    int ret = 0;

    for (std::list<int>::iterator i = testIndices.begin();
            i != testIndices.end(); ++i) {
        // Test vars must be discrete.
        ret += rangeProduct * static_cast<int>(allVars[*i]);
        rangeProduct *= bn.getRange(*i);
    }

    return ret;
}


// Runs `burnInIters` sweeps over `chain`, resampling (via model.MBsample)
// every variable that is not tested in `evidence`.
void burnInChain(VarSet &chain, VarSet& evidence, int burnInIters,
        BayesNet& model)
{
    for (int iter = 0; iter < burnInIters; iter++) {
        for (int v = 0; v < evidence.getNumVars(); v++) {
            if (!evidence.isTested(v)) {
                chain[v] = model.MBsample(v, chain);
            }
        }
    }
}


// Draws an index from the discrete distribution `dist` using rand().
// If the cumulative mass never exceeds the drawn number (e.g. the entries
// sum to less than the draw), a debug message is printed and the last
// index is returned.
unsigned int sampleFromDist(const std::vector<double>& dist)
{
    double p = (double)rand()/RAND_MAX;

    for (unsigned int v = 0; v < dist.size(); v++) {
        p -= dist[v];
        if (p < 0.0) {
            return v;
        }
    }

    // DEBUG
    std::cout << "Error: returning last value in sampleFromDist()\n";
    return (dist.size() - 1);
}


// Estimates the likelihood of the configuration that `query` assigns to the
// "test" variables (tested in `query` but not in `evidence`) given `evidence`.
//
// numChains chains are started from `evidence`, completed with
// model.wholeSample() and burned in for burnInIters sweeps. Sampling then
// continues until, after at least minIters iterations, sqrt(R) is at most
// convergenceRatio for every configuration of the test variables, where R
// combines the averaged within-chain variance with the between-chain
// variance of the per-iteration counts.
//
// Every configuration starts with a count of 1.0, so the returned value
//   sum over chains of counts[queryConfig] / (maxCount * numChains)
// is never exactly zero.
double getLikelihood(VarSet& query, VarSet& evidence, BayesNet& model)
{
    // Count number of sampled variables, and the product of all test
    // variables
    int numVars = 0;
    int testRangeProduct = 1;
    std::list<int> testIndices;
    for (int v = 0; v < query.getNumVars(); v++) {
        if (!evidence.isTested(v) && query.isTested(v)) {
            testIndices.push_back(v);
            testRangeProduct *= model.getRange(v);
            numVars++;
        }
    }

    // Keep track of counts for every configuration of the test variables
    std::vector<std::vector<double> > counts(numChains,
            std::vector<double>(testRangeProduct, 1.0));
    std::vector<std::vector<double> > sqCounts(numChains,
            std::vector<double>(testRangeProduct, 1.0));

    // Initialize and burn-in all chains
    std::vector<VarSet> chains(numChains);
    for (int c = 0; c < numChains; c++) {
        chains[c] = evidence;
        model.wholeSample(chains[c]);
        burnInChain(chains[c], evidence, burnInIters, model);
    }

    // Sample, sample, sample until convergence
    bool notConverged = true;
    double iter = 0;
    // Normaliser for the returned estimate: the initial count mass plus
    // numVars for every iteration performed.
    int maxCount = testRangeProduct;
    do {
        // Per-configuration within-chain variance, averaged over chains.
        std::vector<double> avgChainVariance(testRangeProduct, 0.0);

        iter++;
        maxCount += numVars;

        for (int c = 0; c < numChains; c++) {
            for (int v = 0; v < query.getNumVars(); v++) {

                // Don't resample evidence variables
                if (evidence.isTested(v)) {
                    continue;
                }

#ifdef RAOBLACKWELL
                // Update (Rao-Blackwellized) counts
                if (query.isTested(v)) {
                    std::vector<double> dist = model.MBdist(v, chains[c]);
                    for (int val = 0; val < model.getRange(v); val++) {
                        // Update counts using the distribution
                        chains[c][v] = val;
                        int config = configIndex(chains[c], testIndices, model);
                        counts[c][config] += dist[val];
                        // DEBUG
                        if (std::isnan(dist[val])) {
                            std::cout << "NaN count -- var: " << v
                                      << "; val: " << val << std::endl;
                        }
                        sqCounts[c][config] += dist[val]*dist[val];
                    }

                    // Assign new value
                    chains[c][v] = sampleFromDist(dist);
                } else {
                    // Simply sample
                    chains[c][v] = model.MBsample(v, chains[c]);
                }
#else
                // Sample
                chains[c][v] = model.MBsample(v, chains[c]);
                if (query.isTested(v)) {
                    int config = configIndex(chains[c], testIndices, model);
                    counts[c][config]++;
                    sqCounts[c][config]++;
                }
#endif
            }

            // HACK: what should actually go here?
            if (iter < minIters) {
                continue;
            }

            // Convergence criteria:
            // Compute within-chain variance (and add to total average)
            for (int config = 0; config < testRangeProduct; config++) {
                double chainVar = (sqCounts[c][config]
                        - SQR(counts[c][config])/iter)/(iter-1);
                avgChainVariance[config] += chainVar/numChains;
            }
        }

        // Don't check for convergence until we've completed some
        // minimum number of iterations. notConverged is still true here,
        // so `continue` starts the next iteration of the do-while.
        if (iter < minIters) {
            continue;
        }

        // Convergence criteria:
        // Compute between-chain variance
        notConverged = false;
        for (int config = 0; config < testRangeProduct; config++) {
            double squareSum = 0.0;
            double sum = 0.0;
            for (int c = 0; c < numChains; c++) {
                squareSum += SQR(counts[c][config]/iter);
                sum += counts[c][config]/iter;
            }
            double betweenChainVariance
                = (squareSum - SQR(sum)/numChains)/(numChains-1);
            double R = ((iter-1)/iter*avgChainVariance[config]
                    + betweenChainVariance)/avgChainVariance[config];
            if (std::sqrt(R) > convergenceRatio) {
                notConverged = true;
                break;
            }
        }

    } while (notConverged);

    int queryIndex = configIndex(query, testIndices, model);
    double totalCounts = 0.0;
    for (int c = 0; c < numChains; c++) {
        totalCounts += counts[c][queryIndex];
    }

    // DEBUG
    if (std::isnan(totalCounts)) {
        std::cout << "totalCounts nan!\n";
    }
    if (maxCount == 0) {
        std::cout << "maxCount = 0!\n";
    }
    return totalCounts/(maxCount * numChains);
}


// Writes the state needed to resume a run to the file "gibbs.<pid>":
//   argv[1] .. argv[argc-1], one per line
//   an empty line
//   ll, numExamples, numZeroes, one per line
void checkpoint(int argc, char** argv, double ll, int numExamples,
        int numZeroes)
{
    char filename[1024];
    snprintf(filename, sizeof(filename), "gibbs.%d",
            static_cast<int>(getpid()));
    std::ofstream tmpfile(filename);

    // 17 significant digits round-trip a double exactly; the default of 6
    // would truncate the accumulated log-likelihood on resume.
    tmpfile.precision(17);

    for (int i = 1; i < argc; i++) {
        tmpfile << argv[i] << std::endl;
    }
    tmpfile << std::endl;
    tmpfile << ll << std::endl;
    tmpfile << numExamples << std::endl;
    tmpfile << numZeroes << std::endl;
    tmpfile.close();
}


// Copies `s`, including its terminating NUL, into a mutable char buffer.
static std::vector<char> toCString(const std::string& s)
{
    return std::vector<char>(s.c_str(), s.c_str() + s.size() + 1);
}


// Reads a file produced by checkpoint().
//
// On success `args` holds programName followed by the saved arguments (all
// lines up to the blank separator line), `ll` and `numExamples` hold the
// saved values, and the global numZeroes is restored if it is present in
// the file. Returns false, leaving `ll` and `numExamples` untouched, if the
// file cannot be opened or the two numeric fields cannot be parsed.
static bool restoreCheckpoint(const char* path, const char* programName,
        std::vector<std::vector<char> >& args, double& ll, int& numExamples)
{
    std::ifstream checkfile(path);
    if (!checkfile) {
        return false;
    }

    args.clear();
    args.push_back(toCString(programName));

    std::string line;
    while (std::getline(checkfile, line) && !line.empty()) {
        args.push_back(toCString(line));
    }

    double savedLl = 0.0;
    int savedExamples = 0;
    if (!(checkfile >> savedLl >> savedExamples)) {
        return false;
    }
    ll = savedLl;
    numExamples = savedExamples;

    int savedZeroes = 0;
    if (checkfile >> savedZeroes) {
        numZeroes = savedZeroes;
    }
    return true;
}


// Usage: gibbs <model> <testfile> <condfile> <sampling iters> <c-ratio>
//    or: gibbs <checkpoint file>      (resume an interrupted run)
//
// Returns 0 on success and -1 on a usage or file error.
int main(int argc, char* argv[])
{
    double ll = 0.0;
    int numExamples = 0;

    // Backing storage for an argv rebuilt from a checkpoint; must outlive
    // every use of argv below.
    std::vector<std::vector<char> > restoredArgs;
    std::vector<char*> restoredArgv;

    if (argc == 2) {
        if (!restoreCheckpoint(argv[1], argv[0], restoredArgs,
                    ll, numExamples)) {
            std::cout << "ERROR: could not read checkpoint file \""
                      << argv[1] << "\"\n";
            return -1;
        }
        // Take the pointers only after restoredArgs has stopped growing.
        for (std::vector<std::vector<char> >::size_type i = 0;
                i < restoredArgs.size(); i++) {
            restoredArgv.push_back(&restoredArgs[i][0]);
        }
        argc = static_cast<int>(restoredArgv.size());
        argv = &restoredArgv[0];
    }

    if (argc < 6) {
        std::cout << "Usage: gibbs <model> <testfile> <condfile> ";
        std::cout << "<sampling iters> <c-ratio>\n";
        return -1;
    }

    BayesNet model;
    FILE* modelIn = fopen(argv[1], "r");
    if (!modelIn) {
        std::cout << "ERROR: could not open model file \""
                  << argv[1] << "\"\n";
        return -1;
    }
    loadModel(modelIn, model);
    // DEBUG
    std::cout << "Successfully loaded model.\n";

    std::ifstream testIn(argv[2]);
    if (!testIn) {
        std::cout << "ERROR: could not open test file \""
                  << argv[2] << "\"\n";
        return -1;
    }

    std::ifstream condIn(argv[3]);
    if (!condIn) {
        std::cout << "ERROR: could not open conditions file \""
                  << argv[3] << "\"\n";
        return -1;
    }

    minIters = atoi(argv[4]);
    convergenceRatio = atof(argv[5]);

    VarSet example;
    VarSet cond;

    // Skip the examples already accounted for by a restored checkpoint.
    for (int i = 0; i < numExamples; i++) {
        testIn >> example;
        condIn >> cond;
    }

    for (;;) {
        testIn >> example;
        condIn >> cond;
        // Stop as soon as a read fails, so the failed read is never scored.
        // NOTE(review): assumes operator>> for VarSet puts the stream into a
        // failed state only when no example could be read -- confirm.
        if (!testIn) {
            break;
        }

        double l0 = getLikelihood(example, cond, model);
        // DEBUG
        if (l0 == 0.0) {
            std::cout << "ERROR: zero likelihood!\n";
            std::cout << "Query: " << example << std::endl;
            std::cout << "Evidence: " << cond << std::endl;
        }
        double l = std::log(l0);
        ll += l;
        numExamples++;

        checkpoint(argc, argv, ll, numExamples, numZeroes);
    }

    std::cout << "numZeroes = " << numZeroes << std::endl;
    std::cout << "log(likelihood) = " << ll/numExamples << std::endl;

    return 0;
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -