📄 collect_evidence.c
字号:
/*
 * collect_evidence.c -- C MEX implementation of collect_evidence for the
 * @jtree_ndx_inf_engine class (called from enter_evidence.m in that
 * directory).
 *
 * MATLAB usage (3 inputs, up to 2 outputs):
 *     [clpot, seppot] = collect_evidence(engine, clpot, seppot)
 *
 *   engine - struct; fields read: cliques, maximize, postorder,
 *            postorder_parents, ndx_type, marg_cl_onto_sep_ndx_id and
 *            mult_cl_by_sep_ndx_id.
 *   clpot  - cell array of clique potentials (real double arrays).
 *   seppot - cell array of separator potentials; the entry for the edge
 *            (child n, parent p) is stored at 0-based linear index
 *            n*nCliques + p.
 *
 * For each of the first nCliques-1 cliques n listed in engine.postorder,
 * the potential of n is marginalised onto the separator (summed, or
 * maximised when engine.maximize is nonzero), stored in seppot, and the
 * potential of n's first parent p is then multiplied in place by that
 * separator.  Both outputs are updated deep copies of the inputs; the
 * caller's arrays are never modified.  The precomputed index tables come
 * from the MATLAB global B_NDX, SD_NDX or D_NDX, chosen by the first letter
 * of engine.ndx_type.
 *
 * NOTE(review): index tables are read through an int pointer, i.e. they are
 * assumed to be 32-bit integer arrays holding 0-based offsets -- confirm
 * against the MATLAB code that builds the *_NDX globals.
 */
#include "mex.h"

/* Abort the MEX call with msg unless cond is nonzero. */
static void require(int cond, const char *msg)
{
    if (!cond) {
        mexErrMsgTxt(msg);
    }
}

/* Return field `name` of the 1x1 struct s, aborting if it is missing. */
static const mxArray *get_required_field(const mxArray *s, const char *name)
{
    const mxArray *f;

    require(mxIsStruct(s), "Expected a struct argument.\n");
    f = mxGetField(s, 0, name);
    if (f == NULL) {
        mexErrMsgIdAndTxt("collect_evidence:missingField",
                          "Required field '%s' is missing.", name);
    }
    return f;
}

/*
 * Return element idx (0-based) of cell array `cells`, aborting if `cells`
 * is not a cell array, idx is out of range, or the element is empty.
 */
static mxArray *get_cell_checked(const mxArray *cells, int idx)
{
    mxArray *elem;

    require(mxIsCell(cells), "Expected a cell array.\n");
    require(idx >= 0 && (mwSize)idx < mxGetNumberOfElements(cells),
            "Cell index out of range.\n");
    elem = mxGetCell(cells, (mwIndex)idx);
    require(elem != NULL, "Unexpected empty cell element.\n");
    return elem;
}

/* Fold x into the running value acc: maximum when maximize, sum otherwise. */
static double combine(double acc, double x, int maximize)
{
    if (maximize) {
        return (acc < x) ? x : acc;
    }
    return acc + x;
}

/*
 * Return the first position >= start whose flag in used[0..n-1] is clear.
 * Aborts instead of scanning past the end of `used` when every remaining
 * position is taken, i.e. when the D index does not fit the table.
 */
static mwSize next_unused(const char *used, mwSize start, mwSize n)
{
    while (start < n && used[start]) {
        start++;
    }
    require(start < n, "D index does not match the potential's size.\n");
    return start;
}

/*
 * Multiply Tbig in place using a B-style index:
 * Tbig[i] *= Tsmall[ndx[i]] for every element i of ndx.
 * Assumes Tbig has at least numel(ndx) elements and every ndx value is a
 * valid offset into Tsmall.
 */
static void multiply_by_table_ndxB(double *Tbig, const double *Tsmall,
                                   const mxArray *ndx)
{
    const int *offs = (const int *)mxGetData(ndx);
    mwSize n = mxGetNumberOfElements(ndx);
    mwSize i;

    for (i = 0; i < n; i++) {
        Tbig[i] *= Tsmall[offs[i]];
    }
}

/*
 * Multiply Tbig in place using an SD-style index: a struct whose fields
 * "small" and "diff" hold offsets.  For every pair (i, j),
 * Tbig[small[i] + diff[j]] *= Tsmall[i].
 */
static void multiply_by_table_ndxSD(double *Tbig, const double *Tsmall,
                                    const mxArray *ndx)
{
    const mxArray *psmall = get_required_field(ndx, "small");
    const mxArray *pdiff = get_required_field(ndx, "diff");
    const int *small_ndx = (const int *)mxGetData(psmall);
    const int *diff_ndx = (const int *)mxGetData(pdiff);
    mwSize S = mxGetNumberOfElements(psmall);
    mwSize D = mxGetNumberOfElements(pdiff);
    mwSize i, j;

    for (i = 0; i < S; i++) {
        for (j = 0; j < D; j++) {
            Tbig[small_ndx[i] + diff_ndx[j]] *= Tsmall[i];
        }
    }
}

/*
 * Multiply the array Tbig in place by Tsmall using a D-style index: a list
 * of J offsets.  Group i of Tbig starts at the lowest position not yet
 * visited (`base`) and consists of positions base + index[j]; every element
 * of that group is multiplied by Tsmall[i].  Special cases: J == 1 is an
 * elementwise product, and numel(Tbig)/J == 1 scales all of Tbig by
 * Tsmall[0].
 */
static void multiply_by_table_ndxD(mxArray *Tbig, const mxArray *Tsmall,
                                   const mxArray *index)
{
    double *pb = mxGetPr(Tbig);
    const double *ps = mxGetPr(Tsmall);
    const int *ndx = (const int *)mxGetData(index);
    mwSize N = mxGetNumberOfElements(Tbig);
    mwSize J = mxGetNumberOfElements(index);
    mwSize I, i, j, base;
    char *used;

    require(J > 0, "D index must not be empty.\n");
    I = N / J;

    if (J == 1) {
        for (i = 0; i < N; i++) {
            pb[i] *= ps[i];
        }
        return;
    }
    if (I == 1) {
        for (i = 0; i < N; i++) {
            pb[i] *= ps[0];
        }
        return;
    }

    /* mxCalloc zero-fills and is released automatically if the call aborts. */
    used = (char *)mxCalloc(N, sizeof(char));
    base = 0;
    for (i = 0; i < I; i++) {
        base = next_unused(used, base, N);
        used[base] = 1;
        for (j = 0; j < J; j++) {
            mwSize k = base + (mwSize)ndx[j];
            pb[k] *= ps[i];
            used[k] = 1;
        }
        base++;
    }
    mxFree(used);
}

/*
 * Marginalise Tbig with a B-style index into a new rows(index) x 1 double
 * column: output[index[i]] accumulates Tbig[i] (sum, or max when maximize).
 * Outputs that no index entry refers to stay 0.
 */
static mxArray *marg_table_ndxB(const double *Tbig, const int maximize,
                                const mxArray *index)
{
    const int *ndx = (const int *)mxGetData(index);
    mwSize I = mxGetM(index);
    mwSize N = mxGetNumberOfElements(index);
    mxArray *result = mxCreateDoubleMatrix(I, 1, mxREAL);
    double *Tsmall = mxGetPr(result);
    char *seen;
    mwSize i;

    if (!maximize) {
        for (i = 0; i < N; i++) {
            Tsmall[ndx[i]] += Tbig[i];
        }
        return result;
    }

    /* Seed each maximum from the data rather than the zero fill, which
     * would otherwise clamp negative maxima to 0. */
    seen = (char *)mxCalloc(I, sizeof(char));
    for (i = 0; i < N; i++) {
        int s = ndx[i];
        if (!seen[s] || Tsmall[s] < Tbig[i]) {
            Tsmall[s] = Tbig[i];
            seen[s] = 1;
        }
    }
    mxFree(seen);
    return result;
}

/*
 * Marginalise Tbig with an SD-style index (fields "small" and "diff") into
 * a new S x 1 double column, where S = numel(small):
 * output[i] = sum or max over j of Tbig[small[i] + diff[j]].
 * An empty "diff" yields an all-zero result.
 */
static mxArray *marg_table_ndxSD(const double *Tbig, const int maximize,
                                 const mxArray *index)
{
    const mxArray *psmall = get_required_field(index, "small");
    const mxArray *pdiff = get_required_field(index, "diff");
    const int *small_ndx = (const int *)mxGetData(psmall);
    const int *diff_ndx = (const int *)mxGetData(pdiff);
    mwSize S = mxGetNumberOfElements(psmall);
    mwSize D = mxGetNumberOfElements(pdiff);
    mxArray *result = mxCreateDoubleMatrix(S, 1, mxREAL);
    double *Tsmall = mxGetPr(result);
    mwSize i, j;

    if (D == 0) {
        return result;
    }
    for (i = 0; i < S; i++) {
        /* Seed with the first member of the group; the previous fixed
         * sentinel (-1e-20) was wrong for negative entries. */
        double acc = Tbig[small_ndx[i] + diff_ndx[0]];
        for (j = 1; j < D; j++) {
            acc = combine(acc, Tbig[small_ndx[i] + diff_ndx[j]], maximize);
        }
        Tsmall[i] = acc;
    }
    return result;
}

/*
 * Marginalise the array Tbig with a D-style index (J offsets) into a new
 * array.  If J == 1, the result is a copy of Tbig.  If numel(Tbig)/J == 1,
 * the result is a 1x1 reduction of all of Tbig.  Otherwise it is an I x 1
 * column (I = numel(Tbig)/J) whose entry i reduces the group described in
 * multiply_by_table_ndxD.
 */
static mxArray *marg_table_ndxD(const mxArray *Tbig, const int maximize,
                                const mxArray *index)
{
    const double *pb = mxGetPr(Tbig);
    const int *ndx = (const int *)mxGetData(index);
    mwSize N = mxGetNumberOfElements(Tbig);
    mwSize J = mxGetNumberOfElements(index);
    mwSize I, i, j, base;
    mxArray *result;
    double *ps;
    double acc;
    char *used;

    require(J > 0, "D index must not be empty.\n");
    I = N / J;

    if (J == 1) {
        return mxDuplicateArray(Tbig);
    }
    if (I == 1) {
        /* Here N >= J >= 2, so pb[0] exists. */
        result = mxCreateDoubleMatrix(1, 1, mxREAL);
        acc = pb[0];
        for (i = 1; i < N; i++) {
            acc = combine(acc, pb[i], maximize);
        }
        mxGetPr(result)[0] = acc;
        return result;
    }

    result = mxCreateDoubleMatrix(I, 1, mxREAL);
    ps = mxGetPr(result);
    used = (char *)mxCalloc(N, sizeof(char));
    base = 0;
    for (i = 0; i < I; i++) {
        base = next_unused(used, base, N);
        used[base] = 1;
        for (j = 0; j < J; j++) {
            mwSize k = base + (mwSize)ndx[j];
            ps[i] = (j == 0) ? pb[k] : combine(ps[i], pb[k], maximize);
            used[k] = 1;
        }
        base++;
    }
    mxFree(used);
    return result;
}

/* MEX entry point; see the file header for the calling convention. */
void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
{
    const mxArray *engine, *pTemp, *pPostP, *mm_ndx, *ndx_type, *parents;
    const double *postorder, *pMargndxId, *pMultndxId;
    mxArray *clpot, *seppot, *ndx, *pClpot, *pSeppot, *old;
    mwSize nPost, nMarg, nMult;
    int nCliques, loops, loop, p, n, np, pn, maximize;
    char buf[3] = {0, 0, 0};
    char ch0;

    require(nrhs == 3,
            "collect_evidence requires 3 inputs: engine, clpot, seppot.\n");
    require(nlhs <= 2, "collect_evidence returns at most 2 outputs.\n");
    require(mxIsCell(prhs[1]) && mxIsCell(prhs[2]),
            "clpot and seppot must be cell arrays.\n");
    engine = prhs[0];

    nCliques = (int)mxGetNumberOfElements(get_required_field(engine, "cliques"));
    loops = (nCliques > 0) ? nCliques - 1 : 0;
    maximize = (int)mxGetScalar(get_required_field(engine, "maximize"));

    pTemp = get_required_field(engine, "postorder");
    postorder = mxGetPr(pTemp);
    nPost = mxGetNumberOfElements(pTemp);
    require(nPost >= (mwSize)loops, "engine.postorder is too short.\n");
    pPostP = get_required_field(engine, "postorder_parents");

    /* Only the first character of ndx_type selects the index flavour. */
    ndx_type = get_required_field(engine, "ndx_type");
    require(mxIsChar(ndx_type), "engine.ndx_type must be a string.\n");
    mxGetString(ndx_type, buf, sizeof buf);
    ch0 = buf[0];
    switch (ch0) {
    case 'B':
        mm_ndx = mexGetVariablePtr("global", "B_NDX");
        break;
    case 'S':
        mm_ndx = mexGetVariablePtr("global", "SD_NDX");
        break;
    case 'D':
        mm_ndx = mexGetVariablePtr("global", "D_NDX");
        break;
    default:
        mm_ndx = NULL;
        mexErrMsgTxt("There is no this kind of index. \n");
    }
    if (mm_ndx == NULL) {
        mexErrMsgTxt("Could not get the global index.\n");
    }

    pTemp = get_required_field(engine, "marg_cl_onto_sep_ndx_id");
    pMargndxId = mxGetPr(pTemp);
    nMarg = mxGetNumberOfElements(pTemp);
    pTemp = get_required_field(engine, "mult_cl_by_sep_ndx_id");
    pMultndxId = mxGetPr(pTemp);
    nMult = mxGetNumberOfElements(pTemp);

    /* Work on deep copies so the caller's cell arrays are left untouched. */
    clpot = mxDuplicateArray(prhs[1]);
    seppot = mxDuplicateArray(prhs[2]);

    for (loop = 0; loop < loops; loop++) {
        /* n: clique from the postorder; p: its first listed parent
         * (both converted from MATLAB 1-based to 0-based). */
        n = (int)postorder[loop] - 1;
        parents = get_cell_checked(pPostP, n);
        require(mxGetNumberOfElements(parents) > 0,
                "Clique in postorder has no parent.\n");
        p = (int)mxGetPr(parents)[0] - 1;
        np = p * nCliques + n;
        pn = n * nCliques + p;
        require(np >= 0 && (mwSize)np < nMarg && pn >= 0 && (mwSize)pn < nMult
                    && (mwSize)pn < mxGetNumberOfElements(seppot),
                "Separator index out of range.\n");

        /* Marginalise clique n onto the separator. */
        ndx = get_cell_checked(mm_ndx, (int)pMargndxId[np] - 1);
        pClpot = get_cell_checked(clpot, n);
        switch (ch0) {
        case 'B':
            pSeppot = marg_table_ndxB(mxGetPr(pClpot), maximize, ndx);
            break;
        case 'S':
            pSeppot = marg_table_ndxSD(mxGetPr(pClpot), maximize, ndx);
            break;
        default: /* 'D' -- ch0 was validated above */
            pSeppot = marg_table_ndxD(pClpot, maximize, ndx);
            break;
        }

        /* mxSetCell does not free the element it replaces, so destroy the
         * old separator (owned by our deep copy) to avoid leaking it. */
        old = mxGetCell(seppot, (mwIndex)pn);
        if (old != NULL) {
            mxDestroyArray(old);
        }
        mxSetCell(seppot, (mwIndex)pn, pSeppot);

        /* Multiply parent clique p by the new separator, in place. */
        ndx = get_cell_checked(mm_ndx, (int)pMultndxId[pn] - 1);
        pClpot = get_cell_checked(clpot, p);
        switch (ch0) {
        case 'B':
            multiply_by_table_ndxB(mxGetPr(pClpot), mxGetPr(pSeppot), ndx);
            break;
        case 'S':
            multiply_by_table_ndxSD(mxGetPr(pClpot), mxGetPr(pSeppot), ndx);
            break;
        default: /* 'D' */
            multiply_by_table_ndxD(pClpot, pSeppot, ndx);
            break;
        }
    }

    plhs[0] = clpot;
    /* Only write plhs[1] when the caller asked for it. */
    if (nlhs >= 2) {
        plhs[1] = seppot;
    } else {
        mxDestroyArray(seppot);
    }
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -