⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 c_inference.cpp

📁 The package includes 3 MATLAB interfaces to the C code: 1. inference.m — an interface to the full…
💻 CPP
📖 第 1 页 / 共 2 页
字号:
      }            break;    case AT_GIBBS:      algorithm = new Gibbs(mrf,startX,burningTime,samplingInterval,num_samples);      delete[] startX;      startX = 0;            break;    case AT_WOLFF:      algorithm = new Wolff((PottsMRF*)mrf,startX,burningTime,samplingInterval,num_samples);      delete[] startX;      startX = 0;            break;    case AT_SWENDSEN_WANG:      algorithm = new SwendsenWang((PottsMRF*)mrf,startX,burningTime,samplingInterval,num_samples);      delete[] startX;      startX = 0;            break;    case AT_METROPOLIS:      algorithm = new Metropolis(mrf,startX,burningTime,samplingInterval,num_samples);      delete[] startX;      startX = 0;            break;    case AT_MEAN_FIELD:      if (logspace) {	algorithm = new LogMeanField(mrf,maxIter,logBels,threshold);      }      else {	algorithm = new MeanField(mrf,maxIter,threshold);      }      break;          default:      mexErrMsgTxt("invalid algorithm type. possible values are: 0-loopy, 1-gbp, 2-gibbs, 3-wolff, 4-swendswen-wang, 5-metropolis, 6-mean-field");      break;  }    // **************************************************************************  // make inference  // **************************************************************************  int converged;  double** beliefs = algorithm->inference(&converged);  double** singleBeliefs = 0;  double**** pairBeliefs = 0;    switch (algo_type) {        case AT_LOOPY:            if (nlhs > 2) {	if (countingNode != 0) {	  pairBeliefs = ((PairsGBP*)algorithm)->calcPairBeliefs();	}	else {	  pairBeliefs = ((Loopy*)algorithm)->calcPairBeliefs();	}      }            break;          case AT_GBP:      bool marg;      if (sumOrMax==SUM) {	marg = ((GBP*)algorithm)->isSumMarg(.0001);      }      else {	marg = ((GBP*)algorithm)->isMaxMarg(.0001);      }      if (!marg) {	converged = -2;	mexWarnMsgTxt("resulted beliefs are not marginalizable\n");      }      if (!regBeliefs || nlhs > 4) {		singleBeliefs = new double*[num_nodes];	for (int i=0; 
i<num_nodes; i++) {	  singleBeliefs[i] = new double[mrf->V[i]];      	}	processor->extractSingle(beliefs,singleBeliefs,sumOrMax);      	if ((!regBeliefs && nlhs > 2) || (regBeliefs && nlhs > 5)) {	  pairBeliefs = new double***[num_nodes];	  for (int i=0; i<num_nodes; i++) {	    pairBeliefs[i] = new double**[mrf->neighbNum(i)];	    for (int n=0; n<mrf->neighbNum(i); n++) {	      pairBeliefs[i][n] = 0;	      int j = mrf->adjMat[i][n];	      if (i<j) {		pairBeliefs[i][n] = new double*[mrf->V[i]];		for (int xi=0; xi<mrf->V[i]; xi++) {		  pairBeliefs[i][n][xi] = new double[mrf->V[j]];		}	      }	    }	  }	  processor->extractPairs(beliefs,pairBeliefs,sumOrMax);	}	if (!regBeliefs) {	  beliefs = singleBeliefs;	}      }      break;          default:            break;  }    // **************************************************************************  // assign results to output argument (if given)  // **************************************************************************  if (regBeliefs) {    int num_regs = reg_mrf->N;    int regs_dims[2] = {1,num_regs};    Region** all_regions = processor->getAllRegions();    // assign: 1. regions 2. regions' adj-matrix 3. region beliefs 4. 
convergence flag    plhs[0] = mxCreateCellArray(2,regs_dims);    plhs[1] = mxCreateCellArray(2,regs_dims);    plhs[2] = mxCreateCellArray(2,regs_dims);    plhs[3] = mxCreateDoubleScalar(converged);    for (int i=0; i<num_regs; i++) {      // regions      Region* region = all_regions[i];      int reg_i_size = (int)(region->size());      int reg_i_dims[2] = {1,reg_i_size};      mxArray* reg_i = mxCreateNumericArray(2,reg_i_dims,mxDOUBLE_CLASS, mxREAL);      double* reg_i_ptr = mxGetPr(reg_i);      for (int n=0; n<reg_i_size; n++) {	reg_i_ptr[n] = (double)((*region)[n] + 1);      }      mxSetCell(plhs[0],i,reg_i);            // adj-matrix      int adj_dims[2] = {1,reg_mrf->neighbNum(i)};      mxArray* adj_i = mxCreateNumericArray(2,adj_dims,mxDOUBLE_CLASS, mxREAL);      double* adj_i_ptr = mxGetPr(adj_i);      for (int n=0; n<reg_mrf->neighbNum(i); n++) {	adj_i_ptr[n] = (double)(reg_mrf->adjMat[i][n] + 1);      }      mxSetCell(plhs[1],i,adj_i);      // region beliefs      int val_dims[2] = {reg_mrf->V[i],1};      mxArray* bel_i = mxCreateNumericArray(2,val_dims,mxDOUBLE_CLASS, mxREAL);      double* resBelPtr = mxGetPr(bel_i);      for (int xi=0; xi<reg_mrf->V[i]; xi++) {	resBelPtr[xi] = beliefs[i][xi];      }      mxSetCell(plhs[2],i,bel_i);    }    if (nlhs > 4) {      int bel_dims[2] = {1,num_nodes};      // assign: 5. single beliefs 6. 
pairwise beliefs (if required)      plhs[4] = mxCreateCellArray(2,bel_dims);      if (nlhs > 5) {	plhs[5] = mxCreateCellArray(2,bel_dims);      }      for (int i=0; i<num_nodes; i++) {	// single beliefs	int val_dims[2] = {mrf->V[i],1};	mxArray* bel_i = mxCreateNumericArray(2,val_dims,mxDOUBLE_CLASS, mxREAL);	double* resBelPtr = mxGetPr(bel_i);	for (int xi=0; xi<mrf->V[i]; xi++) {	  resBelPtr[xi] = singleBeliefs[i][xi];	}	mxSetCell(plhs[4],i,bel_i);	// pairwise beliefs	if (nlhs > 5) {	  int pair_i_dims[2] = {1, mrf->neighbNum(i)};	  mxArray* pbel_i = mxCreateCellArray(2,pair_i_dims);	  	  for (int n=0; n<mrf->neighbNum(i); n++) {	    int j = mrf->adjMat[i][n];	    if (i<j) {	      int pval_dims[2] = {mrf->V[i], mrf->V[j]};	      mxArray* pbel_ij = mxCreateNumericArray(2,pval_dims,mxDOUBLE_CLASS,mxREAL);	      	      double* resPBelPtr = mxGetPr(pbel_ij);	      for (int xi=0; xi<mrf->V[i]; xi++) {		for (int xj=0; xj<mrf->V[j]; xj++) {		  resPBelPtr[xi + xj*mrf->V[i]] = pairBeliefs[i][n][xi][xj];		}	      }	      mxSetCell(pbel_i, n, pbel_ij);	      	    }	  }	  mxSetCell(plhs[5], i, pbel_i);	}      }    }  }  else {    if (nlhs > 0) {      int bel_dims[2] = {1,num_nodes};      plhs[0] = mxCreateCellArray(2,bel_dims);      for (int i=0; i<num_nodes; i++) {	int val_dims[2] = {mrf->V[i],1};	mxArray* bel_i = mxCreateNumericArray(2,val_dims,mxDOUBLE_CLASS, mxREAL);	double* resBelPtr = mxGetPr(bel_i);	for (int xi=0; xi<mrf->V[i]; xi++) {	  resBelPtr[xi] = beliefs[i][xi];	}	mxSetCell(plhs[0],i,bel_i);      }      if (nlhs > 1) {	plhs[1] = mxCreateDoubleScalar(converged); // For matlab6.5	//plhs[1] = mxCreateScalarDouble(converged);	if (nlhs > 2) {	  int pair_dims[2] = {1,num_nodes};	  plhs[2] = mxCreateCellArray(2,pair_dims);	  for (int i=0; i<num_nodes; i++) {	    int pair_i_dims[2] = {1, mrf->neighbNum(i)};	    mxArray* bel_i = mxCreateCellArray(2,pair_i_dims);	  	    for (int n=0; n<mrf->neighbNum(i); n++) {	      int j = mrf->adjMat[i][n];	      if (i<j) {		int 
pval_dims[2] = {mrf->V[i], mrf->V[j]};		mxArray* bel_ij = mxCreateNumericArray(2,pval_dims,mxDOUBLE_CLASS,mxREAL);	      		double* resBelPtr = mxGetPr(bel_ij);		for (int xi=0; xi<mrf->V[i]; xi++) {		  for (int xj=0; xj<mrf->V[j]; xj++) {		    resBelPtr[xi + xj*mrf->V[i]] = pairBeliefs[i][n][xi][xj];		  }		}		mxSetCell(bel_i, n, bel_ij);	      	      }	    }	    mxSetCell(plhs[2], i, bel_i);	  }	  if (nlhs > 3) {	    double*** msg = 0;	    int msg_dims[2];	  	    switch (algo_type) {    	      case AT_LOOPY:      		if (countingNode != 0) {		  msg = ((PairsGBP*)algorithm)->getMessages();		}		else {		  msg = ((Loopy*)algorithm)->getMessages();		}	    		msg_dims[0] = 1;		msg_dims[1] = num_nodes;	      		plhs[3] = mxCreateCellArray(2,msg_dims);		for (int i=0; i<num_nodes; i++) {		  int Ni = mrf->neighbNum(i);		  int msg_i_dims[2];		  msg_i_dims[0] = 1;		  msg_i_dims[1] = (saveTime ? num_nodes : Ni);		  mxArray* msg_i = mxCreateCellArray(2,msg_i_dims);		  for (int n=0; n<Ni; n++) {		    int j = mrf->adjMat[i][n];		    int nei = (saveTime ? 
j : n);		    int msg_ij_dims[2] = {1, mrf->V[j]};		    mxArray* msg_ij = mxCreateNumericArray(2,msg_ij_dims,mxDOUBLE_CLASS, mxREAL);		    double* msg_ij_ptr = mxGetPr(msg_ij);		    for (int xj=0; xj<mrf->V[j]; xj++) {		      msg_ij_ptr[xj] = msg[i][nei][xj];		    }		    mxSetCell(msg_i,nei,msg_ij);		  }		  mxSetCell(plhs[3],i,msg_i);		}				break;      	      case AT_GBP:		if (!trw) {		  msg = ((GBP*)algorithm)->getMessages();		  msg_dims[0] = 1;		  msg_dims[1] = reg_mrf->N;		  plhs[3] = mxCreateCellArray(2,msg_dims);		  for (int i=0; i<reg_mrf->N; i++) {		    int Ni = reg_mrf->neighbNum(i);		    int msg_i_dims[2] = {1,Ni};		    mxArray* msg_i = mxCreateCellArray(2,msg_i_dims);		    for (int n=0; n<Ni; n++) {		      int j = reg_mrf->adjMat[i][n];		      int numStates = reg_mrf->V[max(i,j)];		      int msg_ij_dims[2] = {1, numStates};		      mxArray* msg_ij = mxCreateNumericArray(2,msg_ij_dims,mxDOUBLE_CLASS, mxREAL);		      double* msg_ij_ptr = mxGetPr(msg_ij);		      for (int xs=0; xs<numStates; xs++) {			msg_ij_ptr[xs] = msg[i][n][xs];		      }		      mxSetCell(msg_i,n,msg_ij);		    }		    mxSetCell(plhs[3],i,msg_i);		  }		}		break;	      default:	      		break;	    	    }	  	  }	}      }    }  }  // **************************************************************************  // free memory  // **************************************************************************  delete algorithm;  algorithm = 0;    if (singleBeliefs != 0) {    for (int i=0; i<num_nodes; i++) {      delete[] singleBeliefs[i];    }    delete[] singleBeliefs;        singleBeliefs = 0;  }  if (pairBeliefs != 0 && algo_type == AT_GBP) {    for (int i=0; i<num_nodes; i++) {      for (int n=0; n<mrf->neighbNum(i); n++) {	if (pairBeliefs[i][n] != 0) {	  for (int xi=0; xi<mrf->V[i]; xi++) {	    delete[] pairBeliefs[i][n][xi];	  }	  delete[] pairBeliefs[i][n];	}      }      delete[] pairBeliefs[i];    }    delete[] pairBeliefs;    pairBeliefs = 0;  }  if (processor != 0) {    delete processor;    
processor = 0;  }  if (rho != 0) {    for (int i=0; i<num_nodes; i++) {      delete[] rho[i];      rho[i] = 0;    }    delete[] rho;    rho = 0;  }  if (countingNode != 0) {    delete[] countingNode;    countingNode = 0;  }  if (bigRegsPot != 0) {    for (int i=0; i<num_regs; i++) {      delete[] bigRegsPot[i];    }    delete[] bigRegsPot;    bigRegsPot = 0;  }    delete mrf;  mrf = 0;  delete adjMat;  adjMat = 0;  }

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -