📄 init_pot.c
字号:
/* C mex file for init_pot.c in @jtree_C_inf_engine directory */
/* Called by enter_evidence.m in the @jtree_C_inf_engine directory */
/**************************************/
/* init_pot takes 5 inputs and returns 2 outputs */
/* engine */
/* clqs */
/* pot */
/* pot_type */
/* onodes */
/* */
/* clpot */
/* seppot */
/**************************************/
#include "mex.h"
/*
 * multiply_pot: multiply a "big" potential table in place by a "small" one.
 *
 *   bTable  - big table; its entries are overwritten with the product.
 *   sTable  - small table, broadcast over the big table's extra dimensions.
 *   bDomain - node ids (1-based) naming the dimensions of bTable.
 *   sDomain - node ids (1-based) naming the dimensions of sTable.
 *   ns      - node sizes; ns[id - 1] is the size of node id.
 *
 * Both tables are walked with the first dimension varying fastest. The small
 * table's strides are built in bDomain order, so this presumably assumes that
 * sDomain lists its nodes in the same relative order as bDomain
 * (NOTE(review): confirm callers keep domains sorted). Nodes of sDomain that
 * do not occur in bDomain are ignored.
 */
void multiply_pot(mxArray *bTable, const mxArray *sTable, const mxArray *bDomain, const mxArray *sDomain, const double *ns){
    int i, j, ndim, siz_s, NB, NS;
    int *bsize, *ssize, *stride, *rewind, *subs;
    double *pb, *ps, *bp, *sp;

    ndim  = (int)mxGetNumberOfElements(bDomain);
    siz_s = (int)mxGetNumberOfElements(sDomain);
    pb = mxGetPr(bDomain);
    ps = mxGetPr(sDomain);
    NB = (int)mxGetNumberOfElements(bTable);
    NS = (int)mxGetNumberOfElements(sTable);
    bp = mxGetPr(bTable);
    sp = mxGetPr(sTable);

    /* Fast path: a scalar small table scales every entry. */
    if(NS == 1){
        for(i = 0; i < NB; i++){
            bp[i] *= *sp;
        }
        return;
    }
    /* Fast path: equal sizes, so multiply element by element. */
    if(NS == NB){
        for(i = 0; i < NB; i++){
            bp[i] *= sp[i];
        }
        return;
    }
    /* The general path needs at least one big dimension; otherwise the
       tables are inconsistent and the stride setup would index [-1]. */
    if(ndim < 1){
        mexErrMsgIdAndTxt("init_pot:badDomain",
                          "multiply_pot: empty big domain but table sizes differ.");
    }

    /* mxMalloc aborts the MEX call on failure, so no NULL checks needed. */
    bsize  = mxMalloc(ndim * sizeof *bsize);   /* size of each big dim       */
    ssize  = mxMalloc(ndim * sizeof *ssize);   /* same dim in small table    */
    stride = mxMalloc(ndim * sizeof *stride);  /* small-table stride per dim */
    rewind = mxMalloc(ndim * sizeof *rewind);  /* offset to undo a full pass */
    subs   = mxMalloc(ndim * sizeof *subs);    /* odometer over big table    */

    for(i = 0; i < ndim; i++){
        bsize[i] = (int)ns[(int)pb[i] - 1];
        ssize[i] = 1;  /* dims absent from the small table have extent 1 */
        subs[i] = 0;
    }
    /* Each small-domain node takes its size at its first match in bDomain. */
    for(i = 0; i < siz_s; i++){
        for(j = 0; j < ndim; j++){
            if(ps[i] == pb[j]){
                ssize[j] = bsize[j];
                break;
            }
        }
    }

    /* Column-major strides of the small table, expressed per big dim. */
    stride[0] = 1;
    for(i = 0; i < ndim; i++){
        rewind[i] = stride[i] * (ssize[i] - 1);
        if(i + 1 < ndim){
            stride[i + 1] = stride[i] * ssize[i];
        }
    }

    /* Walk the big table linearly while moving sp to the matching small entry. */
    for(j = 0; j < NB; j++){
        *bp++ *= *sp;
        for(i = 0; i < ndim; i++){
            if(subs[i] == bsize[i] - 1){
                /* This digit wraps: return sp to the start of this dim, carry on. */
                subs[i] = 0;
                if(ssize[i] != 1){
                    sp -= rewind[i];
                }
            }
            else{
                subs[i]++;
                if(ssize[i] != 1){
                    sp += stride[i];
                }
                break;
            }
        }
    }

    mxFree(bsize);
    mxFree(ssize);
    mxFree(stride);
    mxFree(rewind);
    mxFree(subs);
}
/* Return engine.<name>, aborting the MEX call if the field is missing. */
static mxArray *init_pot_required_field(const mxArray *engine, const char *name){
    mxArray *field = mxGetField(engine, 0, name);
    if(field == NULL){
        mexErrMsgIdAndTxt("init_pot:missingField",
                          "init_pot: engine has no field \"%s\".", name);
    }
    return field;
}

/*
 * MEX gateway:  [clpot, seppot] = init_pot(engine, clqs, pot, pot_type, onodes)
 *
 * Only the first three inputs are read (pot_type and onodes are ignored):
 *   prhs[0] engine - struct-like value with fields "cliques" (cell array of
 *                    clique domains), "eff_node_sizes" and "clique_weight".
 *   prhs[1] clqs   - clqs(k) is the 1-based index of the clique that absorbs
 *                    pot{k}.
 *   prhs[2] pot    - cell array of structs with fields "domain" and "T".
 *
 * Outputs:
 *   plhs[0] clpot  - one cell per clique; cell c holds a clique_weight(c)-by-1
 *                    table initialised to ones, then multiplied (via
 *                    multiply_pot) by every potential assigned to clique c.
 *   plhs[1] seppot - nCliques-by-nCliques cell array, left empty here; it is
 *                    only created when the caller asks for a second output.
 */
void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[]){
    int i, j, loop, nNodes, nCliques, num;
    double *ns, *clweight, *pClNode, *pr;
    mxArray *pTemp, *pCliques, *pBig, *pSmall, *pClpot, *pCPDpot, *pWeight;
    mwSize nClq;      /* mxCreateCellArray requires mwSize, not int, dims */
    mwSize dims[2];

    if(nrhs < 3){
        mexErrMsgIdAndTxt("init_pot:nrhs",
                          "init_pot: expected at least 3 inputs (engine, clqs, pot).");
    }

    nNodes = (int)mxGetNumberOfElements(prhs[1]);
    pCliques = init_pot_required_field(prhs[0], "cliques");
    nCliques = (int)mxGetNumberOfElements(pCliques);
    ns = mxGetPr(init_pot_required_field(prhs[0], "eff_node_sizes"));
    pWeight = init_pot_required_field(prhs[0], "clique_weight");
    clweight = mxGetPr(pWeight);

    if((int)mxGetNumberOfElements(pWeight) < nCliques){
        mexErrMsgIdAndTxt("init_pot:badWeight",
                          "init_pot: clique_weight has fewer entries than cliques.");
    }
    if((int)mxGetNumberOfElements(prhs[2]) < nNodes){
        mexErrMsgIdAndTxt("init_pot:badPot",
                          "init_pot: pot has fewer entries than clqs.");
    }

    /* clpot{c} starts as a column of ones of length clique_weight(c). */
    nClq = (mwSize)nCliques;
    plhs[0] = mxCreateCellArray(1, &nClq);
    for(i = 0; i < nCliques; i++){
        num = (int)clweight[i];
        pTemp = mxCreateDoubleMatrix(num, 1, mxREAL);
        pr = mxGetPr(pTemp);
        for(j = 0; j < num; j++){
            pr[j] = 1;
        }
        mxSetCell(plhs[0], i, pTemp);
    }

    /* Multiply each potential into the clique table it is assigned to. */
    pClNode = mxGetPr(prhs[1]);
    for(loop = 0; loop < nNodes; loop++){
        i = (int)pClNode[loop] - 1;  /* MATLAB indices are 1-based */
        if(i < 0 || i >= nCliques){
            mexErrMsgIdAndTxt("init_pot:badClique",
                              "init_pot: clique index out of range.");
        }
        pBig = mxGetCell(pCliques, i);
        pTemp = mxGetCell(prhs[2], loop);
        if(pBig == NULL || pTemp == NULL){
            mexErrMsgIdAndTxt("init_pot:emptyCell",
                              "init_pot: empty clique domain or potential.");
        }
        pSmall = mxGetField(pTemp, 0, "domain");
        pCPDpot = mxGetField(pTemp, 0, "T");
        if(pSmall == NULL || pCPDpot == NULL){
            mexErrMsgIdAndTxt("init_pot:badPotStruct",
                              "init_pot: potential lacks a \"domain\" or \"T\" field.");
        }
        pClpot = mxGetCell(plhs[0], i);
        multiply_pot(pClpot, pCPDpot, pBig, pSmall, ns);
    }

    /* plhs only has room for the outputs the caller requested. */
    if(nlhs > 1){
        dims[0] = (mwSize)nCliques;
        dims[1] = (mwSize)nCliques;
        plhs[1] = mxCreateCellArray(2, dims);
    }
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -