⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 enter_softev_c.c

📁 麻省理工学院的人工智能工具箱,很珍贵,希望对大家有用!
💻 C
字号:
/* C MEX version of enter_softev (this file, enter_softev_c.c) in the @jtree_C_inf_engine directory */
/* It is called by enter_evidence.m in the @jtree_C_inf_engine directory */

/**************************************/
/* enter_softev has 6 input & 2 output*/
/* engine                             */
/* clqs                               */
/* pot                                */
/* onodes                             */
/* pot_type                           */
/*                                    */
/* clpot                              */
/* loglik                             */
/**************************************/


#include <math.h>
#include "mex.h"

/*
 * Scale the n entries of `array` in place so that they sum to one.
 *
 * Returns the sum of the entries as they were before scaling.  When that
 * sum is exactly zero the entries are divided by 1.0 instead, i.e. they are
 * left as they were, and 0.0 is returned.
 */
double myNormalise(double *array, const int n){
    double total = 0.0;
    double divisor;
    int k;

    for(k = 0; k < n; ++k)
        total += array[k];

    /* Guard against division by zero for an all-zero table. */
    divisor = (total == 0.0) ? 1.0 : total;

    for(k = 0; k < n; ++k)
        array[k] /= divisor;

    return total;
}

/*
 * Build the index table used to marginalise a potential over the "big"
 * domain pb down to the "small" domain ps.
 *
 *   pb, siz_b : node ids (1-based, stored as doubles) of the big domain
 *   ps, siz_s : node ids (1-based, stored as doubles) of the small domain
 *   ns        : node sizes, indexed by (node id - 1)
 *   margndx   : output, NB ints, where NB is the product of the sizes of
 *               the nodes in pb
 *
 * With NS the product of the sizes of the nodes in ps and ND = NB / NS, the
 * table is laid out as NS rows by ND columns: margndx[j + i*NS] is the
 * linear index (first dimension of pb varying fastest) of the i-th big-table
 * entry that contributes to small-table entry j.
 *
 * Returns NS.
 *
 * The dimensions of pb that are summed out are the ones that do not occur in
 * ps.  The function assumes every node of ps also occurs in pb.
 */
int mk_marginalise_ndx(const double *pb, const double *ps, const int *ns, const int siz_b, const int siz_s, int *margndx)
{
	int  i, j, n_diff, shared, size, stride, rest, offset, base, NB, NS, ND;
	int  *diff_stride, *diff_size, *model, *used;

	/* Table sizes of the big and the small domain. */
	NB = 1;
	for(i=0; i<siz_b; i++){
		NB *= ns[(int)pb[i] - 1];
	}
	NS = 1;
	for(i=0; i<siz_s; i++){
		NS *= ns[(int)ps[i] - 1];
	}
	ND = NB / NS;

	/* Nothing is summed out: both tables have the same layout. */
	if(ND == 1){
		for(i=0; i<NB; i++){
			margndx[i] = i;
		}
		return NS;
	}

	/* For every dimension of pb that is absent from ps, record its stride
	   in the big table and its number of states. */
	diff_stride = malloc(siz_b * sizeof(int));
	diff_size   = malloc(siz_b * sizeof(int));
	n_diff = 0;
	stride = 1;
	for(i=0; i<siz_b; i++){
		size = ns[(int)pb[i] - 1];
		shared = 0;
		for(j=0; j<siz_s; j++){
			if(pb[i] == ps[j]){
				shared = 1;
				break;
			}
		}
		if(!shared){
			diff_stride[n_diff] = stride;
			diff_size[n_diff] = size;
			n_diff++;
		}
		stride *= size;
	}

	/* model[i] is the offset, relative to a base entry, of the i-th joint
	   state of the summed-out dimensions (first such dimension fastest).
	   The mixed-radix digits of i are peeled off one dimension at a time. */
	model = malloc(ND * sizeof(int));
	for(i=0; i<ND; i++){
		rest = i;
		offset = 0;
		for(j=0; j<n_diff; j++){
			offset += diff_stride[j] * (rest % diff_size[j]);
			rest /= diff_size[j];
		}
		model[i] = offset;
	}

	used = malloc(NB * sizeof(int));
	for(i=0; i<NB; i++){
		used[i] = 0;
	}

	/* Walk the big table in order.  Each entry not yet claimed is the base
	   (all summed-out dimensions at state 0) of the next small-table row;
	   claim it together with its ND - 1 companions. */
	base = 0;
	for(j=0; j<NS; j++){
		while(used[base]){
			base++;
		}
		for(i=0; i<ND; i++){
			offset = base + model[i];
			margndx[j + i * NS] = offset;
			used[offset] = 1;
		}
		base++;
	}

	free(used);
	free(model);
	free(diff_size);
	free(diff_stride);
	return NS;
}


/*
 * Build the index table used to multiply a potential over the "small"
 * domain ps into a potential over the "big" domain pb.
 *
 *   pb, siz_b : node ids (1-based, stored as doubles) of the big domain
 *   ps, siz_s : node ids (1-based, stored as doubles) of the small domain
 *   ns        : node sizes, indexed by (node id - 1)
 *   multndx   : output, one int per big-table entry
 *
 * On return multndx[k] is the linear index of the small-table entry that
 * lines up with big-table entry k (first dimension of pb varying fastest).
 */
void mk_multiply_ndx(const double *pb, const double *ps, const int *ns, const int siz_b, const int siz_s, int *multndx)
{
	int  d, k, matched, node, big_total, small_total, small_pos;
	int  *pos_in_big, *big_dim, *small_dim, *last_sub, *step, *counter, *wrap;

	/* Position inside pb of each ps node that also occurs in pb. */
	pos_in_big = malloc(siz_s * sizeof(int));
	matched = 0;
	for(k=0; k<siz_s; k++){
		d = 0;
		while(d < siz_b && ps[k] != pb[d]){
			d++;
		}
		if(d < siz_b){
			pos_in_big[matched++] = d;
		}
	}

	/* big_dim: size of every pb dimension.
	   small_dim: the same shape with every dimension absent from ps set to 1. */
	big_dim   = (int *)malloc(sizeof(int) * siz_b);
	small_dim = (int *)malloc(sizeof(int) * siz_b);
	big_total = 1;
	for(d=0; d<siz_b; d++){
		node = (int)pb[d] - 1;
		big_dim[d] = ns[node];
		small_dim[d] = 1;
		big_total *= ns[node];
	}
	small_total = 1;
	for(k=0; k<matched; k++){
		node = (int)ps[k] - 1;
		small_dim[pos_in_big[k]] = ns[node];
		small_total *= ns[node];
	}

	/* Trivial layouts: a one-entry small table maps everything to 0, and a
	   small table as large as the big one maps every entry to itself. */
	if(small_total == 1 || small_total == big_total){
		for(k=0; k<big_total; k++){
			multndx[k] = (small_total == 1) ? 0 : k;
		}
		free(pos_in_big);
		free(big_dim);
		free(small_dim);
		return;
	}

	last_sub = (int *)malloc(sizeof(int) * siz_b);
	step     = (int *)malloc(sizeof(int) * siz_b);
	counter  = (int *)malloc(sizeof(int) * siz_b);
	wrap     = (int *)malloc(sizeof(int) * siz_b);

	for(d=0; d<siz_b; d++){
		counter[d] = 0;
		last_sub[d] = big_dim[d] - 1;
	}

	/* step[d]: stride of dimension d in the small table.
	   small_dim[d] is reduced to (size - 1), so it is non-zero exactly when
	   dimension d varies in the small table.
	   wrap[d]: amount to subtract when dimension d rolls over to 0. */
	step[0] = 1;
	for(d=0; d<siz_b-1; d++){
		step[d+1] = step[d] * small_dim[d];
		small_dim[d] -= 1;
		wrap[d] = step[d] * small_dim[d];
	}
	small_dim[siz_b-1] -= 1;
	wrap[siz_b-1] = step[siz_b-1] * small_dim[siz_b-1];

	/* Odometer over the big table's subscripts, tracking the matching
	   position in the small table as each digit advances or rolls over. */
	small_pos = 0;
	for(k=0; k<big_total; k++){
		multndx[k] = small_pos;
		for(d=0; d<siz_b; d++){
			if(counter[d] != last_sub[d]){
				counter[d]++;
				if(small_dim[d]){
					small_pos += step[d];
				}
				break;
			}
			counter[d] = 0;
			if(small_dim[d]){
				small_pos -= wrap[d];
			}
		}
	}

	free(wrap);
	free(counter);
	free(step);
	free(last_sub);
	free(small_dim);
	free(big_dim);
	free(pos_in_big);
}

/*
 * Multiply a small table into a big one, in place.
 *
 * For each of the n entries of clpot, entry k is multiplied by
 * pr[multndx[k]].
 */
void multiply_pot(double *clpot, const double* pr, const int *multndx, const int n){
	int k = 0;

	while(k < n){
		clpot[k] *= pr[multndx[k]];
		k++;
	}
}

/*void divide_pot(double *clpot, const double *pr, const int *multndx, const int n){
	int i, temp;
	double f;

	for(i=0; i<n; i++){
		temp = multndx[i];
		f = (pr[temp] == 0.0) ? 1.0 : pr[temp];
		clpot[i] /= f;
	}
}*/

/*
 * Element-wise in-place division of arr0 by arr1 over n entries.
 *
 * An entry of arr0 whose divisor in arr1 equals zero is left unchanged
 * (the original divided it by 1.0).
 */
void array_divide(double *arr0, const double *arr1, const int n){
	int k;

	for(k = 0; k < n; ++k){
		if(arr1[k] != 0.0){
			arr0[k] /= arr1[k];
		}
	}
}


/*
 * Sum a big table down to a small one using a precomputed index table.
 *
 *   clpot   : big table, n entries
 *   pr      : output small table, `row` entries
 *   margndx : n indices into clpot, laid out as `row` rows by n/row columns
 *             (entry (r, c) is margndx[r + c*row])
 *
 * pr[r] receives the sum of the clpot entries selected by row r of margndx.
 */
void marginalise_pot(const double *clpot, double *pr, const int *margndx, const int n, const int row){
	const int col = n / row;
	int r, c, at;

	for(r = 0; r < row; ++r){
		pr[r] = 0.0;
		at = r;
		for(c = 0; c < col; ++c){
			pr[r] += clpot[margndx[at]];
			at += row;
		}
	}
}


/**************************************/
/* enter_softev has 6 input & 2 output*/
/* engine                             */
/* clqs                               */
/* pot                                */
/* onodes                             */
/* pot_type                           */
/*                                    */
/* clpot                              */
/* loglik                             */
/**************************************/


/*
 * MEX gateway.
 *
 * Inputs read (names taken from the header comment of this file):
 *   prhs[0]  engine   : struct; fields read are cliques, inf_engine.bnet
 *                       (parents, node_sizes), postorder, postorder_parents,
 *                       preorder, preorder_children, clique_weight, separator
 *   prhs[1]  clqs     : for every node, the 1-based clique its table is
 *                       multiplied into
 *   prhs[2]  pot      : cell array, one struct per node with a field T
 *   prhs[3]  onodes   : 1-based ids of nodes whose size is forced to 1
 *   prhs[4]  pot_type : string, must start with 'd'
 *
 * Outputs:
 *   plhs[0]  cell array with one normalised column vector per clique
 *   plhs[1]  log of the normalising constant of the last clique
 *            (only written when a second output is requested)
 */
void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[]){

	int     i, j, loop, loops, nNodes, nCliques, ipointer, pc, temp, buflen, count;
	int     parent, child, siz_b, siz_s, max_table, row, num, bad_type;
	int     *ns, *multndx, *margndx, *distribute_order, *collect_order, *family_num;
	int     **p_sep;
	double  lik, loglik;
	double  *clweight, *pClNode, *pr, *pr1, *pBigDom, *pSmallDom, *seppot0;
	double  **clpot, **seppot, **family;	
	mxArray *pTemp, *pCliques, *pBig, *pSmall, *pSeparator, *pBnet, *pParents, *pPostP, *pPreCh, *pt;
	char    *buf;

	/* prhs[0..4] are dereferenced below; refuse to run with fewer. */
	if(nrhs < 5){
		mexErrMsgTxt("enter_softev_c requires at least 5 inputs: engine, clqs, pot, onodes, pot_type.\n");
	}

	/* Only the first character of pot_type matters.  The buffer is released
	   before any error is raised because mexErrMsgTxt does not return. */
	buflen = mxGetNumberOfElements(prhs[4]) + 1;
	buf = malloc(buflen * sizeof(char));
	buf[0] = '\0';	/* defined content even if mxGetString fails */
	mxGetString(prhs[4], buf, buflen);
	bad_type = (buf[0] != 'd');
	free(buf);
	if(bad_type){
		mexErrMsgTxt("Currently can only process pot_type 'd'.\n");
	}

	nNodes = mxGetNumberOfElements(prhs[1]);
	pCliques = mxGetField(prhs[0], 0, "cliques");
	nCliques = mxGetNumberOfElements(pCliques);
	loops = nCliques - 1;	/* number of parent/child edges processed per pass */

	/* family[i] = parents of node i followed by node i itself (1-based ids). */
	pTemp = mxGetField(prhs[0], 0, "inf_engine");
	pBnet = mxGetField(pTemp, 0, "bnet");
	pParents = mxGetField(pBnet, 0, "parents");
	family = (double **)malloc(nNodes * sizeof(double *));
	family_num = malloc(nNodes * sizeof(int));
	for(i=0; i<nNodes; i++){
		pTemp = mxGetCell(pParents, i);
		pr = mxGetPr(pTemp);
		num = mxGetNumberOfElements(pTemp);
		family[i] = malloc((num+1) * sizeof(double));
		for(j=0; j<num; j++){
			family[i][j] = pr[j];
		}
		family[i][j] = i + 1.0;
		family_num[i] = num + 1;
	}

	/* Node sizes; every node listed in prhs[3] is given size 1. */
	ns = malloc(nNodes * sizeof(int));
	pTemp = mxGetField(pBnet, 0, "node_sizes");
	pr = mxGetPr(pTemp);
	for(i=0; i<nNodes; i++){
		ns[i] = (int)pr[i];
	}
	num = mxGetNumberOfElements(prhs[3]);
	pr = mxGetPr(prhs[3]);
	for(i=0; i<num; i++){
		ns[(int)pr[i] - 1] = 1;
	}

	/* Both order arrays hold `loops` parents followed by `loops` children
	   (0-based clique indices). */
	collect_order = malloc(2 * loops * sizeof(int));
	distribute_order = malloc(2 * loops * sizeof(int));

	pTemp = mxGetField(prhs[0], 0, "postorder");
	pr = mxGetPr(pTemp);
	pPostP = mxGetField(prhs[0], 0, "postorder_parents");
	for(i=0; i<loops; i++){
		temp = (int)pr[i] - 1;
		pTemp = mxGetCell(pPostP, temp);
		pr1 = mxGetPr(pTemp);
		collect_order[i] = (int)pr1[0] - 1;
		collect_order[i+loops] = temp;
	}

	pTemp = mxGetField(prhs[0], 0, "preorder");
	pr = mxGetPr(pTemp);
	pPreCh = mxGetField(prhs[0], 0, "preorder_children");
	count = 0;
	for(i=0; i<nCliques; i++){
		temp = (int)pr[i] - 1;
		pTemp = mxGetCell(pPreCh, temp);
		pr1 = mxGetPr(pTemp);
		num = mxGetNumberOfElements(pTemp);
		for(j=0; j<num; j++){
			distribute_order[count] = temp;
			distribute_order[count + loops] = (int)pr1[j] - 1;
			count++;
		}
	}

	/* Largest clique table; sizes the shared scratch buffers. */
	max_table = 0;
	pTemp = mxGetField(prhs[0], 0, "clique_weight");
	clweight = mxGetPr(pTemp);
	for(i=0; i<nCliques; i++){
		temp = (int)clweight[i];
		max_table = (max_table > temp) ? max_table : temp;
	}

	/* p_sep[parent][child] = slot in seppot[] used for that edge. */
	p_sep = (int **)malloc(nCliques * sizeof(int *));
	for(i=0; i<nCliques; i++){
		p_sep[i] = malloc(nCliques * sizeof(int));
		for(j=0; j<nCliques; j++){
			p_sep[i][j] = 0;
		}
	}
	for(i=0; i<loops; i++){
		parent = distribute_order[i];
		child  = distribute_order[i+loops];
		p_sep[parent][child] = i;
	}

	/* Clique potentials start as all ones.  The loop bound is the same
	   truncated integer used for the allocation, so a non-integer weight
	   cannot make the fill run past the buffer. */
	clpot = (double **)malloc(nCliques * sizeof(double *));
	for(i=0; i<nCliques; i++){
		num = (int)clweight[i];
		clpot[i] = malloc(num * sizeof(double));
		for(j=0; j<num; j++){
			clpot[i][j] = 1.0;
		}
	}
  
	multndx = malloc(max_table * sizeof(int));
	margndx = malloc(max_table * sizeof(int));

	/* Multiply every node's table T into the clique assigned to it. */
	pClNode = mxGetPr(prhs[1]);
	for(loop=0; loop<nNodes; loop++){
		ipointer = (int)pClNode[loop] - 1;
		pBig = mxGetCell(pCliques, ipointer);
		pBigDom = mxGetPr(pBig);
		siz_b = mxGetNumberOfElements(pBig);
		pSmallDom = family[loop];
		siz_s = family_num[loop];
		mk_multiply_ndx(pBigDom, pSmallDom, ns, siz_b, siz_s, multndx);
		num = (int)clweight[ipointer];
		pTemp = mxGetCell(prhs[2], loop);
		pt = mxGetField(pTemp, 0, "T");
		pr = mxGetPr(pt);
		multiply_pot(clpot[ipointer], pr, multndx, num); 
	}

	/* Collect pass: marginalise each child onto the separator it shares
	   with its parent, keep that separator table, and multiply it into the
	   parent. */
	seppot = (double **)malloc((nCliques-1) * sizeof(double *));
	seppot0 = malloc(max_table * sizeof(double));
	pSeparator = mxGetField(prhs[0], 0, "separator");
	for(loop=0; loop<loops; loop++){
		parent = collect_order[loop];
		child  = collect_order[loop+loops];
		ipointer = nCliques * child + parent;	/* linear index into the separator cell array */
		pSmall = mxGetCell(pSeparator, ipointer);
		pSmallDom = mxGetPr(pSmall);
		siz_s = mxGetNumberOfElements(pSmall);
		pc = p_sep[parent][child];

		pBig  = mxGetCell(pCliques, child);
		pBigDom = mxGetPr(pBig);
		siz_b = mxGetNumberOfElements(pBig);
		row = mk_marginalise_ndx(pBigDom, pSmallDom, ns, siz_b, siz_s, margndx);
		seppot[pc] = malloc(row * sizeof(double));
		num = (int)clweight[child];
		marginalise_pot(clpot[child], seppot[pc], margndx, num, row);

		pBig  = mxGetCell(pCliques, parent);
		pBigDom = mxGetPr(pBig);
		siz_b = mxGetNumberOfElements(pBig);
		mk_multiply_ndx(pBigDom, pSmallDom, ns, siz_b, siz_s, multndx);
		num = (int)clweight[parent];
		multiply_pot(clpot[parent], seppot[pc], multndx, num); 
	}

	/* Distribute pass: marginalise each parent onto the separator, divide
	   by the table stored during the collect pass, and multiply the ratio
	   into the child. */
	for(loop=0; loop<loops; loop++){
		parent = distribute_order[loop];
		child  = distribute_order[loop+loops];
		ipointer = nCliques * child + parent;
		pSmall = mxGetCell(pSeparator, ipointer);
		pSmallDom = mxGetPr(pSmall);
		siz_s = mxGetNumberOfElements(pSmall);

		pBig  = mxGetCell(pCliques, parent);
		pBigDom = mxGetPr(pBig);
		siz_b = mxGetNumberOfElements(pBig);
		row = mk_marginalise_ndx(pBigDom, pSmallDom, ns, siz_b, siz_s, margndx);
		num = (int)clweight[parent];
		marginalise_pot(clpot[parent], seppot0, margndx, num, row);
		
		pc = p_sep[parent][child];
		array_divide(seppot0, seppot[pc], row);

		pBig  = mxGetCell(pCliques, child);
		pBigDom = mxGetPr(pBig);
		siz_b = mxGetNumberOfElements(pBig);
		mk_multiply_ndx(pBigDom, pSmallDom, ns, siz_b, siz_s, multndx);
		num = (int)clweight[child];
		multiply_pot(clpot[child], seppot0, multndx, num); 
	}

	/* Normalise every clique; the constant of the last one is reported. */
	lik = 1.0;	/* defined value even if the loop body never runs */
	for(i=0; i<nCliques; i++){
		num = (int)clweight[i];
		lik = myNormalise(clpot[i], num);
	}
	loglik = log(lik);

	/* NOTE(review): mxCreateCellArray expects a pointer to mwSize; on builds
	   where mwSize is not int, &nCliques has the wrong type -- confirm. */
	plhs[0] = mxCreateCellArray(1, &nCliques);
	for(i=0; i<nCliques; i++){
		num = (int)clweight[i];
		pTemp = mxCreateDoubleMatrix(num, 1, mxREAL);
		pr = mxGetPr(pTemp);
		for(j=0; j<num; j++){
			pr[j] = clpot[i][j];
		}
		mxSetCell(plhs[0], i, pTemp);
	}

	/* Only touch plhs[1] when the caller asked for a second output. */
	if(nlhs > 1){
		plhs[1] = mxCreateDoubleMatrix(1, 1, mxREAL);
		*mxGetPr(plhs[1]) = loglik;
	}

	free(multndx);
	free(margndx);
	for(i=0; i<nCliques; i++){
		free(clpot[i]);
		free(p_sep[i]);
	}
	free(clpot);
	free(p_sep);
	for(i=0; i<loops; i++){
		free(seppot[i]);
	}
	free(seppot);
	free(seppot0);	/* was never released in the original */

	free(distribute_order);
	free(collect_order);

	free(ns);

	for(i=0; i<nNodes; i++){
		free(family[i]);
	}
	free(family);
	free(family_num);

}
	

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -