⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 c_inference.cpp

📁 The package includes 3 Matlab-interfaces to the c-code: 1. inference.m An interface to the full
💻 CPP
📖 第 1 页 / 共 2 页
字号:
#include "mex.h"#include "fillMethods.h"#include "LogLoopy.h"#include "LogLoopySTime.h"#include "LogPairsGBP.h"#include "LogGBP.h"#include "Gibbs.h"#include "Wolff.h"#include "SwendsenWang.h"#include "Metropolis.h"#include "LogMeanField.h"#include "GBPPreProcessor.h"#include <iostream>// *****************************************************************************// enumerators// *****************************************************************************enum algorithmType {AT_LOOPY,AT_GBP,AT_GIBBS,AT_WOLFF,AT_SWENDSEN_WANG,		    AT_METROPOLIS,AT_MEAN_FIELD};// *****************************************************************************// mexFunction// *****************************************************************************void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[]){  // **************************************************************************  // variables declaration  // **************************************************************************  // variables for both Loopy and GBP  double*** initMsg = 0;  bool trw = false;  bool full = true;  double* countingNode = 0; // relevant for loopy, or for gbp if trw = true      // variables for Loopy  Strategy strategy;  SumOrMax sumOrMax;  double gbp_alpha;  double** rho = 0; // relevant if trw = true  bool saveTime = false;  // variables for GBP  int*** assignInd = 0;  double* bethe = 0;  GBPPreProcessor* processor = 0;  MRF* reg_mrf = 0;  RegionLevel* regions = 0;  vector<RegionLevel>* allRegions = 0;  Potentials* bigRegsPot = 0;  bool allLevels = false;  bool regBeliefs = false;  int num_regs = 0;  // variables for Monte-Carlo  int burningTime, samplingInterval, num_samples;  int* startX = 0;  // for Loopy, Mean-Field  bool logspace = false;  bool logBels = false;  // for Loopy,GBP,Mean-Field  int maxIter;  double threshold;    // **************************************************************************  // reading input arguments  // **************************************************************************  // check number of input arguments.  // arguments should be:  //  // adjMat - 1xN cell array, each cell {i} is a row vector with the indices of  //          i's neighbours  //  // lambda - there are 2 forms for lambda:  //          1. in general MRF algorithms (loopy, gbp, gibbs, mean-field) :  //             lambda should be a cell array of 1xN, each cell {i} is a cell  //             array of 1xneighbNum(i). each cell {i}{n} is a VixVj matrix,  //             where j is the n-th neighbour of i  //          2. in PottsMRF alhorithms (monte-carlo algorithms which are planned  //             for Potts model, i.e. metropolis and the cluster algorithms  //             wolff and swendsen-wang) :  //             here lambda should be 1xN cell array, each cell {i} is a row  //             vector with the strength of interaction of i with each of its  //             neighbours  // note: Psi{i,j} = exp( [lambda(i,j), 0; 0, lambda(i,j)] )  //  // local - cell array of Nx1, each cell {i} is a row vector of length Vi  //  // algorithm - integer representing the inference algorithm to use, see the  //             enumerator algorithmType at the top of this page  //  // temperature - double scalar, the temperature of the system  //  // model - integeger representing the model, see the enumerator in "definitions.h"  //  // trw - use Tree-Reweighted  //  // for other parameters required for each algorithm see the header of "inference.m"  //  //  // note: N = number of nodes, V = number of possible values    if (nrhs < 10 || nrhs > 20) {    mexErrMsgTxt("Incorrect number of inputs.");  }  // get algorithm-type  algorithmType algo_type = (algorithmType)((int)(mxGetScalar(prhs[3])));  Model model = (Model)((int)(mxGetScalar(prhs[5])));  bool potts_model = ((model==POTTS) ||		      (algo_type==AT_WOLFF) ||		      (algo_type==AT_SWENDSEN_WANG));  bool monte_carlo = ((algo_type==AT_GIBBS) ||		      (algo_type==AT_WOLFF) ||		      (algo_type==AT_SWENDSEN_WANG) ||		      (algo_type==AT_METROPOLIS));  // check number of output arguments  if ((nlhs > 6) || ((nlhs > 2) && (algo_type != AT_GBP) && (algo_type != AT_LOOPY))) {    mexErrMsgTxt("Too many output arguments.");  }  // get number of nodes and adjMat  vector<Nodes>* adjMat = new vector<Nodes>();  fillAdjMat(prhs[0],*adjMat);  int num_nodes = adjMat->size();  // define the MRF  MRF* mrf = 0;  if (potts_model) {    mrf = new PottsMRF(*adjMat);  }  else {    mrf = new MRF(*adjMat);  }    // get local potentials  fillLocalMat(prhs[2],mrf);    // get pairwise potentials  if (potts_model) {    fillLambdaMat(prhs[1],(PottsMRF*)mrf);  }  else {    fillPsiMat(prhs[1],mrf);  }  // For monte-carlo algorithms (gibbs, wolff, swendsen-wang), get  // the initial state and the sampling parameters  if (monte_carlo) {    if (nrhs != 10) {      mexErrMsgTxt("incorrect number of inputs");    }    startX = new int[num_nodes];    fillInitialAssignment(prhs[6], startX, num_nodes);    // get burningTime, samplingInterval, num_samples    burningTime = (int)(mxGetScalar(prhs[7]));    samplingInterval = (int)(mxGetScalar(prhs[8]));    num_samples = (int)(mxGetScalar(prhs[9]));      }  else {    // for all non-monte-carlo-algorithms (Mean-Field, BP & GBP)    maxIter = (int)(mxGetScalar(prhs[6]));    threshold = mxGetScalar(prhs[7]);    // get log-space flag    logspace = ((int)(mxGetScalar(prhs[8]))) > 0;    mrf->logspace = logspace;    logBels = ((int)(mxGetScalar(prhs[9]))) > 0;    if (algo_type == AT_LOOPY) {      // for loopy belief propagation:      // get sum-or-max-flag and strategy      if (nrhs != 17) {	mexErrMsgTxt("incorrect number of inputs");      }	      sumOrMax = (SumOrMax)((int)(mxGetScalar(prhs[10])));      strategy = (Strategy)((int)(mxGetScalar(prhs[11])));      trw = ((int)(mxGetScalar(prhs[12]))) > 0;      if (trw) {	rho = new double*[num_nodes];	for (int i=0; i<num_nodes; i++) {	  rho[i] = new double[mrf->neighbNum(i)];	}	fillRhoMat(prhs[13],mrf,rho);      }	      // get save-time flag      saveTime = ((int)(mxGetScalar(prhs[15]))) > 0;      // get initial messages, if given      int initM_nd = mxGetNumberOfDimensions(prhs[14]);      const int* initM_dim = mxGetDimensions(prhs[14]);      if ((initM_nd == 2) && (initM_dim[0] == 1) &&	  (initM_dim[1] == num_nodes) && mxIsCell(prhs[14])) {	initMsg = new double**[num_nodes];	for (int i=0; i<num_nodes; i++) {	  mxArray* initMsg_i = mxGetCell(prhs[14],i);	  int Ni = mrf->neighbNum(i);	  int len = (saveTime ? num_nodes : Ni);	  initMsg[i] = new double*[len];	  if (saveTime) {	    for (int j=0; j<num_nodes; j++) {	      initMsg[i][j] = 0;	    }	  }	  for (int n=0; n<Ni; n++) {	    int j = mrf->adjMat[i][n];	    int nei = (saveTime ? j : n);	    initMsg[i][nei] = new double[mrf->V[j]];	    mxArray* initMsg_ij = mxGetCell(initMsg_i,nei);	    fillDouble(initMsg_ij,initMsg[i][nei],mrf->V[j]);	  }	}      }      int count_nd = mxGetNumberOfDimensions(prhs[16]);      const int* count_dim = mxGetDimensions(prhs[16]);      if ((count_nd==2) && (count_dim[0]*count_dim[1]==num_nodes)) {	countingNode = new double[num_nodes];	fillDouble(prhs[16], countingNode, num_nodes);	// incorporate local potentials into pairwise	for (int i=0; i<num_nodes; i++) {	  if (mrf->neighbNum(i)>0) {	    int j = mrf->adjMat[i][0];	    if (i<j) {	      for (int xi=0; xi<mrf->V[i]; xi++) {		for (int xj=0; xj<mrf->V[j]; xj++) {		  mrf->lambdaMat[i][0][xi][xj] *= mrf->localMat[i][xi];		}		mrf->localMat[i][xi] = 1.0;	      }	    }	    else {	      int n = 0;	      while (mrf->adjMat[j][n] != i) {		n++;	      }	      for (int xi=0; xi<mrf->V[i]; xi++) {		for (int xj=0; xj<mrf->V[j]; xj++) {		  mrf->lambdaMat[j][n][xj][xi] *= mrf->localMat[i][xi];		}		mrf->localMat[i][xi] = 1.0;	      }	      	    }	  }	  else {	    mexPrintf("warning: the graph is not connected\n");	  }	}      }    }    if (algo_type == AT_GBP) {      // for generalized belief propagation:      // get regions, regions-adj (if given), sum-or-max-flag      // and alpha      if (nrhs != 20) {	mexErrMsgTxt("incorrect number of inputs");      }      allLevels = (int)(mxGetScalar(prhs[11])) > 0;      if (allLevels) {	allRegions = new vector<RegionLevel>();	allRegions->clear();	fillRegionLevels(prhs[10],*allRegions);      }      else {	regions = new RegionLevel();	fillRegions(prhs[10],*regions);      }      sumOrMax = (SumOrMax)((int)(mxGetScalar(prhs[12])));      gbp_alpha = mxGetScalar(prhs[13]);      trw = ((int)(mxGetScalar(prhs[14]))) > 0;      if (trw) {	countingNode = new double[num_nodes];	fillDouble(prhs[15], countingNode, num_nodes);      }      full = ((int)(mxGetScalar(prhs[16]))) > 0;      // get initial messages, if given      int initM_nd = mxGetNumberOfDimensions(prhs[17]);      const int* initM_dim = mxGetDimensions(prhs[17]);      if ((initM_nd == 2) && mxIsCell(prhs[17])) {	num_regs = initM_dim[0] * initM_dim[1];	initMsg = new double**[num_regs];	for (int i=0; i<num_regs; i++) {	  mxArray* initMsg_i = mxGetCell(prhs[17],i);	  int msg_i_nd = mxGetNumberOfDimensions(initMsg_i);	  const int* msg_i_dim = mxGetDimensions(initMsg_i);	  if ((msg_i_nd != 2) || !mxIsCell(initMsg_i)) {	    mexErrMsgTxt("each cell {i} in initMsg for GBP should be a cell array in length of number of neighbour-regions to region i\n");	  }	  int Ni = msg_i_dim[0] * msg_i_dim[1];	  initMsg[i] = new double*[Ni];	  for (int n=0; n<Ni; n++) {	    mxArray* initMsg_ij = mxGetCell(initMsg_i,n);	    const int* msg_ij_dim = mxGetDimensions(initMsg_ij);	    int numStates = msg_ij_dim[0] * msg_ij_dim[1];	    initMsg[i][n] = new double[numStates];	    fillDouble(initMsg_ij,initMsg[i][n],numStates);	  }	}      }      // get potentials for the big regions, if given      int regPot_nd = mxGetNumberOfDimensions(prhs[18]);      const int* regPot_dim = mxGetDimensions(prhs[18]);      if ((regPot_nd == 2) && mxIsCell(prhs[18])) {	num_regs = regPot_dim[0] * regPot_dim[1];	bigRegsPot = new Potentials[num_regs];	for (int i=0; i<num_regs; i++) {	  mxArray* regPot_i = mxGetCell(prhs[18],i);	  int regPot_i_nd = mxGetNumberOfDimensions(regPot_i);	  const int* regPot_i_dim = mxGetDimensions(regPot_i);	  if (regPot_i_nd != 2) {	    mexErrMsgTxt("each cell {i} in big-regions' potentials for GBP should be a vector in length of number of possible states for the region i\n");	  }	  int Vi = regPot_i_dim[0] * regPot_i_dim[1];	  bigRegsPot[i] = new Potential[Vi];	  fillDouble(regPot_i,bigRegsPot[i],Vi);	}	      }      // if true - get the region beliefs (instead of the single beliefs)      regBeliefs = ((int)(mxGetScalar(prhs[19]))) > 0;    }  }  // get tepmerature  double temperature = mxGetScalar(prhs[4]);  mrf->setTemperature(temperature);    // **************************************************************************  // create the algorithm  // **************************************************************************  InferenceAlgorithm* algorithm = 0;  switch (algo_type) {    case AT_LOOPY:      if (saveTime) {	if (logspace) {	  algorithm = new LogLoopySTime(mrf,sumOrMax,strategy,maxIter,rho,initMsg,logBels,threshold);	}	else {	  algorithm = new LoopySTime(mrf,sumOrMax,strategy,maxIter,rho,initMsg,threshold);	}      }      else {	if (logspace) {	  if (countingNode != 0) {	    algorithm = new LogPairsGBP(mrf,sumOrMax,strategy,maxIter,countingNode,initMsg,logBels,threshold);	  }	  else {	    algorithm = new LogLoopy(mrf,sumOrMax,strategy,maxIter,rho,initMsg,logBels,threshold);	  }	}	else {	  if (countingNode != 0) {	    algorithm = new PairsGBP(mrf,sumOrMax,strategy,maxIter,countingNode,initMsg,threshold);	  }	  else {	    algorithm = new Loopy(mrf,sumOrMax,strategy,maxIter,rho,initMsg,threshold);	  }	}      }      break;    case AT_GBP:      if (allLevels) {	processor = new GBPPreProcessor(allRegions, mrf, trw, full, countingNode, bigRegsPot);      }      else {	processor = new GBPPreProcessor(*regions, mrf, trw, full, countingNode, bigRegsPot);	regions->clear();	delete regions;	regions = 0;      }      reg_mrf = processor->getRegionMRF();      assignInd = processor->getAssignTable();      bethe = processor->getBethe();      if (logspace) {	algorithm = new LogGBP(reg_mrf,assignInd,bethe,sumOrMax,gbp_alpha,maxIter,initMsg,logBels,threshold);      }      else {	algorithm = new GBP(reg_mrf,assignInd,bethe,sumOrMax,gbp_alpha,maxIter,initMsg,threshold);

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -