⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 init_pot.c

📁 麻省理工学院的人工智能工具箱,很珍贵,希望对大家有用!
💻 C
字号:
/* C MEX file init_pot.c in the @jtree_C_inf_engine directory           */
/* It is called by enter_evidence.m in the same @jtree_C_inf_engine dir */

/**************************************/
/* init_pot has 5 inputs, 2 outputs   */
/* inputs:  engine                    */
/*          clqs                      */
/*          pot                       */
/*          pot_type                  */
/*          onodes                    */
/*                                    */
/* outputs: clpot                     */
/*          seppot                    */
/**************************************/


#include <stdlib.h>

#include "mex.h"

/*
 * multiply_pot: multiply the big table bTable in place by the small table
 * sTable.
 *
 * bDomain / sDomain list the variables (1-based node ids stored as doubles)
 * of bTable / sTable, one table dimension per variable, first dimension
 * varying fastest. The size of each dimension is ns[node_id - 1].
 * Every variable of sDomain must also appear in bDomain; sTable is
 * broadcast along big-domain variables it does not mention. The small
 * domain may list its variables in any order relative to the big domain.
 *
 * Errors (via mexErrMsgTxt) when a small-domain variable is missing from
 * the big domain, when sTable has fewer entries than its domain implies,
 * or when scratch memory cannot be allocated.
 */
void multiply_pot(mxArray *bTable, const mxArray *sTable, const mxArray *bDomain, const mxArray *sDomain, const double *ns){
	int     i, j, k, NB, NS, siz_b, siz_s, ssize, offset, inorder;
	int     *work, *pos, *sizes, *stride, *subs;
	double  *pb, *ps, *bp, *sp;

	siz_b = (int)mxGetNumberOfElements(bDomain);
	siz_s = (int)mxGetNumberOfElements(sDomain);
	pb = mxGetPr(bDomain);
	ps = mxGetPr(sDomain);

	NB = (int)mxGetNumberOfElements(bTable);
	NS = (int)mxGetNumberOfElements(sTable);
	bp = mxGetPr(bTable);
	sp = mxGetPr(sTable);

	/* Nothing to scale. */
	if(NB == 0)
		return;

	/* A scalar small table scales every entry. */
	if(NS == 1){
		for(i = 0; i < NB; i++){
			bp[i] *= sp[0];
		}
		return;
	}

	/* One scratch buffer for all index arrays; the +1 keeps the request
	   non-zero so a NULL result always means allocation failure. */
	work = malloc(sizeof(int) * ((size_t)siz_s + 3 * (size_t)siz_b + 1));
	if(work == NULL){
		mexErrMsgTxt("multiply_pot: out of memory.");
		return;
	}
	pos    = work;           /* pos[k]: big-domain position of small var k */
	sizes  = pos + siz_s;    /* sizes[i]: size of big-domain dimension i   */
	stride = sizes + siz_b;  /* stride[i]: small-table step for big dim i  */
	subs   = stride + siz_b; /* subs[i]: current subscript in big dim i    */

	for(i = 0; i < siz_b; i++){
		sizes[i]  = (int)ns[(int)pb[i] - 1];
		stride[i] = 0;       /* dims absent from the small domain broadcast */
		subs[i]   = 0;
	}

	/* Give each big dimension that belongs to the small domain the stride it
	   has in the small table, computed in the SMALL domain's own order, so
	   the result is correct however the two domains are ordered. */
	ssize = 1;
	inorder = 1;
	for(k = 0; k < siz_s; k++){
		pos[k] = -1;
		for(j = 0; j < siz_b; j++){
			if(ps[k] == pb[j]){
				pos[k] = j;
				break;
			}
		}
		if(pos[k] < 0){
			free(work);
			mexErrMsgTxt("multiply_pot: small domain is not a subset of the big domain.");
			return;
		}
		if(k > 0 && pos[k] <= pos[k-1])
			inorder = 0;
		stride[pos[k]] = ssize;
		ssize *= sizes[pos[k]];
	}

	/* The walk below reads offsets up to ssize-1; refuse to read past sTable. */
	if(ssize > NS){
		free(work);
		mexErrMsgTxt("multiply_pot: small table has fewer entries than its domain implies.");
		return;
	}

	/* Identical memory layout: multiply element by element. */
	if(inorder && NS == NB && ssize == NS){
		for(i = 0; i < NB; i++){
			bp[i] *= sp[i];
		}
		free(work);
		return;
	}

	/* Walk every big-table entry in storage order, keeping `offset` pointed
	   at the matching small-table entry. */
	offset = 0;
	for(j = 0; j < NB; j++){
		bp[j] *= sp[offset];
		/* Advance the big-domain subscript like an odometer (first dimension
		   fastest), moving the small-table offset along with it. */
		for(i = 0; i < siz_b; i++){
			if(++subs[i] < sizes[i]){
				offset += stride[i];
				break;
			}
			subs[i] = 0;
			offset -= stride[i] * (sizes[i] - 1);
		}
	}

	free(work);
}


/*
 * MEX gateway: [clpot, seppot] = init_pot(engine, clqs, pot, pot_type, onodes)
 *
 * prhs[0]  engine struct; reads fields "cliques" (cell array of clique
 *          domains), "eff_node_sizes" (indexed by node id - 1) and
 *          "clique_weight" (number of entries in each clique table).
 * prhs[1]  clqs: for each potential, the 1-based index of the clique it is
 *          multiplied into.
 * prhs[2]  pot: cell array of structs with fields "domain" and "T".
 * prhs[3], prhs[4] (pot_type, onodes) are not read here.
 *
 * plhs[0]  clpot: cell array with one column vector per clique, initialised
 *          to ones and then multiplied by every potential assigned to it.
 * plhs[1]  seppot: empty nCliques x nCliques cell array (only when requested).
 */
void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[]){

	int     i, j, loop, nNodes, nCliques, num;
	double  *ns, *clweight, *pClNode, *pr;
	mxArray *pTemp, *pCliques, *pNs, *pWeight, *pBig, *pSmall, *pClpot, *pCPDpot;
	mwSize  cdim;
	mwSize  dims[2];   /* mxCreateCellArray takes const mwSize *, not int * */

	if(nrhs < 3){
		mexErrMsgTxt("init_pot: requires at least the engine, clqs and pot inputs.");
		return;
	}

	pCliques = mxGetField(prhs[0], 0, "cliques");
	pNs      = mxGetField(prhs[0], 0, "eff_node_sizes");
	pWeight  = mxGetField(prhs[0], 0, "clique_weight");
	if(pCliques == NULL || pNs == NULL || pWeight == NULL || !mxIsCell(pCliques)){
		mexErrMsgTxt("init_pot: engine must have cliques, eff_node_sizes and clique_weight fields.");
		return;
	}
	if(!mxIsCell(prhs[2])){
		mexErrMsgTxt("init_pot: pot must be a cell array.");
		return;
	}

	nNodes   = (int)mxGetNumberOfElements(prhs[1]);
	nCliques = (int)mxGetNumberOfElements(pCliques);
	ns       = mxGetPr(pNs);
	clweight = mxGetPr(pWeight);

	if((int)mxGetNumberOfElements(pWeight) < nCliques){
		mexErrMsgTxt("init_pot: clique_weight has fewer entries than there are cliques.");
		return;
	}
	if((int)mxGetNumberOfElements(prhs[2]) < nNodes){
		mexErrMsgTxt("init_pot: pot has fewer entries than clqs.");
		return;
	}

	/* clpot{c} = ones(clique_weight(c), 1) */
	cdim = (mwSize)nCliques;
	plhs[0] = mxCreateCellArray(1, &cdim);
	for(i = 0; i < nCliques; i++){
		num = (int)clweight[i];
		pTemp = mxCreateDoubleMatrix(num, 1, mxREAL);
		pr = mxGetPr(pTemp);
		for(j = 0; j < num; j++){
			pr[j] = 1;
		}
		mxSetCell(plhs[0], i, pTemp);
	}

	/* Multiply each potential into the clique it was assigned to. */
	pClNode = mxGetPr(prhs[1]);
	for(loop = 0; loop < nNodes; loop++){
		i = (int)pClNode[loop] - 1;
		if(i < 0 || i >= nCliques){
			mexErrMsgTxt("init_pot: clique index in clqs is out of range.");
			return;
		}
		pBig = mxGetCell(pCliques, i);
		pTemp = mxGetCell(prhs[2], loop);
		pSmall  = (pTemp != NULL) ? mxGetField(pTemp, 0, "domain") : NULL;
		pCPDpot = (pTemp != NULL) ? mxGetField(pTemp, 0, "T") : NULL;
		if(pBig == NULL || pSmall == NULL || pCPDpot == NULL){
			mexErrMsgTxt("init_pot: missing clique domain or pot domain/T field.");
			return;
		}

		pClpot = mxGetCell(plhs[0], i);
		multiply_pot(pClpot, pCPDpot, pBig, pSmall, ns);
	}

	/* seppot: empty nCliques x nCliques cell; plhs[1] only exists when asked for. */
	if(nlhs > 1){
		dims[0] = (mwSize)nCliques;
		dims[1] = (mwSize)nCliques;
		plhs[1] = mxCreateCellArray(2, dims);
	}
}
	

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -