⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 init_pot.c

📁 贝叶斯算法(matlab编写) 安装,添加目录 /home/ai2/murphyk/matlab/FullBNT
💻 C
📖 第 1 页 / 共 2 页
字号:
	sequence = malloc(NZB * 2 * sizeof(int));
	bigTable = malloc(NZB * sizeof(double));
	samemask = malloc(sdim * sizeof(int));
	diffmask = malloc(diffdim * sizeof(int));
	bCumprod = malloc(bdim * sizeof(int));
	sCumprod = malloc(sdim * sizeof(int));
	weight = malloc(ND * sizeof(int));
	ssubv = malloc(sdim * sizeof(int));

	count = 0;
	count1 = 0;
	for(i=0; i<bdim; i++){
		match = 0;
		for(j=0; j<sdim; j++){
			if(pbDomain[i] == psDomain[j]){
				samemask[count] = i;
				match = 1;
				count++;
				break;
			}
		}
		if(match == 0){
			diffmask[count1] = i; 
			count1++;
		}
	}

	bCumprod[0] = 1;
	for(i=0; i<bdim-1; i++){
		bCumprod[i+1] = bCumprod[i] * (int)pbSize[i];
	}
	sCumprod[0] = 1;
	for(i=0; i<sdim-1; i++){
		sCumprod[i+1] = sCumprod[i] * (int)psSize[i];
	}

	count = 0;
	compute_fixed_weight(weight, pbSize, diffmask, bCumprod, ND, diffdim);
	for(i=0; i<NZS; i++){
		sindex = sir[i];
		ind_subv(sindex, sCumprod, sdim, ssubv);
		temp = 0;
		for(j=0; j<sdim; j++){
			temp += ssubv[j] * bCumprod[samemask[j]];
		}
		for(j=0; j<ND; j++){
			bindex = weight[j] + temp;
			bigTable[nzCounts] = spr[i];
			sequence[count] = bindex;
			count++;
			sequence[count] = nzCounts;
			nzCounts++;
			count++;
		}
	}

	pTemp = mxGetField(bigPot, 0, "T");
	if(pTemp)mxDestroyArray(pTemp);
	qsort(sequence, nzCounts, sizeof(int) * 2, compare);
	pTemp = convert_ill_table_to_sparse(bigTable, sequence, nzCounts, NB);
	mxSetField(bigPot, 0, "T", pTemp);

	free(sequence); 
	free(bigTable);
	free(samemask);
	free(diffmask);
	free(bCumprod);
	free(sCumprod);
	free(weight);
	free(ssubv);
}

/*
 * Multiply a sparse big potential by a full (dense) small potential, in place.
 *
 * bigPot    struct with fields "domain", "sizes" and "T"; "T" is accessed
 *           through its sparse Ir/Jc/Pr arrays and only column 0 is used.
 * smallPot  struct with the same fields; its "T" is read as a dense double
 *           array.
 *
 * For every stored entry of bigPot.T the entry's linear index is expanded to
 * subscripts over bigPot's domain, projected onto smallPot's domain, and the
 * stored value is multiplied by the matching smallPot.T entry.  Entries whose
 * small-table factor is exactly zero are dropped.  The previous bigPot.T is
 * destroyed and replaced by the array returned from convert_table_to_sparse().
 *
 * NOTE(review): allocation results are not checked, and the sparse index
 * arrays are held as int* -- confirm this matches the MEX API the file is
 * built against.
 */
