⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 collect_evidence.c

📁 麻省理工学院的人工智能工具箱,很珍贵,希望对大家有用!
💻 C
字号:
/* C MEX implementation of collect_evidence for the @jtree_C_inf_engine class. */
/* It is called by enter_evidence.m in the @jtree_C_inf_engine directory.      */

/******************************************/
/* collect_evidence: 3 inputs, 2 outputs  */
/* Inputs:                                */
/*   engine                               */
/*   clpot                                */
/*   seppot                               */
/* Outputs:                               */
/*   clpot                                */
/*   seppot                               */
/******************************************/

#include "mex.h"


/*
 * multiply_pot: multiply the big potential table, in place, by a small one.
 *
 *   bTable  - big table of doubles; overwritten with the product.
 *   sTable  - small table of doubles; read only.
 *   bDomain - node ids (1-based) naming bTable's dimensions, in storage order.
 *   sDomain - node ids (1-based) naming sTable's nodes.
 *   ns      - node sizes; ns[id - 1] is the size of node id.
 *
 * Tables are walked with the first dimension varying fastest. sTable is
 * assumed to store its dimensions in the same relative order in which they
 * appear in bDomain (NOTE(review): holds when both domains are sorted -
 * confirm with callers). Every bTable entry is multiplied by the sTable entry
 * whose subscripts agree with it on the nodes shared with sDomain; nodes of
 * sDomain that are not in bDomain are ignored. Raises a MATLAB error when
 * sTable's element count does not match the shared dimensions.
 */
void multiply_pot(mxArray *bTable, const mxArray *sTable, const mxArray *bDomain, const mxArray *sDomain, const double *ns){
	int     i, j, d, NB, NS, ndim, siz_s, expected, offset;
	int     *big_size, *small_stride, *subs;
	double  *pb, *ps, *bp, *sp;

	ndim  = (int)mxGetNumberOfElements(bDomain);
	siz_s = (int)mxGetNumberOfElements(sDomain);
	pb = mxGetPr(bDomain);
	ps = mxGetPr(sDomain);

	NB = (int)mxGetNumberOfElements(bTable);
	NS = (int)mxGetNumberOfElements(sTable);
	bp = mxGetPr(bTable);
	sp = mxGetPr(sTable);

	/* Scalar small table: scale every entry. */
	if(NS == 1){
		for(i=0; i<NB; i++){
			bp[i] *= sp[0];
		}
		return;
	}

	/* Equal element counts: the tables are treated as sharing one layout. */
	if(NS == NB){
		for(i=0; i<NB; i++){
			bp[i] *= sp[i];
		}
		return;
	}

	/* No big dimensions to index a multi-entry small table with. */
	if(ndim == 0){
		mexErrMsgIdAndTxt("collect_evidence:sizeMismatch",
			"multiply_pot: small table has %d entries, expected %d.", NS, 1);
	}

	/* mxMalloc terminates the MEX call cleanly on failure; no NULL check needed. */
	big_size     = mxMalloc(ndim * sizeof *big_size);
	small_stride = mxMalloc(ndim * sizeof *small_stride);
	subs         = mxMalloc(ndim * sizeof *subs);

	/* For each big dimension: its size, and its stride inside sTable.
	 * A stride of 0 marks a node absent from sDomain, so stepping along
	 * that dimension leaves the sTable position unchanged. */
	expected = 1;
	for(d=0; d<ndim; d++){
		big_size[d]     = (int)ns[(int)pb[d] - 1];
		small_stride[d] = 0;
		subs[d]         = 0;
		for(i=0; i<siz_s; i++){
			if(ps[i] == pb[d]){
				small_stride[d] = expected;
				expected *= big_size[d];
				break;
			}
		}
	}

	/* Guard against reading past the end of sTable. */
	if(expected != NS){
		mxFree(big_size);
		mxFree(small_stride);
		mxFree(subs);
		mexErrMsgIdAndTxt("collect_evidence:sizeMismatch",
			"multiply_pot: small table has %d entries, expected %d.", NS, expected);
	}

	/* Walk bTable linearly, keeping its subscripts (first dimension fastest)
	 * and the matching sTable offset in step, odometer style. */
	offset = 0;
	for(j=0; j<NB; j++){
		bp[j] *= sp[offset];
		for(d=0; d<ndim; d++){
			if(subs[d] < big_size[d] - 1){
				subs[d]++;
				offset += small_stride[d];
				break;
			}
			/* Dimension d wraps around: rewind it and carry into d + 1. */
			offset -= small_stride[d] * (big_size[d] - 1);
			subs[d] = 0;
		}
	}

	mxFree(big_size);
	mxFree(small_stride);
	mxFree(subs);
}

/*
 * marginalise_pot: reduce a big potential table onto the nodes of sDomain.
 *
 *   bTable   - big table of doubles; read only.
 *   bDomain  - node ids (1-based) naming bTable's dimensions, in storage order.
 *   sDomain  - node ids (1-based) of the nodes to keep.
 *   ns       - node sizes; ns[id - 1] is the size of node id.
 *   maximize - nonzero: take the maximum over the eliminated dimensions;
 *              zero: take the sum.
 *
 * Returns a newly created NS-by-1 double array (1-by-1 when nothing is kept),
 * owned by the caller. Its dimensions are the nodes of sDomain that occur in
 * bDomain, stored in bDomain order with the first dimension fastest; nodes of
 * sDomain not in bDomain are ignored.
 */
mxArray* marginalise_pot(const mxArray *bTable, const mxArray *bDomain, const mxArray *sDomain, const double *ns, const int maximize){
	int     i, j, d, NB, NS, ndim, siz_s, offset;
	int     *big_size, *small_stride, *subs;
	double  *pb, *ps, *bp, *sp, neg_inf;
	mxArray *sTable;

	ndim  = (int)mxGetNumberOfElements(bDomain);
	siz_s = (int)mxGetNumberOfElements(sDomain);
	pb = mxGetPr(bDomain);
	ps = mxGetPr(sDomain);

	NB = (int)mxGetNumberOfElements(bTable);
	bp = mxGetPr(bTable);

	/* Nothing to keep (or no dimensions to index): reduce to one scalar.
	 * Seeding from bp[0] keeps the max correct for negative entries. */
	if(siz_s == 0 || ndim == 0){
		sTable = mxCreateDoubleMatrix(1, 1, mxREAL);
		sp = mxGetPr(sTable);
		if(NB > 0){
			sp[0] = bp[0];
			for(i=1; i<NB; i++){
				if(maximize){
					if(bp[i] > sp[0]) sp[0] = bp[i];
				}
				else{
					sp[0] += bp[i];
				}
			}
		}
		return sTable;
	}

	/* mxMalloc terminates the MEX call cleanly on failure; no NULL check needed. */
	big_size     = mxMalloc(ndim * sizeof *big_size);
	small_stride = mxMalloc(ndim * sizeof *small_stride);
	subs         = mxMalloc(ndim * sizeof *subs);

	/* For each big dimension: its size, and its stride inside the result.
	 * A stride of 0 marks an eliminated dimension. NS ends up as the product
	 * of the kept dimension sizes. */
	NS = 1;
	for(d=0; d<ndim; d++){
		big_size[d]     = (int)ns[(int)pb[d] - 1];
		small_stride[d] = 0;
		subs[d]         = 0;
		for(i=0; i<siz_s; i++){
			if(ps[i] == pb[d]){
				small_stride[d] = NS;
				NS *= big_size[d];
				break;
			}
		}
	}

	sTable = mxCreateDoubleMatrix(NS, 1, mxREAL);
	sp = mxGetPr(sTable);

	if(NS == NB){
		/* Only size-1 dimensions were eliminated: the layout is unchanged. */
		for(i=0; i<NB; i++){
			sp[i] = bp[i];
		}
	}
	else{
		/* mxCreateDoubleMatrix zero-fills; a max must start from -Inf instead,
		 * otherwise all-negative slices would wrongly come out as 0. */
		if(maximize && NB > 0){
			neg_inf = -mxGetInf();
			for(i=0; i<NS; i++){
				sp[i] = neg_inf;
			}
		}

		/* Walk bTable linearly, keeping its subscripts (first dimension
		 * fastest) and the matching result offset in step, odometer style. */
		offset = 0;
		for(j=0; j<NB; j++){
			if(maximize){
				if(bp[j] > sp[offset]) sp[offset] = bp[j];
			}
			else{
				sp[offset] += bp[j];
			}
			for(d=0; d<ndim; d++){
				if(subs[d] < big_size[d] - 1){
					subs[d]++;
					offset += small_stride[d];
					break;
				}
				/* Dimension d wraps around: rewind it and carry into d + 1. */
				offset -= small_stride[d] * (big_size[d] - 1);
				subs[d] = 0;
			}
		}
	}

	mxFree(big_size);
	mxFree(small_stride);
	mxFree(subs);

	return sTable;
}


/******************************************/
/* collect_evidence: 3 inputs, 2 outputs  */
/* Inputs:                                */
/*   engine                               */
/*   clpot                                */
/*   seppot                               */
/* Outputs:                               */
/*   clpot                                */
/*   seppot                               */
/******************************************/



void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[]){

	int     i, loop, loops, nCliques, temp, maximize;
	int     parent, child;
	int     *collect_order;
	double  *ns, *pr, *pr1;
	mxArray *pTemp, *pCliques, *pBig, *pSmall, *pSeparator, *pPostP, *pClpot, *pSeppot;

	pCliques = mxGetField(prhs[0], 0, "cliques");
	nCliques = mxGetNumberOfElements(pCliques);
	loops = nCliques - 1;

	pTemp = mxGetField(prhs[0], 0, "eff_node_sizes");
	ns = mxGetPr(pTemp);

	pTemp = mxGetField(prhs[0], 0, "maximize");
	maximize = (int)mxGetScalar(pTemp);

	collect_order = malloc(2 * loops * sizeof(int));

	pTemp = mxGetField(prhs[0], 0, "postorder");
	pr = mxGetPr(pTemp);
	pPostP = mxGetField(prhs[0], 0, "postorder_parents");
	for(i=0; i<loops; i++){
		temp = (int)pr[i] - 1;
		pTemp = mxGetCell(pPostP, temp);
		pr1 = mxGetPr(pTemp);
		collect_order[i] = (int)pr1[0] - 1;
		collect_order[i+loops] = temp;
	}

	plhs[0] = mxDuplicateArray(prhs[1]);
	plhs[1] = mxDuplicateArray(prhs[2]);

	pSeparator = mxGetField(prhs[0], 0, "separator");
	for(loop=0; loop<loops; loop++){
		parent = collect_order[loop];
		child  = collect_order[loop+loops];
		pBig  = mxGetCell(pCliques, child);
		i = nCliques * child + parent;
		pSmall = mxGetCell(pSeparator, i);
		pClpot = mxGetCell(plhs[0], child);
        pSeppot = marginalise_pot(pClpot, pBig, pSmall, ns, maximize);
		mxSetCell(plhs[1], i, pSeppot);
		
		pBig  = mxGetCell(pCliques, parent);
		pClpot = mxGetCell(plhs[0], parent);
		multiply_pot(pClpot, pSeppot, pBig, pSmall, ns);
	}
	free(collect_order);
}
	

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -