⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 bnlearn-engine.c

📁 数据挖掘方面的源码
💻 C
📖 第 1 页 / 共 3 页
字号:
#include "bnlearn-engine.h"
#include <stdio.h>
#include <string.h>

#include <assert.h>

#include <sys/times.h>
#include <unistd.h>
#include <time.h>
#include <math.h>

/* Minimal boolean type; used below for the score-cache validity flags. */
#define BOOL char
#define TRUE 1
#define FALSE 0

/* Baseline parameter count. When gParams->gMaxParameterGrowthMult != -1, a
   candidate net is rejected unless its parameter count stays below
   gMaxParameterGrowthMult times this value. Not assigned in the visible part
   of this file — TODO confirm where it is set. */
long  gInitialParameterCount=0;
/* Not referenced in the visible part of this file. */
long  gBranchFactor=0;
/* Per-candidate memory budget: when gParams->gLimitBytes != -1, a candidate
   whose cloning allocated gMaxBytesPerModel bytes or more is discarded. */
long  gMaxBytesPerModel=0;
/* Not referenced in the visible part of this file; presumably a prior
   network — verify against its users. */
BeliefNet gPriorNet=0;


/*
 * Allocate a BNLearnParams block and fill every field with its default:
 * no data / input-network / output-network sources, all search limits set
 * to -1, kappa 0.5, smoothing amount 1.0, cycle checking off and no
 * callback API.
 *
 * A value of -1 disables the corresponding limit. The structure-search code
 * tests gMaxParentsPerNode, gMaxParameterGrowthMult, gMaxParameterCount and
 * gLimitBytes against -1 this way. Presumably the same holds for the other
 * -1 fields (TODO confirm).
 *
 * Returns NULL if the allocation fails. Release the result with
 * BNLearn_FreeParams().
 */
BNLearnParams *BNLearn_NewParams() {
   BNLearnParams *param = MNewPtr(sizeof *param);

   /* Report allocation failure to the caller instead of writing through a
      NULL pointer below. */
   if (param == NULL) {
      return NULL;
   }

   /* Training data source. */
   param->gDataMemory             = NULL;
   param->gDataFile               = NULL;

   /* Optional starting network. */
   param->gInputNetMemory         = NULL;
   param->gInputNetFile           = NULL;
   param->gInputNetEmptySpec      = NULL;

   /* Where the learned network goes. */
   param->gOutputNetFilename      = NULL;
   param->gOutputNetToMemory      = 0;
   param->gOutputNetMemory        = NULL;

   /* Search limits and tuning knobs (-1 = no limit / unset). */
   param->gLimitBytes             = -1;
   param->gLimitSeconds           = -1;
   param->gNoReverse              = 0;
   param->gKappa                  = 0.5;
   param->gOnlyEstimateParameters = 0;
   param->gMaxSearchSteps         = -1;
   param->gMaxParentsPerNode      = -1;
   param->gMaxParameterGrowthMult = -1;
   param->gMaxParameterCount      = -1;
   param->gSeed                   = -1;
   param->gSmoothAmount           = 1.0;
   param->gCheckModelCycle        = 0;
   param->gCallbackAPI            = NULL;

   return param;
}

/*
 * Release a parameter block obtained from BNLearn_NewParams().
 * Only the block itself is freed; nothing it points to is released here.
 */
void BNLearn_FreeParams(BNLearnParams *params)
{
   MFreePtr(params);
}



/* Parameters of the learning run in progress. The helpers in this file
   dereference it without a NULL check (e.g. gCallbackAPI,
   gOnlyEstimateParameters, gMaxParentsPerNode). It is not assigned in the
   visible part of this file — presumably set by the learning entry point
   (TODO confirm). */
BNLearnParams *gParams;


/* Net-level user data that _InitUserData attaches to a network. */
typedef struct BNUserData_ {
   // Nodes of this net whose parent sets were modified relative to the
   // current net (0 when unused). _BNGetNodeByID returns one of these when
   // its ID matches and otherwise falls back to the current net's node.
   BeliefNetNode changedOne;
   BeliefNetNode changedTwo;

   /* -1 for removal, 0 for nothing, 1 for reverse, 2 for add */
   int changeComplexity;

   // data passed into the callback API, if one exists; its address is
   // handed to the NetInit and NetFree hooks
   void *callbackAPIdata;   
} BNUserDataStruct, *BNUserData;


/* Per-node user data that _InitUserData attaches to every node of a net. */
typedef struct BNNodeUserData_ {
   // The node's score (presumably its average data log-likelihood, going by
   // the name — confirm against the scoring code). It is filled from the
   // score cache when a cached value exists.
   double         avgDataLL;
   // Nonzero when this node's parent set differs from the same node in the
   // current net; _InvalidateScoreCache drops cached scores for such nodes.
   int            isChangedFromCurrent;
   struct BNNodeUserData_ *current; // null if this is a current node; otherwise the same-ID node's data in the current net

   // For caching scores:
   double *writebackScore;      //  when you compute the score, write it here (set on a cache miss, NULL otherwise)
   BOOL *writebackScoreIsValid; //  when you compute the score, set this to valid (set on a cache miss, NULL otherwise)
   BOOL scoreIsValid; // is avgDataLL already a valid score?

} BNNodeUserDataStruct, *BNNodeUserData;

/*
 * Release the user data that _InitUserData attached to `bn`: first every
 * per-node record, then the net-level record. If a callback API with a
 * NetFree hook is installed, it is given the address of the net's
 * callbackAPIdata (and the net) before the net-level record is freed.
 * The user-data pointers stored in bn and its nodes are not cleared.
 */
static void _FreeUserData(BeliefNet bn) {
   BNUserData netInfo;
   int nodeId, nodeCount;

   /* Per-node records first. */
   nodeCount = BNGetNumNodes(bn);
   for (nodeId = 0; nodeId < nodeCount; ++nodeId) {
      MFreePtr(BNNodeGetUserData(BNGetNodeByID(bn, nodeId)));
   }

   /* Then the net-level record, letting the callback API clean up first. */
   netInfo = BNGetUserData(bn);
   if (gParams->gCallbackAPI != NULL && gParams->gCallbackAPI->NetFree != NULL) {
      gParams->gCallbackAPI->NetFree(&netInfo->callbackAPIdata, bn);
   }
   MFreePtr(netInfo);
}



/*
 * Attach freshly allocated user data to every node of `bn` and to `bn`.
 *
 * bn       - network being annotated.
 * current  - the current (reference) network. When bn == current, each
 *            node's `current` link is NULL. Otherwise it points at the user
 *            data of the same-ID node in `current`, so `current` must already
 *            have been initialised.
 * action, childId, parentId - describe how bn differs from current. Here they
 *            are only forwarded to the callback API's NetInit hook.
 *
 * Initial node state: score 0, no valid score, no cache write-back slots,
 * isChangedFromCurrent = 0.
 * Initial net state: no changed nodes, changeComplexity 0, callbackAPIdata
 * NULL. Callers fill in the change information afterwards.
 */
static void _InitUserData(BeliefNet bn, BeliefNet current, BNLAction action, int childId, int parentId) {
   int i, numNodes;
   BeliefNetNode bnn, currentNode;
   BNNodeUserData data;
   BNUserData netData;

   numNodes = BNGetNumNodes(bn);
   for(i = 0 ; i < numNodes ; i++) {
      bnn = BNGetNodeByID(bn, i);

      data = MNewPtr(sizeof(BNNodeUserDataStruct));
      BNNodeSetUserData(bnn, data);

      data->avgDataLL = 0;
      data->writebackScore=NULL;
      data->writebackScoreIsValid=NULL;
      data->scoreIsValid=FALSE;

      data->isChangedFromCurrent = 0;

      if(bn == current) { /* the init for the current net */
         data->current = NULL;
      } else {
         currentNode = BNGetNodeByID(current, i);
         data->current = BNNodeGetUserData(currentNode);
      }
   }

   netData = MNewPtr(sizeof(BNUserDataStruct));
   BNSetUserData(bn, netData);
   netData->changedOne = netData->changedTwo = 0;
   netData->changeComplexity = 0;
   /* Start from a defined value. _FreeUserData hands &callbackAPIdata to
      NetFree whenever that hook exists, even if NetInit is absent or leaves
      the field untouched, so it must never be left as uninitialised memory. */
   netData->callbackAPIdata = NULL;
   if (gParams->gCallbackAPI && gParams->gCallbackAPI->NetInit)
      gParams->gCallbackAPI->NetInit(&netData->callbackAPIdata, bn, current, action, childId, parentId);
}


/*
 * Look up node `id` as it appears in candidate net `bn`: if one of bn's
 * recorded changed nodes (changedOne, then changedTwo) has that ID it is
 * returned; otherwise the node with that ID in `current` is returned.
 */
static BeliefNetNode _BNGetNodeByID(BeliefNet bn, BeliefNet current, int id) {
   BNUserData info = BNGetUserData(bn);
   BeliefNetNode candidate;

   candidate = info->changedOne;
   if (candidate && BNNodeGetID(candidate) == id) {
      return candidate;
   }

   candidate = info->changedTwo;
   if (candidate && BNNodeGetID(candidate) == id) {
      return candidate;
   }

   return BNGetNodeByID(current, id);
}

/*
 * Return 1 if the link parent -> child is "covered", 0 otherwise.
 *
 * x -> y is covered when par(y) = par(x) U {x}. The test here is that
 * `parent` has exactly one fewer parent than `child`, and that every parent
 * of `child` other than `parent` is also a parent of `parent`.
 *
 * This matters for equivalence checking: graphs G and G' that are identical
 * except for the reversal of x -> y are equivalent iff x -> y is covered
 * in G.
 *
 * `bn` is not used. Cost is k^2 in the number of parents k; sorted parent
 * lists would allow k.
 */
static int _IsLinkCoveredIn(BeliefNetNode parent, BeliefNetNode child, 
                                                      BeliefNet bn) {
   int k;
   BeliefNetNode other;

   if (BNNodeGetNumParents(parent) != BNNodeGetNumParents(child) - 1) {
      return 0;
   }

   for (k = 0; k < BNNodeGetNumParents(child); k++) {
      other = BNNodeGetParent(child, k);
      if (other != parent && !BNNodeHasParent(parent, other)) {
         return 0;
      }
   }

   return 1;
}



/********************************
  Caching scores for speed
  *********************************/
// Per-node score caches. Each score array has a parallel BOOL array saying
// whether the stored score is valid; _InvalidateScoreCache clears the flags
// of nodes whose parent sets changed.
//
// 2-D caches, indexed [id of the child node][id of the other node]:
//   ...WithAdditionalParent: score of the child with that node added as a parent.
//   ...WithRemovedParent:    score of the child with that parent removed.
double **gScoreWithAdditionalParent;
BOOL **gScoreWithAdditionalParentIsValid;
double **gScoreWithRemovedParent;
BOOL **gScoreWithRemovedParentIsValid;
// 1-D cache, indexed by node id: score of the node with its parents unchanged.
double *gScoreAsIs;
BOOL *gScoreAsIsIsValid;

/*
 * Allocate the score caches for a net with BNGetNumNodes(bn) nodes:
 * n x n tables for the added-parent and removed-parent scores, and length-n
 * vectors for the unchanged scores.
 *
 * Every validity flag starts FALSE. The score slots themselves are left as
 * MNewPtr returned them, since the flags decide whether a slot is read.
 * Tables from a previous call are not released here.
 */
void _InitScoreCache(BeliefNet bn)
{
  int n = BNGetNumNodes(bn);
  int row, col;

  gScoreWithAdditionalParentIsValid = MNewPtr(n * sizeof *gScoreWithAdditionalParentIsValid);
  gScoreWithAdditionalParent        = MNewPtr(n * sizeof *gScoreWithAdditionalParent);
  gScoreWithRemovedParentIsValid    = MNewPtr(n * sizeof *gScoreWithRemovedParentIsValid);
  gScoreWithRemovedParent           = MNewPtr(n * sizeof *gScoreWithRemovedParent);
  gScoreAsIsIsValid                 = MNewPtr(n * sizeof *gScoreAsIsIsValid);
  gScoreAsIs                        = MNewPtr(n * sizeof *gScoreAsIs);

  for (row = 0; row < n; row++) {
    BOOL *addedValid, *removedValid;

    addedValid   = MNewPtr(n * sizeof *addedValid);
    gScoreWithAdditionalParentIsValid[row] = addedValid;
    gScoreWithAdditionalParent[row] = MNewPtr(n * sizeof **gScoreWithAdditionalParent);
    removedValid = MNewPtr(n * sizeof *removedValid);
    gScoreWithRemovedParentIsValid[row] = removedValid;
    gScoreWithRemovedParent[row] = MNewPtr(n * sizeof **gScoreWithRemovedParent);

    gScoreAsIsIsValid[row] = FALSE;
    for (col = 0; col < n; col++) {
      addedValid[col]   = FALSE;
      removedValid[col] = FALSE;
    }
  }
}


/*
 * For every node of bn flagged isChangedFromCurrent, mark all of its cached
 * scores invalid: the as-is score and its whole row of added-parent and
 * removed-parent scores. Caches of unchanged nodes are left alone.
 */
void _InvalidateScoreCache(BeliefNet bn)
{
  int total = BNGetNumNodes(bn);
  int child, other;

  for (child = 0; child < total; child++) {
    BNNodeUserData info = BNNodeGetUserData(BNGetNodeByID(bn, child));

    if (!info->isChangedFromCurrent) {
      continue;
    }

    DebugMessage(1, 3, "Invalidating cache for node %d\n", child);
    for (other = 0; other < total; other++) {
      gScoreWithAdditionalParentIsValid[child][other] = FALSE;
      gScoreWithRemovedParentIsValid[child][other] = FALSE;
    }
    gScoreAsIsIsValid[child] = FALSE;
  }
}


/*
 * Score lookup for node `childId` after adding `parentId` as a parent.
 * On a cache hit, copy the cached score into the node data and mark it
 * valid. On a miss, mark it invalid and point the write-back slots at the
 * cache entry, so that whoever computes the score can fill the cache.
 */
void _SetCachedScore_AddedParent(BNNodeUserData dataOfNewChild, int childId, int parentId) {
  BOOL *validFlag = &gScoreWithAdditionalParentIsValid[childId][parentId];
  double *slot = &gScoreWithAdditionalParent[childId][parentId];

  if (*validFlag) {
    dataOfNewChild->avgDataLL = *slot;
    dataOfNewChild->scoreIsValid = TRUE;
    return;
  }

  DebugMessage(1, 3, "need to calc modifying %d by adding %d\n", childId, parentId);
  dataOfNewChild->scoreIsValid = FALSE;
  dataOfNewChild->writebackScore = slot;
  dataOfNewChild->writebackScoreIsValid = validFlag;
}


/*
 * Score lookup for node `childId` after removing its parent `parentId`.
 * On a cache hit, copy the cached score into the node data and mark it
 * valid. On a miss, mark it invalid and point the write-back slots at the
 * cache entry, so that whoever computes the score can fill the cache.
 */
void _SetCachedScore_RemovedParent(BNNodeUserData dataOfNewChild, int childId, int parentId) {
  BOOL *validFlag = &gScoreWithRemovedParentIsValid[childId][parentId];
  double *slot = &gScoreWithRemovedParent[childId][parentId];

  if (*validFlag) {
    dataOfNewChild->avgDataLL = *slot;
    dataOfNewChild->scoreIsValid = TRUE;
    return;
  }

  DebugMessage(1, 3, "need to calc modifying %d by removing %d\n", childId, parentId);
  dataOfNewChild->scoreIsValid = FALSE;
  dataOfNewChild->writebackScore = slot;
  dataOfNewChild->writebackScoreIsValid = validFlag;
}

/*
 * Score lookup for node `childId` with its parent set unchanged.
 * On a cache hit, copy the cached score into the node data and mark it
 * valid. On a miss, mark it invalid and point the write-back slots at the
 * cache entry, so that whoever computes the score can fill the cache.
 */
void _SetCachedScore_NoParentChange(BNNodeUserData dataOfNewChild, int childId) {
  BOOL *validFlag = &gScoreAsIsIsValid[childId];
  double *slot = &gScoreAsIs[childId];

  if (*validFlag) {
    dataOfNewChild->avgDataLL = *slot;
    dataOfNewChild->scoreIsValid = TRUE;
    return;
  }

  DebugMessage(1, 3, "need to calc %d \n", childId);
  dataOfNewChild->scoreIsValid = FALSE;
  dataOfNewChild->writebackScore = slot;
  dataOfNewChild->writebackScoreIsValid = validFlag;
}

/****************************
** End of Caching routines
*****************************/



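/* 
  Builds the list of candidate nets one search step away from bn.
  The list always starts with a structurally unchanged copy of bn. Then, for
  each pair of nodes (a, b), there are 2 possibilities:
   no link : make two new nets, one with a -> b added and one with b -> a added
   a -> b  : make one new net with the link removed and one with it reversed

   remember: don't add any BNs with cycles!
   also: don't add any BNs that are equivalent
*/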
void _GetOneStepChoicesForBN(BeliefNet bn, VoidListPtr list) {
   int i, j, srcParentOfDstIndex, dstParentOfSrcIndex;
   BeliefNet newBN;
   BeliefNetNode srcNode, dstNode;
   BNNodeUserData data;
   BNUserData netData;
   int isLinkCovered;
   long preAllocation;

   _InitUserData(bn, bn, BNL_NO_CHANGE, 0, 0);

   newBN = BNCloneNoCPTs(bn);
   _InitUserData(newBN, bn, BNL_NO_CHANGE, 0, 0);
   // mattr: The following loop may not be neccessary because any score
   //  checking will look at node->current->AvgDataLL instead of 
   //  node->AvgDataLL.
   for (i=0; i<BNGetNumNodes(bn); i++) {
      _SetCachedScore_NoParentChange( BNNodeGetUserData( BNGetNodeByID(bn, i) ), i );
   }
   VLAppend(list, newBN);

   if(gParams->gOnlyEstimateParameters) {
      return;
   }

   for(i = 0 ; i < BNGetNumNodes(bn) ; i++) {
      for(j = i + 1 ; j < BNGetNumNodes(bn) ; j++) {
         srcNode = BNGetNodeByID(bn, i);
         dstNode = BNGetNodeByID(bn, j);

         dstParentOfSrcIndex = BNNodeLookupParentIndex(srcNode, dstNode);
         srcParentOfDstIndex = BNNodeLookupParentIndex(dstNode, srcNode);

         if(dstParentOfSrcIndex == -1 && srcParentOfDstIndex == -1) {
            /* nodes unrelated, make 2 copies, one with link each way */

            /* copy one */
            preAllocation = MGetTotalAllocation();
            newBN = BNCloneNoCPTs(bn);

            BNFlushStructureCache(newBN);

            srcNode = BNGetNodeByID(newBN, i);
            dstNode = BNGetNodeByID(newBN, j);

            BNNodeAddParent(srcNode, dstNode);
            BNNodeInitCPT(srcNode);

            _InitUserData(newBN, bn, BNL_ADDED_PARENT, i, j);
            data = BNNodeGetUserData(srcNode);
            data->isChangedFromCurrent = 1;
            _SetCachedScore_AddedParent(data, i, j);
            netData = BNGetUserData(newBN);
            netData->changedOne = srcNode;
            netData->changeComplexity = 2;

            isLinkCovered = 0;

            if((gParams->gMaxParentsPerNode == -1 ||
                   BNNodeGetNumParents(srcNode) <= gParams->gMaxParentsPerNode) &&

                   (gParams->gMaxParameterGrowthMult == -1 ||
                    (BNGetNumParameters(newBN) < 
                        (gParams->gMaxParameterGrowthMult * gInitialParameterCount))) &&

                   (gParams->gMaxParameterCount == -1 ||
                    (BNNodeGetNumParameters(srcNode) <= 
                        gParams->gMaxParameterCount)) &&

             (gParams->gLimitBytes == -1 ||
              ((MGetTotalAllocation() - preAllocation) <  gMaxBytesPerModel))&&

                   !BNHasCycle(newBN)) {
               VLAppend(list, newBN);
               isLinkCovered = _IsLinkCoveredIn(dstNode, srcNode, newBN);
            } else {
               _FreeUserData(newBN);
               BNFree(newBN);
            }

            if(!isLinkCovered) {
               /* if it is covered then copy two would be equivilant to
                        copy one, so don't bother adding it */

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -