⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 bnlearn-engine.c

📁 数据挖掘方面的源码
💻 C
📖 第 1 页 / 共 3 页
字号:
      data->avgDataLL -= numValues*gamma_1;
   }
   /* now scale it by the likelihood of the node given structural prior */
   data->avgDataLL += _GetStructuralDifferenceScoreNode(bnn);

   if (data->writebackScore) {
     *(data->writebackScore) = data->avgDataLL;
     *(data->writebackScoreIsValid) = TRUE;
   }



}

/*
 * Recompute the BD score of every node of the current network.
 * Node IDs are visited in ascending order, 0 .. BNGetNumNodes(bn)-1,
 * and each node is handed to _UpdateNodeBD.
 */
static void _UpdateCurrentNetScoreBD(BeliefNet bn) {
   int nodeID = 0;

   while(nodeID < BNGetNumNodes(bn)) {
      _UpdateNodeBD(BNGetNodeByID(bn, nodeID));
      nodeID++;
   }
}

/*
 * Refresh the BD score only for the node(s) recorded in this network's
 * user data as changed (changedOne / changedTwo). A zero entry means
 * there is no such node and it is skipped.
 */
static void _UpdateNetScoreBD(BeliefNet bn) {
   BNUserData userData = BNGetUserData(bn);

   if(userData->changedOne != 0) {
      _UpdateNodeBD(userData->changedOne);
   }

   if(userData->changedTwo != 0) {
      _UpdateNodeBD(userData->changedTwo);
   }
}


/*
 * Score of a single node.
 * When the node is linked to a counterpart in the current network
 * (data->current != 0) and is not marked as changed from it, the
 * counterpart's cached avgDataLL is reused; otherwise the node's own
 * avgDataLL is returned.
 */
static double _GetNodeScore(BeliefNetNode bnn) {
   BNNodeUserData nodeData = BNNodeGetUserData(bnn);

   if(nodeData->current != 0 && !nodeData->isChangedFromCurrent) {
      return nodeData->current->avgDataLL;
   }

   return nodeData->avgDataLL;
}

float _ScoreBNOptimized(BeliefNet bn);

/*
 * Decide whether network `first` scores strictly higher than `second`.
 *
 * Only the nodes recorded as changed in either network (changedOne /
 * changedTwo of each net's user data) are compared. For every such node
 * ID, the score of that node in `first` is summed against the score of
 * the same ID in `second` (via _GetNodeScore).
 *
 * Each node ID is counted once even if it appears in several of the four
 * changed slots. The original only de-duplicated second's slots against
 * first's, so a repeat within one network would have been double-counted.
 * Summation order (first's slots, then second's) is unchanged.
 *
 * Returns nonzero if first's summed score is strictly greater; ties
 * return 0.
 */
static int _IsFirstNetBetter(BeliefNet first, BeliefNet second) {
   BNUserData firstData = BNGetUserData(first);
   BNUserData secondData = BNGetUserData(second);
   int ids[4];
   int numIds = 0;
   int k, j, seenBefore;
   double scoreOne = 0, scoreTwo = 0;

   if(firstData->changedOne) {
      ids[numIds++] = BNNodeGetID(firstData->changedOne);
   }
   if(firstData->changedTwo) {
      ids[numIds++] = BNNodeGetID(firstData->changedTwo);
   }
   if(secondData->changedOne) {
      ids[numIds++] = BNNodeGetID(secondData->changedOne);
   }
   if(secondData->changedTwo) {
      ids[numIds++] = BNNodeGetID(secondData->changedTwo);
   }

   for(k = 0 ; k < numIds ; k++) {
      /* don't compare the same node twice */
      seenBefore = 0;
      for(j = 0 ; j < k ; j++) {
         if(ids[j] == ids[k]) {
            seenBefore = 1;
            break;
         }
      }
      if(seenBefore) {
         continue;
      }

      scoreOne += _GetNodeScore(BNGetNodeByID(first, ids[k]));
      scoreTwo += _GetNodeScore(BNGetNodeByID(second, ids[k]));
   }

   return scoreOne > scoreTwo;
}

/*
 * Total score of a network: the sum of _GetNodeScore over all nodes,
 * visited in ascending ID order. The sum is accumulated in double and
 * narrowed to float on return.
 */
float _ScoreBNOptimized(BeliefNet bn) {
   double total = 0;
   int nodeID = 0;

   while(nodeID < BNGetNumNodes(bn)) {
      total += _GetNodeScore(BNGetNodeByID(bn, nodeID));
      nodeID++;
   }

   return (float)total;
}

/*
 * Copy CPTs from `source` into `target` for every node that `target`
 * has NOT changed relative to the current network.
 * Nodes are matched by ID. Each unchanged target node has its CPT
 * re-initialised (BNNodeInitCPT) before the copy (BNNodeSetCPTFrom).
 * Nodes with isChangedFromCurrent set keep their own CPTs.
 */
static void _UpdateCPTsForFrom(BeliefNet target, BeliefNet source) {
   BeliefNetNode dst;
   BNNodeUserData dstData;
   int id;

   for(id = 0 ; id < BNGetNumNodes(target) ; id++) {
      dst = BNGetNodeByID(target, id);
      dstData = BNNodeGetUserData(dst);

      if(dstData->isChangedFromCurrent) {
         continue;
      }

      BNNodeInitCPT(dst);
      BNNodeSetCPTFrom(dst, BNGetNodeByID(source, id));
   }
}

/*
 * Returns 1 when the network's user data records no changed nodes
 * (both changedOne and changedTwo are zero), otherwise 0.
 * Presumably this identifies the current net as opposed to a one-step
 * candidate; confirm against the code that builds candidates.
 */
static int _IsCurrentNet(BeliefNet bn) {
   BNUserData netData = (BNUserData)BNGetUserData(bn);

   if(netData->changedOne != 0) {
      return 0;
   }

   return netData->changedTwo == 0;
}

/*
 * Single-elimination tournament over the candidate networks.
 *
 * First the BD scores of every node of `current` are refreshed. Then
 * candidates are popped from the back of `netChoices`, and each one's
 * changed nodes are rescored. Each candidate is compared with the
 * incumbent via _IsFirstNetBetter. The incumbent survives only if it is
 * strictly better; on a tie the newly popped candidate takes over. Every
 * loser has its user data released (_FreeUserData) and is freed (BNFree).
 *
 * On return, `netChoices` holds exactly one element: the survivor.
 * NOTE(review): assumes `netChoices` is non-empty on entry; an empty list
 * would make the first VLRemove use index -1.
 */
static void _CompareNetsFreeLoosers(BeliefNet current, VoidListPtr netChoices) {
   BeliefNet best, challenger, loser;

   _UpdateCurrentNetScoreBD(current);

   best = VLRemove(netChoices, VLLength(netChoices) - 1);
   _UpdateNetScoreBD(best);

   while(VLLength(netChoices) != 0) {
      challenger = VLRemove(netChoices, VLLength(netChoices) - 1);
      _UpdateNetScoreBD(challenger);

      if(_IsFirstNetBetter(best, challenger)) {
         loser = challenger;
      } else {
         loser = best;
         best = challenger;
      }

      _FreeUserData(loser);
      BNFree(loser);
   }

   VLAppend(netChoices, best);
}

/*
 * Allocation-failure callback (registered with MSetAllocFailFunction in
 * BNLearn). Prints the requested size to stdout; this function itself
 * does not abort or recover.
 */
void _AllocFailed(int allocationSize) {
   fprintf(stdout, "MEMORY ALLOCATION FAILED, size %d\n", allocationSize);
}

/*
 * Check the CPU-time budget.
 *
 * Returns 0 when gParams->gLimitSeconds is -1 (no limit). Otherwise it
 * measures the user CPU time (tms_utime) consumed since `starttime` and
 * returns nonzero once that, in whole seconds, reaches the limit.
 *
 * NOTE(review): ticks are converted by dividing by 100, i.e. a 100 Hz
 * clock tick is assumed. POSIX reports the real rate through
 * sysconf(_SC_CLK_TCK); confirm for the target platform.
 */
static int _IsTimeExpired(struct tms starttime) {
   struct tms now;
   long elapsedSeconds;

   if(gParams->gLimitSeconds == -1) {
      return 0;
   }

   times(&now);
   /* double -> long conversion truncates toward zero */
   elapsedSeconds = (double)(now.tms_utime - starttime.tms_utime) / 100.0;

   return elapsedSeconds >= gParams->gLimitSeconds;
}


/*
 * Greedy structure search for a belief network.
 *
 * Starting network, in order of precedence:
 *   params->gInputNetMemory    - used directly. It is freed by BNFree at
 *                                the end of the first search step, so
 *                                ownership passes to BNLearn.
 *   params->gInputNetFile      - read as BIF (BNReadBIFFILEP), CPTs zeroed.
 *   params->gInputNetEmptySpec - built with BNNewFromSpec.
 * Data: params->gDataMemory and/or params->gDataFile. When a data file
 * is given, it is rewound after every search step.
 *
 * Each step:
 * 1. Builds the one-step candidate networks.
 * 2. Feeds all data to them.
 * 3. Keeps the best one (_CompareNetsFreeLoosers) and moves to it.
 * The search stops when any of the following holds:
 * - the winner has the same structure as the current net, or is not
 *   better than it;
 * - the gMaxSearchSteps limit is reached;
 * - the CPU-time limit (gLimitSeconds) expires;
 * - gOnlyEstimateParameters is set;
 * - gCheckModelCycle is set and a previously visited structure recurs.
 *   In that last case, the best-scoring net seen during the cycle becomes
 *   the result.
 *
 * Result: written as BIF to gOutputNetFilename when set. It is stored in
 * gOutputNetMemory when gOutputNetToMemory is set; otherwise it is freed.
 *
 * Fixes relative to the previous version:
 * - Inputs are validated before anything is loaded or allocated, so the
 *   error returns no longer leak.
 * - A failed BIF read returns instead of calling BNZeroCPTs(NULL).
 * - The example read just before a time-limit stop is freed.
 * - The cycle check no longer aliases (and frees) a net still owned by
 *   the previous-winners list, and its choice of best net is actually
 *   used.
 * - The previous-winners list and its clones are released.
 * - An unopenable output file is reported instead of being passed to
 *   BNWriteBIF/fclose.
 */
void BNLearn(BNLearnParams *params) {
   ExampleSpecPtr es;
   int allDone, searchStep;
   int cycleDetected, cycleBestIndex;
   BeliefNet bn, winner;

   long learnTime, seenTotal;
   struct tms starttime;
   struct tms endtime;

   VoidListPtr netChoices;
   VoidListPtr previousWinners;

   gParams = params;

   MSetAllocFailFunction(_AllocFailed);

   /* Validate both inputs before loading or allocating anything, so these
      error returns cannot leak a freshly read net or the winners list. */
   if(!gParams->gInputNetMemory && !gParams->gInputNetFile &&
                                   !gParams->gInputNetEmptySpec) {
      DebugError(1,"Error, at least one of gInputNetMemory, gInputNetFile, or gInputNetEmptySpec must be specified");
      return;
   }
   if(!gParams->gDataMemory && !gParams->gDataFile) {
      DebugError(1,"Error, at least one of gDataMemory or gDataFile must be set");
      return;
   }

   /* Set up the input network */
   if(gParams->gInputNetMemory) {
      bn = gParams->gInputNetMemory;
   } else if(gParams->gInputNetFile) {
      bn = BNReadBIFFILEP(gParams->gInputNetFile);
      if(bn == 0) {
         /* bail out instead of handing a null net to BNZeroCPTs */
         DebugMessage(1, 1, "couldn't read net\n");
         return;
      }
      BNZeroCPTs(bn);
   } else {
      bn = BNNewFromSpec(gParams->gInputNetEmptySpec);
   }
   assert(bn);
   es = BNGetExampleSpec(bn);

   /* clones of each step's starting net, used for cycle detection */
   previousWinners = VLNew();

   gInitialParameterCount = BNGetNumParameters(bn);
   gBranchFactor = BNGetNumNodes(bn) * BNGetNumNodes(bn);
   if(gParams->gLimitBytes != -1) {
      gMaxBytesPerModel = gParams->gLimitBytes / gBranchFactor;
      DebugMessage(1, 2, "Limit models to %.4lf megs\n", 
                     gMaxBytesPerModel / (1024.0 * 1024.0));
   }

   gPriorNet = BNClone(bn);

   RandomInit();
   /* seed */
   if(gParams->gSeed != -1) {
      RandomSeed(gParams->gSeed);
   } else {
      gParams->gSeed = RandomRange(1, 30000);
      RandomSeed(gParams->gSeed);
   }

   DebugMessage(1, 1, "running with seed %d\n", gParams->gSeed);
   DebugMessage(1, 1, "allocation %ld\n", MGetTotalAllocation());
   DebugMessage(1, 1, "initial parameters %ld\n", gInitialParameterCount);

   times(&starttime);

   seenTotal = 0;
   learnTime = 0;

   _InitScoreCache(bn);
   allDone = 0;
   searchStep = 0;
   while(!allDone) {
      searchStep++;
      DebugMessage(1, 2, "============== Search step: %d ==============\n",
             searchStep);
      DebugMessage(1, 2, "  Total samples: %ld\n", seenTotal);
      DebugMessage(1, 2, " allocation before choices %ld\n", 
                                               MGetTotalAllocation());

      DebugMessage(1, 2, "   best with score %f:\n", _ScoreBN(bn));
      if(DebugGetMessageLevel() >= 2) {
         BNPrintStats(bn);
      }

      if(DebugGetMessageLevel() >= 3) {
         BNWriteBIF(bn, DebugGetTarget());
      }

      netChoices = VLNew();

      _GetOneStepChoicesForBN(bn, netChoices);

      DebugMessage(1, 2, " allocation after choices %ld there are %d\n", 
                         MGetTotalAllocation(), VLLength(netChoices));

      if(gParams->gDataMemory) {
         _OptimizedAddSamples(bn, netChoices, gParams->gDataMemory);
         seenTotal += VLLength(gParams->gDataMemory);
         if(_IsTimeExpired(starttime)) {
            allDone = 1;
         }
      }
      if(gParams->gDataFile) {
         int stepDone = 0;
         ExamplePtr e = ExampleRead(gParams->gDataFile, es);

         _OptimizedAddSampleInit(bn, netChoices);
         while(!stepDone && e != 0) {
            seenTotal++;
            /* put the eg in every net choice */
            _OptimizedAddSample(bn, netChoices, e);

            if(_IsTimeExpired(starttime)) {
               stepDone = allDone = 1;
            }
            ExampleFree(e);
            e = ExampleRead(gParams->gDataFile, es);
         } /* !stepDone && e != 0 */

         /* a time-limit stop leaves one unconsumed example behind */
         if(e != 0) {
            ExampleFree(e);
         }
      }

      _CompareNetsFreeLoosers(bn, netChoices);

      /* if the winner is the current one then we are all done */
      if(BNStructureEqual(bn, (BeliefNet)VLIndex(netChoices, 0)) ||
         !_IsFirstNetBetter((BeliefNet)VLIndex(netChoices, 0), bn)) {
         allDone = 1;
      } else if(gParams->gMaxSearchSteps != -1 &&
                searchStep >= gParams->gMaxSearchSteps) {
         DebugMessage(1, 1, "Stopped because of search step limit\n");
         allDone = 1;
      }

      /* copy all the CPTs that are only in bn into the new winner;
         only really needed for final output but done for debugging too */
      _UpdateCPTsForFrom((BeliefNet)VLIndex(netChoices, 0), bn);
      _InvalidateScoreCache((BeliefNet)VLIndex(netChoices, 0));

      if(gParams->gOnlyEstimateParameters) {
         allDone = 1;
      }

      /* Cycle check. If bn's structure matches a previous winner, every
         winner from that one onward is part of the cycle; keep the best
         scoring of those (or bn itself) as the result. cycleBestIndex is
         the list index of that best previous winner, or -1 if bn is best.
         List entries are only read here, never aliased into bn. */
      cycleDetected = 0;
      cycleBestIndex = -1;
      if(gParams->gCheckModelCycle) {
         if(!allDone) {
            BeliefNet cycleBest = bn;
            int i;

            for(i = 0 ; i < VLLength(previousWinners) ; i++) {
               BeliefNet previous = (BeliefNet)VLIndex(previousWinners, i);

               if(!cycleDetected && BNStructureEqual(bn, previous)) {
                  cycleDetected = 1;
               }
               if(cycleDetected &&
                       _ScoreBN(cycleBest) < _ScoreBN(previous)) {
                  cycleBest = previous;
                  cycleBestIndex = i;
               }
            }

            if(cycleDetected) {
               allDone = 1;
            }
         }
         if(!allDone) {
            VLAppend(previousWinners, BNClone(bn));
         }
      }

      /* Move to the chosen net. Each net receives _FreeUserData exactly
         once, and nets that are dropped are also BNFree'd. */
      winner = (BeliefNet)VLRemove(netChoices, 0);
      VLFree(netChoices);

      if(cycleDetected && cycleBestIndex == -1) {
         /* the current net is the best of the cycle: keep it */
         _FreeUserData(winner);
         BNFree(winner);
      } else {
         _FreeUserData(bn);
         BNFree(bn);
         if(cycleBestIndex != -1) {
            /* an earlier winner is best: take it out of the list so the
               list no longer owns it */
            _FreeUserData(winner);
            BNFree(winner);
            bn = (BeliefNet)VLRemove(previousWinners, cycleBestIndex);
         } else {
            bn = winner;
         }
      }
      _FreeUserData(bn);

      DebugMessage(1, 2, " allocation after all free %ld\n", 
                                           MGetTotalAllocation());

      /* reset data file */
      if(gParams->gDataFile) {
         rewind(gParams->gDataFile);
      }

   } /* while !allDone */

   /* Release the clones kept for cycle detection.
      NOTE(review): only BNFree is applied, mirroring how nets are freed
      after _FreeUserData elsewhere. If BNClone also deep-copies user data,
      that copy is not released here; confirm. */
   while(VLLength(previousWinners) != 0) {
      BNFree((BeliefNet)VLRemove(previousWinners,
                                 VLLength(previousWinners) - 1));
   }
   VLFree(previousWinners);

   if(gParams->gSmoothAmount != 0) {
      BNSmoothProbabilities(bn, gParams->gSmoothAmount);
   }

   times(&endtime);
   learnTime += endtime.tms_utime - starttime.tms_utime;

   DebugMessage(1, 1, "done learning...\n");
   /* NOTE(review): assumes 100 clock ticks per second; POSIX provides the
      real rate via sysconf(_SC_CLK_TCK). */
   DebugMessage(1, 1, "time %.2lfs\n", ((double)learnTime) / 100);

   DebugMessage(1, 1, "Total Samples: %ld\n", seenTotal);
   if(DebugGetMessageLevel() >= 1) {
      DebugMessage(1, 1, "Samples per round:\n");
   }

   if(gParams->gOutputNetFilename) {
      FILE *netOut = fopen(gParams->gOutputNetFilename, "w");

      if(netOut == 0) {
         DebugError(1, "Error, couldn't open output net file");
      } else {
         BNWriteBIF(bn, netOut);
         fclose(netOut);
      }
   }

   if(gParams->gOutputNetToMemory) {
      gParams->gOutputNetMemory = bn;
   } else {
      BNFree(bn);
   }

   /* es came from BNGetExampleSpec(bn) and is not freed here
      (presumably owned by the net; confirm). */
   DebugMessage(1, 1, "   allocation %ld\n", MGetTotalAllocation());
}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -