📄 bnlearn-engine.c
字号:
#include "bnlearn-engine.h"
#include <stdio.h>
#include <string.h>
#include <assert.h>
#include <sys/times.h>
#include <unistd.h>
#include <time.h>
#include <math.h>
/* Minimal boolean type for the TRUE/FALSE flags used by the score caches
   and node user data in this file. */
#define BOOL char
#define TRUE 1
#define FALSE 0
/* Baseline parameter count: when gMaxParameterGrowthMult is not -1, a
   candidate net is rejected unless its parameter count is below
   gMaxParameterGrowthMult times this value. Not assigned in this part of
   the file. */
long gInitialParameterCount=0;
/* Not referenced in this part of the file -- TODO(review) confirm its use. */
long gBranchFactor=0;
/* Byte budget per candidate model: when gLimitBytes is not -1, a candidate
   is rejected unless building it allocated fewer than this many bytes
   (measured as the change in MGetTotalAllocation()). */
long gMaxBytesPerModel=0;
/* Not referenced in this part of the file; presumably a prior network used
   elsewhere in the engine -- TODO(review) confirm. */
BeliefNet gPriorNet=0;
/*
 * Allocate a BNLearnParams block and fill it with the engine defaults:
 *   - no data, input-net or output-net sources (all NULL / 0),
 *   - every limit, and the seed, set to -1; the candidate-generation code
 *     treats -1 as "no limit",
 *   - kappa 0.5, smoothing amount 1.0,
 *   - reversals allowed, full structure search (not parameters only),
 *     no model-cycle check, and no callback API.
 *
 * Returns the new block, which the caller releases with BNLearn_FreeParams.
 * Returns NULL if the allocation fails, instead of dereferencing a NULL
 * pointer.
 */
BNLearnParams *BNLearn_NewParams() {
    BNLearnParams *param = MNewPtr(sizeof *param);

    if (param == NULL) {
        return NULL;
    }

    /* Inputs and outputs: nothing configured yet. */
    param->gDataMemory = NULL;
    param->gDataFile = NULL;
    param->gInputNetMemory = NULL;
    param->gInputNetFile = NULL;
    param->gInputNetEmptySpec = NULL;
    param->gOutputNetFilename = NULL;
    param->gOutputNetToMemory = 0;
    param->gOutputNetMemory = NULL;

    /* Resource limits: -1 means unlimited. */
    param->gLimitBytes = -1;
    param->gLimitSeconds = -1;

    /* Search behaviour. */
    param->gNoReverse = 0;
    param->gKappa = 0.5;
    param->gOnlyEstimateParameters = 0;
    param->gMaxSearchSteps = -1;
    param->gMaxParentsPerNode = -1;
    param->gMaxParameterGrowthMult = -1;
    param->gMaxParameterCount = -1;
    param->gSeed = -1;
    param->gSmoothAmount = 1.0;
    param->gCheckModelCycle = 0;
    param->gCallbackAPI = NULL;

    return param;
}
/*
 * Release a parameter block obtained from BNLearn_NewParams.
 * The pointed-to fields themselves (file names, memory buffers, callback
 * API) are not released here.
 */
void BNLearn_FreeParams(BNLearnParams *params)
{
    MFreePtr(params);
}
/* Parameters for the current learning run. The helpers in this file read it
   directly (gCallbackAPI, gOnlyEstimateParameters, the gMax* limits, ...).
   It is not assigned in this part of the file -- presumably the engine entry
   point sets it before learning starts (TODO(review) confirm). */
BNLearnParams *gParams;
/* Per-network bookkeeping. _InitUserData attaches it to a net with
   BNSetUserData, and _FreeUserData releases it. */
typedef struct BNUserData_ {
/* Nodes of this candidate net that were modified relative to the current
   net, or 0 when unused. _BNGetNodeByID returns these in preference to the
   current net's node with the same ID. */
BeliefNetNode changedOne;
BeliefNetNode changedTwo;
/* -1 for removal, 0 for nothing, 1 for reverse, 2 for add */
int changeComplexity;
// data passed into the callback API, if one exists
// (its address is handed to gCallbackAPI->NetInit and ->NetFree)
void *callbackAPIdata;
} BNUserDataStruct, *BNUserData;
/* Per-node bookkeeping. _InitUserData attaches it to each node with
   BNNodeSetUserData. */
typedef struct BNNodeUserData_ {
/* The node's score; the score-cache helpers copy cached values in here.
   Presumably the average data log-likelihood, going by the name --
   TODO(review) confirm. */
double avgDataLL;
/* Set to 1 when this node's parents were changed relative to the current
   net; _InvalidateScoreCache uses it to decide which cache rows to drop. */
int isChangedFromCurrent;
struct BNNodeUserData_ *current; // null if this is a current node
// (otherwise it points at the user data of the same-ID node in the current net)
// For caching scores:
double *writebackScore; // when you compute the score, write it here
BOOL *writebackScoreIsValid; // when you compute the score, set this to valid
BOOL scoreIsValid; // is avgDataLL already a valid score?
} BNNodeUserDataStruct, *BNNodeUserData;
/*
 * Release the user data that _InitUserData attached to bn.
 *
 * Steps, in order:
 *   1. Free each node's BNNodeUserData.
 *   2. If a callback API with a NetFree hook is installed, give it the
 *      address of the net's callbackAPIdata slot so it can release its data.
 *   3. Free the per-net BNUserData.
 *
 * After freeing, each user-data pointer is reset to NULL so that bn never
 * keeps a dangling pointer to freed memory. A net that carries no
 * BNUserData is tolerated: the callback and the net-level free are skipped,
 * rather than dereferencing a NULL pointer.
 */
static void _FreeUserData(BeliefNet bn) {
    int numNodes = BNGetNumNodes(bn);
    BNUserData netData;
    int i;

    for (i = 0; i < numNodes; i++) {
        BeliefNetNode bnn = BNGetNodeByID(bn, i);

        MFreePtr(BNNodeGetUserData(bnn));
        BNNodeSetUserData(bnn, NULL);
    }

    netData = BNGetUserData(bn);
    if (netData == NULL) {
        return;
    }
    if (gParams->gCallbackAPI != NULL && gParams->gCallbackAPI->NetFree != NULL) {
        gParams->gCallbackAPI->NetFree(&netData->callbackAPIdata, bn);
    }
    MFreePtr(netData);
    BNSetUserData(bn, NULL);
}
/*
 * Attach freshly allocated user data to every node of bn and to bn itself.
 *
 * Each node gets a BNNodeUserData with:
 *   - a zero score, no write-back target, scoreIsValid = FALSE,
 *   - isChangedFromCurrent = 0,
 *   - a `current` link. The link is null when bn is the current net
 *     (bn == current). Otherwise it points at the user data of the node with
 *     the same ID in `current`, so callers initialize `current` first.
 *
 * The net-level BNUserData starts with no changed nodes and
 * changeComplexity 0. If a callback API with a NetInit hook is installed,
 * it is called last, with action/childId/parentId describing how bn was
 * derived from current.
 *
 * Any user data previously attached to bn is overwritten, not freed.
 */
static void _InitUserData(BeliefNet bn, BeliefNet current, BNLAction action, int childId, int parentId) {
    const int initializingCurrent = (bn == current);
    const int nodeCount = BNGetNumNodes(bn);
    BNUserData netInfo;
    int id;

    for (id = 0; id < nodeCount; id++) {
        BeliefNetNode node = BNGetNodeByID(bn, id);
        BNNodeUserData info = MNewPtr(sizeof(BNNodeUserDataStruct));

        BNNodeSetUserData(node, info);
        info->avgDataLL = 0;
        info->writebackScore = NULL;
        info->writebackScoreIsValid = NULL;
        info->scoreIsValid = FALSE;
        info->isChangedFromCurrent = 0;
        if (initializingCurrent) {
            info->current = NULL;
        } else {
            info->current = BNNodeGetUserData(BNGetNodeByID(current, id));
        }
    }

    netInfo = MNewPtr(sizeof(BNUserDataStruct));
    BNSetUserData(bn, netInfo);
    netInfo->changedOne = 0;
    netInfo->changedTwo = 0;
    netInfo->changeComplexity = 0;

    if (gParams->gCallbackAPI != NULL && gParams->gCallbackAPI->NetInit != NULL) {
        gParams->gCallbackAPI->NetInit(&netInfo->callbackAPIdata, bn, current,
                                       action, childId, parentId);
    }
}
/*
 * Resolve node `id` for candidate net bn.
 *
 * If the ID matches one of the nodes recorded as changed in bn's
 * BNUserData (changedOne is checked first, then changedTwo), that node of
 * bn is returned. Otherwise the node with that ID is taken from `current`.
 */
static BeliefNetNode _BNGetNodeByID(BeliefNet bn, BeliefNet current, int id) {
    BNUserData info = BNGetUserData(bn);
    BeliefNetNode changed[2];
    int k;

    changed[0] = info->changedOne;
    changed[1] = info->changedTwo;

    for (k = 0; k < 2; k++) {
        if (changed[k] && BNNodeGetID(changed[k]) == id) {
            return changed[k];
        }
    }
    return BNGetNodeByID(current, id);
}
/*
 * Return 1 if the link parent -> child is "covered", and 0 otherwise.
 * A link is covered when the parents of child are exactly the parents of
 * `parent` plus `parent` itself:
 *
 *     par(child) == par(parent) U {parent}
 *
 * This matters for equivalence: G and G' that are identical except for the
 * reversal of x -> y are equivalent iff x -> y is covered in G. The search
 * therefore skips the reversed candidate when the link is covered.
 *
 * The bn argument is not consulted. The membership test is O(k^2) in the
 * number of parents; it could be O(k) if parent lists were kept sorted.
 */
static int _IsLinkCoveredIn(BeliefNetNode parent, BeliefNetNode child,
                            BeliefNet bn) {
    const int childParentCount = BNNodeGetNumParents(child);
    int k;

    /* par(child) must be exactly one element larger than par(parent). */
    if (BNNodeGetNumParents(parent) != childParentCount - 1) {
        return 0;
    }

    /* Every parent of child other than `parent` must also parent `parent`. */
    for (k = 0; k < childParentCount; k++) {
        BeliefNetNode other = BNNodeGetParent(child, k);

        if (other == parent) {
            continue;
        }
        if (!BNNodeHasParent(parent, other)) {
            return 0;
        }
    }
    return 1;
}
/********************************
Caching scores for speed
*********************************/
/* Per-node score caches. _InitScoreCache allocates them, and
   _InvalidateScoreCache invalidates them node by node. A score slot may
   only be read while its matching ...IsValid flag is TRUE. */
// Indexed by id of the node, then id of the additional parent.
double **gScoreWithAdditionalParent;
BOOL **gScoreWithAdditionalParentIsValid;
// Indexed by id of the node, then id of the removed parent.
double **gScoreWithRemovedParent;
BOOL **gScoreWithRemovedParentIsValid;
// Indexed by id of the node
double *gScoreAsIs;
BOOL *gScoreAsIsIsValid;
/*
 * Allocate the score caches for a net with BNGetNumNodes(bn) nodes:
 *   - numNodes x numNodes tables for the added-parent and removed-parent
 *     scores,
 *   - length-numNodes arrays for the unchanged-node scores.
 *
 * Every validity flag starts out FALSE. The score slots themselves are left
 * uninitialized and must only be read once their flag is TRUE.
 *
 * Caches from an earlier call are not released here.
 */
void _InitScoreCache(BeliefNet bn)
{
    const int n = BNGetNumNodes(bn);
    int row, col;

    gScoreWithAdditionalParentIsValid = MNewPtr(sizeof(BOOL*) * n);
    gScoreWithAdditionalParent = MNewPtr(sizeof(double*) * n);
    gScoreWithRemovedParentIsValid = MNewPtr(sizeof(BOOL*) * n);
    gScoreWithRemovedParent = MNewPtr(sizeof(double*) * n);
    gScoreAsIsIsValid = MNewPtr(sizeof(BOOL) * n);
    gScoreAsIs = MNewPtr(sizeof(double) * n);

    for (row = 0; row < n; row++) {
        BOOL *addedValid = MNewPtr(sizeof(BOOL) * n);
        double *addedScore = MNewPtr(sizeof(double) * n);
        BOOL *removedValid = MNewPtr(sizeof(BOOL) * n);
        double *removedScore = MNewPtr(sizeof(double) * n);

        for (col = 0; col < n; col++) {
            addedValid[col] = FALSE;
            removedValid[col] = FALSE;
        }

        gScoreWithAdditionalParentIsValid[row] = addedValid;
        gScoreWithAdditionalParent[row] = addedScore;
        gScoreWithRemovedParentIsValid[row] = removedValid;
        gScoreWithRemovedParent[row] = removedScore;
        gScoreAsIsIsValid[row] = FALSE;
    }
}
/*
 * Drop the cached scores of every node of bn whose user data has
 * isChangedFromCurrent set. For such a node, all of the following are
 * marked invalid:
 *   - its row of the added-parent table,
 *   - its row of the removed-parent table,
 *   - its unchanged-node entry.
 *
 * Cache entries of the other nodes are kept. A debug message is emitted for
 * each node that is invalidated.
 */
void _InvalidateScoreCache(BeliefNet bn)
{
    const int numNodes = BNGetNumNodes(bn);
    int child, other;

    for (child = 0; child < numNodes; child++) {
        BNNodeUserData info = BNNodeGetUserData(BNGetNodeByID(bn, child));

        if (!info->isChangedFromCurrent) {
            continue;
        }

        DebugMessage(1, 3, "Invalidating cache for node %d\n", child);
        for (other = 0; other < numNodes; other++) {
            gScoreWithAdditionalParentIsValid[child][other] = FALSE;
            gScoreWithRemovedParentIsValid[child][other] = FALSE;
        }
        gScoreAsIsIsValid[child] = FALSE;
    }
}
/*
 * Prepare the score of node childId after it gained parent parentId.
 *
 * On a cache hit, the cached value is copied into avgDataLL and marked
 * valid. On a miss, the score is marked invalid and the node is pointed at
 * the cache slot (and its validity flag). Whoever computes the score writes
 * it there.
 */
void _SetCachedScore_AddedParent(BNNodeUserData dataOfNewChild, int childId, int parentId) {
    double *slot = &gScoreWithAdditionalParent[childId][parentId];
    BOOL *slotIsValid = &gScoreWithAdditionalParentIsValid[childId][parentId];

    if (!*slotIsValid) {
        DebugMessage(1, 3, "need to calc modifying %d by adding %d\n", childId, parentId);
        dataOfNewChild->scoreIsValid = FALSE;
        dataOfNewChild->writebackScore = slot;
        dataOfNewChild->writebackScoreIsValid = slotIsValid;
        return;
    }

    dataOfNewChild->avgDataLL = *slot;
    dataOfNewChild->scoreIsValid = TRUE;
}
/*
 * Prepare the score of node childId after it lost parent parentId.
 *
 * On a cache hit, the cached value is copied into avgDataLL and marked
 * valid. On a miss, the score is marked invalid and the node is pointed at
 * the cache slot (and its validity flag). Whoever computes the score writes
 * it there.
 */
void _SetCachedScore_RemovedParent(BNNodeUserData dataOfNewChild, int childId, int parentId) {
    double *slot = &gScoreWithRemovedParent[childId][parentId];
    BOOL *slotIsValid = &gScoreWithRemovedParentIsValid[childId][parentId];

    if (!*slotIsValid) {
        DebugMessage(1, 3, "need to calc modifying %d by removing %d\n", childId, parentId);
        dataOfNewChild->scoreIsValid = FALSE;
        dataOfNewChild->writebackScore = slot;
        dataOfNewChild->writebackScoreIsValid = slotIsValid;
        return;
    }

    dataOfNewChild->avgDataLL = *slot;
    dataOfNewChild->scoreIsValid = TRUE;
}
/*
 * Prepare the score of node childId when its parent set is unchanged.
 *
 * On a cache hit, the cached value is copied into avgDataLL and marked
 * valid. On a miss, the score is marked invalid and the node is pointed at
 * the cache slot (and its validity flag). Whoever computes the score writes
 * it there.
 */
void _SetCachedScore_NoParentChange(BNNodeUserData dataOfNewChild, int childId) {
    double *slot = &gScoreAsIs[childId];
    BOOL *slotIsValid = &gScoreAsIsIsValid[childId];

    if (!*slotIsValid) {
        DebugMessage(1, 3, "need to calc %d \n", childId);
        dataOfNewChild->scoreIsValid = FALSE;
        dataOfNewChild->writebackScore = slot;
        dataOfNewChild->writebackScoreIsValid = slotIsValid;
        return;
    }

    dataOfNewChild->avgDataLL = *slot;
    dataOfNewChild->scoreIsValid = TRUE;
}
/****************************
** End of Caching routines
*****************************/
/*
For each pair of nodes (a, b) there are 2 possibilities:
no link : make a new net for each direction, with that parent link added
a -> b  : make one new net with the link removed and one with the link reversed
remember: don't add any BNs with cycles!
also: don't add any BNs that are equivalent
*/
void _GetOneStepChoicesForBN(BeliefNet bn, VoidListPtr list) {
int i, j, srcParentOfDstIndex, dstParentOfSrcIndex;
BeliefNet newBN;
BeliefNetNode srcNode, dstNode;
BNNodeUserData data;
BNUserData netData;
int isLinkCovered;
long preAllocation;
_InitUserData(bn, bn, BNL_NO_CHANGE, 0, 0);
newBN = BNCloneNoCPTs(bn);
_InitUserData(newBN, bn, BNL_NO_CHANGE, 0, 0);
// mattr: The following loop may not be neccessary because any score
// checking will look at node->current->AvgDataLL instead of
// node->AvgDataLL.
for (i=0; i<BNGetNumNodes(bn); i++) {
_SetCachedScore_NoParentChange( BNNodeGetUserData( BNGetNodeByID(bn, i) ), i );
}
VLAppend(list, newBN);
if(gParams->gOnlyEstimateParameters) {
return;
}
for(i = 0 ; i < BNGetNumNodes(bn) ; i++) {
for(j = i + 1 ; j < BNGetNumNodes(bn) ; j++) {
srcNode = BNGetNodeByID(bn, i);
dstNode = BNGetNodeByID(bn, j);
dstParentOfSrcIndex = BNNodeLookupParentIndex(srcNode, dstNode);
srcParentOfDstIndex = BNNodeLookupParentIndex(dstNode, srcNode);
if(dstParentOfSrcIndex == -1 && srcParentOfDstIndex == -1) {
/* nodes unrelated, make 2 copies, one with link each way */
/* copy one */
preAllocation = MGetTotalAllocation();
newBN = BNCloneNoCPTs(bn);
BNFlushStructureCache(newBN);
srcNode = BNGetNodeByID(newBN, i);
dstNode = BNGetNodeByID(newBN, j);
BNNodeAddParent(srcNode, dstNode);
BNNodeInitCPT(srcNode);
_InitUserData(newBN, bn, BNL_ADDED_PARENT, i, j);
data = BNNodeGetUserData(srcNode);
data->isChangedFromCurrent = 1;
_SetCachedScore_AddedParent(data, i, j);
netData = BNGetUserData(newBN);
netData->changedOne = srcNode;
netData->changeComplexity = 2;
isLinkCovered = 0;
if((gParams->gMaxParentsPerNode == -1 ||
BNNodeGetNumParents(srcNode) <= gParams->gMaxParentsPerNode) &&
(gParams->gMaxParameterGrowthMult == -1 ||
(BNGetNumParameters(newBN) <
(gParams->gMaxParameterGrowthMult * gInitialParameterCount))) &&
(gParams->gMaxParameterCount == -1 ||
(BNNodeGetNumParameters(srcNode) <=
gParams->gMaxParameterCount)) &&
(gParams->gLimitBytes == -1 ||
((MGetTotalAllocation() - preAllocation) < gMaxBytesPerModel))&&
!BNHasCycle(newBN)) {
VLAppend(list, newBN);
isLinkCovered = _IsLinkCoveredIn(dstNode, srcNode, newBN);
} else {
_FreeUserData(newBN);
BNFree(newBN);
}
if(!isLinkCovered) {
/* if it is covered then copy two would be equivilant to
copy one, so don't bother adding it */
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -