/* bnlearn-engine.c */
data->avgDataLL -= numValues*gamma_1;
}
/* now scale it by the likelihood of the node given structural prior */
data->avgDataLL += _GetStructuralDifferenceScoreNode(bnn);
if (data->writebackScore) {
*(data->writebackScore) = data->avgDataLL;
*(data->writebackScoreIsValid) = TRUE;
}
}
/* Recompute the BD score for every node of the current network. */
static void _UpdateCurrentNetScoreBD(BeliefNet bn) {
   int node;

   for(node = 0 ; node < BNGetNumNodes(bn) ; node++) {
      _UpdateNodeBD(BNGetNodeByID(bn, node));
   }
}
/* Recompute BD scores only for the (at most two) nodes recorded as
   structurally changed relative to the current network. */
static void _UpdateNetScoreBD(BeliefNet bn) {
   BNUserData netData = BNGetUserData(bn);
   BeliefNetNode changed;

   changed = netData->changedOne;
   if(changed != 0) {
      _UpdateNodeBD(changed);
   }

   changed = netData->changedTwo;
   if(changed != 0) {
      _UpdateNodeBD(changed);
   }
}
/* Return the cached score for a node.  When the node tracks an
   unchanged 'current' counterpart, the counterpart's cached score is
   authoritative; otherwise use the node's own avgDataLL. */
static double _GetNodeScore(BeliefNetNode bnn) {
   BNNodeUserData data = BNNodeGetUserData(bnn);

   if(data->current != 0 && !data->isChangedFromCurrent) {
      return data->current->avgDataLL;
   }
   return data->avgDataLL;
}
float _ScoreBNOptimized(BeliefNet bn);

/* Compare two candidate networks by summing per-node scores over the
   union of nodes that either candidate changed (each node counted at
   most once).  Returns non-zero when 'first' scores strictly higher;
   ties favor 'second'. */
static int _IsFirstNetBetter(BeliefNet first, BeliefNet second) {
   BNUserData firstData = BNGetUserData(first);
   BNUserData secondData = BNGetUserData(second);
   double firstScore = 0, secondScore = 0;
   int firstId1 = -1, firstId2 = -1;
   int id;

   if(firstData->changedOne) {
      firstId1 = BNNodeGetID(firstData->changedOne);
      firstScore += _GetNodeScore(BNGetNodeByID(first, firstId1));
      secondScore += _GetNodeScore(BNGetNodeByID(second, firstId1));
   }

   if(firstData->changedTwo) {
      firstId2 = BNNodeGetID(firstData->changedTwo);
      firstScore += _GetNodeScore(BNGetNodeByID(first, firstId2));
      secondScore += _GetNodeScore(BNGetNodeByID(second, firstId2));
   }

   if(secondData->changedOne) {
      id = BNNodeGetID(secondData->changedOne);
      /* don't compare same nodes twice */
      if(id != firstId1 && id != firstId2) {
         firstScore += _GetNodeScore(BNGetNodeByID(first, id));
         secondScore += _GetNodeScore(BNGetNodeByID(second, id));
      }
   }

   if(secondData->changedTwo) {
      id = BNNodeGetID(secondData->changedTwo);
      /* don't compare same nodes twice */
      if(id != firstId1 && id != firstId2) {
         firstScore += _GetNodeScore(BNGetNodeByID(first, id));
         secondScore += _GetNodeScore(BNGetNodeByID(second, id));
      }
   }

   return firstScore > secondScore;
}
/* Sum the cached per-node scores over the whole network.
   (Accumulates in double; the float return matches the existing
   forward declaration.) */
float _ScoreBNOptimized(BeliefNet bn) {
   double total = 0;
   int node;

   for(node = 0 ; node < BNGetNumNodes(bn) ; node++) {
      total += _GetNodeScore(BNGetNodeByID(bn, node));
   }

   return total;
}
/* Copy into 'target' the CPT of every node that is NOT marked changed,
   taking it from the same-ID node of 'source'.  Changed nodes keep the
   CPTs they learned themselves. */
static void _UpdateCPTsForFrom(BeliefNet target, BeliefNet source) {
   int id;

   for(id = 0 ; id < BNGetNumNodes(target) ; id++) {
      BeliefNetNode dst = BNGetNodeByID(target, id);
      BeliefNetNode src = BNGetNodeByID(source, id);
      BNNodeUserData info = BNNodeGetUserData(dst);

      if(!info->isChangedFromCurrent) {
         BNNodeInitCPT(dst);
         BNNodeSetCPTFrom(dst, src);
      }
   }
}
/* A net is the "current" net exactly when no structural change
   (changedOne/changedTwo) is recorded against it. */
static int _IsCurrentNet(BeliefNet bn) {
   BNUserData info = (BNUserData)BNGetUserData(bn);

   return !(info->changedOne || info->changedTwo);
}
/* Run a knockout tournament over netChoices: score every candidate,
   free each loser, and leave the single winner as the only element
   of the list. */
static void _CompareNetsFreeLoosers(BeliefNet current,
                  VoidListPtr netChoices) {
   BeliefNet best, challenger;

   _UpdateCurrentNetScoreBD(current);

   best = VLRemove(netChoices, VLLength(netChoices) - 1);
   _UpdateNetScoreBD(best);

   while(VLLength(netChoices) > 0) {
      challenger = VLRemove(netChoices, VLLength(netChoices) - 1);
      _UpdateNetScoreBD(challenger);

      if(!_IsFirstNetBetter(best, challenger)) {
         /* challenger wins (or ties): drop the old best */
         _FreeUserData(best);
         BNFree(best);
         best = challenger;
      } else {
         _FreeUserData(challenger);
         BNFree(challenger);
      }
   }

   VLAppend(netChoices, best);
}
/* Callback installed via MSetAllocFailFunction: report a failed
   allocation of 'allocationSize' bytes.  Fixed: the diagnostic now
   goes to stderr instead of stdout, so it is not lost when stdout is
   redirected or still sitting in a buffer when the process dies. */
void _AllocFailed(int allocationSize) {
   fprintf(stderr, "MEMORY ALLOCATION FAILED, size %d\n", allocationSize);
}
/* Return non-zero when the user CPU time elapsed since 'starttime'
   meets or exceeds gParams->gLimitSeconds (-1 disables the check).
   Fixed: 'seconds' was declared long, silently truncating the double
   division result (and drawing a conversion warning); it is now double
   so the comparison sees the fractional part.
   NOTE(review): the /100.0 assumes 100 clock ticks per second --
   strictly this should be sysconf(_SC_CLK_TCK), but the rest of the
   file uses the same constant, so it is kept for consistency. */
static int _IsTimeExpired(struct tms starttime) {
   struct tms endtime;
   double seconds;

   if(gParams->gLimitSeconds != -1) {
      times(&endtime);
      seconds = (double)(endtime.tms_utime - starttime.tms_utime) / 100.0;
      return seconds >= gParams->gLimitSeconds;
   }
   return 0;
}
/*
 * BNLearn: hill-climbing structure search over Bayesian networks.
 *
 * Each search step generates all one-step structural neighbors of the
 * current net, scores them against the data (in-memory list and/or
 * file), keeps the single best candidate, and repeats until the
 * current net wins, a step/time limit is hit, or (optionally) a cycle
 * of previous winners is detected.  The result is written to
 * gOutputNetFilename and/or handed back via gOutputNetMemory.
 *
 * params: configuration struct; stored in the module global gParams,
 * which every helper in this file reads.  Exactly one of
 * gInputNetMemory / gInputNetFile / gInputNetEmptySpec must be set,
 * and at least one of gDataMemory / gDataFile.
 */
void BNLearn(BNLearnParams *params) {
   ExampleSpecPtr es;
   int allDone, searchStep;
   BeliefNet bn;
   long learnTime, allocation, seenTotal;
   struct tms starttime;
   struct tms endtime;
   VoidListPtr netChoices;       /* one-step neighbors of the current net */
   VoidListPtr previousWinners;  /* clones of past winners, for cycle check */
   /* publish params to the module-level global used by all helpers */
   gParams = params;
   previousWinners = VLNew();
   MSetAllocFailFunction(_AllocFailed);
   // Set up the input network
   if (gParams->gInputNetMemory) {
      /* caller-supplied net; note we take it by reference, not a clone */
      bn = gParams->gInputNetMemory;
   }
   else if(gParams->gInputNetFile) {
      bn = BNReadBIFFILEP(gParams->gInputNetFile);
      if(bn == 0) {
         DebugMessage(1, 1, "couldn't read net\n");
      }
      /* NOTE(review): if the read failed, bn is 0 here and BNZeroCPTs
         still runs on it -- this probably deserves an early return */
      BNZeroCPTs(bn);
   }
   else if(gParams->gInputNetEmptySpec) {
      bn = BNNewFromSpec(gParams->gInputNetEmptySpec);
   }
   else {
      DebugError(1,"Error, at least one of gInputNetMemory, gInputNetFile, or gInputNetEmptySpec must be specified");
      return;
   }
   assert(bn);
   es = BNGetExampleSpec(bn);
   if (!gParams->gDataMemory && !gParams->gDataFile) {
      DebugError(1,"Error, at least one of gDataMemory or gDataFile must be set");
      return;
   }
   gInitialParameterCount = BNGetNumParameters(bn);
   /* candidate count per step is roughly numNodes^2 (one per node pair) */
   gBranchFactor = BNGetNumNodes(bn) * BNGetNumNodes(bn);
   if(gParams->gLimitBytes != -1) {
      /* split the memory budget evenly across the candidate models */
      gMaxBytesPerModel = gParams->gLimitBytes / gBranchFactor;
      DebugMessage(1, 2, "Limit models to %.4lf megs\n",
            gMaxBytesPerModel / (1024.0 * 1024.0));
   }
   gPriorNet = BNClone(bn);
   RandomInit();
   /* seed */
   if(gParams->gSeed != -1) {
      RandomSeed(gParams->gSeed);
   } else {
      /* no seed given: pick one at random but record it for repeatability */
      gParams->gSeed = RandomRange(1, 30000);
      RandomSeed(gParams->gSeed);
   }
   DebugMessage(1, 1, "running with seed %d\n", gParams->gSeed);
   DebugMessage(1, 1, "allocation %ld\n", MGetTotalAllocation());
   DebugMessage(1, 1, "initial parameters %ld\n", gInitialParameterCount);
   times(&starttime);
   seenTotal = 0;
   learnTime = 0;
   _InitScoreCache(bn);
   allDone = 0;
   searchStep = 0;
   /* main hill-climbing loop: one iteration per search step */
   while(!allDone) {
      searchStep++;
      DebugMessage(1, 2, "============== Search step: %d ==============\n",
            searchStep);
      DebugMessage(1, 2, "   Total samples: %ld\n", seenTotal);
      DebugMessage(1, 2, "   allocation before choices %ld\n",
            MGetTotalAllocation());
      DebugMessage(1, 2, "   best with score %f:\n", _ScoreBN(bn));
      if(DebugGetMessageLevel() >= 2) {
         BNPrintStats(bn);
      }
      if(DebugGetMessageLevel() >= 3) {
         BNWriteBIF(bn, DebugGetTarget());
      }
      /* generate all one-step structural neighbors of bn */
      netChoices = VLNew();
      _GetOneStepChoicesForBN(bn, netChoices);
      DebugMessage(1, 2, "   allocation after choices %ld there are %d\n",
            MGetTotalAllocation(), VLLength(netChoices));
      if (gParams->gDataMemory) {
         /* batch path: feed the whole in-memory sample list at once */
         _OptimizedAddSamples(bn, netChoices, gParams->gDataMemory);
         seenTotal += VLLength(gParams->gDataMemory);
         if(_IsTimeExpired(starttime))
            allDone = 1;
      }
      if (gParams->gDataFile) {
         /* streaming path: read examples one at a time from the file */
         int stepDone=0;
         ExamplePtr e = ExampleRead(gParams->gDataFile, es);
         _OptimizedAddSampleInit(bn, netChoices);
         while(!stepDone && e != 0) {
            seenTotal++;
            /* put the eg in every net choice */
            _OptimizedAddSample(bn, netChoices, e);
            if(_IsTimeExpired(starttime)) {
               stepDone = allDone = 1;
            }
            ExampleFree(e);
            e = ExampleRead(gParams->gDataFile, es);
         } /* !stepDone && e != 0 */
      }
      /* knock out everything but the single best candidate */
      _CompareNetsFreeLoosers(bn, netChoices);
      /* if the winner is the current one then we are all done */
      if(BNStructureEqual(bn, (BeliefNet)VLIndex(netChoices, 0)) ||
            !_IsFirstNetBetter((BeliefNet)VLIndex(netChoices,0), bn)) {
         /* make sure to free all loosing choices and the netChoices list */
         allDone = 1;
      } else if(gParams->gMaxSearchSteps != -1 && searchStep >= gParams->gMaxSearchSteps) {
         DebugMessage(1, 1, "Stopped because of search step limit\n");
         allDone = 1;
      }
      /* copy all the CPTs that are only in bn into the new winner */
      /* only really needed for final output but I do it for debugging too */
      _UpdateCPTsForFrom((BeliefNet)VLIndex(netChoices, 0), bn);
      _InvalidateScoreCache((BeliefNet)VLIndex(netChoices, 0));
      if(gParams->gOnlyEstimateParameters) {
         /* parameter-only mode: one pass over the data, no search */
         allDone = 1;
      }
      if(gParams->gCheckModelCycle) {
         /* now check all previous winners */
         /* if we detect a cycle, pick the best that happens
            in the period of the cycle */
         if(!allDone) {
            int i;
            for(i = 0 ; i < VLLength(previousWinners) ; i++) {
               if(BNStructureEqual(bn, (BeliefNet)VLIndex(previousWinners, i))) {
                  allDone = 1;
               }
               if(allDone) {
                  if(_ScoreBN(bn) <
                        _ScoreBN((BeliefNet)VLIndex(previousWinners, i))) {
                     /* NOTE(review): bn now aliases an element of
                        previousWinners; the _FreeUserData/BNFree below
                        then frees that shared element -- looks like a
                        latent double-ownership issue, verify before
                        changing */
                     bn = VLIndex(previousWinners, i);
                  }
               }
            }
         }
         VLAppend(previousWinners, BNClone(bn));
      }
      /* retire the old current net and promote the winner */
      _FreeUserData(bn);
      BNFree(bn);
      bn = (BeliefNet)VLRemove(netChoices, 0);
      _FreeUserData(bn);
      VLFree(netChoices);
      DebugMessage(1, 2, "   allocation after all free %ld\n",
            MGetTotalAllocation());
      /* reset data file */
      if (gParams->gDataFile)
         rewind(gParams->gDataFile);
   } /* while !allDone */
   if (gParams->gSmoothAmount != 0) {
      BNSmoothProbabilities(bn, gParams->gSmoothAmount);
   }
   times(&endtime);
   /* tms user time is in clock ticks; 100 ticks/sec assumed below */
   learnTime += endtime.tms_utime - starttime.tms_utime;
   DebugMessage(1, 1, "done learning...\n");
   DebugMessage(1, 1, "time %.2lfs\n", ((double)learnTime) / 100);
   DebugMessage(1, 1, "Total Samples: %ld\n", seenTotal);
   if(DebugGetMessageLevel() >= 1) {
      DebugMessage(1, 1, "Samples per round:\n");
   }
   allocation = MGetTotalAllocation();
   //printf("Final score: %f\n", _ScoreBN(bn));
   if (gParams->gOutputNetFilename) {
      /* NOTE(review): fopen result is not checked before BNWriteBIF */
      FILE *netOut = fopen(gParams->gOutputNetFilename, "w");
      BNWriteBIF(bn, netOut);
      fclose(netOut);
   }
   if (gParams->gOutputNetToMemory) {
      /* ownership of bn transfers to the caller */
      gParams->gOutputNetMemory = bn;
   }
   else
      BNFree(bn);
   //ExampleSpecFree(es);
   DebugMessage(1, 1, " allocation %ld\n", MGetTotalAllocation());
}