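/* init_pot.c
 * Listing header from the hosting page: "From the MIT artificial intelligence
 * toolbox -- very valuable, hope it is useful to everyone!" · C code · 222 lines
 */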
/* C MEX implementation of init_pot for the @jtree_ndx_inf_engine directory. */
/* It is called by enter_evidence.m in the @jtree_ndx_inf_engine directory. */
/**************************************/
/* init_pot.c has 6 inputs & 2 outputs */
/* inputs:                             */
/*   engine                            */
/*   nodes                             */
/*   CPDs                              */
/*   clqs                              */
/*   pots                              */
/*   ndx2                              */
/* outputs:                            */
/*   clpot                             */
/*   seppot                            */
/**************************************/
#include "mex.h"
/*
 * Multiply Tsmall into Tbig in place using a "B" index.
 *
 * The payload of ndx is read as int data; entry k of Tbig is scaled by
 * Tsmall[idx[k]] for k in [0, rows(ndx) * cols(ndx)).
 *
 * NOTE(review): the single-element branch looks intended to broadcast one
 * factor over all of Tbig, but its loop bound is still rows*cols of ndx
 * (which is 1 there), so only Tbig[0] is scaled. Broadcasting would need the
 * length of Tbig, which this signature does not receive -- confirm against
 * the MATLAB multiply_by_table before relying on that path.
 */
void multiply_by_table_ndxB(double *Tbig, const double *Tsmall, const mxArray *ndx){
    const int *idx = mxGetData(ndx);
    int count = (int)mxGetM(ndx) * (int)mxGetN(ndx);
    int k;

    if((int)mxGetNumberOfElements(ndx) != 1){
        /* General case: gather one factor per entry of Tbig. */
        for(k = 0; k < count; k++){
            Tbig[k] *= Tsmall[idx[k]];
        }
        return;
    }

    {
        /* Single-entry index: one factor, looked up once. */
        double factor = Tsmall[idx[0]];
        for(k = 0; k < count; k++){
            Tbig[k] *= factor;
        }
    }
}
/*
 * Multiply Tsmall into Tbig in place using an "SD" index.
 *
 * ndx is a struct with fields "small" and "diff", both read as int data.
 * Entry s of Tsmall scales every Tbig[small[s] + diff[d]] for each d.
 *
 * NOTE(review): nothing is validated here -- both fields are assumed to
 * exist and to hold in-range 0-based offsets; confirm against the code that
 * builds SD_NDX.
 */
void multiply_by_table_ndxSD(double *Tbig, const double *Tsmall, const mxArray *ndx){
    const mxArray *small_field = mxGetField(ndx, 0, "small");
    const mxArray *diff_field = mxGetField(ndx, 0, "diff");
    const int *small_off = mxGetData(small_field);
    const int *diff_off = mxGetData(diff_field);
    int n_small = (int)mxGetNumberOfElements(small_field);
    int n_diff = (int)mxGetNumberOfElements(diff_field);
    int s, d;

    for(s = 0; s < n_small; s++){
        /* Every target of Tsmall[s] shares this base offset. */
        const int base = small_off[s];
        for(d = 0; d < n_diff; d++){
            Tbig[base + diff_off[d]] *= Tsmall[s];
        }
    }
}
/*
 * Multiply the potential Tsmall into Tbig in place using a "D" index.
 *
 * index holds J offsets (read as int data). Entry i of Tsmall scales a group
 * of J entries of Tbig at positions base + index[j]. Here base is the lowest
 * position of Tbig not already covered by an earlier group. Tsmall is
 * consumed in order, one entry per group, for N / J groups in total
 * (N = numel(Tbig)).
 *
 * Shortcuts:
 *   J == 1      multiplies element-wise (Tsmall read as N entries).
 *   N / J == 1  scales all of Tbig by Tsmall[0].
 *
 * An empty index, or offsets that would address outside Tbig, abort the MEX
 * call with an error. Previously these caused a divide-by-zero or an
 * out-of-bounds write.
 *
 * NOTE(review): presumably index[0] == 0 and offsets are 0-based; confirm
 * against the code that builds D_NDX.
 */
void multiply_by_table_ndxD(mxArray *Tbig, const mxArray *Tsmall, const mxArray *index){
    double *pb, *ps;
    int i, j, I, J, N, pointer, temp;
    int *ndx;
    char *used;

    N = (int)mxGetNumberOfElements(Tbig);
    pb = mxGetPr(Tbig);
    ndx = mxGetData(index);
    ps = mxGetPr(Tsmall);
    J = (int)mxGetNumberOfElements(index);
    if(J == 0){
        /* N / J below would be an integer division by zero. */
        mexErrMsgTxt("multiply_by_table_ndxD: empty D index.\n");
    }
    I = N / J;
    if(J == 1){
        for(i=0; i<N; i++){
            *pb++ *= *ps++;
        }
        return;
    }
    if(I == 1){
        for(i=0; i<N; i++){
            *pb++ *= *ps;
        }
        return;
    }
    /*
     * used[k] != 0 once Tbig[k] belongs to some group. mxCalloc returns
     * zero-filled memory and terminates the MEX call itself if allocation
     * fails. MATLAB also frees it automatically if an error below aborts
     * the call.
     */
    used = mxCalloc(N, sizeof *used);
    pointer = 0;
    for(i=0; i<I; i++){
        /* Base of the next group: first position no earlier group used. */
        while(pointer < N && used[pointer]){
            pointer++;
        }
        if(pointer >= N){
            mexErrMsgTxt("multiply_by_table_ndxD: D index does not match the table size.\n");
        }
        used[pointer] = 1;
        for(j=0; j<J; j++){
            temp = pointer + ndx[j];
            if(temp < 0 || temp >= N){
                mexErrMsgTxt("multiply_by_table_ndxD: D index out of range.\n");
            }
            pb[temp] *= *ps;
            used[temp] = 1;
        }
        ps++;
        pointer++;
    }
    mxFree(used);
}
void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[]){
int i, j, n, c, loop, num, nNodes, nCliques, buflen;
double *clweight, *pNode, *pClq, *pr, *pt, *pNdxID;
mxArray *pTemp, *pCliques, *pClpot, *ndx;
int dims[2];
char *buf, ch0;
const mxArray *mult_ndx;
const mxArray *ndx_type;
nNodes = mxGetNumberOfElements(prhs[1]);
pCliques = mxGetField(prhs[0], 0, "cliques");
nCliques = mxGetNumberOfElements(pCliques);
pTemp = mxGetField(prhs[0], 0, "clique_weight");
clweight = mxGetPr(pTemp);
plhs[0] = mxCreateCellArray(1, &nCliques);
for(i=0; i<nCliques; i++){
num = (int)clweight[i];
pTemp = mxCreateDoubleMatrix(num, 1, mxREAL);
pr = mxGetPr(pTemp);
for(j=0; j<num; j++){
pr[j] = 1.0;
}
mxSetCell(plhs[0], i, pTemp);
}
buflen = 3;
buf = (char *)malloc(buflen * sizeof(char));
ndx_type = mxGetField(prhs[0], 0, "ndx_type");
mxGetString(ndx_type, buf, buflen);
ch0 = buf[0];
switch(ch0){
case 'B':
mult_ndx = (mxArray *)mexGetArrayPtr("B_NDX", "global");
break;
case 'S':
mult_ndx = (mxArray *)mexGetArrayPtr("SD_NDX", "global");
break;
case 'D':
mult_ndx = (mxArray *)mexGetArrayPtr("D_NDX", "global");
break;
default :
mexErrMsgTxt("There is no this kind of index. \n");
}
if(mult_ndx == NULL){
mexErrMsgTxt("Could not get the global index.\n");
}
pTemp = mxGetField(prhs[0], 0, "mult_node_ndx_id");
pNdxID = mxGetPr(pTemp);
pTemp = mxGetField(prhs[0], 0, "clq_ass_to_node");
pClq = mxGetPr(pTemp);
pNode = mxGetPr(prhs[1]);
for(loop=0; loop<nNodes; loop++){
n = (int)pNode[loop] - 1;
c = (int)pClq[n] - 1;
ndx = mxGetCell(mult_ndx, (int)pNdxID[n]-1);
pTemp = mxGetCell(prhs[2], loop);
pt = mxGetPr(pTemp);
pClpot = mxGetCell(plhs[0], c);
pr = mxGetPr(pClpot);
switch(ch0){
case 'B':
multiply_by_table_ndxB(pr, pt, ndx);
break;
case 'S':
multiply_by_table_ndxSD(pr, pt, ndx);
break;
case 'D':
multiply_by_table_ndxD(pClpot, pTemp, ndx);
break;
}
}
if(nrhs>3){
nNodes = mxGetNumberOfElements(prhs[3]);
pClq = mxGetPr(prhs[3]);
for(loop=0; loop<nNodes; loop++){
c = (int)pClq[loop] - 1;
ndx = mxGetCell(prhs[5], loop);
pTemp = mxGetCell(prhs[4], loop);
pt = mxGetPr(pTemp);
pClpot = mxGetCell(plhs[0], c);
pr = mxGetPr(pClpot);
switch(ch0){
case 'B':
multiply_by_table_ndxB(pr, pt, ndx);
break;
case 'S':
multiply_by_table_ndxSD(pr, pt, ndx);
break;
case 'D':
multiply_by_table_ndxD(pClpot, pTemp, ndx);
break;
}
}
}
free(buf);
dims[0] = nCliques;
dims[1] = nCliques;
plhs[1] = mxCreateCellArray(2, dims);
}
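/* Hosting-page residue (keyboard shortcut help):
 *   Copy code Ctrl+C; Search code Ctrl+F; Full screen F11;
 *   Increase font size Ctrl+=; Decrease font size Ctrl+-; Show shortcuts ?
 */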