📄 distribute_evidence.c
字号:
position = (result - sequence) / 2;
if(maximize)
sTable[position] = (sTable[position] < bpr[i]) ? bpr[i] : sTable[position];
else sTable[position] += bpr[i];
}
else {
if(maximize)
sTable[nzCounts] = (sTable[nzCounts] < bpr[i]) ? bpr[i] : sTable[nzCounts];
else sTable[nzCounts] += bpr[i];
sequence[count] = sindex;
count++;
sequence[count] = nzCounts;
nzCounts++;
count++;
}
}
pTemp = mxGetField(smallPot, 0, "T");
if(pTemp)mxDestroyArray(pTemp);
qsort(sequence, nzCounts, sizeof(int) * 2, compare);
pTemp = convert_ill_table_to_sparse(sTable, sequence, nzCounts, NS);
mxSetField(smallPot, 0, "T", pTemp);
free(sTable);
free(sequence);
free(mask);
free(bCumprod);
free(sCumprod);
free(bsubv);
free(ssubv);
}
/*
 * divide_null_by_spPot
 *
 * Replaces field "T" of bigPot with a sparse NB-by-1 column in which all NB
 * entries are stored explicitly (NB = product of bigPot's "sizes").  Every
 * entry starts at 1 and is overwritten with 1/s for each stored entry s of
 * smallPot's "T", expanded over the dimensions of bigPot that are not in
 * smallPot's domain.  The caller in this file uses it when bigPot has no
 * table or an empty one.
 *
 * Zero divisors are treated as 1: entries missing from smallPot's sparse
 * table leave the result at 1, and explicitly stored zeros are handled the
 * same way (the scalar branch already did this; the general branch now
 * matches it instead of producing Inf).
 *
 * bigPot   - struct with fields "domain", "sizes", "T"; its "T" is destroyed
 *            (if present) and replaced.
 * smallPot - struct with fields "domain", "sizes" and a sparse "T"; read only.
 *
 * NOTE(review): mxGetIr/mxGetJc are documented as returning mwIndex*; the
 * int* usage is kept to match the rest of this file and presumably relies on
 * a 32-bit-index build - confirm.
 * NOTE(review): assumes every variable of smallPot's domain also appears in
 * bigPot's domain, in the same relative order; otherwise the mask buffers
 * below are indexed out of range.  malloc results are not checked, as in the
 * rest of the file.
 */
void divide_null_by_spPot(mxArray *bigPot, const mxArray *smallPot){
    int i, j, count, count1, match, temp, bdim, sdim, diffdim, NB, NS, ND, NZS, bindex, sindex;
    int *samemask, *diffmask, *rir, *rjc, *sir, *sjc, *bCumprod, *sCumprod, *ssubv, *weight;
    double *pbDomain, *psDomain, *pbSize, *psSize, *rpr, *spr, value;
    mxArray *pTemp, *pTemp1;

    /* Domains and dimension sizes of both potentials. */
    pTemp = mxGetField(bigPot, 0, "domain");
    pbDomain = mxGetPr(pTemp);
    bdim = mxGetNumberOfElements(pTemp);
    pTemp = mxGetField(smallPot, 0, "domain");
    psDomain = mxGetPr(pTemp);
    sdim = mxGetNumberOfElements(pTemp);
    pTemp = mxGetField(bigPot, 0, "sizes");
    pbSize = mxGetPr(pTemp);
    pTemp = mxGetField(smallPot, 0, "sizes");
    psSize = mxGetPr(pTemp);

    /* NB: total number of entries of the big table. */
    NB = 1;
    for(i=0; i<bdim; i++){
        NB *= (int)pbSize[i];
    }

    /* Sparse storage of the divisor; sjc[1] is its stored-entry count. */
    pTemp = mxGetField(smallPot, 0, "T");
    spr = mxGetPr(pTemp);
    sir = mxGetIr(pTemp);
    sjc = mxGetJc(pTemp);
    NZS = sjc[1];

    if(sdim == 0){
        /* Scalar divisor: every entry of the result is 1/value. */
        pTemp1 = mxGetField(bigPot, 0, "T");
        if(pTemp1)mxDestroyArray(pTemp1);
        pTemp = mxCreateSparse(NB, 1, NB, mxREAL);
        mxSetField(bigPot, 0, "T", pTemp);
        rpr = mxGetPr(pTemp);
        rir = mxGetIr(pTemp);
        rjc = mxGetJc(pTemp);
        rjc[0] = 0;
        rjc[1] = NB;
        /* Only read the stored value when there is one; an empty sparse
           scalar is a zero, which divides as 1. */
        value = (NZS > 0) ? spr[0] : 1;
        if(value == 0) value = 1;
        for(i=0; i<NB; i++){
            rpr[i] = 1 / value;
            rir[i] = i;
        }
        return;
    }

    /* NS: entries of the small table; ND: entries spanned by the dimensions
       that only the big potential has. */
    NS = 1;
    for(i=0; i<sdim; i++){
        NS *= (int)psSize[i];
    }
    ND = NB / NS;

    /* Result column: all NB rows stored, initialised to 1. */
    pTemp = mxCreateSparse(NB, 1, NB, mxREAL);
    rpr = mxGetPr(pTemp);
    rir = mxGetIr(pTemp);
    rjc = mxGetJc(pTemp);
    rjc[0] = 0;
    rjc[1] = NB;
    for(i=0; i<NB; i++){
        rpr[i] = 1;
        rir[i] = i;
    }

    diffdim = bdim - sdim;
    samemask = malloc(sdim * sizeof(int));
    diffmask = malloc(diffdim * sizeof(int));
    bCumprod = malloc(bdim * sizeof(int));
    sCumprod = malloc(sdim * sizeof(int));
    weight = malloc(ND * sizeof(int));
    ssubv = malloc(sdim * sizeof(int));

    /* Split the big dimensions into those shared with the small domain
       (samemask) and those that are not (diffmask). */
    count = 0;
    count1 = 0;
    for(i=0; i<bdim; i++){
        match = 0;
        for(j=0; j<sdim; j++){
            if(pbDomain[i] == psDomain[j]){
                samemask[count] = i;
                match = 1;
                count++;
                break;
            }
        }
        if(match == 0){
            diffmask[count1] = i;
            count1++;
        }
    }

    /* Linear-index strides (cumulative products of sizes) for both tables. */
    bCumprod[0] = 1;
    for(i=0; i<bdim-1; i++){
        bCumprod[i+1] = bCumprod[i] * (int)pbSize[i];
    }
    sCumprod[0] = 1;
    for(i=0; i<sdim-1; i++){
        sCumprod[i+1] = sCumprod[i] * (int)psSize[i];
    }

    /* weight[] receives ND offsets that are added to the shared-dimension
       base offset below (filled by a helper defined elsewhere in the file). */
    compute_fixed_weight(weight, pbSize, diffmask, bCumprod, ND, diffdim);

    for(i=0; i<NZS; i++){
        value = spr[i];
        /* A stored zero is treated like a missing entry: divisor 1, so the
           result stays at its initial value. */
        if(value == 0) continue;
        sindex = sir[i];
        ind_subv(sindex, sCumprod, sdim, ssubv);
        /* Offset of this small-table entry within the big table. */
        temp = 0;
        for(j=0; j<sdim; j++){
            temp += ssubv[j] * bCumprod[samemask[j]];
        }
        for(j=0; j<ND; j++){
            bindex = weight[j] + temp;
            rpr[bindex] = 1 / value;
        }
    }

    /* Swap the new table into bigPot, releasing the old one first. */
    pTemp1 = mxGetField(bigPot, 0, "T");
    if(pTemp1)mxDestroyArray(pTemp1);
    mxSetField(bigPot, 0, "T", pTemp);

    free(samemask);
    free(diffmask);
    free(bCumprod);
    free(sCumprod);
    free(weight);
    free(ssubv);
}
/*
 * divide_spPot_by_spPot
 *
 * Divides, in place, every stored entry of bigPot's sparse "T" by the entry
 * of smallPot's sparse "T" that corresponds to it under the shared
 * variables.  The sparsity pattern of bigPot's table is not changed.
 *
 * Zero divisors are treated as 1: a big entry whose matching small entry is
 * not stored is left unchanged, and an explicitly stored zero is handled the
 * same way (the scalar branch already did this; the general branch now
 * matches it instead of producing Inf/NaN).
 *
 * bigPot   - struct with fields "domain", "sizes" and a sparse "T"; the
 *            values of "T" are modified.
 * smallPot - struct with fields "domain", "sizes" and a sparse "T"; read only.
 *
 * NOTE(review): mxGetIr/mxGetJc are documented as returning mwIndex*; the
 * int* usage is kept because the row-index array is searched with the
 * file's int-based `compare` callback - presumably a 32-bit-index build,
 * confirm.
 * NOTE(review): assumes every variable of smallPot's domain appears in
 * bigPot's domain; otherwise part of `mask` stays uninitialised.  malloc
 * results are not checked, as in the rest of the file.
 */
void divide_spPot_by_spPot(mxArray *bigPot, const mxArray *smallPot){
    int i, j, count, bdim, sdim, NZB, NZS, position, bindex, sindex;
    int *mask, *result, *bir, *sir, *bjc, *sjc, *bCumprod, *sCumprod, *bsubv, *ssubv;
    double *pbDomain, *psDomain, *pbSize, *psSize, *bpr, *spr, value;
    mxArray *pTemp;

    /* Domains and dimension sizes of both potentials. */
    pTemp = mxGetField(bigPot, 0, "domain");
    pbDomain = mxGetPr(pTemp);
    bdim = mxGetNumberOfElements(pTemp);
    pTemp = mxGetField(smallPot, 0, "domain");
    psDomain = mxGetPr(pTemp);
    sdim = mxGetNumberOfElements(pTemp);
    pTemp = mxGetField(bigPot, 0, "sizes");
    pbSize = mxGetPr(pTemp);
    pTemp = mxGetField(smallPot, 0, "sizes");
    psSize = mxGetPr(pTemp);

    /* Sparse storage of dividend and divisor; jc[1] is the stored count. */
    pTemp = mxGetField(bigPot, 0, "T");
    bpr = mxGetPr(pTemp);
    bir = mxGetIr(pTemp);
    bjc = mxGetJc(pTemp);
    NZB = bjc[1];
    pTemp = mxGetField(smallPot, 0, "T");
    spr = mxGetPr(pTemp);
    sir = mxGetIr(pTemp);
    sjc = mxGetJc(pTemp);
    NZS = sjc[1];

    if(sdim == 0){
        /* Scalar divisor.  Only read the stored value when there is one; an
           empty sparse scalar is a zero, which divides as 1. */
        value = (NZS > 0) ? spr[0] : 1;
        if(value == 0)value = 1;
        for(i=0; i<NZB; i++){
            bpr[i] /= value;
        }
        return;
    }

    mask = malloc(sdim * sizeof(int));
    bCumprod = malloc(bdim * sizeof(int));
    sCumprod = malloc(sdim * sizeof(int));
    bsubv = malloc(bdim * sizeof(int));
    ssubv = malloc(sdim * sizeof(int));

    /* mask[k]: position inside the big domain of the k-th small variable. */
    count = 0;
    for(i=0; i<sdim; i++){
        for(j=0; j<bdim; j++){
            if(psDomain[i] == pbDomain[j]){
                mask[count] = j;
                count++;
                break;
            }
        }
    }

    /* Linear-index strides (cumulative products of sizes) for both tables. */
    bCumprod[0] = 1;
    for(i=0; i<bdim-1; i++){
        bCumprod[i+1] = bCumprod[i] * (int)pbSize[i];
    }
    sCumprod[0] = 1;
    for(i=0; i<sdim-1; i++){
        sCumprod[i+1] = sCumprod[i] * (int)psSize[i];
    }

    for(i=0; i<NZB; i++){
        /* Project the big entry's subscripts onto the small domain. */
        bindex = bir[i];
        ind_subv(bindex, bCumprod, bdim, bsubv);
        for(j=0; j<sdim; j++){
            ssubv[j] = bsubv[mask[j]];
        }
        sindex = subv_ind(sdim, sCumprod, ssubv);

        /* Look the projected index up among the divisor's stored rows. */
        result = (int *) bsearch(&sindex, sir, NZS, sizeof(int), compare);
        if(result){
            position = result - sir;
            value = spr[position];
            /* A stored zero behaves like a missing entry: divide by 1. */
            if(value != 0){
                bpr[i] /= value;
            }
        }
    }

    free(mask);
    free(bCumprod);
    free(sCumprod);
    free(bsubv);
    free(ssubv);
}
/*
 * mexFunction - MEX gateway.
 *
 * Inputs
 *   prhs[0]  struct with fields "cliques" (only its element count is used),
 *            "maximize" (scalar flag passed on to the marginalisation),
 *            "preorder" (clique numbers; one is subtracted before use as an
 *            index) and "preorder_children" (cell array; each cell lists the
 *            child clique numbers of that clique, again one-based).
 *   prhs[1]  cell array of clique potentials.
 *   prhs[2]  cell array of separator potentials, addressed with the linear
 *            index nCliques*child + parent.
 *
 * Outputs
 *   plhs[0]  updated copy of prhs[1].
 *   plhs[1]  updated copy of prhs[2] (only returned when requested).
 *
 * For every (parent, child) edge, visited in preorder, the child potential
 * is divided by the separator, the separator is recomputed by marginalising
 * the parent potential, and the child is multiplied by the new separator.
 *
 * Scratch buffers come from mxMalloc so that the error exits below do not
 * leak memory.
 */
void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[]){
    int i, j, k, nCliques, nPreCh, maxEdges, nEdges, nChildren, node, parent, child, maximize;
    int *edgeParent, *edgeChild;
    double *pr, *pr1;
    const mxArray *pCliques, *pMaximize, *pPreorder, *pPreCh, *pKids;
    mxArray *clqPots, *sepPots, *pTemp, *pChildPot, *pParentPot, *pSeppot;

    /* --- argument validation ------------------------------------------- */
    if(nrhs < 3)
        mexErrMsgTxt("distribute_evidence: three input arguments are required.");
    if(nlhs > 2)
        mexErrMsgTxt("distribute_evidence: at most two output arguments are produced.");
    if(!mxIsStruct(prhs[0]) || !mxIsCell(prhs[1]) || !mxIsCell(prhs[2]))
        mexErrMsgTxt("distribute_evidence: expected a struct followed by two cell arrays.");

    pCliques  = mxGetField(prhs[0], 0, "cliques");
    pMaximize = mxGetField(prhs[0], 0, "maximize");
    pPreorder = mxGetField(prhs[0], 0, "preorder");
    pPreCh    = mxGetField(prhs[0], 0, "preorder_children");
    if(!pCliques || !pMaximize || !pPreorder || !pPreCh)
        mexErrMsgTxt("distribute_evidence: first argument lacks one of the fields "
                     "cliques, maximize, preorder, preorder_children.");
    if(mxIsEmpty(pMaximize) || !mxIsCell(pPreCh))
        mexErrMsgTxt("distribute_evidence: field maximize must be a scalar and "
                     "preorder_children a cell array.");

    nCliques = mxGetNumberOfElements(pCliques);
    if((int)mxGetNumberOfElements(pPreorder) < nCliques)
        mexErrMsgTxt("distribute_evidence: field preorder has fewer entries than cliques.");
    maximize = (int)mxGetScalar(pMaximize);
    pr = mxGetPr(pPreorder);
    nPreCh = mxGetNumberOfElements(pPreCh);

    /* --- collect the (parent, child) edges in preorder ------------------ */
    /* The buffers hold nCliques-1 edges; at least one slot is requested so
       the allocation size is never zero. */
    maxEdges = (nCliques > 0) ? nCliques - 1 : 0;
    edgeParent = mxMalloc((maxEdges > 0 ? maxEdges : 1) * sizeof *edgeParent);
    edgeChild  = mxMalloc((maxEdges > 0 ? maxEdges : 1) * sizeof *edgeChild);

    nEdges = 0;
    for(i=0; i<nCliques; i++){
        node = (int)pr[i] - 1;
        if(node < 0 || node >= nPreCh)
            mexErrMsgTxt("distribute_evidence: preorder entry is out of range.");
        pKids = mxGetCell(pPreCh, node);
        if(!pKids) continue;            /* unpopulated cell: no children */
        pr1 = mxGetPr(pKids);
        nChildren = mxGetNumberOfElements(pKids);
        for(j=0; j<nChildren; j++){
            child = (int)pr1[j] - 1;
            if(child < 0)
                mexErrMsgTxt("distribute_evidence: child index is out of range.");
            if(nEdges >= maxEdges)
                mexErrMsgTxt("distribute_evidence: preorder_children lists more than "
                             "nCliques-1 edges.");
            edgeParent[nEdges] = node;
            edgeChild[nEdges] = child;
            nEdges++;
        }
    }

    /* --- work on private copies of the potentials ----------------------- */
    clqPots = mxDuplicateArray(prhs[1]);
    sepPots = mxDuplicateArray(prhs[2]);
    plhs[0] = clqPots;
    /* plhs[1] may only be written when the caller asked for it. */
    if(nlhs >= 2) plhs[1] = sepPots;

    /* Only the edges that were actually collected are processed. */
    for(k=0; k<nEdges; k++){
        parent = edgeParent[k];
        child = edgeChild[k];
        i = nCliques * child + parent;

        pChildPot = mxGetCell(clqPots, child);
        pParentPot = mxGetCell(clqPots, parent);
        pSeppot = mxGetCell(sepPots, i);
        if(!pChildPot || !pParentPot || !pSeppot)
            mexErrMsgTxt("distribute_evidence: missing clique or separator potential.");

        /* Child potential / old separator; a missing or empty table takes
           the "null" path. */
        pTemp = mxGetField(pChildPot, 0, "T");
        if(pTemp && !mxIsEmpty(pTemp))
            divide_spPot_by_spPot(pChildPot, pSeppot);
        else
            divide_null_by_spPot(pChildPot, pSeppot);

        /* New separator from the parent.  pSeppot is the element already
           stored in sepPots and is updated through the pointer, so it does
           not need to be stored back into the cell array. */
        marginal_spPot_to_spPot(pParentPot, pSeppot, maximize);

        /* Child potential * new separator. */
        multiply_spPot_by_spPot(pChildPot, pSeppot);
    }

    if(nlhs < 2) mxDestroyArray(sepPots);
    mxFree(edgeParent);
    mxFree(edgeChild);
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -