📄 bnlearn-engine.c
字号:
/* copy two */
preAllocation = MGetTotalAllocation();
newBN = BNCloneNoCPTs(bn);
BNFlushStructureCache(newBN);
srcNode = BNGetNodeByID(newBN, i);
dstNode = BNGetNodeByID(newBN, j);
BNNodeAddParent(dstNode, srcNode);
BNNodeInitCPT(dstNode);
_InitUserData(newBN, bn, BNL_ADDED_PARENT, j, i);
data = BNNodeGetUserData(dstNode);
data->isChangedFromCurrent = 1;
_SetCachedScore_AddedParent(data, j, i);
netData = BNGetUserData(newBN);
netData->changedOne = dstNode;
netData->changeComplexity = 2;
if((gParams->gMaxParentsPerNode == -1 ||
BNNodeGetNumParents(dstNode) <= gParams->gMaxParentsPerNode) &&
(gParams->gMaxParameterGrowthMult == -1 ||
(BNGetNumParameters(newBN) <
(gParams->gMaxParameterGrowthMult * gInitialParameterCount))) &&
(gParams->gMaxParameterCount == -1 ||
(BNNodeGetNumParameters(dstNode) <=
gParams->gMaxParameterCount)) &&
(gParams->gLimitBytes == -1 ||
((MGetTotalAllocation() - preAllocation) < gMaxBytesPerModel))&&
!BNHasCycle(newBN)) {
VLAppend(list, newBN);
} else {
_FreeUserData(newBN);
BNFree(newBN);
}
}
} else {
/* one is a parent of the other, make 2 copies, one with
link removed, one with link reversed */
/* copy one: remove */
/* removing a link changes the structure and will not
be equivalent to any of the candidates */
newBN = BNCloneNoCPTs(bn);
BNFlushStructureCache(newBN);
preAllocation = MGetTotalAllocation();
srcNode = BNGetNodeByID(newBN, i);
dstNode = BNGetNodeByID(newBN, j);
if(dstParentOfSrcIndex != -1) {
BNNodeRemoveParent(srcNode, dstParentOfSrcIndex);
BNNodeInitCPT(srcNode);
_InitUserData(newBN, bn, BNL_REMOVED_PARENT, i, j);
data = BNNodeGetUserData(srcNode);
data->isChangedFromCurrent = 1;
_SetCachedScore_RemovedParent(data, i, j);
netData = BNGetUserData(newBN);
netData->changedOne = srcNode;
netData->changeComplexity = -1;
} else {
BNNodeRemoveParent(dstNode, srcParentOfDstIndex);
BNNodeInitCPT(dstNode);
_InitUserData(newBN, bn, BNL_REMOVED_PARENT, j, i);
data = BNNodeGetUserData(dstNode);
data->isChangedFromCurrent = 1;
_SetCachedScore_RemovedParent(data, j, i);
netData = BNGetUserData(newBN);
netData->changedOne = dstNode;
netData->changeComplexity = -1;
}
if(!BNHasCycle(newBN) &&
(gParams->gMaxParameterGrowthMult == -1 ||
(BNGetNumParameters(newBN) <
(gParams->gMaxParameterGrowthMult * gInitialParameterCount))) &&
(gParams->gMaxParameterCount == -1 ||
(BNNodeGetNumParameters(srcNode) <=
gParams->gMaxParameterCount &&
BNNodeGetNumParameters(dstNode) <=
gParams->gMaxParameterCount))&&
(gParams->gLimitBytes == -1 ||
((MGetTotalAllocation() - preAllocation) <
gMaxBytesPerModel))) {
VLAppend(list, newBN);
} else {
_FreeUserData(newBN);
BNFree(newBN);
}
/* copy two: reverse */
if(!gParams->gNoReverse) {
preAllocation = MGetTotalAllocation();
newBN = BNCloneNoCPTs(bn);
BNFlushStructureCache(newBN);
srcNode = BNGetNodeByID(newBN, i);
dstNode = BNGetNodeByID(newBN, j);
if(dstParentOfSrcIndex != -1) {
BNNodeRemoveParent(srcNode, dstParentOfSrcIndex);
BNNodeAddParent(dstNode, srcNode);
BNNodeInitCPT(srcNode);
BNNodeInitCPT(dstNode);
_InitUserData(newBN, bn, BNL_REVERSED_PARENT, i, j);
data = BNNodeGetUserData(srcNode);
data->isChangedFromCurrent = 1;
_SetCachedScore_RemovedParent(data, i, j);
data = BNNodeGetUserData(dstNode);
data->isChangedFromCurrent = 1;
_SetCachedScore_AddedParent(data, j, i);
netData = BNGetUserData(newBN);
netData->changedOne = srcNode;
netData->changedTwo = dstNode;
netData->changeComplexity = 1;
isLinkCovered = _IsLinkCoveredIn(srcNode, dstNode, newBN);
} else {
BNNodeRemoveParent(dstNode, srcParentOfDstIndex);
BNNodeAddParent(srcNode, dstNode);
BNNodeInitCPT(srcNode);
BNNodeInitCPT(dstNode);
_InitUserData(newBN, bn, BNL_REVERSED_PARENT, j, i);
data = BNNodeGetUserData(srcNode);
data->isChangedFromCurrent = 1;
_SetCachedScore_AddedParent(data, i, j);
data = BNNodeGetUserData(dstNode);
data->isChangedFromCurrent = 1;
_SetCachedScore_RemovedParent(data, j, i);
netData = BNGetUserData(newBN);
netData->changedOne = srcNode;
netData->changedTwo = dstNode;
netData->changeComplexity = 1;
isLinkCovered = _IsLinkCoveredIn(dstNode, srcNode, newBN);
}
if((gParams->gMaxParentsPerNode == -1 ||
(BNNodeGetNumParents(srcNode) <= gParams->gMaxParentsPerNode &&
BNNodeGetNumParents(dstNode) <= gParams->gMaxParentsPerNode)) &&
!isLinkCovered &&
(gParams->gLimitBytes == -1 ||
((MGetTotalAllocation() - preAllocation) < gMaxBytesPerModel))&&
(gParams->gMaxParameterGrowthMult == -1 ||
(BNGetNumParameters(newBN) <
(gParams->gMaxParameterGrowthMult * gInitialParameterCount))) &&
(gParams->gMaxParameterCount == -1 ||
(BNNodeGetNumParameters(srcNode) <=
gParams->gMaxParameterCount &&
BNNodeGetNumParameters(dstNode) <=
gParams->gMaxParameterCount))&&
!BNHasCycle(newBN)) {
VLAppend(list, newBN);
} else {
_FreeUserData(newBN);
BNFree(newBN);
}
}
}
}
}
}
/*
 * Prepare for sample accumulation: call BNNodeZeroCPT on every node
 * whose user data reports scoreIsValid == 0.
 *
 * current - network whose nodes list is scanned in full.
 * newNets - list of candidate networks; for each one only the nodes
 *           recorded in its user data as changedOne / changedTwo are
 *           considered (either may be unset).
 *
 * Nodes whose score is already valid are left untouched.
 */
static void _OptimizedAddSampleInit(BeliefNet current, VoidListPtr newNets) {
  int idx;

  /* Pass 1: every node of the current network. */
  for (idx = 0; idx < VLLength(current->nodes); idx++) {
    BeliefNetNode curNode = (BeliefNetNode)VLIndex(current->nodes, idx);
    BNNodeUserData curData = BNNodeGetUserData(curNode);

    if (curData->scoreIsValid) {
      continue;
    }
    BNNodeZeroCPT(curNode);
  }

  /* Pass 2: only the nodes each candidate network marks as changed. */
  for (idx = 0; idx < VLLength(newNets); idx++) {
    BNUserData candidate = BNGetUserData(VLIndex(newNets, idx));
    BNNodeUserData changedData;

    if (candidate->changedOne) {
      changedData = BNNodeGetUserData(candidate->changedOne);
      if (!changedData->scoreIsValid) {
        BNNodeZeroCPT(candidate->changedOne);
      }
    }

    if (candidate->changedTwo) {
      changedData = BNNodeGetUserData(candidate->changedTwo);
      if (!changedData->scoreIsValid) {
        BNNodeZeroCPT(candidate->changedTwo);
      }
    }
  }
}
/*
 * Feed a single example to every node whose user data reports
 * scoreIsValid == 0, via BNNodeAddSample.
 *
 * current - network whose nodes list is scanned in full.
 * newNets - list of candidate networks; for each one only the nodes
 *           recorded in its user data as changedOne / changedTwo are
 *           considered (either may be unset).
 * e       - the example handed to BNNodeAddSample.
 *
 * NOTE(review): presumably a selective stand-in for calling
 * BNAddSample(current, e) on the whole network - confirm.
 */
static void _OptimizedAddSample(BeliefNet current, VoidListPtr newNets,
                                ExamplePtr e) {
  int idx;

  /* Pass 1: every node of the current network. */
  for (idx = 0; idx < VLLength(current->nodes); idx++) {
    BeliefNetNode curNode = (BeliefNetNode)VLIndex(current->nodes, idx);
    BNNodeUserData curData = BNNodeGetUserData(curNode);

    if (curData->scoreIsValid) {
      continue;
    }
    BNNodeAddSample(curNode, e);
  }

  /* Pass 2: only the nodes each candidate network marks as changed. */
  for (idx = 0; idx < VLLength(newNets); idx++) {
    BNUserData candidate = BNGetUserData(VLIndex(newNets, idx));
    BNNodeUserData changedData;

    if (candidate->changedOne) {
      changedData = BNNodeGetUserData(candidate->changedOne);
      if (!changedData->scoreIsValid) {
        BNNodeAddSample(candidate->changedOne, e);
      }
    }

    if (candidate->changedTwo) {
      changedData = BNNodeGetUserData(candidate->changedTwo);
      if (!changedData->scoreIsValid) {
        BNNodeAddSample(candidate->changedTwo, e);
      }
    }
  }
}
/*
 * Batch variant: for every node whose user data reports
 * scoreIsValid == 0, call BNNodeZeroCPT and then BNNodeAddSamples
 * with the whole sample list.
 *
 * current - network whose nodes list is scanned in full.
 * newNets - list of candidate networks; for each one only the nodes
 *           recorded in its user data as changedOne / changedTwo are
 *           considered (either may be unset).
 * samples - list handed unchanged to BNNodeAddSamples.
 *
 * NOTE(review): presumably a selective stand-in for calling
 * BNAddSamples(current, samples) on the whole network - confirm.
 */
static void _OptimizedAddSamples(BeliefNet current, VoidListPtr newNets,
                                 VoidListPtr samples) {
  int idx;

  /* Pass 1: every node of the current network. */
  for (idx = 0; idx < VLLength(current->nodes); idx++) {
    BeliefNetNode curNode = (BeliefNetNode)VLIndex(current->nodes, idx);
    BNNodeUserData curData = BNNodeGetUserData(curNode);

    if (curData->scoreIsValid) {
      continue;
    }
    /* Zero first, then add the samples to the node. */
    BNNodeZeroCPT(curNode);
    BNNodeAddSamples(curNode, samples);
  }

  /* Pass 2: only the nodes each candidate network marks as changed. */
  for (idx = 0; idx < VLLength(newNets); idx++) {
    BNUserData candidate = BNGetUserData(VLIndex(newNets, idx));
    BNNodeUserData changedData;

    if (candidate->changedOne) {
      changedData = BNNodeGetUserData(candidate->changedOne);
      if (!changedData->scoreIsValid) {
        BNNodeZeroCPT(candidate->changedOne);
        BNNodeAddSamples(candidate->changedOne, samples);
      }
    }

    if (candidate->changedTwo) {
      changedData = BNNodeGetUserData(candidate->changedTwo);
      if (!changedData->scoreIsValid) {
        BNNodeZeroCPT(candidate->changedTwo);
        BNNodeAddSamples(candidate->changedTwo, samples);
      }
    }
  }
}
/*
 * Score a network from the counts stored on its nodes.
 *
 * The score is the sum, over nodes i, CPT rows j (parent
 * combinations) and values k, of
 *     p_ijk * log(p_ijk)  -  p_ij * log(p_ij)
 * where p_ijk = eventCounts[j][k] / numSamples and
 *       p_ij  = rowCounts[j]      / numSamples,
 * using the natural logarithm. Terms with probability 0 are skipped,
 * and a node with 0 samples contributes nothing.
 *
 * The sum is accumulated in double and converted to float on return.
 */
static float _ScoreBN(BeliefNet bn) {
  double total = 0;
  double p, sampleCount;
  int rowTotal;
  int nodeIdx, row, val;
  BeliefNetNode cur;

  for (nodeIdx = 0; nodeIdx < BNGetNumNodes(bn); nodeIdx++) {
    cur = BNGetNodeByID(bn, nodeIdx);
    sampleCount = BNNodeGetNumSamples(cur);
    rowTotal = BNNodeGetNumCPTRows(cur);

    for (row = 0; row < rowTotal; row++) {
      /* HACK for efficiency: reads the node's count arrays directly
         instead of going through the BNN accessors. */
      for (val = 0; val < BNNodeGetNumValues(cur); val++) {
        if (sampleCount == 0) {
          continue;
        }

        p = cur->eventCounts[row][val] / sampleCount;
        DebugMessage(1, 4,
                     " i: %d j: %d k: %d eventcount: %lf rowcount: %lf\n",
                     nodeIdx, row, val, cur->eventCounts[row][val],
                     cur->rowCounts[row]);
        if (p != 0) {
          total += p * log(p);
          DebugMessage(1, 4, " Score now: %lf\n", total);
        }
      }

      if (sampleCount == 0) {
        continue;
      }

      /* The divisor is deliberately the accessor call, not the cached
         sampleCount, so the arithmetic matches the established
         behavior exactly. */
      p = cur->rowCounts[row] / BNNodeGetNumSamples(cur);
      DebugMessage(1, 4,
                   " i: %d j: %d rowcount: %lf numsamples: %lf\n",
                   nodeIdx, row, cur->rowCounts[row],
                   sampleCount);
      if (p != 0) {
        total -= p * log(p);
        DebugMessage(1, 4, " Score now: %lf\n", total);
      }
    }
  }

  return total;
}
/*
 * Structural-difference term for one node, measured against the node
 * with the same ID in gPriorNet.
 *
 * Counts the parents of bnn that the prior node lacks plus the parents
 * of the prior node that bnn lacks, and returns that count multiplied
 * by log(gParams->gKappa).
 */
static double _GetStructuralDifferenceScoreNode(BeliefNetNode bnn) {
  BeliefNetNode reference = BNGetNodeByID(gPriorNet, BNNodeGetID(bnn));
  int mismatches = 0;
  int p;

  /* Parents present on bnn but absent from the prior node. */
  for (p = 0; p < BNNodeGetNumParents(bnn); p++) {
    if (BNNodeHasParentID(reference, BNNodeGetParentID(bnn, p))) {
      continue;
    }
    mismatches++;
  }

  /* Parents present on the prior node but absent from bnn. */
  for (p = 0; p < BNNodeGetNumParents(reference); p++) {
    if (BNNodeHasParentID(bnn, BNNodeGetParentID(reference, p))) {
      continue;
    }
    mismatches++;
  }

  return mismatches * log(gParams->gKappa);
}
void _UpdateNodeBD(BeliefNetNode bnn) {
int numCPTRows, numValues;
BNNodeUserData data;
int j, k;
double numSamples;
double gamma_numValues, gamma_1;
data = BNNodeGetUserData(bnn);
if (data->scoreIsValid) {
//printf("cached value is %g\n", data->avgDataLL);
return;
}
if (gParams->gCallbackAPI && gParams->gCallbackAPI->NodeUpdateBD) {
BNUserData netData = BNGetUserData(bnn->bn);
gParams->gCallbackAPI->NodeUpdateBD(&netData->callbackAPIdata, bnn, bnn->eventCounts, bnn->rowCounts);
}
data->avgDataLL = 0;
numCPTRows = BNNodeGetNumCPTRows(bnn);
numSamples = BNNodeGetNumSamples(bnn);
numValues = BNNodeGetNumValues(bnn);
gamma_numValues = StatLogGamma(numValues);
gamma_1 = StatLogGamma(1);
for(j = 0 ; j < numCPTRows ; j++) {
/* HACK for efficiency break BNN ADT */
/* HERE HERE update this to use the probabilities from the
prior network */
data->avgDataLL += gamma_numValues -
StatLogGamma(numValues + bnn->rowCounts[j]);
for(k = 0 ; k < numValues ; k++) {
data->avgDataLL += StatLogGamma(1 + bnn->eventCounts[j][k]);
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -