/* collect_evidence.c */
/* C MEX source for collect_evidence in the @jtree_C_inf_engine directory. */
/* It is called by enter_evidence.m in the @jtree_C_inf_engine directory.  */
/******************************************/
/* collect_evidence has 3 inputs and 2 outputs */
/* Inputs:  engine */
/*          clpot  */
/*          seppot */
/*                 */
/* Outputs: clpot  */
/*          seppot */
/******************************************/
#include "mex.h"
/*
 * multiply_pot: multiply the big potential table bTable in place by the
 * small potential table sTable, broadcasting sTable along every variable of
 * bDomain that does not appear in sDomain.
 *
 *   bTable  - table over bDomain; overwritten with the product. It is walked
 *             in column-major storage order with one dimension per entry of
 *             bDomain, in bDomain order.
 *   sTable  - table over sDomain. Its dimensions are taken to appear in the
 *             same relative order as in bDomain (NOTE(review): relies on the
 *             caller keeping domains consistently ordered -- confirm).
 *   bDomain - 1-based variable ids of the big potential.
 *   sDomain - 1-based variable ids of the small potential; ids that do not
 *             occur in bDomain are ignored.
 *   ns      - node sizes, indexed by (variable id - 1).
 *
 * Memory is obtained with mxMalloc, so an allocation failure terminates the
 * MEX call with a MATLAB error instead of dereferencing NULL.
 */
void multiply_pot(mxArray *bTable, const mxArray *sTable, const mxArray *bDomain, const mxArray *sDomain, const double *ns){
    int i, j, NB, NS, siz_b, siz_s, ndim;
    int *sx, *sy, *cpsy, *subs, *s, *cpsy2;
    double *pb, *ps, *bp, *sp;

    siz_b = (int)mxGetNumberOfElements(bDomain);
    siz_s = (int)mxGetNumberOfElements(sDomain);
    pb = mxGetPr(bDomain);
    ps = mxGetPr(sDomain);
    NB = (int)mxGetNumberOfElements(bTable);
    NS = (int)mxGetNumberOfElements(sTable);
    bp = mxGetPr(bTable);
    sp = mxGetPr(sTable);

    /* Scalar small table: uniform scaling. */
    if(NS == 1){
        for(i=0; i<NB; i++){
            bp[i] *= *sp;
        }
        return;
    }
    /* Equal element counts: assuming sDomain is a subset of bDomain, every
       variable missing from sDomain has size 1, so both tables share one
       layout and the product is elementwise. */
    if(NS == NB){
        for(i=0; i<NB; i++){
            bp[i] *= sp[i];
        }
        return;
    }
    /* The broadcasting walk below needs at least one dimension. An empty
       bDomain means bTable is a scalar, which cannot absorb a non-scalar
       sTable. */
    if(siz_b == 0){
        mexErrMsgTxt("multiply_pot: small table does not fit a big potential with an empty domain.");
    }

    ndim = siz_b;
    sx = mxMalloc(ndim * sizeof *sx);
    sy = mxMalloc(ndim * sizeof *sy);
    /* sx[i]: extent of dimension i in bTable.
       sy[i]: extent of dimension i in sTable (1 where sTable is broadcast). */
    for(i=0; i<ndim; i++){
        sx[i] = (int)ns[(int)pb[i] - 1];
        sy[i] = 1;
    }
    /* Set sy at the matched position itself, so an sDomain id that is absent
       from bDomain cannot shift the remaining sizes onto the wrong dimension. */
    for(i=0; i<siz_s; i++){
        for(j=0; j<ndim; j++){
            if(ps[i] == pb[j]){
                sy[j] = sx[j];
                break;
            }
        }
    }

    s = mxMalloc(ndim * sizeof *s);
    cpsy = mxMalloc(ndim * sizeof *cpsy);
    subs = mxMalloc(ndim * sizeof *subs);
    cpsy2 = mxMalloc(ndim * sizeof *cpsy2);
    /* subs: current subscript of the bTable walk; s[i]: last subscript of dim i. */
    for(i = 0; i < ndim; i++){
        subs[i] = 0;
        s[i] = sx[i] - 1;
    }
    /* cpsy[i] is the stride of dimension i inside sTable. After this loop
       sy[i] holds (extent - 1), so sy[i] == 0 means sp never moves along
       dimension i, and cpsy2[i] is the offset that rewinds dimension i from
       its last subscript back to 0. */
    cpsy[0] = 1;
    for(i = 0; i < ndim-1; i++){
        cpsy[i+1] = cpsy[i]*sy[i]--;
        cpsy2[i] = cpsy[i]*sy[i];
    }
    cpsy2[ndim-1] = cpsy[ndim-1]*(--sy[ndim-1]);

    /* Walk bTable in storage order while sp tracks the matching sTable entry,
       incrementing subs like an odometer. */
    for(j=0; j<NB; j++){
        *bp++ *= *sp;
        for(i = 0; i < ndim; i++){
            if(subs[i] == s[i]){
                subs[i] = 0;
                if(sy[i])
                    sp -= cpsy2[i];
            }
            else{
                subs[i]++;
                if(sy[i])
                    sp += cpsy[i];
                break;
            }
        }
    }
    mxFree(sx);
    mxFree(sy);
    mxFree(s);
    mxFree(cpsy);
    mxFree(subs);
    mxFree(cpsy2);
}
mxArray* marginalise_pot(const mxArray *bTable, const mxArray *bDomain, const mxArray *sDomain, const double *ns, const int maximize){
int i, j, count, NB, NS, siz_b, siz_s, ndim, temp;
int *mask, *sx, *sy, *cpsy, *subs, *s, *cpsy2;
double *pb, *ps, *bp, *sp;
mxArray *sTable;
siz_b = mxGetNumberOfElements(bDomain);
siz_s = mxGetNumberOfElements(sDomain);
pb = mxGetPr(bDomain);
ps = mxGetPr(sDomain);
NB = mxGetNumberOfElements(bTable);
bp = mxGetPr(bTable);
if(siz_s == 0){
sTable = mxCreateDoubleMatrix(1, 1, mxREAL);
sp = mxGetPr(sTable);
if(maximize){
for(i=0; i<NB; i++){
*sp = (*sp > bp[i])? *sp : bp[i];
}
}
else{
for(i=0; i<NB; i++){
*sp += bp[i];
}
}
return sTable;
}
mask = malloc(siz_s * sizeof(int));
count = 0;
for(i=0; i<siz_s; i++){
for(j=0; j<siz_b; j++){
if(ps[i] == pb[j]){
mask[count] = j;
count++;
break;
}
}
}
ndim = siz_b;
sx = (int *)malloc(sizeof(int)*ndim);
sy = (int *)malloc(sizeof(int)*ndim);
for(i=0; i<ndim; i++){
temp = (int)pb[i] - 1;
sx[i] = (int)ns[temp];
sy[i] = 1;
}
for(i=0; i<count; i++){
temp = (int)ps[i] - 1;
sy[mask[i]] = (int)ns[temp];
}
NS = 1;
for(i=0; i<ndim; i++){
NS *= sy[i];
}
sTable = mxCreateDoubleMatrix(NS, 1, mxREAL);
sp = mxGetPr(sTable);
if(NS == NB){
for(i=0; i<NB; i++){
sp[i] = bp[i];
}
free(mask);
free(sx);
free(sy);
return sTable;
}
s = (int *)malloc(sizeof(int)*ndim);
*(cpsy = (int *)malloc(sizeof(int)*ndim)) = 1;
subs = (int *)malloc(sizeof(int)*ndim);
cpsy2 = (int *)malloc(sizeof(int)*ndim);
for(i = 0; i < ndim; i++){
subs[i] = 0;
s[i] = sx[i] - 1;
}
for(i = 0; i < ndim-1; i++){
cpsy[i+1] = cpsy[i]*sy[i]--;
cpsy2[i] = cpsy[i]*sy[i];
}
cpsy2[ndim-1] = cpsy[ndim-1]*(--sy[ndim-1]);
if(maximize){
for(j=0; j<NB; j++){
if(*bp > *sp) *sp = *bp;
bp++;
for(i = 0; i < ndim; i++){
if(subs[i] == s[i]){
subs[i] = 0;
if(sy[i])
sp -= cpsy2[i];
}
else{
subs[i]++;
if(sy[i])
sp += cpsy[i];
break;
}
}
}
}
else{
for(j=0; j<NB; j++){
*sp += *bp++;
for(i = 0; i < ndim; i++){
if(subs[i] == s[i]){
subs[i] = 0;
if(sy[i])
sp -= cpsy2[i];
}
else{
subs[i]++;
if(sy[i])
sp += cpsy[i];
break;
}
}
}
}
free(sx);
free(sy);
free(s);
free(cpsy);
free(subs);
free(cpsy2);
free(mask);
return sTable;
}
/******************************************/
/* collect_evidence has 3 inputs and 2 outputs */
/* Inputs:  engine */
/*          clpot  */
/*          seppot */
/*                 */
/* Outputs: clpot  */
/*          seppot */
/******************************************/
void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[]){
int i, loop, loops, nCliques, temp, maximize;
int parent, child;
int *collect_order;
double *ns, *pr, *pr1;
mxArray *pTemp, *pCliques, *pBig, *pSmall, *pSeparator, *pPostP, *pClpot, *pSeppot;
pCliques = mxGetField(prhs[0], 0, "cliques");
nCliques = mxGetNumberOfElements(pCliques);
loops = nCliques - 1;
pTemp = mxGetField(prhs[0], 0, "eff_node_sizes");
ns = mxGetPr(pTemp);
pTemp = mxGetField(prhs[0], 0, "maximize");
maximize = (int)mxGetScalar(pTemp);
collect_order = malloc(2 * loops * sizeof(int));
pTemp = mxGetField(prhs[0], 0, "postorder");
pr = mxGetPr(pTemp);
pPostP = mxGetField(prhs[0], 0, "postorder_parents");
for(i=0; i<loops; i++){
temp = (int)pr[i] - 1;
pTemp = mxGetCell(pPostP, temp);
pr1 = mxGetPr(pTemp);
collect_order[i] = (int)pr1[0] - 1;
collect_order[i+loops] = temp;
}
plhs[0] = mxDuplicateArray(prhs[1]);
plhs[1] = mxDuplicateArray(prhs[2]);
pSeparator = mxGetField(prhs[0], 0, "separator");
for(loop=0; loop<loops; loop++){
parent = collect_order[loop];
child = collect_order[loop+loops];
pBig = mxGetCell(pCliques, child);
i = nCliques * child + parent;
pSmall = mxGetCell(pSeparator, i);
pClpot = mxGetCell(plhs[0], child);
pSeppot = marginalise_pot(pClpot, pBig, pSmall, ns, maximize);
mxSetCell(plhs[1], i, pSeppot);
pBig = mxGetCell(pCliques, parent);
pClpot = mxGetCell(plhs[0], parent);
multiply_pot(pClpot, pSeppot, pBig, pSmall, ns);
}
free(collect_order);
}
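/* Residue from the web page this file was copied from (keyboard-shortcut help):
   Copy code: Ctrl + C; Search code: Ctrl + F; Full screen: F11;
   Toggle theme: Ctrl + Shift + D; Show shortcuts: ?;
   Increase font size: Ctrl + =; Decrease font size: Ctrl + - */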