void multiply_spPot_by_fuPot(mxArray *bigPot, const mxArray *smallPot){
	int     i, j, count, bdim, sdim, NB, NZB, bindex, sindex, nzCounts=0;
	int     *mask, *index, *bir, *bjc, *bCumprod, *sCumprod, *bsubv, *ssubv;
	double  *bigTable, *pbDomain, *psDomain, *pbSize, *psSize, *bpr, *spr, value;
	mxArray *pTemp;

	/* Domains (variable ids) and their lengths. */
	pTemp = mxGetField(bigPot, 0, "domain");
	pbDomain = mxGetPr(pTemp);
	bdim = mxGetNumberOfElements(pTemp);
	pTemp = mxGetField(smallPot, 0, "domain");
	psDomain = mxGetPr(pTemp);
	sdim = mxGetNumberOfElements(pTemp);

	/* Per-variable sizes. */
	pTemp = mxGetField(bigPot, 0, "sizes");
	pbSize = mxGetPr(pTemp);
	pTemp = mxGetField(smallPot, 0, "sizes");
	psSize = mxGetPr(pTemp);

	/* NB = total number of cells in the big table. */
	NB = 1;
	for(i=0; i<bdim; i++){
		NB *= (int)pbSize[i];
	}

	/* Sparse storage of the big table; NZB = entries stored in column 0. */
	pTemp = mxGetField(bigPot, 0, "T");
	bpr = mxGetPr(pTemp);
	bir = mxGetIr(pTemp);
	bjc = mxGetJc(pTemp);
	NZB = bjc[1];

	pTemp = mxGetField(smallPot, 0, "T");
	spr = mxGetPr(pTemp);

	/* Work buffers.  Each is sized from the pointer it is assigned to, so
	 * the element type cannot drift from the declaration (index is int,
	 * not double).  bigTable starts zero-filled. */
	bigTable = calloc(NZB, sizeof *bigTable);
	index = malloc(NZB * sizeof *index);
	mask = malloc(sdim * sizeof *mask);
	bCumprod = malloc(bdim * sizeof *bCumprod);
	sCumprod = malloc(sdim * sizeof *sCumprod);
	bsubv = malloc(bdim * sizeof *bsubv);
	ssubv = malloc(sdim * sizeof *ssubv);

	/* mask[k] = position in the big domain of the k-th matched small-domain
	 * variable.  Unmatched variables leave the tail of mask unset, so the
	 * small domain is presumably a subset of the big one -- TODO confirm. */
	count = 0;
	for(i=0; i<sdim; i++){
		for(j=0; j<bdim; j++){
			if(psDomain[i] == pbDomain[j]){
				mask[count] = j;
				count++;
				break;
			}
		}
	}

	/* Cumulative products of the sizes: strides for index <-> subscripts. */
	bCumprod[0] = 1;
	for(i=0; i<bdim-1; i++){
		bCumprod[i+1] = bCumprod[i] * (int)pbSize[i];
	}
	sCumprod[0] = 1;
	for(i=0; i<sdim-1; i++){
		sCumprod[i+1] = sCumprod[i] * (int)psSize[i];
	}

	for(i=0; i<NZB; i++){
		bindex = bir[i];
		ind_subv(bindex, bCumprod, bdim, bsubv);
		/* Project the big subscripts onto the small domain. */
		for(j=0; j<sdim; j++){
			ssubv[j] = bsubv[mask[j]];
		}
		sindex = subv_ind(sdim, sCumprod, ssubv);
		value = spr[sindex];
		/* Keep only products that can be non-zero. */
		if(value != 0){
			bigTable[nzCounts] = bpr[i] * value;
			index[nzCounts] = bindex;
			nzCounts++;
		}
	}

	/* Swap the old table for the new one. */
	pTemp = mxGetField(bigPot, 0, "T");
	if(pTemp)mxDestroyArray(pTemp);
	pTemp = convert_table_to_sparse(bigTable, index, nzCounts, NB);
	mxSetField(bigPot, 0, "T", pTemp);

	free(bigTable);
	free(index);
	free(mask);
	free(bCumprod);
	free(sCumprod);
	free(bsubv);
	free(ssubv);
}

/*
 * Multiply a sparse big potential by a sparse small potential, in place.
 *
 * bigPot    struct with fields "domain", "sizes" and "T"; "T" is accessed
 *           through its sparse Ir/Jc/Pr arrays and only column 0 is used.
 * smallPot  struct with the same fields; its "T" is accessed the same way.
 *
 * For every stored entry of bigPot.T the entry's linear index is expanded to
 * subscripts over bigPot's domain and projected onto smallPot's domain.  The
 * resulting small-table index is looked up in smallPot.T's row-index array
 * with bsearch(); when it is present the two values are multiplied and kept,
 * otherwise the entry is dropped.  The previous bigPot.T is destroyed and
 * replaced by the array returned from convert_table_to_sparse().
 *
 * NOTE(review): allocation results are not checked, and the sparse index
 * arrays are held as int* -- confirm this matches the MEX API the file is
 * built against.  bsearch() also relies on smallPot.T's row indices being in
 * the order implied by compare(), which is defined outside this block.
 */
void multiply_spPot_by_spPot(mxArray *bigPot, const mxArray *smallPot){
	int     i, j, count, bdim, sdim, NB, NZB, NZS, position, bindex, sindex, nzCounts=0;
	int     *mask, *index, *result, *bir, *sir, *bjc, *sjc, *bCumprod, *sCumprod, *bsubv, *ssubv;
	double  *bigTable, *pbDomain, *psDomain, *pbSize, *psSize, *bpr, *spr, value;
	mxArray *pTemp;

	/* Domains (variable ids) and their lengths. */
	pTemp = mxGetField(bigPot, 0, "domain");
	pbDomain = mxGetPr(pTemp);
	bdim = mxGetNumberOfElements(pTemp);
	pTemp = mxGetField(smallPot, 0, "domain");
	psDomain = mxGetPr(pTemp);
	sdim = mxGetNumberOfElements(pTemp);

	/* Per-variable sizes. */
	pTemp = mxGetField(bigPot, 0, "sizes");
	pbSize = mxGetPr(pTemp);
	pTemp = mxGetField(smallPot, 0, "sizes");
	psSize = mxGetPr(pTemp);

	/* NB = total number of cells in the big table. */
	NB = 1;
	for(i=0; i<bdim; i++){
		NB *= (int)pbSize[i];
	}

	/* Sparse storage of both tables; NZB/NZS = entries stored in column 0. */
	pTemp = mxGetField(bigPot, 0, "T");
	bpr = mxGetPr(pTemp);
	bir = mxGetIr(pTemp);
	bjc = mxGetJc(pTemp);
	NZB = bjc[1];

	pTemp = mxGetField(smallPot, 0, "T");
	spr = mxGetPr(pTemp);
	sir = mxGetIr(pTemp);
	sjc = mxGetJc(pTemp);
	NZS = sjc[1];

	/* Work buffers.  Each is sized from the pointer it is assigned to, so
	 * the element type cannot drift from the declaration (index is int,
	 * not double).  bigTable starts zero-filled. */
	bigTable = calloc(NZB, sizeof *bigTable);
	index = malloc(NZB * sizeof *index);
	mask = malloc(sdim * sizeof *mask);
	bCumprod = malloc(bdim * sizeof *bCumprod);
	sCumprod = malloc(sdim * sizeof *sCumprod);
	bsubv = malloc(bdim * sizeof *bsubv);
	ssubv = malloc(sdim * sizeof *ssubv);

	/* mask[k] = position in the big domain of the k-th matched small-domain
	 * variable.  Unmatched variables leave the tail of mask unset, so the
	 * small domain is presumably a subset of the big one -- TODO confirm. */
	count = 0;
	for(i=0; i<sdim; i++){
		for(j=0; j<bdim; j++){
			if(psDomain[i] == pbDomain[j]){
				mask[count] = j;
				count++;
				break;
			}
		}
	}

	/* Cumulative products of the sizes: strides for index <-> subscripts. */
	bCumprod[0] = 1;
	for(i=0; i<bdim-1; i++){
		bCumprod[i+1] = bCumprod[i] * (int)pbSize[i];
	}
	sCumprod[0] = 1;
	for(i=0; i<sdim-1; i++){
		sCumprod[i+1] = sCumprod[i] * (int)psSize[i];
	}

	for(i=0; i<NZB; i++){
		value = bpr[i];
		bindex = bir[i];
		ind_subv(bindex, bCumprod, bdim, bsubv);
		/* Project the big subscripts onto the small domain. */
		for(j=0; j<sdim; j++){
			ssubv[j] = bsubv[mask[j]];
		}
		sindex = subv_ind(sdim, sCumprod, ssubv);
		/* A small-table index that is not stored means a zero factor. */
		result = (int *) bsearch(&sindex, sir, NZS, sizeof(int), compare);
		if(result){
			position = (int)(result - sir);
			value *= spr[position];
			bigTable[nzCounts] = value;
			index[nzCounts] = bindex;
			nzCounts++;
		}
	}

	/* Swap the old table for the new one. */
	pTemp = mxGetField(bigPot, 0, "T");
	if(pTemp)mxDestroyArray(pTemp);
	pTemp = convert_table_to_sparse(bigTable, index, nzCounts, NB);
	mxSetField(bigPot, 0, "T", pTemp);

	free(bigTable);
	free(index);
	free(mask);
	free(bCumprod);
	free(sCumprod);
	free(bsubv);
	free(ssubv);
}


/*
 * MEX entry point: create one potential per clique and multiply into it the
 * potentials assigned to that clique.
 *
 * prhs[0]  struct with the fields "cliques" (cell array; each cell lists the
 *          1-based node ids of one clique) and "eff_node_sizes" (indexed by
 *          node id).
 * prhs[1]  numeric array; element k is the 1-based clique that receives
 *          prhs[2]{k}.
 * prhs[2]  cell array of potentials, each a struct with a "T" field that is
 *          either sparse or full.
 *
 * plhs[0]  cell array with one struct per clique (fields "domain", "T",
 *          "sizes").
 * plhs[1]  nCliques-by-nCliques empty cell array; created only when the
 *          caller requests a second output.
 *
 * NOTE(review): the inputs are not validated (argument count, field
 * presence, index ranges) -- confirm callers always pass well-formed data.
 */
void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[]){
	int     i, j, c, loop, nNodes, nCliques, ndomain, dims[2];
	double  *pClqs, *pr, *pt, *pSize;
	mxArray *pTemp, *pTemp1, *pStruct, *pCliques, *pBigpot, *pSmallpot;
	const char *field_names[] = {"domain", "T", "sizes"};

	nNodes = mxGetNumberOfElements(prhs[1]);
	pCliques = mxGetField(prhs[0], 0, "cliques");
	nCliques = mxGetNumberOfElements(pCliques);
	pTemp = mxGetField(prhs[0], 0, "eff_node_sizes");
	pSize = mxGetPr(pTemp);

	/* One struct per clique: copy the domain and look up each node's size.
	 * "T" is left unset here; the loop below fills it in. */
	plhs[0] = mxCreateCellArray(1, &nCliques);
	for(i=0; i<nCliques; i++){
		pStruct = mxCreateStructMatrix(1, 1, 3, field_names);
		mxSetCell(plhs[0], i, pStruct);
		pTemp = mxGetCell(pCliques, i);
		ndomain = mxGetNumberOfElements(pTemp);
		pt = mxGetPr(pTemp);
		pTemp1 = mxDuplicateArray(pTemp);
		mxSetField(pStruct, 0, "domain", pTemp1);

		pTemp = mxCreateDoubleMatrix(1, ndomain, mxREAL);
		mxSetField(pStruct, 0, "sizes", pTemp);
		pr = mxGetPr(pTemp);
		for(j=0; j<ndomain; j++){
			/* Node ids are 1-based; pSize is a 0-based C array. */
			pr[j] = pSize[(int)pt[j]-1];
		}
	}

	/* Fold every input potential into its clique's potential.  The routine
	 * depends on whether the clique already has a "T" and on whether the
	 * incoming "T" is sparse or full. */
	pClqs = mxGetPr(prhs[1]);
	for(loop=0; loop<nNodes; loop++){
		c = (int)pClqs[loop] - 1;
		pSmallpot = mxGetCell(prhs[2], loop);
		pTemp = mxGetField(pSmallpot, 0, "T");
		pBigpot = mxGetCell(plhs[0], c);
		pTemp1 = mxGetField(pBigpot, 0, "T");
		if(pTemp1){
			if(mxIsSparse(pTemp))
				multiply_spPot_by_spPot(pBigpot, pSmallpot);
			else multiply_spPot_by_fuPot(pBigpot, pSmallpot);
		}
		else{
			if(mxIsSparse(pTemp))
				multiply_null_by_spPot(pBigpot, pSmallpot);
			else multiply_null_by_fuPot(pBigpot, pSmallpot);
		}
	}

	/* Only write the second output when it was requested: plhs is not
	 * guaranteed to have a slot for outputs beyond nlhs. */
	if(nlhs > 1){
		dims[0] = nCliques;
		dims[1] = nCliques;
		plhs[1] = mxCreateCellArray(2, dims);
	}
}


⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -