init_pot.c

来自「麻省理工学院的人工智能工具箱,很珍贵,希望对大家有用!」· C语言 代码 · 共 222 行

C
222
字号
/* C MEX implementation of init_pot in the @jtree_ndx_inf_engine directory. */
/* Called by enter_evidence.m in the @jtree_ndx_inf_engine directory.        */

/**************************************/
/* init_pot.c: 6 inputs & 2 outputs   */
/* Inputs:                            */
/*   engine                           */
/*   nodes                            */
/*   CPDs                             */
/* Optional (all three or none):      */
/*   clqs                             */
/*   pots                             */
/*   ndx2                             */
/* Outputs:                           */
/*   clpot                            */
/*   seppot                           */
/**************************************/

#include "mex.h"

/* Multiply the table Tbig, in place, by the table Tsmall using a B_NDX
 * index.
 *
 * Tbig  : table to update; one entry is updated per element of ndx.
 * Tsmall: factors; Tbig[i] is scaled by Tsmall[ndx[i]].
 * ndx   : index mxArray whose data is read as C ints (presumably an int32
 *         array -- confirm against the code that builds B_NDX).
 *
 * The element count comes straight from mxGetNumberOfElements, which equals
 * mxGetM * mxGetN without the risk of overflowing an int product, so a
 * scalar index needs no separate code path.
 * Raises a MATLAB error if ndx is NULL (e.g. an unpopulated cell). */
void multiply_by_table_ndxB(double *Tbig, const double *Tsmall, const mxArray *ndx){
	const int *pData;
	size_t     i, N;

	if(ndx == NULL){
		mexErrMsgTxt("Missing B index.\n");
	}
	pData = mxGetData(ndx);
	N = mxGetNumberOfElements(ndx);

	for(i=0; i<N; i++){
		Tbig[i] *= Tsmall[pData[i]];
	}
}

/* Multiply the table Tbig, in place, by the table Tsmall using an SD_NDX
 * index struct with fields "small" and "diff".
 *
 * For every small entry i and every offset j, the entry
 * Tbig[small[i] + diff[j]] is scaled by Tsmall[i]. Both index fields are
 * read as C ints (presumably int32 arrays -- confirm against the code that
 * builds SD_NDX).
 *
 * Raises a MATLAB error when ndx is NULL or lacks either field. Without the
 * check, mxGetField would return NULL and mxGetData would dereference it. */
void multiply_by_table_ndxSD(double *Tbig, const double *Tsmall, const mxArray *ndx){
	const mxArray *psmall, *pdiff;
	const int     *prs, *prd;
	size_t        i, j, S, D;
	double        factor;
	int           start;

	psmall = (ndx != NULL) ? mxGetField(ndx, 0, "small") : NULL;
	pdiff  = (ndx != NULL) ? mxGetField(ndx, 0, "diff")  : NULL;
	if(psmall == NULL || pdiff == NULL){
		mexErrMsgTxt("SD index must be a struct with fields 'small' and 'diff'.\n");
	}
	prs = mxGetData(psmall);
	prd = mxGetData(pdiff);
	S = mxGetNumberOfElements(psmall);
	D = mxGetNumberOfElements(pdiff);

	for(i=0; i<S; i++){
		/* Loop-invariant for the inner loop: the factor and its base offset. */
		factor = Tsmall[i];
		start  = prs[i];
		for(j=0; j<D; j++){
			Tbig[start + prd[j]] *= factor;
		}
	}
}

/* Multiply the table held in Tbig, in place, by the table in Tsmall using a
 * D_NDX index.
 *
 * Tbig  : mxArray of N doubles; updated in place.
 * Tsmall: mxArray of doubles; entry i scales one group of Tbig entries.
 * index : mxArray of J offsets, read as C ints (presumably int32 --
 *         confirm against the code that builds D_NDX).
 *
 * Groups are formed greedily. Group i is based at the first Tbig entry not
 * yet touched by an earlier group, and it covers base + index[j] for every j.
 * Two shortcuts skip that bookkeeping:
 *   - J == 1 is an element-wise product over N entries;
 *   - N / J == 1 scales all of Tbig by Tsmall[0].
 * Raises a MATLAB error if an argument is NULL or if the index would address
 * past the end of Tbig. */
void multiply_by_table_ndxD(mxArray *Tbig, const mxArray *Tsmall, const mxArray *index){
	double       *pb;
	const double *ps;
	const int    *ndx;
	char         *used;
	size_t       i, j, I, J, N, base, pos;

	if(Tsmall == NULL || index == NULL){
		mexErrMsgTxt("Missing D index or table.\n");
	}
	N = mxGetNumberOfElements(Tbig);
	J = mxGetNumberOfElements(index);
	/* Nothing to scale. An empty index would also make N / J divide by zero. */
	if(N == 0 || J == 0){
		return;
	}
	I   = N / J;
	pb  = mxGetPr(Tbig);
	ps  = mxGetPr(Tsmall);
	ndx = mxGetData(index);

	if(J == 1){
		for(i=0; i<N; i++){
			pb[i] *= ps[i];
		}
		return;
	}
	if(I == 1){
		for(i=0; i<N; i++){
			pb[i] *= ps[0];
		}
		return;
	}

	/* mxCalloc zero-fills the "already covered" flags and never returns NULL
	 * in a MEX file. MATLAB also reclaims this memory automatically if
	 * mexErrMsgTxt aborts the call below. */
	used = mxCalloc(N, sizeof *used);
	base = 0;
	for(i=0; i<I; i++){
		while(base < N && used[base]){
			base++;
		}
		if(base >= N){
			mexErrMsgTxt("D index does not match the table sizes.\n");
		}
		used[base] = 1;
		for(j=0; j<J; j++){
			/* Reject offsets that would write outside Tbig. */
			if(ndx[j] < 0 || (size_t)ndx[j] >= N - base){
				mexErrMsgTxt("D index does not match the table sizes.\n");
			}
			pos = base + (size_t)ndx[j];
			pb[pos] *= ps[i];
			used[pos] = 1;
		}
		base++;
	}
	mxFree(used);
}

/* Fetch a required field of the engine struct. When the field is absent,
 * raise a MATLAB error instead of returning NULL for later code to
 * dereference. */
static const mxArray *get_engine_field(const mxArray *engine, const char *name){
	const mxArray *field = mxGetField(engine, 0, name);

	if(field == NULL){
		mexErrMsgIdAndTxt("init_pot:missingField",
		                  "The engine has no field '%s'.", name);
	}
	return field;
}

/* Convert a 1-based MATLAB index stored as a double into a 0-based C index.
 * Fractional values are truncated toward zero. Raises a MATLAB error unless
 * the truncated index lies in 1..limit. */
static size_t to_zero_based(double value, size_t limit, const char *what){
	/* The negated comparison also rejects NaN. */
	if(!(value >= 1.0) || value >= (double)limit + 1.0){
		mexErrMsgIdAndTxt("init_pot:badIndex",
		                  "%s index %g is outside 1..%lu.",
		                  what, value, (unsigned long)limit);
	}
	return (size_t)value - 1;
}

/* Return element k (0-based) of the cell array `cells`. Raises a MATLAB
 * error if `cells` is not a cell array, if k is out of range, or if the cell
 * is unpopulated (mxGetCell returns NULL in that case). */
static mxArray *get_cell(const mxArray *cells, size_t k, const char *what){
	mxArray *elem = NULL;

	if(mxIsCell(cells) && k < mxGetNumberOfElements(cells)){
		elem = mxGetCell(cells, k);
	}
	if(elem == NULL){
		mexErrMsgIdAndTxt("init_pot:badCell",
		                  "Element %lu of %s is missing.",
		                  (unsigned long)(k + 1), what);
	}
	return elem;
}

/* Multiply `table` into the clique potential `clpot`, in place. The index
 * scheme is selected by `kind`: 'B', 'S' (SD) or 'D'. */
static void multiply_into_clique(char kind, mxArray *clpot, const mxArray *table, const mxArray *ndx){
	switch(kind){
	case 'B':
		multiply_by_table_ndxB(mxGetPr(clpot), mxGetPr(table), ndx);
		break;
	case 'S':
		multiply_by_table_ndxSD(mxGetPr(clpot), mxGetPr(table), ndx);
		break;
	case 'D':
		multiply_by_table_ndxD(clpot, table, ndx);
		break;
	default:
		mexErrMsgTxt("There is no this kind of index. \n");
	}
}

/* MEX gateway: [clpot, seppot] = init_pot(engine, nodes, CPDs [, clqs, pots, ndx2])
 *
 * prhs[0] engine: struct; fields read are cliques, clique_weight, ndx_type,
 *                 mult_node_ndx_id and clq_ass_to_node.
 * prhs[1] nodes : 1-based node numbers.
 * prhs[2] CPDs  : cell array of tables, one per entry of nodes.
 * prhs[3] clqs  : (optional) 1-based clique numbers.
 * prhs[4] pots  : (optional) cell array of tables, one per entry of clqs.
 * prhs[5] ndx2  : (optional) cell array of indices, one per entry of clqs.
 *
 * plhs[0] clpot : cell array with one column of ones per clique (length
 *                 clique_weight(c)), multiplied by every table assigned to
 *                 that clique.
 * plhs[1] seppot: empty nCliques x nCliques cell array; created only when
 *                 requested.
 *
 * For the node tables, the index cell array is read from the global
 * variable B_NDX, SD_NDX or D_NDX, chosen by the first letter of ndx_type. */
void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[]){
	const mxArray *engine, *weights, *ndx_type, *clq_ass, *ndx_ids;
	const mxArray *mult_ndx = NULL;
	const double  *clweight, *pNode, *pClq, *pNdxID;
	double        *pr;
	mxArray       *pTemp, *ndx, *table;
	size_t        i, j, k, loop, num, nNodes, nCliques, nSlots, n, c;
	char          buf[3] = {0};
	char          ch0;

	/* prhs[3..5] are read together, so accept exactly 3 or at least 6 inputs. */
	if(nrhs < 3 || (nrhs > 3 && nrhs < 6)){
		mexErrMsgTxt("init_pot expects 3 or 6 input arguments.\n");
	}
	engine = prhs[0];

	/* Start every clique potential as an all-ones column sized by its weight. */
	nCliques = mxGetNumberOfElements(get_engine_field(engine, "cliques"));
	weights  = get_engine_field(engine, "clique_weight");
	if(mxGetNumberOfElements(weights) < nCliques){
		mexErrMsgTxt("clique_weight has fewer entries than there are cliques.\n");
	}
	clweight = mxGetPr(weights);
	plhs[0] = mxCreateCellMatrix(nCliques, 1);
	for(i=0; i<nCliques; i++){
		if(!(clweight[i] >= 0.0)){
			mexErrMsgTxt("clique_weight entries must be non-negative.\n");
		}
		num = (size_t)clweight[i];
		pTemp = mxCreateDoubleMatrix(num, 1, mxREAL);
		pr = mxGetPr(pTemp);
		for(j=0; j<num; j++){
			pr[j] = 1.0;
		}
		mxSetCell(plhs[0], i, pTemp);
	}

	/* Only the first character of ndx_type matters ("B", "SD" or "D"), so a
	 * small stack buffer suffices. Truncating longer strings is harmless. */
	ndx_type = get_engine_field(engine, "ndx_type");
	if(!mxIsChar(ndx_type)){
		mexErrMsgTxt("ndx_type must be a string.\n");
	}
	mxGetString(ndx_type, buf, sizeof buf);
	ch0 = buf[0];

	switch(ch0){
	case 'B':
		mult_ndx = mexGetVariablePtr("global", "B_NDX");
		break;
	case 'S':
		mult_ndx = mexGetVariablePtr("global", "SD_NDX");
		break;
	case 'D':
		mult_ndx = mexGetVariablePtr("global", "D_NDX");
		break;
	default:
		mexErrMsgTxt("There is no this kind of index. \n");
	}
	if(mult_ndx == NULL){
		mexErrMsgTxt("Could not get the global index.\n");
	}

	/* Multiply each node's table into the clique it is assigned to, using the
	 * node's entry in the global index cell array. */
	clq_ass = get_engine_field(engine, "clq_ass_to_node");
	ndx_ids = get_engine_field(engine, "mult_node_ndx_id");
	pClq    = mxGetPr(clq_ass);
	pNdxID  = mxGetPr(ndx_ids);
	/* A node number must be valid for both per-node lookup arrays. */
	nSlots  = mxGetNumberOfElements(clq_ass);
	if(mxGetNumberOfElements(ndx_ids) < nSlots){
		nSlots = mxGetNumberOfElements(ndx_ids);
	}
	pNode  = mxGetPr(prhs[1]);
	nNodes = mxGetNumberOfElements(prhs[1]);
	for(loop=0; loop<nNodes; loop++){
		n     = to_zero_based(pNode[loop], nSlots, "Node");
		c     = to_zero_based(pClq[n], nCliques, "Clique");
		k     = to_zero_based(pNdxID[n], mxGetNumberOfElements(mult_ndx), "Global index");
		ndx   = get_cell(mult_ndx, k, "the global index");
		table = get_cell(prhs[2], loop, "CPDs");
		multiply_into_clique(ch0, mxGetCell(plhs[0], c), table, ndx);
	}

	/* Optional extra tables: multiply pots{k} into clique clqs(k) using the
	 * per-table index ndx2{k}. */
	if(nrhs > 3){
		pClq   = mxGetPr(prhs[3]);
		nNodes = mxGetNumberOfElements(prhs[3]);
		for(loop=0; loop<nNodes; loop++){
			c     = to_zero_based(pClq[loop], nCliques, "Clique");
			ndx   = get_cell(prhs[5], loop, "ndx2");
			table = get_cell(prhs[4], loop, "pots");
			multiply_into_clique(ch0, mxGetCell(plhs[0], c), table, ndx);
		}
	}

	/* plhs has room for max(nlhs, 1) outputs, so only fill the second slot
	 * when the caller asked for it. */
	if(nlhs > 1){
		plhs[1] = mxCreateCellMatrix(nCliques, nCliques);
	}
}


⌨️ 快捷键说明

复制代码Ctrl + C
搜索代码Ctrl + F
全屏模式F11
增大字号Ctrl + =
减小字号Ctrl + -
显示快捷键?