⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 loopystime.cpp

📁 The package includes three MATLAB interfaces to the C code: 1. inference.m, an interface to the full
💻 CPP
📖 第 1 页 / 共 2 页
字号:
#include "LoopySTime.h"
#include <math.h>
#include <iostream>
#include "mex.h"

using namespace std;

// Divides each entry of a rows x cols table by `sum` so that the entries add
// up to one. A sum that is not strictly positive (zero, negative or NaN)
// leaves the table unchanged. Both pair-belief routines below use this guard.
static void normalizePairTable(double** table, int rows, int cols,
                               double sum) {
  if (!(sum > 0.0)) {
    return;
  }
  for (int r = 0; r < rows; r++) {
    for (int c = 0; c < cols; c++) {
      table[r][c] /= sum;
    }
  }
}

// Releases every buffer this object allocates: the message matrix, the pair
// beliefs and the copied TRW edge weights.
LoopySTime::~LoopySTime() {
  freeMessages();
  freePairBeliefs();
  // initTRWRho() allocates l_trwRho with new[]. Without this call that
  // buffer leaked whenever TRW weights had been supplied.
  freeTRWRho();
}

// (Re)initializes the message matrix. l_messages[i][j] is the message sent
// from node i to neighbour j: a vector over the ia_mrf->V[j] states of j.
// Entries for non-adjacent pairs are null.
//
// initMsg: if non-null, this object adopts the given matrix as-is, without
//   copying it. freeMessages() later delete[]s it, so presumably the caller
//   must allocate it with new[] in the same layout and must not free it.
//   TODO confirm the ownership contract with callers.
//   If null, every message is set to the uniform distribution 1 / V[j].
void LoopySTime::initMessages(double*** initMsg) {
  if (initMsg != 0 && initMsg == l_messages) {
    // Re-adopting the matrix we already own. Freeing it first would leave
    // l_messages pointing at released memory, so there is nothing to do.
    return;
  }
  freeMessages();

  if (initMsg != 0) {
    l_messages = initMsg;
    return;
  }

  l_messages = new double**[ia_mrf->N];
  for (int i = 0; i < ia_mrf->N; i++) {
    l_messages[i] = new double*[ia_mrf->N];
    for (int j = 0; j < ia_mrf->N; j++) {
      l_messages[i][j] = 0;
    }
    for (int n = 0; n < ia_mrf->neighbNum(i); n++) {
      const int j = ia_mrf->adjMat[i][n];
      l_messages[i][j] = new double[ia_mrf->V[j]];
      for (int xj = 0; xj < ia_mrf->V[j]; xj++) {
        l_messages[i][j][xj] = 1.0 / ia_mrf->V[j];
      }
    }
  }
}

// Releases the message matrix, if any, and resets l_messages to null.
// Expects the N x N layout built by initMessages(), with null entries for
// non-adjacent pairs.
void LoopySTime::freeMessages() {
  if (l_messages == 0) {
    return;
  }
  for (int i = 0; i < ia_mrf->N; i++) {
    for (int j = 0; j < ia_mrf->N; j++) {
      delete[] l_messages[i][j];  // delete[] of a null pointer is a no-op
    }
    delete[] l_messages[i];
  }
  delete[] l_messages;
  l_messages = 0;
}

// Allocates l_pairBeliefs and fills it with the unnormalized default
//   b_ij(xi, xj) = Psi_i(xi) * Psi_j(xj) * Psi_ij(xi, xj),
// i.e. the two local potentials times the pair potential, with no messages.
// l_pairBeliefs[i][n] belongs to the edge (i, adjMat[i][n]). Only entries
// with i < j are allocated; the others are null. Presumably the adjacency is
// symmetric, so each edge is stored exactly once (TODO confirm).
void LoopySTime::initPairBeliefs() {
  freePairBeliefs();

  l_pairBeliefs = new double***[ia_mrf->N];
  for (int i = 0; i < ia_mrf->N; i++) {
    l_pairBeliefs[i] = new double**[ia_mrf->neighbNum(i)];
    for (int n = 0; n < ia_mrf->neighbNum(i); n++) {
      l_pairBeliefs[i][n] = 0;
      const int j = ia_mrf->adjMat[i][n];
      if (!(i < j)) {
        continue;
      }
      l_pairBeliefs[i][n] = new double*[ia_mrf->V[i]];
      for (int xi = 0; xi < ia_mrf->V[i]; xi++) {
        l_pairBeliefs[i][n][xi] = new double[ia_mrf->V[j]];
        for (int xj = 0; xj < ia_mrf->V[j]; xj++) {
          l_pairBeliefs[i][n][xi][xj] = (ia_mrf->localMat[i][xi] *
                                         ia_mrf->localMat[j][xj] *
                                         ia_mrf->pairPotential(i, n, xi, xj));
        }
      }
    }
  }
}

// Releases the pair-belief tables, if any, and resets l_pairBeliefs to null.
// Null per-edge entries (the i >= j half) are skipped.
void LoopySTime::freePairBeliefs() {
  if (l_pairBeliefs == 0) {
    return;
  }
  for (int i = 0; i < ia_mrf->N; i++) {
    for (int n = 0; n < ia_mrf->neighbNum(i); n++) {
      if (l_pairBeliefs[i][n] == 0) {
        continue;
      }
      for (int xi = 0; xi < ia_mrf->V[i]; xi++) {
        delete[] l_pairBeliefs[i][n][xi];
      }
      delete[] l_pairBeliefs[i][n];
    }
    delete[] l_pairBeliefs[i];
  }
  delete[] l_pairBeliefs;
  l_pairBeliefs = 0;
}

// Stores a private copy of the TRW edge weights: l_trwRho[i][n] is the weight
// of the edge (i, adjMat[i][n]). A null argument clears the weights. A null
// l_trwRho makes calcPairBeliefs() use plain BP instead of TRBP.
void LoopySTime::initTRWRho(double** trwRho) {
  if (trwRho != 0 && trwRho == l_trwRho) {
    // Copying from our own buffer after freeing it would read released
    // memory. The weights are already in place.
    return;
  }
  freeTRWRho();

  if (trwRho == 0) {
    return;
  }
  l_trwRho = new double*[ia_mrf->N];
  for (int i = 0; i < ia_mrf->N; i++) {
    l_trwRho[i] = new double[ia_mrf->neighbNum(i)];
    for (int n = 0; n < ia_mrf->neighbNum(i); n++) {
      l_trwRho[i][n] = trwRho[i][n];
    }
  }
}

// Releases the copied TRW edge weights, if any, and resets l_trwRho to null.
void LoopySTime::freeTRWRho() {
  if (l_trwRho == 0) {
    return;
  }
  for (int i = 0; i < ia_mrf->N; i++) {
    delete[] l_trwRho[i];
  }
  delete[] l_trwRho;
  l_trwRho = 0;
}

// Recomputes the normalized pairwise beliefs from the current messages.
// When TRW edge weights are set, the work is delegated to
// calcPairBeliefsTRBP(). For plain BP, each edge (i, j) with i < j gets
//   b_ij(xi, xj) ∝ Psi_i(xi) Psi_j(xj) Psi_ij(xi, xj)
//                  * prod_{k in N(i)\j} m_ki(xi) * prod_{k in N(j)\i} m_kj(xj).
// The previous l_pairBeliefs are freed and replaced. The returned pointer is
// owned by this object.
// NOTE(review): l_messages is dereferenced without a check, so messages must
// already be initialized (initMessages()).
double**** LoopySTime::calcPairBeliefs() {
  if (l_trwRho != 0) {
    return calcPairBeliefsTRBP();
  }

  double**** new_pairBeliefs = new double***[ia_mrf->N];
  for (int i = 0; i < ia_mrf->N; i++) {
    new_pairBeliefs[i] = new double**[ia_mrf->neighbNum(i)];
    for (int n = 0; n < ia_mrf->neighbNum(i); n++) {
      new_pairBeliefs[i][n] = 0;
      const int j = ia_mrf->adjMat[i][n];
      if (!(i < j)) {
        continue;
      }
      double sum_beliefs_ij = 0.0;
      double** belief = new double*[ia_mrf->V[i]];
      new_pairBeliefs[i][n] = belief;
      for (int xi = 0; xi < ia_mrf->V[i]; xi++) {
        belief[xi] = new double[ia_mrf->V[j]];
        for (int xj = 0; xj < ia_mrf->V[j]; xj++) {
          double value = (ia_mrf->localMat[i][xi] *
                          ia_mrf->localMat[j][xj] *
                          ia_mrf->pairPotential(i, n, xi, xj));
          // Messages into i from every neighbour except j.
          for (int ni = 0; ni < ia_mrf->neighbNum(i); ni++) {
            const int k = ia_mrf->adjMat[i][ni];
            if (k != j) {
              value *= l_messages[k][i][xi];
            }
          }
          // Messages into j from every neighbour except i.
          for (int nj = 0; nj < ia_mrf->neighbNum(j); nj++) {
            const int k = ia_mrf->adjMat[j][nj];
            if (k != i) {
              value *= l_messages[k][j][xj];
            }
          }
          belief[xi][xj] = value;
          sum_beliefs_ij += value;
        }
      }
      normalizePairTable(belief, ia_mrf->V[i], ia_mrf->V[j], sum_beliefs_ij);
    }
  }

  freePairBeliefs();
  l_pairBeliefs = new_pairBeliefs;
  return l_pairBeliefs;
}

// TRBP variant of calcPairBeliefs(), used when TRW edge weights are set.
// For each edge (i, j) with i < j it computes
//   b_ij(xi, xj) ∝ Psi_i(xi) Psi_j(xj) Psi_ij(xi, xj)
//                  * prod_{k in N(i)} m_ki(xi)^{rho_i,k} / m_ji(xi)
//                  * prod_{k in N(j)} m_kj(xj)^{rho_j,k} / m_ij(xj),
// where rho_i,k = l_trwRho[i][n] for the neighbour k = adjMat[i][n].
// An earlier variant raised Psi_ij to the power 1 / rho_ij here; this code
// leaves that variant disabled.
// NOTE(review): a zero message m_ji(xi) or m_ij(xj) makes the divisions
// produce inf/NaN. Confirm that messages stay strictly positive.
double**** LoopySTime::calcPairBeliefsTRBP() {
  double**** new_pairBeliefs = new double***[ia_mrf->N];
  for (int i = 0; i < ia_mrf->N; i++) {
    new_pairBeliefs[i] = new double**[ia_mrf->neighbNum(i)];
    for (int n = 0; n < ia_mrf->neighbNum(i); n++) {
      new_pairBeliefs[i][n] = 0;
      const int j = ia_mrf->adjMat[i][n];
      if (!(i < j)) {
        continue;
      }
      double sum_beliefs_ij = 0.0;
      double** belief = new double*[ia_mrf->V[i]];
      new_pairBeliefs[i][n] = belief;
      for (int xi = 0; xi < ia_mrf->V[i]; xi++) {
        belief[xi] = new double[ia_mrf->V[j]];
        for (int xj = 0; xj < ia_mrf->V[j]; xj++) {
          double value = (ia_mrf->localMat[i][xi] *
                          ia_mrf->localMat[j][xj] *
                          ia_mrf->pairPotential(i, n, xi, xj));
          // All weighted messages into i; m_ji is divided back out below.
          for (int ni = 0; ni < ia_mrf->neighbNum(i); ni++) {
            const int k = ia_mrf->adjMat[i][ni];
            value *= pow(l_messages[k][i][xi], l_trwRho[i][ni]);
          }
          value /= l_messages[j][i][xi];
          // All weighted messages into j; m_ij is divided back out below.
          for (int nj = 0; nj < ia_mrf->neighbNum(j); nj++) {
            const int k = ia_mrf->adjMat[j][nj];
            value *= pow(l_messages[k][j][xj], l_trwRho[j][nj]);
          }
          value /= l_messages[i][j][xj];
          belief[xi][xj] = value;
          sum_beliefs_ij += value;
        }
      }
      normalizePairTable(belief, ia_mrf->V[i], ia_mrf->V[j], sum_beliefs_ij);
    }
  }

  freePairBeliefs();
  l_pairBeliefs = new_pairBeliefs;
  return l_pairBeliefs;
}

// Runs the message-passing loop. It dispatches to inferenceTRBP() when TRW
// edge weights are set. The function body continues past this point.
double** LoopySTime::inference(int* converged) {
  if
(l_trwRho != 0) {    return inferenceTRBP(converged);  }    double epsilon = pow(10.,-200);  double dBel = l_th+1.0;  int nIter = 0;    double*** new_messages = 0;  if (l_strategy == PARALLEL) {    new_messages = new double**[ia_mrf->N];    for (int i=0; i<ia_mrf->N; i++) {      new_messages[i] = new double*[ia_mrf->N];      for (int j=0; j<ia_mrf->N; j++) {	new_messages[i][j] = 0;      }      for (int n=0; n<ia_mrf->neighbNum(i); n++) {	int j = ia_mrf->adjMat[i][n];	new_messages[i][j] = new double[ia_mrf->V[j]];	for (int xj=0; xj<ia_mrf->V[j]; xj++) {	  new_messages[i][j][xj] = l_messages[i][j][xj];	}      }    }  }  while (dBel>l_th && nIter<l_maxIter) {    nIter++;    for (int i=0; i<ia_mrf->N; i++) {      // init the incoming messages to 1      double* incoming = new double[ia_mrf->V[i]];      double* factor = new double[ia_mrf->neighbNum(i)];      for (int xi=0; xi<ia_mrf->V[i]; xi++) {	incoming[xi] = 1.0;      }      // get incoming messages      for (int n=0; n<ia_mrf->neighbNum(i); n++) {	int j = ia_mrf->adjMat[i][n];	factor[n] = 0.0;	for (int xi=0; xi<ia_mrf->V[i]; xi++) {	  incoming[xi] *= l_messages[j][i][xi];	  factor[n] += incoming[xi];	}	for (int xi=0; xi<ia_mrf->V[i]; xi++) {	  incoming[xi] /= factor[n];	}      }            // calculate outgoing messages      for (int n=0; n<ia_mrf->neighbNum(i); n++) {	int j = ia_mrf->adjMat[i][n];	double sum_outgoing_to_j = 0.0;	double* outgoing = new double[ia_mrf->V[j]];		for (int xj=0; xj<ia_mrf->V[j]; xj++) {	  	  switch (l_sumOrMax) {	    case SUM:	      outgoing[xj] = 0.0;	      break;	    case MAX:	      outgoing[xj] = -1.0;	      break;	    default:	      break;	  }	    	  for (int xi=0; xi<ia_mrf->V[i]; xi++) {	    double outM = ia_mrf->pairPotential(i,n,xi,xj) * ia_mrf->localMat[i][xi] *	      incoming[xi] / l_messages[j][i][xi];// 	    double outM = ia_mrf->pairPotential(i,n,xi,xj) * ia_mrf->localMat[i][xi] *// 	      incoming[xi] * factor[n] / l_messages[j][i][xi];	    switch (l_sumOrMax) {	      case 
SUM:		outgoing[xj] += outM;		break;	      case MAX:		if (outM > outgoing[xj]) {		  outgoing[xj] = outM;		}		break;	      default:		break;

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -