⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 bnlearn-engine.c

📁 数据挖掘方面的源码
💻 C
📖 第 1 页 / 共 3 页
字号:
               /* copy two */
               preAllocation = MGetTotalAllocation();
               newBN = BNCloneNoCPTs(bn);
               BNFlushStructureCache(newBN);

               srcNode = BNGetNodeByID(newBN, i);
               dstNode = BNGetNodeByID(newBN, j);

               BNNodeAddParent(dstNode, srcNode);
               BNNodeInitCPT(dstNode);

               _InitUserData(newBN, bn, BNL_ADDED_PARENT, j, i);
               data = BNNodeGetUserData(dstNode);
               data->isChangedFromCurrent = 1;
               _SetCachedScore_AddedParent(data, j, i);
               netData = BNGetUserData(newBN);
               netData->changedOne = dstNode;
               netData->changeComplexity = 2;

               if((gParams->gMaxParentsPerNode == -1 ||
                      BNNodeGetNumParents(dstNode) <= gParams->gMaxParentsPerNode) &&

                    (gParams->gMaxParameterGrowthMult == -1 ||
                       (BNGetNumParameters(newBN) < 
                        (gParams->gMaxParameterGrowthMult * gInitialParameterCount))) &&

                   (gParams->gMaxParameterCount == -1 ||
                    (BNNodeGetNumParameters(dstNode) <= 
                        gParams->gMaxParameterCount)) &&

             (gParams->gLimitBytes == -1 ||
              ((MGetTotalAllocation() - preAllocation) <  gMaxBytesPerModel))&&

                      !BNHasCycle(newBN)) {
                  VLAppend(list, newBN);
               } else {
                  _FreeUserData(newBN);
                  BNFree(newBN);
               }
            }
         } else {
            /* one is a parent of the other, make 2 copies, one with
                     link removed, one with link reversed */

            /* copy one: remove */
            /* removing a link changes the structure and will not
                        be equivilant to any of the candidates */
            newBN = BNCloneNoCPTs(bn);
            BNFlushStructureCache(newBN);

            preAllocation = MGetTotalAllocation();

            srcNode = BNGetNodeByID(newBN, i);
            dstNode = BNGetNodeByID(newBN, j);

            if(dstParentOfSrcIndex != -1) {
               BNNodeRemoveParent(srcNode, dstParentOfSrcIndex);
               BNNodeInitCPT(srcNode);

               _InitUserData(newBN, bn, BNL_REMOVED_PARENT, i, j);
               data = BNNodeGetUserData(srcNode);
               data->isChangedFromCurrent = 1;
               _SetCachedScore_RemovedParent(data, i, j);
               netData = BNGetUserData(newBN);
               netData->changedOne = srcNode;
               netData->changeComplexity = -1;
            } else {
               BNNodeRemoveParent(dstNode, srcParentOfDstIndex);
               BNNodeInitCPT(dstNode);

               _InitUserData(newBN, bn, BNL_REMOVED_PARENT, j, i);
               data = BNNodeGetUserData(dstNode);
               data->isChangedFromCurrent = 1;
               _SetCachedScore_RemovedParent(data, j, i);
               netData = BNGetUserData(newBN);
               netData->changedOne = dstNode;
               netData->changeComplexity = -1;
            }

            if(!BNHasCycle(newBN) &&

               (gParams->gMaxParameterGrowthMult == -1 ||
                  (BNGetNumParameters(newBN) < 
                  (gParams->gMaxParameterGrowthMult * gInitialParameterCount))) &&

               (gParams->gMaxParameterCount == -1 ||
                  (BNNodeGetNumParameters(srcNode) <= 
                   gParams->gMaxParameterCount  &&
                  BNNodeGetNumParameters(dstNode) <= 
                                         gParams->gMaxParameterCount))&&


               (gParams->gLimitBytes == -1 ||
                    ((MGetTotalAllocation() - preAllocation) <  
                                                gMaxBytesPerModel))) {
               VLAppend(list, newBN);
            } else {
               _FreeUserData(newBN);
               BNFree(newBN);
            }

            /* copy two: reverse */
            if(!gParams->gNoReverse) {
               preAllocation = MGetTotalAllocation();
               newBN = BNCloneNoCPTs(bn);
               BNFlushStructureCache(newBN);

               srcNode = BNGetNodeByID(newBN, i);
               dstNode = BNGetNodeByID(newBN, j);

               if(dstParentOfSrcIndex != -1) {
                  BNNodeRemoveParent(srcNode, dstParentOfSrcIndex);
                  BNNodeAddParent(dstNode, srcNode);
                  BNNodeInitCPT(srcNode);
                  BNNodeInitCPT(dstNode);

                  _InitUserData(newBN, bn, BNL_REVERSED_PARENT, i, j);
                  data = BNNodeGetUserData(srcNode);
                  data->isChangedFromCurrent = 1;
                  _SetCachedScore_RemovedParent(data, i, j);                  
                  data = BNNodeGetUserData(dstNode);
                  data->isChangedFromCurrent = 1;
                  _SetCachedScore_AddedParent(data, j, i);  
                  netData = BNGetUserData(newBN);
                  netData->changedOne = srcNode;
                  netData->changedTwo = dstNode;
                  netData->changeComplexity = 1;

                  isLinkCovered = _IsLinkCoveredIn(srcNode, dstNode, newBN);
               } else {
                  BNNodeRemoveParent(dstNode, srcParentOfDstIndex);
                  BNNodeAddParent(srcNode, dstNode);
                  BNNodeInitCPT(srcNode);
                  BNNodeInitCPT(dstNode);

                  _InitUserData(newBN, bn, BNL_REVERSED_PARENT, j, i);
                  data = BNNodeGetUserData(srcNode);
                  data->isChangedFromCurrent = 1;
                  _SetCachedScore_AddedParent(data, i, j);

                  data = BNNodeGetUserData(dstNode);
                  data->isChangedFromCurrent = 1;
                  _SetCachedScore_RemovedParent(data, j, i);
                  netData = BNGetUserData(newBN);
                  netData->changedOne = srcNode;
                  netData->changedTwo = dstNode;
                  netData->changeComplexity = 1;

                  isLinkCovered = _IsLinkCoveredIn(dstNode, srcNode, newBN);
               }

               if((gParams->gMaxParentsPerNode == -1 ||
                      (BNNodeGetNumParents(srcNode) <= gParams->gMaxParentsPerNode &&
                       BNNodeGetNumParents(dstNode) <= gParams->gMaxParentsPerNode)) &&

                !isLinkCovered &&

             (gParams->gLimitBytes == -1 ||
              ((MGetTotalAllocation() - preAllocation) <  gMaxBytesPerModel))&&

                (gParams->gMaxParameterGrowthMult == -1 ||
                    (BNGetNumParameters(newBN) < 
                        (gParams->gMaxParameterGrowthMult * gInitialParameterCount))) &&

               (gParams->gMaxParameterCount == -1 ||
                  (BNNodeGetNumParameters(srcNode) <= 
                   gParams->gMaxParameterCount  &&
                  BNNodeGetNumParameters(dstNode) <= 
                                         gParams->gMaxParameterCount))&&

                !BNHasCycle(newBN)) {
                  VLAppend(list, newBN);
               } else {
                  _FreeUserData(newBN);
                  BNFree(newBN);
               }
            }
         }
      }
   }
}

/*
 * Prepare CPTs for a fresh round of sample counting.
 *
 * Calls BNNodeZeroCPT on every node of `current` whose user data has
 * scoreIsValid unset, and does the same for the changedOne / changedTwo
 * nodes recorded in the user data of each network in `newNets`.
 * Nodes whose score is already valid are left untouched.
 *
 *   current - network whose own nodes are examined
 *   newNets - list of candidate networks; only their recorded changed
 *             nodes are examined
 */
static void _OptimizedAddSampleInit(BeliefNet current, VoidListPtr newNets) {
   int idx;
   int slot;

   /* pass 1: nodes of the current network */
   idx = 0;
   while(idx < VLLength(current->nodes)) {
      BeliefNetNode curNode = (BeliefNetNode)VLIndex(current->nodes, idx);
      BNNodeUserData curData = BNNodeGetUserData(curNode);

      if(!curData->scoreIsValid) {
         BNNodeZeroCPT(curNode);
      }
      idx++;
   }

   /* pass 2: the (up to two) changed nodes of each candidate network */
   for(idx = 0 ; idx < VLLength(newNets) ; idx++) {
      BNUserData candData = BNGetUserData(VLIndex(newNets, idx));

      for(slot = 0 ; slot < 2 ; slot++) {
         BeliefNetNode changed;
         BNNodeUserData changedData;

         /* slot 0 is changedOne, slot 1 is changedTwo */
         changed = (slot == 0) ? candData->changedOne : candData->changedTwo;
         if(!changed) {
            continue;
         }

         changedData = BNNodeGetUserData(changed);
         if(!changedData->scoreIsValid) {
            BNNodeZeroCPT(changed);
         }
      }
   }
}


/*
 * Feed a single example into every CPT that still needs counting.
 *
 * BNNodeAddSample is called for each node of `current` whose user data
 * has scoreIsValid unset, and for the changedOne / changedTwo nodes
 * recorded in each network of `newNets` under the same condition.
 * This is done node by node, rather than through a whole-network
 * BNAddSample(current, e), so that nodes with a valid score are skipped.
 *
 *   current - network whose own nodes are examined
 *   newNets - list of candidate networks; only their recorded changed
 *             nodes are examined
 *   e       - example passed to BNNodeAddSample
 */
static void _OptimizedAddSample(BeliefNet current, VoidListPtr newNets, 
                                               ExamplePtr e) {
   int idx;
   int slot;

   /* pass 1: nodes of the current network */
   idx = 0;
   while(idx < VLLength(current->nodes)) {
      BeliefNetNode curNode = (BeliefNetNode)VLIndex(current->nodes, idx);
      BNNodeUserData curData = BNNodeGetUserData(curNode);

      if(!curData->scoreIsValid) {
         BNNodeAddSample(curNode, e);
      }
      idx++;
   }

   /* pass 2: the (up to two) changed nodes of each candidate network */
   for(idx = 0 ; idx < VLLength(newNets) ; idx++) {
      BNUserData candData = BNGetUserData(VLIndex(newNets, idx));

      for(slot = 0 ; slot < 2 ; slot++) {
         BeliefNetNode changed;
         BNNodeUserData changedData;

         /* slot 0 is changedOne, slot 1 is changedTwo */
         changed = (slot == 0) ? candData->changedOne : candData->changedTwo;
         if(!changed) {
            continue;
         }

         changedData = BNNodeGetUserData(changed);
         if(!changedData->scoreIsValid) {
            BNNodeAddSample(changed, e);
         }
      }
   }
}

/*
 * Recount a whole list of samples into every CPT that still needs it.
 *
 * For each node of `current` whose user data has scoreIsValid unset,
 * and for the changedOne / changedTwo nodes recorded in each network of
 * `newNets` under the same condition, the node's CPT is first cleared
 * with BNNodeZeroCPT and then filled with BNNodeAddSamples.
 * This is done node by node, rather than through a whole-network
 * BNAddSamples(current, samples), so nodes with a valid score are skipped.
 *
 *   current - network whose own nodes are examined
 *   newNets - list of candidate networks; only their recorded changed
 *             nodes are examined
 *   samples - sample list passed to BNNodeAddSamples
 */
static void _OptimizedAddSamples(BeliefNet current, VoidListPtr newNets, 
                                               VoidListPtr samples) {
   int idx;
   int slot;

   /* pass 1: nodes of the current network */
   idx = 0;
   while(idx < VLLength(current->nodes)) {
      BeliefNetNode curNode = (BeliefNetNode)VLIndex(current->nodes, idx);
      BNNodeUserData curData = BNNodeGetUserData(curNode);

      if(!curData->scoreIsValid) {
         BNNodeZeroCPT(curNode);
         BNNodeAddSamples(curNode, samples);
      }
      idx++;
   }

   /* pass 2: the (up to two) changed nodes of each candidate network */
   for(idx = 0 ; idx < VLLength(newNets) ; idx++) {
      BNUserData candData = BNGetUserData(VLIndex(newNets, idx));

      for(slot = 0 ; slot < 2 ; slot++) {
         BeliefNetNode changed;
         BNNodeUserData changedData;

         /* slot 0 is changedOne, slot 1 is changedTwo */
         changed = (slot == 0) ? candData->changedOne : candData->changedTwo;
         if(!changed) {
            continue;
         }

         changedData = BNNodeGetUserData(changed);
         if(!changedData->scoreIsValid) {
            BNNodeZeroCPT(changed);
            BNNodeAddSamples(changed, samples);
         }
      }
   }
}

/*
 * Score a network from the counts stored in its nodes.
 *
 * The score is the sum, over every node, every CPT row j and every
 * value k of the node, of
 *       p_jk * log(p_jk)  -  p_j * log(p_j)
 * where p_jk = eventCounts[j][k] / numSamples and
 *       p_j  = rowCounts[j] / numSamples.
 * Terms with probability 0 are skipped, and a node with 0 samples
 * contributes nothing.
 *
 * The sum is accumulated in a double and returned as a float.
 */
static float _ScoreBN(BeliefNet bn) {
   double total = 0;
   int nodeIdx;

   for(nodeIdx = 0 ; nodeIdx < BNGetNumNodes(bn) ; nodeIdx++) {
      BeliefNetNode node = BNGetNodeByID(bn, nodeIdx);
      double sampleCount = BNNodeGetNumSamples(node);
      int rowTotal = BNNodeGetNumCPTRows(node);
      int row, val;
      double p;

      for(row = 0 ; row < rowTotal ; row++) {
         /* HACK for efficiency: reads the node's count arrays directly
            instead of going through the BNN ADT */

         /* per-value terms: + p_jk * log(p_jk) */
         for(val = 0 ; val < BNNodeGetNumValues(node) ; val++) {
            if(sampleCount == 0) {
               continue;
            }
            p = node->eventCounts[row][val] / sampleCount;
            DebugMessage(1, 4,
              "   i: %d j: %d k: %d eventcount: %lf rowcount: %lf\n",
                  nodeIdx, row, val, node->eventCounts[row][val],
                    node->rowCounts[row]);
            if(p != 0) {
               total += p * log(p);
               DebugMessage(1, 4, "   Score now: %lf\n", total);
            }
         }

         /* per-row term: - p_j * log(p_j) */
         if(sampleCount == 0) {
            continue;
         }
         p = node->rowCounts[row] / BNNodeGetNumSamples(node);
         DebugMessage(1, 4,
           "   i: %d j: %d rowcount: %lf numsamples: %lf\n",
               nodeIdx, row, node->rowCounts[row], 
                         sampleCount);
         if(p != 0) {
            total -= p * log(p);
            DebugMessage(1, 4, "   Score now: %lf\n", total);
         }
      }
   }

   return total;
}

/*
 * Structural-difference term for one node, relative to gPriorNet.
 *
 * Looks up the node with the same ID in gPriorNet and counts the parent
 * IDs that appear on exactly one of the two nodes (parents of `bnn`
 * missing from the prior node, plus parents of the prior node missing
 * from `bnn`).
 *
 * Returns that count multiplied by log(gParams->gKappa).
 */
static double _GetStructuralDifferenceScoreNode(BeliefNetNode bnn) {
   BeliefNetNode reference = BNGetNodeByID(gPriorNet, BNNodeGetID(bnn));
   int mismatches = 0;
   int p;

   /* parents of bnn that the prior node does not have */
   for(p = 0 ; p < BNNodeGetNumParents(bnn) ; p++) {
      if(BNNodeHasParentID(reference, BNNodeGetParentID(bnn, p))) {
         continue;
      }
      mismatches += 1;
   }

   /* parents of the prior node that bnn does not have */
   for(p = 0 ; p < BNNodeGetNumParents(reference) ; p++) {
      if(BNNodeHasParentID(bnn, BNNodeGetParentID(reference, p))) {
         continue;
      }
      mismatches += 1;
   }

   return mismatches * log(gParams->gKappa);
}


void _UpdateNodeBD(BeliefNetNode bnn) {
   int numCPTRows, numValues;
   BNNodeUserData data;
   int j, k;
   double numSamples;
   double gamma_numValues, gamma_1;

   data = BNNodeGetUserData(bnn);
   if (data->scoreIsValid) {
     //printf("cached value is %g\n", data->avgDataLL);
     return;
   }

   if (gParams->gCallbackAPI && gParams->gCallbackAPI->NodeUpdateBD) {
      BNUserData netData = BNGetUserData(bnn->bn);
      gParams->gCallbackAPI->NodeUpdateBD(&netData->callbackAPIdata, bnn, bnn->eventCounts, bnn->rowCounts);
   }

   data->avgDataLL = 0;
   numCPTRows = BNNodeGetNumCPTRows(bnn);
   numSamples = BNNodeGetNumSamples(bnn);
   numValues = BNNodeGetNumValues(bnn);     

   gamma_numValues = StatLogGamma(numValues);
   gamma_1 = StatLogGamma(1);
   for(j = 0 ; j < numCPTRows ; j++) {
      /* HACK for efficiency break BNN ADT */

      /* HERE HERE update this to use the probabilities from the
               prior network */
      data->avgDataLL += gamma_numValues -
         StatLogGamma(numValues + bnn->rowCounts[j]);
      for(k = 0 ; k < numValues ; k++) {
         data->avgDataLL += StatLogGamma(1 + bnn->eventCounts[j][k]);
      }

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -