📄 vfdt-engine.c
字号:
while(e != 0 && vfdt->numGrowing > 0) {
_ProcessExample(vfdt, e);
vfdt->examplesSeen++;
if((vfdt->messageLevel >= 1) && (vfdt->examplesSeen % 1000 == 0)) {
printf("processed %ld examples\n", vfdt->examplesSeen);
}
e = ExampleRead(input, vfdt->spec);
}
if(vfdt->messageLevel >= 1) {
printf("finished with all the examples, there are %ld growing nodes\n",
vfdt->numGrowing);
}
}
//void VFDTBootstrapC45(VFDTPtr vfdt, char *fileStem, int overprune, int runC45) {
// char command[500], fileName[100];
// FILE *tree;
// if(runC45) {
// sprintf(command, "c4.5 -f %s -u >> /dev/null", fileStem);
// if(vfdt->messageLevel >= 1) {
// printf("%s\n", command);
// }
// system(command);
// }
// DecisionTreeFree(vfdt->dtree);
// sprintf(fileName, "%s.tree", fileStem);
// tree = fopen(fileName, "r");
// if(overprune) {
// vfdt->dtree = DecisionTreeReadC45Overprune(tree, vfdt->spec);
// } else {
// vfdt->dtree = DecisionTreeReadC45(tree, vfdt->spec);
// }
// fclose(tree);
// /* now reactivate the thing */
// _ReactivateLeaves(vfdt);
//}
/* Public wrapper: hands vfdt and input, unchanged, to _ProcessExamples.
   No other work is done here. */
void VFDTProcessExamples(VFDTPtr vfdt, FILE *input) {
_ProcessExamples(vfdt, input);
}
/* Public wrapper: hands vfdt and the single example e, unchanged, to
   _ProcessExampleBatch.  Unlike VFDTProcessExamplesBatch, this entry point
   does not set vfdt->batchMode itself. */
void VFDTProcessExampleBatch(VFDTPtr vfdt, ExamplePtr e) {
_ProcessExampleBatch(vfdt, e);
}
/* Marks the learner as being in batch mode and then calls _forceSplits on
   the root of the tree (vfdt->dtree).
   NOTE(review): presumably called once after the last
   VFDTProcessExampleBatch call -- confirm against callers. */
void VFDTBatchExamplesDone(VFDTPtr vfdt) {
/* batchMode is set before the splits are forced */
vfdt->batchMode = 1;
_forceSplits(vfdt, vfdt->dtree);
}
/* Sets vfdt->batchMode to 1 and then hands vfdt and input to
   _ProcessExamplesBatch. */
void VFDTProcessExamplesBatch(VFDTPtr vfdt, FILE *input) {
/* the flag is set before any example is processed */
vfdt->batchMode = 1;
_ProcessExamplesBatch(vfdt, input);
}
/* Processes one example: bumps vfdt->examplesSeen and then passes the
   example to _ProcessExample.
   NOTE(review): the counter is incremented BEFORE _ProcessExample runs
   here; confirm that _ProcessExample does not depend on whether the
   current example has already been counted. */
void VFDTProcessExample(VFDTPtr vfdt, ExamplePtr e) {
vfdt->examplesSeen++;
_ProcessExample(vfdt, e);
}
/* Returns non-zero when DecisionTreeIsTreeGrowing reports false for
   vfdt->dtree, and 0 otherwise. */
int VFDTIsDoneLearning(VFDTPtr vfdt) {
return !DecisionTreeIsTreeGrowing(vfdt->dtree);
}
/* Accessor: returns the current value of vfdt->numGrowing. */
long VFDTGetNumGrowing(VFDTPtr vfdt) {
return vfdt->numGrowing;
}
/* Accessor: returns the current value of vfdt->numBoundsUsed. */
long VFDTGetNumBoundsUsed(VFDTPtr vfdt) {
return vfdt->numBoundsUsed;
}
/*
 * Writes a short summary of the learner's state to out.
 *
 * The growing nodes of vfdt->dtree are gathered into a temporary list and,
 * across them, the function reports: the total number of examples seen, the
 * per-node average of that total, the per-node average of the
 * most-common-class count, and how many nodes report as pure.  The averages
 * use integer division.  When there are no growing nodes a fixed message is
 * printed instead.  The bounds-check counter and the tree's own statistics
 * (DecisionTreePrintStats) follow in either case.
 */
void VFDTPrintStats(VFDTPtr vfdt, FILE *out) {
   VoidAListPtr growing = VALNew();
   long totalSeen = 0;
   long totalMostCommon = 0;
   long numPure = 0;
   int idx;

   DecisionTreeGatherGrowingNodes(vfdt->dtree, growing);

   /* visit the gathered nodes from the last one down to the first */
   idx = VALLength(growing);
   while(idx-- > 0) {
      DecisionTreePtr node = VALIndex(growing, idx);
      ExampleGroupStatsPtr stats =
         ((VFDTGrowingDataPtr)DecisionTreeGetGrowingData(node))->egs;

      totalSeen += ExampleGroupStatsNumExamplesSeen(stats);
      totalMostCommon += ExampleGroupStatsGetMostCommonClassCount(stats);
      if(ExampleGroupStatsIsPure(stats)) {
         numPure += 1;
      }
   }

   if(VALLength(growing) <= 0) {
      /* nothing to average over; also avoids dividing by zero */
      fprintf(out, "There aren't any growing nodes\n");
   } else {
      fprintf(out, "growing - seen %ld ave %ld - avg mcc %ld - pure %ld\n",
              totalSeen,
              totalSeen / VALLength(growing),
              totalMostCommon / VALLength(growing),
              numPure);
   }

   /* only the temporary list is released here */
   VALFree(growing);

   fprintf(out, "used %ld bounds checks\n", vfdt->numBoundsUsed);
   DecisionTreePrintStats(vfdt->dtree, out);
}
/*
 * Builds and returns a finished tree from the learner's current tree.
 *
 * vfdt->dtree is cloned with DecisionTreeClone; every growing node of the
 * clone is then switched to a leaf and given a class:
 *   - not in batch mode: the Laplace-adjusted most common class, computed
 *     from the node's example statistics, its recorded parentClass and the
 *     constant 5;
 *   - in batch mode: the plain most common class of the node's statistics.
 *
 * All changes are applied to nodes of the clone.
 * NOTE(review): the caller presumably owns the returned tree and must free
 * it -- confirm.  The growing-data pointers stay attached to the new leaves.
 */
DecisionTreePtr VFDTGetLearnedTree(VFDTPtr vfdt) {
   VoidAListPtr growing = VALNew();
   DecisionTreePtr result = DecisionTreeClone(vfdt->dtree);
   int idx;

   DecisionTreeGatherGrowingNodes(result, growing);

   /* convert the gathered nodes, last to first */
   idx = VALLength(growing);
   while(idx-- > 0) {
      DecisionTreePtr node = VALIndex(growing, idx);
      VFDTGrowingDataPtr growData;

      DecisionTreeSetTypeLeaf(node);
      growData = (VFDTGrowingDataPtr)DecisionTreeGetGrowingData(node);

      if(vfdt->batchMode) {
         DecisionTreeSetClass(node,
               ExampleGroupStatsGetMostCommonClass(growData->egs));
      } else {
         /* the laplace correction is only applied outside batch mode */
         DecisionTreeSetClass(node,
               ExampleGroupStatsGetMostCommonClassLaplace(growData->egs,
                     growData->parentClass, 5));
      }
   }

   VALFree(growing);
   return result;
}
//static float _TestAccuracy(DecisionTreePtr dt, VoidAListPtr examples) {
// long mistakes = 0;
// ExamplePtr e;
// int class;
// int i;
// for(i = 0 ; i < VALLength(examples) ; i++) {
// e = VALIndex(examples, i);
// class = DecisionTreeClassify(dt, e);
// if(class != ExampleGetClass(e)) {
// mistakes++;
// }
// }
// return 1.0 - ((float)mistakes / (float)VALLength(examples));
//}
//float _bestPruneAccuracy;
//DecisionTreePtr _bestPruneNode;
//static void _REPrune(DecisionTreePtr dt, DecisionTreePtr currentNode,
// VoidAListPtr examples) {
// NodeType oldNodeType;
// int oldNodeClass;
// float newAccuracy;
// int i;
// if(currentNode->nodeType == dtnGrowing || currentNode->nodeType == dtnLeaf) {
// return;
// } else {
// oldNodeType = currentNode->nodeType;
// oldNodeClass = currentNode->class;
// currentNode->class = DecisionTreeGetMostCommonClass(currentNode);
// currentNode->nodeType = dtnLeaf;
// newAccuracy = _TestAccuracy(dt, examples);
// if(newAccuracy >= _bestPruneAccuracy) {
// _bestPruneAccuracy = newAccuracy;
// _bestPruneNode = currentNode;
// }
// currentNode->nodeType = oldNodeType;
// currentNode->class = oldNodeClass;
// for(i = 0 ; i < VALLength(currentNode->children) ; i++) {
// _REPrune(dt, VALIndex(currentNode->children, i), examples);
// }
// }
//}
/* HERE I could split the examples as I recur, but...I'm not... */
//void VFDTREPrune(DecisionTreePtr dt, VoidAListPtr examples) {
// float currentAccuracy;
// int progress = 1;
// int i;
// while(progress) {
// _bestPruneAccuracy = 0;
// _bestPruneNode = 0;
// _REPrune(dt, dt, examples);
// if(_bestPruneNode == 0 ||
// (_bestPruneAccuracy < _TestAccuracy(dt, examples))) {
// /* if there is nothing to prune or pruning makes things worse */
// progress = 0;
// } else {
// _bestPruneNode->class = DecisionTreeGetMostCommonClass(_bestPruneNode);
// _bestPruneNode->nodeType = dtnLeaf;
// for(i = 0 ; i < VALLength(_bestPruneNode->children) ; i++) {
// DecisionTreeFree(VALIndex(_bestPruneNode->children, i));
// }
// VALFree(_bestPruneNode->children);
// }
// }
//}
/* Per-node bookkeeping for the pruning routines below.  One record is
   allocated for every node by _InitPruneData and attached to the node with
   DecisionTreeSetGrowingData; _FreePruneData releases it again. */
typedef struct _PRUNEDATA_ {
/* number of prune-set examples that passed through this node and whose
   class differs from `class` (counted by _UpdateErrorCounts) */
long errors;
/* result of DecisionTreeGetMostCommonClass for this node; it is the class
   the node is given if it is pruned to a leaf */
int class;
/* for an internal node: `errors` minus the sum of errors at the
   leaf/growing nodes beneath it (after _InitErrorDelta); for a leaf or
   growing node: equal to `errors`.  A value <= 0 makes the node a prune
   candidate in _FindBestPruneNode */
int errorDelta;
/* node one level up, or 0 for the node pruning was started from */
DecisionTreePtr parent;
} PruneData, *PruneDataPtr;
/*
 * Allocates a zeroed-out PruneData record for current, remembers parent in
 * it, attaches it to the node with DecisionTreeSetGrowingData (replacing
 * whatever growing data the node carried), and does the same for every
 * descendant.  Leaf and growing nodes are not descended into.
 */
static void _InitPruneData(DecisionTreePtr current, DecisionTreePtr parent) {
   PruneDataPtr pd = MNewPtr(sizeof(PruneData));
   long child;

   pd->errors = 0;
   pd->class = DecisionTreeGetMostCommonClass(current);
   pd->errorDelta = 0;
   pd->parent = parent;
   DecisionTreeSetGrowingData(current, pd);

   if(current->nodeType == dtnGrowing || current->nodeType == dtnLeaf) {
      return;
   }

   child = 0;
   while(child < VALLength(current->children)) {
      _InitPruneData(VALIndex(current->children, child), current);
      child++;
   }
}
/*
 * Routes example e down the tree starting at current.  At every node on the
 * path (including the final leaf/growing node) where the node's recorded
 * class differs from the example's class, both errors and errorDelta are
 * incremented.  The next node is chosen with DecisionTreeOneStepClassify.
 */
static void _UpdateErrorCounts(DecisionTreePtr current, ExamplePtr e) {
   DecisionTreePtr node = current;

   for(;;) {
      PruneDataPtr pd = DecisionTreeGetGrowingData(node);

      if(pd->class != ExampleGetClass(e)) {
         pd->errors++;
         pd->errorDelta++;
      }

      if(node->nodeType == dtnGrowing || node->nodeType == dtnLeaf) {
         /* bottom of the path reached */
         break;
      }
      node = DecisionTreeOneStepClassify(node, e);
   }
}
/*
 * Bottom-up pass that turns each internal node's errorDelta into
 * "errors if this node were a leaf" minus "errors made by the leaf/growing
 * nodes beneath it".
 *
 * Returns the number of errors made beneath (or at) current: for a leaf or
 * growing node that is its own errors count; for an internal node it is the
 * sum of the values returned by its children.  Leaf/growing nodes keep the
 * errorDelta they already have.
 */
static long _InitErrorDelta(DecisionTreePtr current) {
   PruneDataPtr pd = DecisionTreeGetGrowingData(current);
   long below = 0;
   long child;

   if(current->nodeType == dtnGrowing || current->nodeType == dtnLeaf) {
      return pd->errors;
   }

   child = 0;
   while(child < VALLength(current->children)) {
      below += _InitErrorDelta(VALIndex(current->children, child));
      child++;
   }

   pd->errorDelta -= below;
   return below;
}
/*
 * Releases the PruneData record attached to dt and to every node beneath it.
 * The node's own record is freed first, then its children are visited;
 * leaf and growing nodes are not descended into.
 *
 * NOTE(review): the node's growing-data slot is not reset here, so it keeps
 * pointing at the freed record -- confirm nothing reads it afterwards.
 */
static void _FreePruneData(DecisionTreePtr dt) {
   long child;

   MFreePtr(DecisionTreeGetGrowingData(dt));

   if(dt->nodeType == dtnGrowing || dt->nodeType == dtnLeaf) {
      return;
   }

   child = 0;
   while(child < VALLength(dt->children)) {
      _FreePruneData(VALIndex(dt->children, child));
      child++;
   }
}
/*
 * Searches the subtree rooted at dt for the internal node with the smallest
 * errorDelta and returns it, provided that delta is <= 0 (pruning there
 * does not add errors on the prune set).  Returns 0 when dt is a leaf or
 * growing node, or when no internal node in the subtree has errorDelta <= 0.
 *
 * On ties dt itself is preferred over its descendants, and an earlier child
 * over a later one, because a candidate only replaces the current best when
 * its delta is strictly smaller.
 */
static DecisionTreePtr _FindBestPruneNode(DecisionTreePtr dt) {
   DecisionTreePtr winner = dt;
   PruneDataPtr winnerData = DecisionTreeGetGrowingData(dt);
   long child;

   if(dt->nodeType == dtnGrowing || dt->nodeType == dtnLeaf) {
      /* only internal nodes can be pruned */
      return 0;
   }

   for(child = 0 ; child < VALLength(dt->children) ; child++) {
      DecisionTreePtr candidate =
         _FindBestPruneNode(VALIndex(dt->children, child));
      PruneDataPtr candidateData;

      if(candidate == 0) {
         continue;
      }
      candidateData = DecisionTreeGetGrowingData(candidate);
      if(candidateData->errorDelta < winnerData->errorDelta) {
         winner = candidate;
         winnerData = candidateData;
      }
   }

   return (winnerData->errorDelta <= 0) ? winner : 0;
}
/*
 * Debug helper: prints one line per node of the subtree rooted at dt to
 * stdout, parents before children.  Each line is indented by one space per
 * level and shows the level, errors, errorDelta and class of the node's
 * PruneData record.
 */
static void _PrintPruneData(DecisionTreePtr dt, int level) {
   PruneDataPtr pd = DecisionTreeGetGrowingData(dt);
   long n;

   /* indentation: one space per level */
   for(n = level ; n > 0 ; n--) {
      printf(" ");
   }
   printf("l%d e%ld d%d c%d\n", level,
          pd->errors, pd->errorDelta, pd->class);

   if(dt->nodeType == dtnGrowing || dt->nodeType == dtnLeaf) {
      return;
   }

   n = 0;
   while(n < VALLength(dt->children)) {
      _PrintPruneData(VALIndex(dt->children, n), level + 1);
      n++;
   }
}
/*
 * Collapses the subtree rooted at dt into a single leaf.
 *
 * Steps, in order:
 *  1. Every ancestor (followed through PruneData.parent) has dt's errorDelta
 *     subtracted from its own errorDelta.
 *  2. If dt is an internal node, each child has its prune data released and
 *     is then freed with DecisionTreeFree.
 *  3. dt becomes a leaf labelled with its recorded class, and its
 *     errorDelta is set to its errors count, matching what other leaves
 *     hold.
 */
static void _PruneNode(DecisionTreePtr dt) {
   PruneDataPtr pd = DecisionTreeGetGrowingData(dt);
   DecisionTreePtr ancestor;
   PruneDataPtr ancestorData;
   long child;

   /* Subtracting looks backwards, but the errors below each ancestor just
      changed by dt's delta, so the ancestor's own delta moves the opposite
      way: improving errors below makes the ancestor a less good place to
      prune. */
   for(ancestor = pd->parent ; ancestor != 0 ; ancestor = ancestorData->parent) {
      ancestorData = DecisionTreeGetGrowingData(ancestor);
      ancestorData->errorDelta -= pd->errorDelta;
   }

   /* discard everything below this node */
   if(dt->nodeType != dtnGrowing && dt->nodeType != dtnLeaf) {
      child = 0;
      while(child < VALLength(dt->children)) {
         DecisionTreePtr doomed = VALIndex(dt->children, child);

         _FreePruneData(doomed);
         DecisionTreeFree(doomed);
         child++;
      }
   }

   /* finally turn the node itself into a leaf */
   DecisionTreeSetTypeLeaf(dt);
   DecisionTreeSetClass(dt, pd->class);
   pd->errorDelta = pd->errors;
}
/*
 * Prunes the tree dt in place, using examples as the prune set.
 *
 * A PruneData record is attached to every node (replacing the node's
 * growing data), each example is run down the tree to count the errors each
 * node would make as a leaf, and the counts are turned into error deltas.
 * The node returned by _FindBestPruneNode is then collapsed into a leaf,
 * repeatedly, until no candidate is left.  All PruneData records are freed
 * before returning.
 *
 * NOTE(review): a node qualifies with errorDelta <= 0, so with an empty
 * example list every delta is 0 and the loop appears to prune until dt
 * itself is a leaf -- confirm that this is intended.
 */
void VFDTREPrune(DecisionTreePtr dt, VoidAListPtr examples) {
   DecisionTreePtr victim;
   long idx;

   // propagate the prune data over the whole tree
   _InitPruneData(dt, 0);

   // pass the prune set through the tree recording errors
   idx = 0;
   while(idx < VALLength(examples)) {
      _UpdateErrorCounts(dt, VALIndex(examples, idx));
      idx++;
   }

   // collect the error deltas
   _InitErrorDelta(dt);
   //_PrintPruneData(dt, 0);

   // keep pruning the best candidate for as long as there is one
   while((victim = _FindBestPruneNode(dt)) != 0) {
      _PruneNode(victim);
   }

   // free the prune data
   _FreePruneData(dt);
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -