⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 collect_evidence.c

📁 麻省理工学院的人工智能工具箱,很珍贵,希望对大家有用!
💻 C
字号:
/*
 * C MEX implementation of collect_evidence for the @jtree_ndx_inf_engine
 * directory.  It is called from enter_evidence.m in that directory.
 *
 * Inputs  (3): engine, clpot, seppot
 * Outputs (2): clpot, seppot
 *
 * All scratch memory is obtained through the MATLAB allocator (mxMalloc /
 * mxCalloc) so that it is reclaimed automatically if mexErrMsgTxt aborts
 * the MEX function part-way through.
 *
 * NOTE(review): every index array is read through an `int *`; this assumes
 * the arrays are stored with a C-int-sized integer class (presumably
 * int32) -- confirm against the code that builds B_NDX / SD_NDX / D_NDX.
 */
#include "mex.h"

/*
 * Fetch field `name` of element 0 of struct array `s`.
 * Raises a MATLAB error (and does not return) when mxGetField yields NULL,
 * instead of letting the caller dereference a NULL pointer.
 */
static mxArray *required_field(const mxArray *s, const char *name)
{
	mxArray *field = mxGetField(s, 0, name);

	if (field == NULL) {
		mexPrintf("collect_evidence: missing field '%s'.\n", name);
		mexErrMsgTxt("Required struct field not found.\n");
	}
	return field;
}

/*
 * "B" index multiply: for every element i of `ndx`,
 *     Tbig[i] *= Tsmall[ndx[i]].
 * The values stored in `ndx` are used directly as C offsets into Tsmall.
 *
 * The original code special-cased a one-element index; for a one-element
 * index the general loop performs exactly the same single update, so only
 * the general loop is kept.
 */
void multiply_by_table_ndxB(double *Tbig, const double *Tsmall, const mxArray *ndx)
{
	const int *offsets = mxGetData(ndx);
	int       count = mxGetNumberOfElements(ndx);
	int       i;

	for (i = 0; i < count; i++) {
		Tbig[i] *= Tsmall[offsets[i]];
	}
}

/*
 * "SD" index multiply.  `ndx` is a struct with fields "small" and "diff".
 * For every i in small and every j in diff:
 *     Tbig[small[i] + diff[j]] *= Tsmall[i].
 */
void multiply_by_table_ndxSD(double *Tbig, const double *Tsmall, const mxArray *ndx)
{
	const mxArray *small_arr = required_field(ndx, "small");
	const mxArray *diff_arr = required_field(ndx, "diff");
	const int     *small_ndx = mxGetData(small_arr);
	const int     *diff_ndx = mxGetData(diff_arr);
	int           n_small = mxGetNumberOfElements(small_arr);
	int           n_diff = mxGetNumberOfElements(diff_arr);
	int           i, j;

	for (i = 0; i < n_small; i++) {
		for (j = 0; j < n_diff; j++) {
			Tbig[diff_ndx[j] + small_ndx[i]] *= Tsmall[i];
		}
	}
}

/*
 * "D" index multiply.  `index` holds J offsets; Tbig (N elements) is
 * treated as I = N / J groups of J elements each.  Group i consists of the
 * elements at base + index[j], where base is the first element of Tbig not
 * yet claimed by an earlier group, and every element of group i is
 * multiplied by Tsmall[i].
 *
 * Shortcuts: J == 1 multiplies element-wise; I == 1 multiplies every
 * element of Tbig by the single value Tsmall[0].
 *
 * An empty index raises a MATLAB error (it would otherwise divide by zero).
 * NOTE(review): base + index[j] is not range-checked; the index is assumed
 * to be consistent with the size of Tbig.
 */
void multiply_by_table_ndxD(mxArray *Tbig, const mxArray *Tsmall, const mxArray *index)
{
	double       *pb = mxGetPr(Tbig);
	const double *ps = mxGetPr(Tsmall);
	const int    *ndx = mxGetData(index);
	int          N = mxGetNumberOfElements(Tbig);
	int          J = mxGetNumberOfElements(index);
	int          I, i, j, base, pos;
	char         *used;

	if (J == 0) {
		mexErrMsgTxt("multiply_by_table_ndxD: index array is empty.\n");
	}
	I = N / J;

	if (J == 1) {
		for (i = 0; i < N; i++) {
			pb[i] *= ps[i];
		}
		return;
	}
	if (I == 1) {
		for (i = 0; i < N; i++) {
			pb[i] *= ps[0];
		}
		return;
	}

	/* used[k] != 0 once element k of Tbig has been assigned to a group. */
	used = mxCalloc(N, sizeof(char));
	base = 0;
	for (i = 0; i < I; i++) {
		while (used[base]) {
			base++;
		}
		used[base] = 1;
		for (j = 0; j < J; j++) {
			pos = base + ndx[j];
			pb[pos] *= ps[i];
			used[pos] = 1;
		}
		base++;
	}
	mxFree(used);
}

/*
 * "B" index marginalisation.  `index` is an I-by-J array; the result is a
 * newly created I-by-1 real matrix (zero-initialised by
 * mxCreateDoubleMatrix).  For every element i of the index:
 *   maximize != 0 : result[index[i]] = max(result[index[i]], Tbig[i])
 *   maximize == 0 : result[index[i]] += Tbig[i]
 * Because the result starts at zero, the maximum is never below zero.
 * The caller owns the returned array.
 */
mxArray *marg_table_ndxB(const double *Tbig, const int maximize, const mxArray *index)
{
	const int *ndx = mxGetData(index);
	int       rows = mxGetM(index);
	int       cols = mxGetN(index);
	int       count = rows * cols;
	mxArray   *result = mxCreateDoubleMatrix(rows, 1, mxREAL);
	double    *Tsmall = mxGetPr(result);
	int       i;

	if (maximize) {
		for (i = 0; i < count; i++) {
			if (Tsmall[ndx[i]] < Tbig[i]) {
				Tsmall[ndx[i]] = Tbig[i];
			}
		}
	} else {
		for (i = 0; i < count; i++) {
			Tsmall[ndx[i]] += Tbig[i];
		}
	}
	return result;
}

/*
 * "SD" index marginalisation.  `index` is a struct with fields "small"
 * (S entries) and "diff" (D entries).  The result is a newly created
 * S-by-1 real matrix where entry i combines Tbig[small[i] + diff[j]] over
 * all j, by maximum when maximize != 0 and by sum otherwise.
 * The caller owns the returned array.
 *
 * NOTE(review): the running maximum is seeded with -1e-20 (a tiny negative
 * number), so entries below that value are never selected.  This may be a
 * typo for -1e20; it is left unchanged here -- confirm before altering.
 */
mxArray *marg_table_ndxSD(const double *Tbig, const int maximize, const mxArray *index)
{
	const mxArray *small_arr = required_field(index, "small");
	const mxArray *diff_arr = required_field(index, "diff");
	const int     *small_ndx = mxGetData(small_arr);
	const int     *diff_ndx = mxGetData(diff_arr);
	int           S = mxGetNumberOfElements(small_arr);
	int           D = mxGetNumberOfElements(diff_arr);
	mxArray       *result = mxCreateDoubleMatrix(S, 1, mxREAL);
	double        *Tsmall = mxGetPr(result);
	double        acc, value;
	int           i, j;

	for (i = 0; i < S; i++) {
		acc = maximize ? -1e-20 : 0;
		for (j = 0; j < D; j++) {
			value = Tbig[small_ndx[i] + diff_ndx[j]];
			if (maximize) {
				if (value > acc) {
					acc = value;
				}
			} else {
				acc += value;
			}
		}
		Tsmall[i] = acc;
	}
	return result;
}

/*
 * "D" index marginalisation.  `index` holds J offsets; Tbig (N elements)
 * is treated as I = N / J groups, formed exactly as in
 * multiply_by_table_ndxD.  The result is a newly created real matrix whose
 * entry i is the maximum (maximize != 0) or the sum (maximize == 0) of the
 * elements of group i.  Result entries start at zero, so a maximum is
 * never below zero.
 *
 * Shortcuts: J == 1 returns a duplicate of Tbig; I == 1 returns a 1-by-1
 * matrix reduced over all of Tbig.
 *
 * An empty index raises a MATLAB error (it would otherwise divide by zero).
 * The caller owns the returned array.
 */
mxArray *marg_table_ndxD(const mxArray *Tbig, const int maximize, const mxArray *index)
{
	const double *pb = mxGetPr(Tbig);
	const int    *ndx = mxGetData(index);
	int          N = mxGetNumberOfElements(Tbig);
	int          J = mxGetNumberOfElements(index);
	int          I, i, j, base, pos;
	mxArray      *Tsmall;
	double       *ps;
	char         *used;

	if (J == 0) {
		mexErrMsgTxt("marg_table_ndxD: index array is empty.\n");
	}
	I = N / J;

	if (J == 1) {
		return mxDuplicateArray(Tbig);
	}

	if (I == 1) {
		Tsmall = mxCreateDoubleMatrix(1, 1, mxREAL);
		ps = mxGetPr(Tsmall);
		for (i = 0; i < N; i++) {
			if (maximize) {
				if (*ps < pb[i]) {
					*ps = pb[i];
				}
			} else {
				*ps += pb[i];
			}
		}
		return Tsmall;
	}

	Tsmall = mxCreateDoubleMatrix(I, 1, mxREAL);
	ps = mxGetPr(Tsmall);

	/* used[k] != 0 once element k of Tbig has been assigned to a group. */
	used = mxCalloc(N, sizeof(char));
	base = 0;
	for (i = 0; i < I; i++) {
		while (used[base]) {
			base++;
		}
		used[base] = 1;
		for (j = 0; j < J; j++) {
			pos = base + ndx[j];
			if (maximize) {
				if (ps[i] < pb[pos]) {
					ps[i] = pb[pos];
				}
			} else {
				ps[i] += pb[pos];
			}
			used[pos] = 1;
		}
		base++;
	}
	mxFree(used);
	return Tsmall;
}

/*
 * MEX entry point.
 *
 *   prhs[0]  engine struct; the fields read here are "cliques",
 *            "maximize", "postorder", "postorder_parents", "ndx_type",
 *            "marg_cl_onto_sep_ndx_id" and "mult_cl_by_sep_ndx_id".
 *   prhs[1]  clique potentials (cell array), duplicated into plhs[0].
 *   prhs[2]  separator potentials (cell array), duplicated into plhs[1].
 *
 * For each of the (number of cliques - 1) entries of "postorder", with n
 * the clique taken from "postorder" and p the first entry of its
 * "postorder_parents" cell (both converted from 1-based to 0-based):
 *   1. marginalise clique potential n using the index selected by
 *      marg_cl_onto_sep_ndx_id[p * nCliques + n] and store the result in
 *      separator cell n * nCliques + p;
 *   2. multiply clique potential p by that separator potential using the
 *      index selected by mult_cl_by_sep_ndx_id[n * nCliques + p].
 *
 * The first character of "ndx_type" ('B', 'S' or 'D') selects both the
 * global index variable (B_NDX, SD_NDX or D_NDX) and the routines used.
 */
void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
{
	int           i, n, p, np, pn, loop, loops, nCliques, maximize;
	int           *collect_order = NULL;
	double        *pr, *pr1, *pMargndxId, *pMultndxId;
	mxArray       *pTemp, *pPostP, *pClpot, *pSeppot, *pDisplaced, *ndx;
	char          buf[3] = {0};  /* long enough for "SD" plus terminator */
	char          ch0;
	const mxArray *ndx_type;
	const mxArray *mm_ndx = NULL;

	(void)nlhs;

	if (nrhs < 3) {
		mexErrMsgTxt("collect_evidence requires 3 inputs: engine, clpot, seppot.\n");
	}

	/* Resolve the index type first so that a bad type aborts before any
	 * work is done. */
	ndx_type = required_field(prhs[0], "ndx_type");
	mxGetString(ndx_type, buf, sizeof(buf));
	ch0 = buf[0];

	/* NOTE(review): mexGetArrayPtr is the legacy global-variable lookup;
	 * newer MATLAB releases provide mexGetVariablePtr(workspace, name)
	 * instead -- confirm against the targeted MATLAB version. */
	switch (ch0) {
	case 'B':
		mm_ndx = (mxArray *)mexGetArrayPtr("B_NDX", "global");
		break;
	case 'S':
		mm_ndx = (mxArray *)mexGetArrayPtr("SD_NDX", "global");
		break;
	case 'D':
		mm_ndx = (mxArray *)mexGetArrayPtr("D_NDX", "global");
		break;
	default:
		mexErrMsgTxt("There is no this kind of index. \n");
	}
	if (mm_ndx == NULL) {
		mexErrMsgTxt("Could not get the global index.\n");
	}

	pTemp = required_field(prhs[0], "cliques");
	nCliques = mxGetNumberOfElements(pTemp);
	loops = nCliques - 1;

	pTemp = required_field(prhs[0], "maximize");
	maximize = (int)mxGetScalar(pTemp);

	/* collect_order[0 .. loops-1]       : parent clique p of each step
	 * collect_order[loops .. 2*loops-1] : clique n of each step */
	if (loops > 0) {
		collect_order = mxMalloc(2 * loops * sizeof(int));
	}
	pTemp = required_field(prhs[0], "postorder");
	pr = mxGetPr(pTemp);
	pPostP = required_field(prhs[0], "postorder_parents");
	for (i = 0; i < loops; i++) {
		n = (int)pr[i] - 1;
		pTemp = mxGetCell(pPostP, n);
		pr1 = mxGetPr(pTemp);
		collect_order[i] = (int)pr1[0] - 1;
		collect_order[i + loops] = n;
	}

	plhs[0] = mxDuplicateArray(prhs[1]);
	plhs[1] = mxDuplicateArray(prhs[2]);

	pTemp = required_field(prhs[0], "marg_cl_onto_sep_ndx_id");
	pMargndxId = mxGetPr(pTemp);
	pTemp = required_field(prhs[0], "mult_cl_by_sep_ndx_id");
	pMultndxId = mxGetPr(pTemp);

	for (loop = 0; loop < loops; loop++) {
		p = collect_order[loop];
		n = collect_order[loop + loops];
		np = p * nCliques + n;
		pn = n * nCliques + p;

		/* Step 1: marginalise clique n onto the separator. */
		ndx = mxGetCell(mm_ndx, (int)pMargndxId[np] - 1);
		pClpot = mxGetCell(plhs[0], n);
		pr = mxGetPr(pClpot);
		pSeppot = NULL;
		switch (ch0) {
		case 'B':
			pSeppot = marg_table_ndxB(pr, maximize, ndx);
			break;
		case 'S':
			pSeppot = marg_table_ndxSD(pr, maximize, ndx);
			break;
		case 'D':
			pSeppot = marg_table_ndxD(pClpot, maximize, ndx);
			break;
		default:
			break;  /* unreachable: ch0 was validated above */
		}
		pr1 = mxGetPr(pSeppot);

		/* mxSetCell does not release the element it displaces, so free
		 * the old separator potential of our private copy first. */
		pDisplaced = mxGetCell(plhs[1], pn);
		if (pDisplaced != NULL) {
			mxDestroyArray(pDisplaced);
		}
		mxSetCell(plhs[1], pn, pSeppot);

		/* Step 2: absorb the separator into parent clique p. */
		pClpot = mxGetCell(plhs[0], p);
		pr = mxGetPr(pClpot);
		ndx = mxGetCell(mm_ndx, (int)pMultndxId[pn] - 1);
		switch (ch0) {
		case 'B':
			multiply_by_table_ndxB(pr, pr1, ndx);
			break;
		case 'S':
			multiply_by_table_ndxSD(pr, pr1, ndx);
			break;
		case 'D':
			multiply_by_table_ndxD(pClpot, pSeppot, ndx);
			break;
		default:
			break;  /* unreachable: ch0 was validated above */
		}
	}

	mxFree(collect_order);
}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -