📄 c_inference.cpp
字号:
} break; case AT_GIBBS: algorithm = new Gibbs(mrf,startX,burningTime,samplingInterval,num_samples); delete[] startX; startX = 0; break; case AT_WOLFF: algorithm = new Wolff((PottsMRF*)mrf,startX,burningTime,samplingInterval,num_samples); delete[] startX; startX = 0; break; case AT_SWENDSEN_WANG: algorithm = new SwendsenWang((PottsMRF*)mrf,startX,burningTime,samplingInterval,num_samples); delete[] startX; startX = 0; break; case AT_METROPOLIS: algorithm = new Metropolis(mrf,startX,burningTime,samplingInterval,num_samples); delete[] startX; startX = 0; break; case AT_MEAN_FIELD: if (logspace) { algorithm = new LogMeanField(mrf,maxIter,logBels,threshold); } else { algorithm = new MeanField(mrf,maxIter,threshold); } break; default: mexErrMsgTxt("invalid algorithm type. possible values are: 0-loopy, 1-gbp, 2-gibbs, 3-wolff, 4-swendswen-wang, 5-metropolis, 6-mean-field"); break; } // ************************************************************************** // make inference // ************************************************************************** int converged; double** beliefs = algorithm->inference(&converged); double** singleBeliefs = 0; double**** pairBeliefs = 0; switch (algo_type) { case AT_LOOPY: if (nlhs > 2) { if (countingNode != 0) { pairBeliefs = ((PairsGBP*)algorithm)->calcPairBeliefs(); } else { pairBeliefs = ((Loopy*)algorithm)->calcPairBeliefs(); } } break; case AT_GBP: bool marg; if (sumOrMax==SUM) { marg = ((GBP*)algorithm)->isSumMarg(.0001); } else { marg = ((GBP*)algorithm)->isMaxMarg(.0001); } if (!marg) { converged = -2; mexWarnMsgTxt("resulted beliefs are not marginalizable\n"); } if (!regBeliefs || nlhs > 4) { singleBeliefs = new double*[num_nodes]; for (int i=0; i<num_nodes; i++) { singleBeliefs[i] = new double[mrf->V[i]]; } processor->extractSingle(beliefs,singleBeliefs,sumOrMax); if ((!regBeliefs && nlhs > 2) || (regBeliefs && nlhs > 5)) { pairBeliefs = new double***[num_nodes]; for (int i=0; i<num_nodes; i++) { pairBeliefs[i] = new double**[mrf->neighbNum(i)]; for (int n=0; n<mrf->neighbNum(i); n++) { pairBeliefs[i][n] = 0; int j = mrf->adjMat[i][n]; if (i<j) { pairBeliefs[i][n] = new double*[mrf->V[i]]; for (int xi=0; xi<mrf->V[i]; xi++) { pairBeliefs[i][n][xi] = new double[mrf->V[j]]; } } } } processor->extractPairs(beliefs,pairBeliefs,sumOrMax); } if (!regBeliefs) { beliefs = singleBeliefs; } } break; default: break; } // ************************************************************************** // assign results to output argument (if given) // ************************************************************************** if (regBeliefs) { int num_regs = reg_mrf->N; int regs_dims[2] = {1,num_regs}; Region** all_regions = processor->getAllRegions(); // assign: 1. regions 2. regions' adj-matrix 3. region beliefs 4. convergence flag plhs[0] = mxCreateCellArray(2,regs_dims); plhs[1] = mxCreateCellArray(2,regs_dims); plhs[2] = mxCreateCellArray(2,regs_dims); plhs[3] = mxCreateDoubleScalar(converged); for (int i=0; i<num_regs; i++) { // regions Region* region = all_regions[i]; int reg_i_size = (int)(region->size()); int reg_i_dims[2] = {1,reg_i_size}; mxArray* reg_i = mxCreateNumericArray(2,reg_i_dims,mxDOUBLE_CLASS, mxREAL); double* reg_i_ptr = mxGetPr(reg_i); for (int n=0; n<reg_i_size; n++) { reg_i_ptr[n] = (double)((*region)[n] + 1); } mxSetCell(plhs[0],i,reg_i); // adj-matrix int adj_dims[2] = {1,reg_mrf->neighbNum(i)}; mxArray* adj_i = mxCreateNumericArray(2,adj_dims,mxDOUBLE_CLASS, mxREAL); double* adj_i_ptr = mxGetPr(adj_i); for (int n=0; n<reg_mrf->neighbNum(i); n++) { adj_i_ptr[n] = (double)(reg_mrf->adjMat[i][n] + 1); } mxSetCell(plhs[1],i,adj_i); // region beliefs int val_dims[2] = {reg_mrf->V[i],1}; mxArray* bel_i = mxCreateNumericArray(2,val_dims,mxDOUBLE_CLASS, mxREAL); double* resBelPtr = mxGetPr(bel_i); for (int xi=0; xi<reg_mrf->V[i]; xi++) { resBelPtr[xi] = beliefs[i][xi]; } mxSetCell(plhs[2],i,bel_i); } if (nlhs > 4) { int bel_dims[2] = {1,num_nodes}; // assign: 5. single beliefs 6. pairwise beliefs (if required) plhs[4] = mxCreateCellArray(2,bel_dims); if (nlhs > 5) { plhs[5] = mxCreateCellArray(2,bel_dims); } for (int i=0; i<num_nodes; i++) { // single beliefs int val_dims[2] = {mrf->V[i],1}; mxArray* bel_i = mxCreateNumericArray(2,val_dims,mxDOUBLE_CLASS, mxREAL); double* resBelPtr = mxGetPr(bel_i); for (int xi=0; xi<mrf->V[i]; xi++) { resBelPtr[xi] = singleBeliefs[i][xi]; } mxSetCell(plhs[4],i,bel_i); // pairwise beliefs if (nlhs > 5) { int pair_i_dims[2] = {1, mrf->neighbNum(i)}; mxArray* pbel_i = mxCreateCellArray(2,pair_i_dims); for (int n=0; n<mrf->neighbNum(i); n++) { int j = mrf->adjMat[i][n]; if (i<j) { int pval_dims[2] = {mrf->V[i], mrf->V[j]}; mxArray* pbel_ij = mxCreateNumericArray(2,pval_dims,mxDOUBLE_CLASS,mxREAL); double* resPBelPtr = mxGetPr(pbel_ij); for (int xi=0; xi<mrf->V[i]; xi++) { for (int xj=0; xj<mrf->V[j]; xj++) { resPBelPtr[xi + xj*mrf->V[i]] = pairBeliefs[i][n][xi][xj]; } } mxSetCell(pbel_i, n, pbel_ij); } } mxSetCell(plhs[5], i, pbel_i); } } } } else { if (nlhs > 0) { int bel_dims[2] = {1,num_nodes}; plhs[0] = mxCreateCellArray(2,bel_dims); for (int i=0; i<num_nodes; i++) { int val_dims[2] = {mrf->V[i],1}; mxArray* bel_i = mxCreateNumericArray(2,val_dims,mxDOUBLE_CLASS, mxREAL); double* resBelPtr = mxGetPr(bel_i); for (int xi=0; xi<mrf->V[i]; xi++) { resBelPtr[xi] = beliefs[i][xi]; } mxSetCell(plhs[0],i,bel_i); } if (nlhs > 1) { plhs[1] = mxCreateDoubleScalar(converged); // For matlab6.5 //plhs[1] = mxCreateScalarDouble(converged); if (nlhs > 2) { int pair_dims[2] = {1,num_nodes}; plhs[2] = mxCreateCellArray(2,pair_dims); for (int i=0; i<num_nodes; i++) { int pair_i_dims[2] = {1, mrf->neighbNum(i)}; mxArray* bel_i = mxCreateCellArray(2,pair_i_dims); for (int n=0; n<mrf->neighbNum(i); n++) { int j = mrf->adjMat[i][n]; if (i<j) { int pval_dims[2] = {mrf->V[i], mrf->V[j]}; mxArray* bel_ij = mxCreateNumericArray(2,pval_dims,mxDOUBLE_CLASS,mxREAL); double* resBelPtr = mxGetPr(bel_ij); for (int xi=0; xi<mrf->V[i]; xi++) { for (int xj=0; xj<mrf->V[j]; xj++) { resBelPtr[xi + xj*mrf->V[i]] = pairBeliefs[i][n][xi][xj]; } } mxSetCell(bel_i, n, bel_ij); } } mxSetCell(plhs[2], i, bel_i); } if (nlhs > 3) { double*** msg = 0; int msg_dims[2]; switch (algo_type) { case AT_LOOPY: if (countingNode != 0) { msg = ((PairsGBP*)algorithm)->getMessages(); } else { msg = ((Loopy*)algorithm)->getMessages(); } msg_dims[0] = 1; msg_dims[1] = num_nodes; plhs[3] = mxCreateCellArray(2,msg_dims); for (int i=0; i<num_nodes; i++) { int Ni = mrf->neighbNum(i); int msg_i_dims[2]; msg_i_dims[0] = 1; msg_i_dims[1] = (saveTime ? num_nodes : Ni); mxArray* msg_i = mxCreateCellArray(2,msg_i_dims); for (int n=0; n<Ni; n++) { int j = mrf->adjMat[i][n]; int nei = (saveTime ? j : n); int msg_ij_dims[2] = {1, mrf->V[j]}; mxArray* msg_ij = mxCreateNumericArray(2,msg_ij_dims,mxDOUBLE_CLASS, mxREAL); double* msg_ij_ptr = mxGetPr(msg_ij); for (int xj=0; xj<mrf->V[j]; xj++) { msg_ij_ptr[xj] = msg[i][nei][xj]; } mxSetCell(msg_i,nei,msg_ij); } mxSetCell(plhs[3],i,msg_i); } break; case AT_GBP: if (!trw) { msg = ((GBP*)algorithm)->getMessages(); msg_dims[0] = 1; msg_dims[1] = reg_mrf->N; plhs[3] = mxCreateCellArray(2,msg_dims); for (int i=0; i<reg_mrf->N; i++) { int Ni = reg_mrf->neighbNum(i); int msg_i_dims[2] = {1,Ni}; mxArray* msg_i = mxCreateCellArray(2,msg_i_dims); for (int n=0; n<Ni; n++) { int j = reg_mrf->adjMat[i][n]; int numStates = reg_mrf->V[max(i,j)]; int msg_ij_dims[2] = {1, numStates}; mxArray* msg_ij = mxCreateNumericArray(2,msg_ij_dims,mxDOUBLE_CLASS, mxREAL); double* msg_ij_ptr = mxGetPr(msg_ij); for (int xs=0; xs<numStates; xs++) { msg_ij_ptr[xs] = msg[i][n][xs]; } mxSetCell(msg_i,n,msg_ij); } mxSetCell(plhs[3],i,msg_i); } } break; default: break; } } } } } } // ************************************************************************** // free memory // ************************************************************************** delete algorithm; algorithm = 0; if (singleBeliefs != 0) { for (int i=0; i<num_nodes; i++) { delete[] singleBeliefs[i]; } delete[] singleBeliefs; singleBeliefs = 0; } if (pairBeliefs != 0 && algo_type == AT_GBP) { for (int i=0; i<num_nodes; i++) { for (int n=0; n<mrf->neighbNum(i); n++) { if (pairBeliefs[i][n] != 0) { for (int xi=0; xi<mrf->V[i]; xi++) { delete[] pairBeliefs[i][n][xi]; } delete[] pairBeliefs[i][n]; } } delete[] pairBeliefs[i]; } delete[] pairBeliefs; pairBeliefs = 0; } if (processor != 0) { delete processor; processor = 0; } if (rho != 0) { for (int i=0; i<num_nodes; i++) { delete[] rho[i]; rho[i] = 0; } delete[] rho; rho = 0; } if (countingNode != 0) { delete[] countingNode; countingNode = 0; } if (bigRegsPot != 0) { for (int i=0; i<num_regs; i++) { delete[] bigRegsPot[i]; } delete[] bigRegsPot; bigRegsPot = 0; } delete mrf; mrf = 0; delete adjMat; adjMat = 0; }
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -