📄 loopy.cpp
字号:
case MAX: if (outM > outgoing[xj]) { outgoing[xj] = outM; } break; default: break; } } if (outgoing[xj] < epsilon) outgoing[xj] = epsilon; sum_outgoing_to_j += outgoing[xj]; } for (int xj=0; xj<ia_mrf->V[j]; xj++) { if (sum_outgoing_to_j > 0.0) { outgoing[xj] /= sum_outgoing_to_j; } if (!(outgoing[xj]>0.0)) { nIter = l_maxIter + 1; break; } switch (l_strategy) { case SEQUENTIAL: l_messages[i][n][xj] = outgoing[xj]; break; case PARALLEL: new_messages[i][n][xj] = outgoing[xj]; break; default: break; } } delete[] outgoing; outgoing = 0; } delete[] incoming; incoming = 0; delete[] factor; factor = 0; } if (l_strategy == PARALLEL) { for (int i=0; i<ia_mrf->N; i++) { for (int n=0; n<ia_mrf->neighbNum(i); n++) { int j = ia_mrf->adjMat[i][n]; for (int xj=0; xj<ia_mrf->V[j]; xj++) { l_messages[i][n][xj] = new_messages[i][n][xj]; } } } } // update beliefs and check for convergence dBel = 0.0; double** new_beliefs = new double*[ia_mrf->N]; for (int i=0; i<ia_mrf->N; i++) { new_beliefs[i] = new double[ia_mrf->V[i]]; double sum_beliefs_i = 0.0; for (int xi=0; xi<ia_mrf->V[i]; xi++) { new_beliefs[i][xi] = ia_mrf->localMat[i][xi]; for (int n=0; n<ia_mrf->neighbNum(i); n++) { int j = ia_mrf->adjMat[i][n]; int nj = 0; while (ia_mrf->adjMat[j][nj] != i) { nj++; } new_beliefs[i][xi] *= l_messages[j][nj][xi]; } sum_beliefs_i += new_beliefs[i][xi]; } double norm_dBel_i = 0.0; for (int xi=0; xi<ia_mrf->V[i]; xi++) { if (sum_beliefs_i > 0.0) { new_beliefs[i][xi] /= sum_beliefs_i; } norm_dBel_i += pow((new_beliefs[i][xi] - ia_beliefs[i][xi]), 2.0); } norm_dBel_i = pow(norm_dBel_i, 0.5); dBel += norm_dBel_i; } freeBeliefs(); ia_beliefs = new_beliefs; new_beliefs = 0; } if (l_strategy == PARALLEL) { for (int i=0; i<ia_mrf->N; i++) { for (int n=0; n<ia_mrf->neighbNum(i); n++) { delete[] new_messages[i][n]; } delete[] new_messages[i]; } delete[] new_messages; new_messages = 0; } if (nIter > l_maxIter) { (*converged) = -1; mexPrintf("c-Loopy: messages decreased to zero, iterating stopped\n"); } else { if (dBel<=l_th) { (*converged) = nIter; mexPrintf("c-Loopy: converged in %d iterations\n",nIter); } else { (*converged) = -1; mexPrintf("c-Loopy: did not converge after %d iterations\n",nIter); } } return ia_beliefs; }double** Loopy::inferenceTRBP(int* converged) { double epsilon = pow(10.,-200); double dBel = l_th+1.0; int nIter = 0; double*** new_messages = 0; if (l_strategy == PARALLEL) { new_messages = new double**[ia_mrf->N]; for (int i=0; i<ia_mrf->N; i++) { new_messages[i] = new double*[ia_mrf->neighbNum(i)]; for (int n=0; n<ia_mrf->neighbNum(i); n++) { int j = ia_mrf->adjMat[i][n]; new_messages[i][n] = new double[ia_mrf->V[j]]; for (int xj=0; xj<ia_mrf->V[j]; xj++) { new_messages[i][n][xj] = l_messages[i][n][xj]; } } } } while (dBel>l_th && nIter<l_maxIter) { nIter++; for (int i=0; i<ia_mrf->N; i++) { // init the incoming messages to 1 double* incoming = new double[ia_mrf->V[i]]; double* factor = new double[ia_mrf->neighbNum(i)]; for (int xi=0; xi<ia_mrf->V[i]; xi++) { incoming[xi] = 1.0; } // get incoming messages for (int n=0; n<ia_mrf->neighbNum(i); n++) { int j = ia_mrf->adjMat[i][n]; int nj = 0; while (ia_mrf->adjMat[j][nj] != i) { nj++; } factor[n] = 0.0; for (int xi=0; xi<ia_mrf->V[i]; xi++) { incoming[xi] *= pow(l_messages[j][nj][xi], l_trwRho[i][n]); factor[n] += incoming[xi]; } for (int xi=0; xi<ia_mrf->V[i]; xi++) { incoming[xi] /= factor[n]; } } // calculate outgoing messages for (int n=0; n<ia_mrf->neighbNum(i); n++) { int j = ia_mrf->adjMat[i][n]; int nj = 0; while (ia_mrf->adjMat[j][nj] != i) { nj++; } double sum_outgoing_to_j = 0.0; double* outgoing = new double[ia_mrf->V[j]]; for (int xj=0; xj<ia_mrf->V[j]; xj++) { switch (l_sumOrMax) { case SUM: outgoing[xj] = 0.0; break; case MAX: outgoing[xj] = -1.0; break; default: break; } for (int xi=0; xi<ia_mrf->V[i]; xi++) { double outM = ia_mrf->pairPotential(i,n,xi,xj) * ia_mrf->localMat[i][xi] * incoming[xi] / l_messages[j][nj][xi]; // the pair-potentials are raised by 1/rho in the matlab interface switch (l_sumOrMax) { case SUM: outgoing[xj] += outM; break; case MAX: if (outM > outgoing[xj]) { outgoing[xj] = outM; } break; default: break; } } if (outgoing[xj] < epsilon) outgoing[xj] = epsilon; sum_outgoing_to_j += outgoing[xj]; } for (int xj=0; xj<ia_mrf->V[j]; xj++) { if (sum_outgoing_to_j > 0.0) { outgoing[xj] /= sum_outgoing_to_j; } if (!(outgoing[xj]>0.0)) { nIter = l_maxIter + 1; break; } switch (l_strategy) { case SEQUENTIAL: l_messages[i][n][xj] = outgoing[xj]; break; case PARALLEL: new_messages[i][n][xj] = outgoing[xj]; break; default: break; } } delete[] outgoing; outgoing = 0; } delete[] incoming; incoming = 0; delete[] factor; factor = 0; } if (l_strategy == PARALLEL) { for (int i=0; i<ia_mrf->N; i++) { for (int n=0; n<ia_mrf->neighbNum(i); n++) { int j = ia_mrf->adjMat[i][n]; for (int xj=0; xj<ia_mrf->V[j]; xj++) { l_messages[i][n][xj] = new_messages[i][n][xj]; } } } } // update beliefs and check for convergence dBel = 0.0; double** new_beliefs = new double*[ia_mrf->N]; for (int i=0; i<ia_mrf->N; i++) { new_beliefs[i] = new double[ia_mrf->V[i]]; double sum_beliefs_i = 0.0; for (int xi=0; xi<ia_mrf->V[i]; xi++) { new_beliefs[i][xi] = ia_mrf->localMat[i][xi]; for (int n=0; n<ia_mrf->neighbNum(i); n++) { int j = ia_mrf->adjMat[i][n]; int nj = 0; while (ia_mrf->adjMat[j][nj] != i) { nj++; } new_beliefs[i][xi] *= pow(l_messages[j][nj][xi],l_trwRho[i][n]); } sum_beliefs_i += new_beliefs[i][xi]; } double norm_dBel_i = 0.0; for (int xi=0; xi<ia_mrf->V[i]; xi++) { if (sum_beliefs_i > 0.0) { new_beliefs[i][xi] /= sum_beliefs_i; } norm_dBel_i += pow((new_beliefs[i][xi] - ia_beliefs[i][xi]), 2.0); } norm_dBel_i = pow(norm_dBel_i, 0.5); dBel += norm_dBel_i; } freeBeliefs(); ia_beliefs = new_beliefs; new_beliefs = 0; } if (l_strategy == PARALLEL) { for (int i=0; i<ia_mrf->N; i++) { for (int n=0; n<ia_mrf->neighbNum(i); n++) { delete[] new_messages[i][n]; } delete[] new_messages[i]; } delete[] new_messages; new_messages = 0; } if (nIter > l_maxIter) { (*converged) = -1; mexPrintf("c-Loopy: messages decreased to zero, iterating stopped\n"); } else { if (dBel<=l_th) { (*converged) = nIter; mexPrintf("c-Loopy: converged in %d iterations\n",nIter); } else { (*converged) = -1; mexPrintf("c-Loopy: did not converge after %d iterations\n",nIter); } } return ia_beliefs; }
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -