📄 modelsearch.c
字号:
#include "ModelSearch.h"

#include <math.h>
#include <stddef.h>
#include <stdint.h>

/* lame hack globals */
extern int gMaxParentsPerNode;
extern long gMaxParameterCount;
extern long gMaxBytesPerModel;
extern int gLimitBytes;
extern float gDelta;
extern int gP0Multiplier;
extern long gNumBoundsUsed;
extern float gTau;
extern int gAllowRemove;
extern long gNumByTie;
extern long gNumByWin;
extern long gMinSamplesInDecision;
extern long gMaxSamplesInDecision;
extern long gTotalSamplesInDecision;
extern long gNumZeroSamplesInDecision;
extern long gNumCurrentInTie;
extern long gNumCurrentByWin;
extern long gNumCurrentByDefault;
extern long gNumByCycleConflict;
extern long gNumByParameterConflict;
extern long gNumByParentLimit;
extern long gNumByMemoryLimit;
extern long gNumByParameterLimit;
extern long gNumByChangeLimit;
extern long gNumRemoved;
extern long gNumAdded;
extern long gNumUsedWholeDB;

/*
 * One candidate model in the search: the current net plus at most one
 * changed node (a clone of a current node with one parent added/removed).
 */
typedef struct BNUserData_ {
   BeliefNetNode changedOne;    /* NULL for the no-change candidate */
   int changedParentID;
   /* -1 for removal, 0 for nothing, 1 for reverse, 2 for add */
   int changeComplexity;
   BeliefNet currentNet;
} BNUserDataStruct, *BNUserData;

/* Per-node scoring state attached with BNNodeSetUserData. */
typedef struct BNNodeUserData_ {
   double avgDataLL;     /* average data log-likelihood contribution */
   double p0;            /* smallest positive eventCount/rowCount seen */
   int isChangedFromCurrent;
   struct BNNodeUserData_ *current; /* NULL if this is a current node */
} BNNodeUserDataStruct, *BNNodeUserData;

/* Frees a candidate: its cloned node (and that node's user data), then the
 * candidate record itself. The shared current net is not touched. */
static void _FreeUserData(BNUserData netData) {
   if(netData->changedOne) {
      MFreePtr(BNNodeGetUserData(netData->changedOne));
      BNNodeFree(netData->changedOne);
   }
   MFreePtr(netData);
}

/*
 * Returns non-zero if the candidate's structure contains a directed cycle.
 *
 * For changeComplexity -1/0 the current net itself is checked. For an add
 * (2) the edge parent->changed is temporarily inserted into the shared
 * current net, tested, and removed again, so this is not reentrant with
 * respect to currentNet.
 */
int _BNHasCycle(BNUserData netData) {
   BeliefNetNode changedNode, parentNode;
   int hasCycle;

   if(netData->changeComplexity == -1 || netData->changeComplexity == 0) {
      return BNHasCycle(netData->currentNet);
   }

   changedNode = BNGetNodeByID(netData->currentNet,
                               BNNodeGetID(netData->changedOne));
   parentNode = BNGetNodeByID(netData->currentNet, netData->changedParentID);

   /* A new edge parent->changed can only close a cycle through an existing
      path changed -> ... -> parent, which needs changed to have children and
      parent to have parents. */
   if(BNNodeGetNumChildren(changedNode) == 0 ||
      BNNodeGetNumParents(parentNode) == 0) {
      return 0;
   }

   if(netData->changeComplexity != 2) {
      /* must be a remove, but that is special cased elsewhere and can't
         cause a cycle */
      return 0;
   }

   BNFlushStructureCache(netData->currentNet);
   BNNodeAddParent(changedNode, parentNode);
   hasCycle = BNHasCycle(netData->currentNet);
   BNNodeRemoveParent(changedNode,
                      BNNodeLookupParentIndex(changedNode, parentNode));
   BNFlushStructureCache(netData->currentNet);

   return hasCycle;
}

BeliefNetNode _BNNodeCloneNoCPT(BeliefNetNode bnn, BeliefNet newBN);

/*
 * Allocates a candidate record.
 *
 * When doChange is zero the candidate is the unchanged current net. Otherwise
 * node nodeID is cloned (without CPT) and changedParentID is added to
 * (isAdd != 0) or removed from its parent list; the clone gets fresh
 * BNNodeUserData linked to the current node's data. Caller owns the result
 * (free with _FreeUserData).
 */
static BNUserData _InitUserData(int nodeID, int doChange, int changedParentID,
                                int isAdd, BeliefNet current) {
   BNNodeUserData data;
   BNUserData netData;
   BeliefNetNode changedBnn = NULL;
   BeliefNetNode currentBnn;
   int changeComplexity = 0;

   if(doChange) {
      currentBnn = BNGetNodeByID(current, nodeID);
      changedBnn = _BNNodeCloneNoCPT(currentBnn, current);

      data = MNewPtr(sizeof *data);
      BNNodeSetUserData(changedBnn, data);
      data->avgDataLL = 0;
      data->isChangedFromCurrent = 1;
      data->p0 = 1;
      data->current = BNNodeGetUserData(currentBnn);

      /* HERE break BN ADT right here: edit the clone's parent list directly.
         Parent IDs are stored in the list as integers cast to pointers;
         go through intptr_t so the cast is width-correct on 64-bit. */
      if(isAdd) {
         changeComplexity = 2;
         VLAppend(changedBnn->parentIDs, (void *)(intptr_t)changedParentID);
      } else {
         /* remove */
         changeComplexity = -1;
         VLRemove(changedBnn->parentIDs,
                  BNNodeLookupParentIndexByID(changedBnn, changedParentID));
      }

      /* Parent set changed: reset the cached combination count (presumably
         -1 means "recompute" — confirm in the BN ADT) and rebuild the CPT. */
      changedBnn->numParentCombinations = -1;
      BNNodeInitCPT(changedBnn);
   }

   netData = MNewPtr(sizeof *netData);
   netData->changedOne = changedBnn;
   netData->changedParentID = changedParentID;
   netData->changeComplexity = changeComplexity;
   netData->currentNet = current;
   return netData;
}

/* Attaches fresh BNNodeUserData to every node of the current net, once.
   The net's own user data is set to (void *)1 as an "already done" flag. */
static void _InitCurrentUserData(BeliefNet current) {
   int i;
   BeliefNetNode bnn;
   BNNodeUserData data;

   if(BNGetUserData(current) != 0) {
      return;
   }

   for(i = 0 ; i < BNGetNumNodes(current) ; i++) {
      bnn = BNGetNodeByID(current, i);
      data = MNewPtr(sizeof *data);
      BNNodeSetUserData(bnn, data);
      data->avgDataLL = 0;
      data->isChangedFromCurrent = 0;
      data->p0 = 1;
      data->current = NULL;
   }

   /* flag to keep it from getting inited too many times */
   BNSetUserData(current, (void *)1);
}

/* Looks up node id in the candidate: its changed node if the id matches,
   otherwise the node from the shared current net. */
static BeliefNetNode _BNGetNodeByID(BNUserData netData, int id) {
   if(netData->changedOne && BNNodeGetID(netData->changedOne) == id) {
      return netData->changedOne;
   }
   return BNGetNodeByID(netData->currentNet, id);
}

/* Smoothed conditional probability (event + 1) / (row + numValues), i.e.
   add-one smoothing in place of the raw event / row. */
static double _CalculateCP(BeliefNetNode bnn, double event, double row) {
   return (event + 1.0) / (row + (double)BNNodeGetNumValues(bnn));
}

/*
 * Feeds one sample of (count / numSamples) * log(smoothedCP) into a
 * StatTracker for every observed event of bnn (each CPT cell contributes
 * eventCounts[row][value] copies) and returns
 * StatTrackerGetNormalBound(st, gDelta).
 * NOTE(review): sample count is O(numSamples); a weighted-add API on
 * StatTracker, if one exists, would avoid the inner loop.
 */
static double _GetNodeLikelihoodNormalBound(BeliefNetNode bnn) {
   StatTracker st;
   int i, j, k, numCPTRows, numValues;
   double likelihood, numSamples, bound;

   st = StatTrackerNew();
   numCPTRows = BNNodeGetNumCPTRows(bnn);
   numValues = BNNodeGetNumValues(bnn);
   numSamples = BNNodeGetNumSamples(bnn);

   for(j = 0 ; j < numCPTRows ; j++) {
      /* HACK for efficiency break BNN ADT */
      for(k = 0 ; k < numValues ; k++) {
         likelihood = (bnn->eventCounts[j][k] / numSamples) *
                      log(_CalculateCP(bnn, bnn->eventCounts[j][k],
                                       bnn->rowCounts[j]));
         for(i = 0 ; i < bnn->eventCounts[j][k] ; i++) {
            StatTrackerAddSample(st, likelihood);
         }
      }
   }

   bound = StatTrackerGetNormalBound(st, gDelta);
   StatTrackerFree(st);
   return bound;
}

/*
 * Sum of the normal bounds for the single node that differs between the two
 * candidates: the changed node of the first candidate that has one, plus the
 * same-ID node in the other candidate. Returns 0 when neither has a change.
 * `current` is unused (kept for interface compatibility).
 */
static double _GetEpsilonNormal(BNUserData firstData, BNUserData secondData,
                                BeliefNet current) {
   BeliefNetNode changed;
   BNUserData otherData;
   double bound = 0;

   (void)current;

   if(firstData->changedOne) {
      changed = firstData->changedOne;
      otherData = secondData;
   } else if(secondData->changedOne) {
      changed = secondData->changedOne;
      otherData = firstData;
   } else {
      return bound;
   }

   bound += _GetNodeLikelihoodNormalBound(changed);
   bound += _GetNodeLikelihoodNormalBound(
               _BNGetNodeByID(otherData, BNNodeGetID(changed)));
   return bound;
}

/* |log p0| where p0 = 1 / (5 * numValues), further capped by
   data->p0 / gP0Multiplier when the node has user data. */
static double _GetNodeScoreRange(BeliefNetNode bnn) {
   BNNodeUserData data = BNNodeGetUserData(bnn);
   double p0 = 1.0 / (5.0 * (double)BNNodeGetNumValues(bnn));

   if(data != 0) {
      p0 = min(p0, data->p0 / (double)gP0Multiplier);
   }
   return fabs(log(p0));
}

/* Score range of the changed node of whichever candidate has one (first
   candidate preferred); 0 if neither has a change. */
static double _GetComparedNodesScoreRange(BNUserData firstData,
                                          BNUserData secondData) {
   if(firstData->changedOne) {
      return 0 + _GetNodeScoreRange(firstData->changedOne);
   }
   if(secondData->changedOne) {
      return 0 + _GetNodeScoreRange(secondData->changedOne);
   }
   return 0;
}

/*
 * Recomputes data->avgDataLL = sum over cells with count > 0 of
 * (count / numSamples) * log(smoothedCP), and lowers data->p0 to the smallest
 * positive count / rowCount seen. Only runs for current nodes
 * (data->current == NULL) or nodes changed from current; others reuse the
 * current node's score (see _GetNodeScore).
 */
void _UpdateNodeAveDataLL(BeliefNetNode bnn) {
   BNNodeUserData data = BNNodeGetUserData(bnn);
   int numCPTRows, numValues, j, k;
   double numSamples, prob, logCP, rawCP;

   if(!data->isChangedFromCurrent && data->current != NULL) {
      return;
   }

   data->avgDataLL = 0;
   numCPTRows = BNNodeGetNumCPTRows(bnn);
   numValues = BNNodeGetNumValues(bnn);
   numSamples = BNNodeGetNumSamples(bnn);

   for(j = 0 ; j < numCPTRows ; j++) {
      /* HACK for efficiency break BNN ADT */
      for(k = 0 ; k < numValues ; k++) {
         if(bnn->eventCounts[j][k] > 0) {
            prob = bnn->eventCounts[j][k] / numSamples;
            logCP = log(_CalculateCP(bnn, bnn->eventCounts[j][k],
                                     bnn->rowCounts[j]));
            /* unsmoothed CP, evaluated exactly as before, computed once */
            rawCP = bnn->eventCounts[j][k] / bnn->rowCounts[j];
            if(data->p0 > rawCP && rawCP > 0) {
               data->p0 = rawCP;
            }
            data->avgDataLL += prob * logCP;
         }
      }
   }
}

/* Refreshes the score of the candidate's changed node, if any. */
static void _UpdateNetScore(BNUserData netData) {
   if(netData->changedOne) {
      _UpdateNodeAveDataLL(netData->changedOne);
   }
}

/* Node score: its own avgDataLL if it is a current or changed node,
   otherwise the avgDataLL of the current node it refers to. */
static double _GetNodeScore(BeliefNetNode bnn) {
   BNNodeUserData data = BNNodeGetUserData(bnn);

   if(data->current == 0 || data->isChangedFromCurrent) {
      return data->avgDataLL;
   }
   return data->current->avgDataLL;
}

/* score(first) - score(second), compared on the single node that differs
   (the changed node of the first candidate that has one). 0 if neither
   candidate has a change. */
double _GetDeltaScore(BNUserData firstData, BNUserData secondData) {
   double scoreOne = 0, scoreTwo = 0;

   if(firstData->changedOne) {
      scoreOne += _GetNodeScore(firstData->changedOne);
      scoreTwo += _GetNodeScore(_BNGetNodeByID(secondData,
                                   BNNodeGetID(firstData->changedOne)));
   } else if(secondData->changedOne) {
      scoreOne += _GetNodeScore(_BNGetNodeByID(firstData,
                                   BNNodeGetID(secondData->changedOne)));
      scoreTwo += _GetNodeScore(secondData->changedOne);
   }

   DebugMessage(1, 4, " first %lf second %lf\n", scoreOne, scoreTwo);
   return scoreOne - scoreTwo;
}

/* Non-zero when the first candidate scores at least as well as the second. */
static int _IsFirstNetBetter(BNUserData firstData, BNUserData secondData) {
   return _GetDeltaScore(firstData, secondData) >= 0;
}

/* A candidate is the current net when it carries no changed node. */
static int _IsCurrentNet(BNUserData data) {
   return data->changedOne == 0;
}

/*
 * Keeps ms->choices[bestIndex] as the sole remaining choice and frees every
 * other candidate. The current net is not favored in ties: the winner is
 * always the entry at bestIndex; gNumCurrentInTie only counts the case where
 * that winner happens to be the current net.
 */
static void _PickWinnerInTieFreeRest(ModelSearch ms, int bestIndex) {
   BNUserData winner;
   int i;

   winner = VLRemove(ms->choices, bestIndex);
   for(i = VLLength(ms->choices) - 1 ; i >= 0 ; i--) {
      _FreeUserData(VLRemove(ms->choices, i));
   }

   if(_IsCurrentNet(winner)) {
      gNumCurrentInTie++;
   }
   VLAppend(ms->choices, winner);
}

/*
 * Fills ms->choices with the no-change candidate followed by one "add parent"
 * candidate for each node unrelated to ms->nodeID that passes the parent,
 * parameter, link-change, memory and acyclicity limits.
 */
void _GetOneStepChoicesForSearch(ModelSearch ms) {
   int i;
   BeliefNetNode currentNode, parentNode;
   //BNNodeUserData data;
   BNUserData netData;
   long preAllocation;

   ms->initialExampleNumber = -1;

   DebugMessage(VLLength(ms->choices) > 0, 0, "******Warning list in choices not empty\n");

   preAllocation = MGetTotalAllocation();
   /* the no-change-net */
   VLAppend(ms->choices, _InitUserData(ms->nodeID, 0, -1, 0, ms->currentModel));
   DebugMessage(1, 3, " allocated current net, size %ld\n", MGetTotalAllocation() - preAllocation);

   for(i = 0 ; i < BNGetNumNodes(ms->currentModel) ; i++) {
      if(i != ms->nodeID) {
         currentNode = BNGetNodeByID(ms->currentModel, ms->nodeID);
         parentNode = BNGetNodeByID(ms->currentModel, i);

         /* if they aren't related consider adding dst as a parent */
         if(BNNodeLookupParentIndex(currentNode, parentNode) == -1 && BNNodeLookupParentIndex(parentNode, currentNode) == -1) {
            preAllocation = MGetTotalAllocation();
            netData = _InitUserData(ms->nodeID, 1, i, 1, ms->currentModel);
            DebugMessage(1, 4, " allocated new net, size %ld\n", MGetTotalAllocation() - preAllocation);
            if((gMaxParentsPerNode == -1 || BNNodeGetNumParents(netData->changedOne) <= gMaxParentsPerNode) && (gMaxParameterCount == -1 || (BNNodeGetNumParameters(netData->changedOne) <= gMaxParameterCount)) && (ms->linkChangeCounts[i] <= 2) && (gLimitBytes == -1 || ((MGetTotalAllocation() - preAllocation) <= gMaxBytesPerModel))&& !_BNHasCycle(netData)) {
               VLAppend(ms->choices, netData);
            } else {
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -