📄 vfbn1.c
字号:
#include "vfml.h"
#include <stdio.h>
#include <string.h>
#include <assert.h>
#include <sys/times.h>
#include <unistd.h>
#include <time.h>
#include <math.h>
/* 1/e, written out to nine decimal places. */
#define kOneOverE (0.367879441)
/* Dataset file stem (-f); _doTests reads <gSourceDirectory>/<gFileStem>.test. */
char *gFileStem = "DF";
/* Net file to start the search from (-startFrom); the usage text says it
 * must be a BIF file. gUseStartingNet is set to 1 when the option is given. */
char *gStartingNet = "";
int gUseStartingNet = 0;
/* File the learned net is written to (-outputTo); gOutputNet is set to 1
 * when the option is given. */
char *gNetOutput = "";
int gOutputNet = 0;
/* Directory holding the data files (-source). */
char *gSourceDirectory = ".";
/* Set to 1 by -u; _processArgs logs "Running tests" when it is set. */
int gDoTests = 0;
/* Allowed chance of error in each decision (-delta). */
float gDelta = 0.00001;
/* NOTE(review): no option in _processArgs sets this; presumably a derived
 * per-comparison delta -- confirm against the rest of the file. */
float gDeltaStar = 0.00001;
/* Tie threshold (-tau). */
float gTau = 0.001;
/* Number of examples to accumulate before checking for a winner (-chunk). */
int gChunk = 10000;
/* Memory limit in bytes; -limitMegs is multiplied by 1024*1024. The usage
 * text says the default (-1 here) means no limit. */
long gLimitBytes = -1;
/* Run-time limit in seconds; -limitMinutes is multiplied by 60. The usage
 * text says the default (-1 here) means no limit. */
double gLimitSeconds = -1;
/* Set to 1 by -stdin: read training examples from standard input. */
int gStdin = 0;
/* Set to 1 by -noReverse: don't reverse links when making candidate nets. */
int gNoReverse = 0;
/* When non-zero _doTests keeps the test examples in memory after the first
 * read; cleared by -noCacheTestSet. */
int gCacheTestSet = 1;
/* Set to 1 by -batch (and by -parametersOnly): run in batch mode. */
int gDoBatch = 0;
/* Structure prior penalty for batch mode (-kappa). */
double gKappa = 0.5;
/* Cleared by -correctBound; -normal sets it back to 1. */
int gUseHeuristicBound = 1;
/* Set to 1 by -normal: use the normal bound instead of Hoeffding. */
int gUseNormalApprox = 0;
/* Set to 1 by -parametersOnly: only estimate parameters for the structure. */
int gOnlyEstimateParameters = 0;
/* Search limits (-maxSearchSteps, -maxParentsPerNode,
 * -maxParameterGrowthMult, -maxParameterCount). The usage text describes
 * the default for each as "no max", which is the -1 stored here. */
int gMaxSearchSteps = -1;
int gMaxParentsPerNode = -1;
int gMaxParameterGrowthMult = -1;
long gMaxParameterCount = -1;
/* Set to 1 by -structureTie: use the structural prior penalty in ties. */
int gUseStructuralPriorInTie = 0;
/* Random seed (-seed); the usage text describes the default as random. */
int gSeed = -1;
/* hack globals */
/* Divisor applied to a node's p0 in _GetNodeScoreRange. */
int gP0Multiplier = 1;
/* NOTE(review): the remaining globals are not referenced in the part of the
 * file visible here; the names suggest run statistics and cached limits --
 * confirm against their users before relying on that. */
long gNumBoundsUsed = 0;
long gNumCheckPointBoundsUsed = 0;
double gObservedP0 = 1;
double gEntropyRangeNet = 0;
long gInitialParameterCount = 0;
long gBranchFactor = 0;
long gMaxBytesPerModel = 0;
BeliefNet gPriorNet = 0;
long gSamplesNeeded = 0;
/*
 * Prints the command line help to standard output.
 *
 * name: program name shown at the start of the banner line (callers pass
 *       argv[0]).
 */
static void _printUsage(char *name) {
   /* One entry per documented option, emitted in this order. */
   static const char *const optionHelp[] = {
      "-f <filestem>\t Set the name of the dataset (default DF)\n",
      "-source <dir>\t Set the source data directory (default '.')\n",
      "-startFrom <filename>\t use net in <filename> as starting point,\n\t\t must be BIF file (default start from empty net)\n",
      "-outputTo <filename>\t output the learned net to <filename>\n",
      "-delta <prob> \t Allowed chance of error in each decision\n\t\t (default 0.00001 that's .001 percent)\n",
      "-tau <tie error>\t Call a tie when score might change < than\n\t\t tau percent. (default 0.001)\n",
      "-chunk <count> \t accumulate 'count' examples before testing for\n\t\t a winning search step(default 10000)\n",
      "-limitMegs <count>\t Limit dynamic memory allocation to 'count'\n\t\t megabytes, don't consider networks that take too much\n\t\t space (default no limit)\n",
      "-limitMinutes <count>\t Limit the run to <count> minutes then\n\t\t return current model (default no limit)\n",
      "-normal \t Use normal bound (default Hoeffding)\n",
      /* "-correctBound \t Use the correct bound (default heuristic bound)\n"
       * is accepted by the argument parser but deliberately not listed. */
      "-stdin \t\t Reads training examples from stdin instead of from\n\t\t <stem>.data causes a 2 second delay to help give\n\t\t input time to setup (default off)\n",
      "-noReverse \t Doesn't reverse links to make nets for next search\n\t\t step (default reverse links)\n",
      "-parametersOnly\t Only estimate parameters for current\n\t\t structure, no other learning\n",
      "-seed <s>\t Seed for random numbers (default random)\n",
      "-maxSearchSteps <num>\tLimit to <num> search steps (default no max).\n",
      "-maxParentsPerNode <num>\tLimit each node to <num> parents\n\t\t (default no max).\n",
      "-maxParameterGrowthMult <mult>\tLimit net to <mult> times starting\n\t\t # of parameters (default no max).\n",
      "-maxParameterCount <count>\tLimit net to <count> parameters\n\t\t (default no max).\n",
      "-kappa <kappa> the structure prior penalty for batch (0 - 1), 1 is\n\t\t no penalty (default 0.5)\n",
      "-structureTie Use the structural prior penalty in ties (default don't)\n",
      "-batch \t Run in batch mode, repeatedly scan disk, don't do hoeffding\n\t\t bounds (default off).\n",
      "-v\t\t Can be used multiple times to increase the debugging output\n"
   };
   size_t entry;

   printf("%s : 'Very Fast Belief Net' structure learning\n", name);
   for(entry = 0 ; entry < sizeof(optionHelp) / sizeof(optionHelp[0]) ; entry++) {
      /* None of the entries contain conversion specifiers, so they are
       * written verbatim. */
      fputs(optionHelp[entry], stdout);
   }
}
/*
 * Returns the value that follows the option at argv[*index] and advances
 * *index onto it, so the caller's loop does not treat the value as an
 * option. Prints an error and exits with status 1 when the option is the
 * last thing on the command line.
 */
static char *_takeArgValue(int argc, char *argv[], int *index) {
   if(*index + 1 >= argc) {
      printf("Missing value for argument: %s. use -h for help\n",
             argv[*index]);
      exit(1);
   }
   (*index)++;
   return argv[*index];
}

/* Reports a value that could not be parsed for `option` and exits with
 * status 1. */
static void _rejectArgValue(const char *option, const char *value) {
   printf("Invalid value for argument %s: %s. use -h for help\n",
          option, value);
   exit(1);
}

/*
 * Parses the command line into the g* configuration globals and then logs
 * the main settings through DebugMessage.
 *
 * argc, argv: as passed to main(); argv[0] is skipped.
 *
 * Exits the process on -h (after printing the usage text), on an unknown
 * option, on an option whose required value is missing, and on a numeric
 * value that sscanf cannot parse. Note that -stdin sleeps for two seconds
 * while the arguments are being parsed, as described in the usage text.
 */
static void _processArgs(int argc, char *argv[]) {
   int i;
   const char *option;
   const char *value;

   for(i = 1 ; i < argc ; i++) {
      option = argv[i];

      if(!strcmp(option, "-f")) {
         gFileStem = _takeArgValue(argc, argv, &i);
      } else if(!strcmp(option, "-startFrom")) {
         gStartingNet = _takeArgValue(argc, argv, &i);
         gUseStartingNet = 1;
      } else if(!strcmp(option, "-outputTo")) {
         gNetOutput = _takeArgValue(argc, argv, &i);
         gOutputNet = 1;
      } else if(!strcmp(option, "-source")) {
         gSourceDirectory = _takeArgValue(argc, argv, &i);
      } else if(!strcmp(option, "-u")) {
         gDoTests = 1;
      } else if(!strcmp(option, "-v")) {
         DebugSetMessageLevel(DebugGetMessageLevel() + 1);
      } else if(!strcmp(option, "-structureTie")) {
         gUseStructuralPriorInTie = 1;
      } else if(!strcmp(option, "-h")) {
         _printUsage(argv[0]);
         exit(0);
      } else if(!strcmp(option, "-delta")) {
         value = _takeArgValue(argc, argv, &i);
         if(sscanf(value, "%f", &gDelta) != 1) {
            _rejectArgValue(option, value);
         }
      } else if(!strcmp(option, "-kappa")) {
         value = _takeArgValue(argc, argv, &i);
         if(sscanf(value, "%lf", &gKappa) != 1) {
            _rejectArgValue(option, value);
         }
      } else if(!strcmp(option, "-tau")) {
         value = _takeArgValue(argc, argv, &i);
         if(sscanf(value, "%f", &gTau) != 1) {
            _rejectArgValue(option, value);
         }
      } else if(!strcmp(option, "-chunk")) {
         value = _takeArgValue(argc, argv, &i);
         if(sscanf(value, "%d", &gChunk) != 1) {
            _rejectArgValue(option, value);
         }
      } else if(!strcmp(option, "-limitMegs")) {
         value = _takeArgValue(argc, argv, &i);
         if(sscanf(value, "%ld", &gLimitBytes) != 1) {
            _rejectArgValue(option, value);
         }
         /* the option is given in megabytes, the limit is kept in bytes */
         gLimitBytes *= 1024 * 1024;
      } else if(!strcmp(option, "-limitMinutes")) {
         value = _takeArgValue(argc, argv, &i);
         if(sscanf(value, "%lf", &gLimitSeconds) != 1) {
            _rejectArgValue(option, value);
         }
         /* the option is given in minutes, the limit is kept in seconds */
         gLimitSeconds *= 60;
      } else if(!strcmp(option, "-maxSearchSteps")) {
         value = _takeArgValue(argc, argv, &i);
         if(sscanf(value, "%d", &gMaxSearchSteps) != 1) {
            _rejectArgValue(option, value);
         }
      } else if(!strcmp(option, "-maxParentsPerNode")) {
         value = _takeArgValue(argc, argv, &i);
         if(sscanf(value, "%d", &gMaxParentsPerNode) != 1) {
            _rejectArgValue(option, value);
         }
      } else if(!strcmp(option, "-maxParameterGrowthMult")) {
         value = _takeArgValue(argc, argv, &i);
         if(sscanf(value, "%d", &gMaxParameterGrowthMult) != 1) {
            _rejectArgValue(option, value);
         }
      } else if(!strcmp(option, "-maxParameterCount")) {
         value = _takeArgValue(argc, argv, &i);
         if(sscanf(value, "%ld", &gMaxParameterCount) != 1) {
            _rejectArgValue(option, value);
         }
      } else if(!strcmp(option, "-stdin")) {
         /* give whatever feeds standard input time to set up */
         sleep(2);
         gStdin = 1;
      } else if(!strcmp(option, "-normal")) {
         gUseNormalApprox = 1;
         gUseHeuristicBound = 1;
      } else if(!strcmp(option, "-correctBound")) {
         gUseHeuristicBound = 0;
      } else if(!strcmp(option, "-parametersOnly")) {
         gOnlyEstimateParameters = 1;
         gDoBatch = 1;
      } else if(!strcmp(option, "-noReverse")) {
         gNoReverse = 1;
      } else if(!strcmp(option, "-seed")) {
         value = _takeArgValue(argc, argv, &i);
         if(sscanf(value, "%d", &gSeed) != 1) {
            _rejectArgValue(option, value);
         }
      } else if(!strcmp(option, "-noCacheTestSet")) {
         gCacheTestSet = 0;
      } else if(!strcmp(option, "-batch")) {
         gDoBatch = 1;
      } else {
         printf("Unknown argument: %s. use -h for help\n", option);
         /* exit status 0 is the existing behaviour and is kept as-is */
         exit(0);
      }
   }

   DebugMessage(1, 1, "Stem: %s\n", gFileStem);
   DebugMessage(1, 1, "Source: %s\n", gSourceDirectory);
   DebugMessage(!gDoBatch, 1, "Delta: %.13f\n", gDelta);
   DebugMessage(!gDoBatch, 1, "Tau: %f\n", gTau);
   DebugMessage(gStdin, 1, "Reading examples from standard in.\n");
   DebugMessage(!gDoBatch, 1,
                "Gather %d examples before checking for winner\n", gChunk);
   DebugMessage(gDoTests, 1, "Running tests\n");
}
VoidAListPtr _testSet;
int _testCacheInited = 0;
static void _doTests(ExampleSpecPtr es, BeliefNet bn, long learnCount, long learnTime, long allocation, int finalOutput) {
int oldPool = MGetActivePool();
char fileNames[255];
FILE *exampleIn;
ExamplePtr e;
long i;
long tested, errors;
// struct tms starttime;
// struct tms endtime;
errors = tested = 0;
/* don't track this allocation against other VFBN stuff */
MSetActivePool(0);
if(gCacheTestSet) {
if(!_testCacheInited) {
_testSet = VALNew();
sprintf(fileNames, "%s/%s.test", gSourceDirectory, gFileStem);
exampleIn = fopen(fileNames, "r");
DebugError(exampleIn == 0, "Unable to open the .test file");
e = ExampleRead(exampleIn, es);
while(e != 0) {
VALAppend(_testSet, e);
e = ExampleRead(exampleIn, es);
}
fclose(exampleIn);
_testCacheInited = 1;
}
for(i = 0 ; i < VALLength(_testSet) ; i++) {
e = VALIndex(_testSet, i);
if(!ExampleIsClassUnknown(e)) {
tested++;
// if(ExampleGetClass(e) != DecisionTreeClassify(dt, e)) {
// errors++;
// }
}
}
} else {
sprintf(fileNames, "%s/%s.test", gSourceDirectory, gFileStem);
exampleIn = fopen(fileNames, "r");
DebugError(exampleIn == 0, "Unable to open the .test file");
DebugMessage(1, 1, "opened test file, starting scan...\n");
e = ExampleRead(exampleIn, es);
while(e != 0) {
if(!ExampleIsClassUnknown(e)) {
tested++;
// if(ExampleGetClass(e) != DecisionTreeClassify(dt, e)) {
// errors++;
// }
}
ExampleFree(e);
e = ExampleRead(exampleIn, es);
}
fclose(exampleIn);
}
// if(finalOutput) {
// DebugMessage(1, 1,
// printf("Tested %ld examples made %ld errors\n", (long)tested,
// (long)errors);
// }
// printf("%.4f\t%ld\n", ((float)errors/(float)tested) * 100,
// (long)DecisionTreeCountNodes(dt));
// } else {
// printf(">> %ld\t%.4f\t%ld\t%ld\t%.2lf\t%.2f\n",
// learnCount,
// ((float)errors/(float)tested) * 100,
// (long)DecisionTreeCountNodes(dt),
// growingNodes,
// ((double)learnTime) / 100,
// ((double)allocation) / (1024 * 1024));
// }
// fflush(stdout);
MSetActivePool(oldPool);
}
/* Per-network record; _InitUserData allocates one, zeroes every field and
 * attaches it to the net with BNSetUserData. */
typedef struct BNUserData_ {
/* NOTE(review): presumably the node(s) touched by the search step that
 * produced this net; both start out as 0 -- confirm against the code that
 * sets them. */
BeliefNetNode changedOne;
BeliefNetNode changedTwo;
/* -1 for removal, 0 for nothing, 1 for reverse, 2 for add */
int changeComplexity;
/* Starts at 0; NOTE(review): presumably the range of the net's score --
 * confirm against the code that fills it in. */
double scoreRange;
} BNUserDataStruct, *BNUserData;
/* Per-node record; _InitUserData allocates one for every node of a net and
 * attaches it with BNNodeSetUserData. */
typedef struct BNNodeUserData_ {
/* Both start at 0 in _InitUserData. NOTE(review): presumably the node's
 * average data log-likelihood and its score -- confirm against their users. */
double avgDataLL;
double score;
/* NOTE(review): presumably bounds on the node's score -- confirm against
 * the code that reads and writes them. */
double upperBound;
double lowerBound;
/* Starts at 1; _GetNodeScoreRange divides it by gP0Multiplier when
 * computing the node's score range. */
double p0;
/* Starts at 0. NOTE(review): presumably a flag marking nodes that differ
 * from the same node in the current net -- confirm. */
int isChangedFromCurrent;
/* Record of the node with the same ID in the current net. */
struct BNNodeUserData_ *current; // null if this is a current node
} BNNodeUserDataStruct, *BNNodeUserData;
/*
 * Releases the user data attached to every node of bn, then the user data
 * attached to the net itself, using MFreePtr.
 *
 * The net and its nodes are not freed, and the pointers stored in them are
 * left as they were.
 */
static void _FreeUserData(BeliefNet bn) {
   int nodeId = 0;

   while(nodeId < BNGetNumNodes(bn)) {
      MFreePtr(BNNodeGetUserData(BNGetNodeByID(bn, nodeId)));
      nodeId++;
   }

   MFreePtr(BNGetUserData(bn));
}
/*
 * Allocates and attaches fresh user data to bn and to each of its nodes.
 *
 * bn:      net to initialise.
 * current: the current net of the search. When bn is itself the current
 *          net, pass the same pointer; each node's `current` link is then
 *          left null. Otherwise each node's `current` link is pointed at
 *          the user data of the node with the same ID in `current`, so
 *          `current` must already have been initialised.
 *
 * Any user data previously attached to bn is overwritten, not freed.
 */
static void _InitUserData(BeliefNet bn, BeliefNet current) {
   int i;
   BeliefNetNode bnn, currentNode;
   BNNodeUserData data;
   BNUserData netData;

   for(i = 0 ; i < BNGetNumNodes(bn) ; i++) {
      bnn = BNGetNodeByID(bn, i);
      data = MNewPtr(sizeof(BNNodeUserDataStruct));
      BNNodeSetUserData(bnn, data);

      data->score = 0;
      data->avgDataLL = 0;
      /* Whether MNewPtr zeroes the memory it returns is not visible from
       * here, so give the bounds a defined starting value as well instead
       * of leaving them indeterminate. */
      data->upperBound = 0;
      data->lowerBound = 0;
      data->isChangedFromCurrent = 0;
      data->p0 = 1;

      if(bn == current) { /* the init for the current net */
         data->current = 0;
      } else {
         currentNode = BNGetNodeByID(current, i);
         data->current = BNNodeGetUserData(currentNode);
      }
   }

   netData = MNewPtr(sizeof(BNUserDataStruct));
   BNSetUserData(bn, netData);
   netData->changedOne = netData->changedTwo = 0;
   netData->changeComplexity = 0;
   netData->scoreRange = 0;
}
/*
 * Returns |log(p0)| for a node, where p0 is a probability picked as
 * follows:
 *  - without user data on the node: 1 / (5 * number of values of the node);
 *  - with user data: the smaller of that value and the node's recorded p0
 *    divided by gP0Multiplier.
 */
static double _GetNodeScoreRange(BeliefNetNode bnn) {
   BNNodeUserData nodeData = BNNodeGetUserData(bnn);
   double valueBasedP0 = 1.0 / (5.0 * (double)BNNodeGetNumValues(bnn));
   double scaledP0;

   if(nodeData == 0) {
      /* nothing recorded for this node, only the value count is known */
      return fabs(log(valueBasedP0));
   }

   scaledP0 = nodeData->p0 / (double)gP0Multiplier;
   return fabs(log(min(valueBasedP0, scaledP0)));
}
static double _GetNetScoreRange(BeliefNet bn) {
int i;
BeliefNetNode bnn;
double numCPTEntries = 0;
for(i = 0 ; i < BNGetNumNodes(bn) ; i++) {
bnn = BNGetNodeByID(bn, i);
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -