⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 gdecisiontree.cpp

📁 一个非常有用的开源代码
💻 CPP
字号:
/*
	Copyright (C) 2006, Mike Gashler

	This library is free software; you can redistribute it and/or
	modify it under the terms of the GNU Lesser General Public
	License as published by the Free Software Foundation; either
	version 2.1 of the License, or (at your option) any later version.

	see http://www.gnu.org/copyleft/lesser.html
*/

#include "GDecisionTree.h"
#include "GArff.h"
#include "../GClasses/GMacros.h"
#include <stdlib.h>
#include <string.h> // FIX: memset/memcpy are used below; include what we use

//#define DEBUGLOG

#ifdef DEBUGLOG
#define dbglog0(a)			fprintf(stderr, a)
#define dbglog1(a,b)		fprintf(stderr, a, b)
#define dbglog2(a,b,c)		fprintf(stderr, a, b, c)
#define dbglog3(a,b,c,d)	fprintf(stderr, a, b, c, d)
#else // DEBUGLOG
#define dbglog0(a)			((void)0)
#define dbglog1(a,b)		((void)0)
#define dbglog2(a,b,c)		((void)0)
// FIX: was (a,b,d) -- a 3-parameter macro; any 4-arg dbglog3 call would
// fail to compile in non-debug builds. Must mirror the DEBUGLOG arity.
#define dbglog3(a,b,c,d)	((void)0)
#endif // !DEBUGLOG

// One node of a decision tree. An interior node holds the attribute it
// splits on (plus a pivot value for continuous attributes) and owns its
// child nodes. A leaf node owns an array of predicted output values
// (one per output attribute). A node never has both children and outputs.
class GDecisionTreeNode
{
friend class GDecisionTree;
protected:
	int m_nAttribute;                // attribute this interior node splits on (-1 for a leaf)
	int m_nChildren;                 // number of entries in m_ppChildren
	int m_nSampleSize;               // number of training rows that reached this node
	GDecisionTreeNode** m_ppChildren; // owned child nodes (NULL for a leaf)
	double* m_pOutputValues;         // owned predicted outputs (NULL for an interior node)
	double m_dPivot;                 // split threshold for continuous attributes

public:
	GDecisionTreeNode()
	{
		m_nAttribute = -1;
		m_nChildren = 0;
		m_nSampleSize = 0;
		m_ppChildren = NULL;
		m_pOutputValues = NULL;
		m_dPivot = 0;
	}

	~GDecisionTreeNode()
	{
		if(m_ppChildren)
		{
			int n;
			for(n = 0; n < m_nChildren; n++)
				delete(m_ppChildren[n]);
			delete[] m_ppChildren; // FIX: allocated with new[], so delete[] (was plain delete -- UB)
		}
		delete[] m_pOutputValues; // FIX: allocated with new double[], so delete[] (was plain delete -- UB)
	}

	// Recursively clones this branch of the tree. While copying, if
	// pInterestingNode is encountered, *ppOutInterestingCopy is set to
	// the corresponding node in the copy (used by DeepPruneNode to find
	// the node to prune in the cloned tree). Caller owns the returned tree.
	GDecisionTreeNode* DeepCopy(GArffRelation* pRelation, GDecisionTreeNode* pInterestingNode, GDecisionTreeNode** ppOutInterestingCopy)
	{
		GDecisionTreeNode* pNewNode = new GDecisionTreeNode();
		pNewNode->m_nAttribute = m_nAttribute;
		pNewNode->m_nChildren = m_nChildren;
		pNewNode->m_nSampleSize = m_nSampleSize;
		pNewNode->m_dPivot = m_dPivot;
		if(m_ppChildren)
		{
			GAssert(!m_pOutputValues, "Can't have children and output values");
			pNewNode->m_pOutputValues = NULL;
			pNewNode->m_ppChildren = new GDecisionTreeNode*[m_nChildren];
			int n;
			for(n = 0; n < m_nChildren; n++)
				pNewNode->m_ppChildren[n] = m_ppChildren[n]->DeepCopy(pRelation, pInterestingNode, ppOutInterestingCopy);
		}
		else
		{
			GAssert(m_pOutputValues, "expected output values");
			pNewNode->m_ppChildren = NULL;
			int nCount = pRelation->GetOutputCount();
			pNewNode->m_pOutputValues = new double[nCount];
			int n;
			for(n = 0; n < nCount; n++)
				pNewNode->m_pOutputValues[n] = m_pOutputValues[n];
		}
		if(this == pInterestingNode)
			*ppOutInterestingCopy = pNewNode;
		return pNewNode;
	}

	// Prints this branch to stdout, indented by nSpaces. szValue is the
	// attribute value (from the parent's split) that leads to this node.
	void Print(GArffRelation* pRelation, int nSpaces, const char* szValue)
	{
		int n;
		for(n = 0; n < nSpaces; n++)
			printf("  ");
		if(m_ppChildren)
		{
			GArffAttribute* pAttr = pRelation->GetAttribute(m_nAttribute);
			if(pAttr->IsContinuous())
				printf("%s -> %s (%f)?\n", szValue, pAttr->GetName(), m_dPivot);
			else
				printf("%s -> %s?\n", szValue, pAttr->GetName());
			for(n = 0; n < m_nChildren; n++)
				m_ppChildren[n]->Print(pRelation, nSpaces + 1, pAttr->GetValue(n));
		}
		else
		{
			int nCount = pRelation->GetOutputCount();
			printf("%s -> ", szValue);
			for(n = 0; n < nCount; n++)
			{
				GArffAttribute* pAttr = pRelation->GetAttribute(pRelation->GetOutputIndex(n));
				if(n > 0)
					printf(", ");
				printf("%s=%s", pAttr->GetName(), pAttr->GetValue((int)m_pOutputValues[n]));
			}
			printf("\n");
		}
	}

	// Recursive function that counts the number of times a particular
	// value is found in a particular output in this branch of the tree.
	// Each leaf contributes its sample size to the bucket of its value.
	void CountValues(int nOutput, int* pnCounts)
	{
		if(m_ppChildren)
		{
			int n;
			for(n = 0; n < m_nChildren; n++)
				m_ppChildren[n]->CountValues(nOutput, pnCounts);
		}
		else
		{
			int nVal = (int)m_pOutputValues[nOutput];
			pnCounts[nVal] += m_nSampleSize;
		}
	}

	// Returns the sample-size-weighted sum of the given output's value
	// over all leaves in this branch (used to average continuous outputs).
	double FindSumOutputValue(int nOutput)
	{
		if(m_ppChildren)
		{
			double dSum = 0;
			int n;
			for(n = 0; n < m_nChildren; n++)
				dSum += m_ppChildren[n]->FindSumOutputValue(nOutput);
			return dSum;
		}
		else
			return m_pOutputValues[nOutput] * m_nSampleSize;
	}

	// Converts this interior node into a leaf: for each output, the new
	// value is the most common value among the descendant leaves (or the
	// weighted mean for continuous outputs), then the children are freed.
	void PruneChildren(GArffRelation* pRelation)
	{
		// Create output values by finding the most common outputs among children
		GAssert(m_ppChildren, "This is a leaf node");
		int nOutputCount = pRelation->GetOutputCount();
		m_pOutputValues = new double[nOutputCount];
		int n;
		for(n = 0; n < nOutputCount; n++)
		{
			// Count the number of occurrences of each possible value for this output attribute
			GArffAttribute* pAttr = pRelation->GetAttribute(pRelation->GetOutputIndex(n));
			int nValueCount = pAttr->GetValueCount();
			if(nValueCount <= 0)
			{
				// Continuous output: use the sample-weighted mean of the leaves
				m_pOutputValues[n] = FindSumOutputValue(n) / m_nSampleSize;
			}
			else
			{
				Holder<int*> hCounts(new int[nValueCount]);
				int* pnCounts = hCounts.Get();
				memset(pnCounts, '\0', sizeof(int) * nValueCount);
				CountValues(n, pnCounts);

				// Find the most frequent value
				int i;
				int nMax = 0;
				for(i = 1; i < nValueCount; i++)
				{
					if(pnCounts[i] > pnCounts[nMax])
						nMax = i;
				}
				m_pOutputValues[n] = (double)nMax;
			}
		}

		// Delete the children
		for(n = 0; n < m_nChildren; n++)
			delete(m_ppChildren[n]);
		delete[] m_ppChildren; // FIX: allocated with new[], so delete[] (was plain delete -- UB)
		m_ppChildren = NULL;
	}
};

// -----------------------------------------------------------------

GDecisionTree::GDecisionTree(GArffRelation* pRelation)
: GSupervisedLearner(pRelation)
{
	m_pRoot = NULL;
	// Fraction of the data used for growing the tree; the rest is held
	// out for reduced-error pruning (see Train).
	m_dTrainingPortion = .65;
}

// Copy constructor used by DeepPruneNode: clones pThat's whole tree and
// reports where pInterestingNode landed in the copy via ppOutInterestingCopy.
GDecisionTree::GDecisionTree(GDecisionTree* pThat, GDecisionTreeNode* pInterestingNode, GDecisionTreeNode** ppOutInterestingCopy)
: GSupervisedLearner(pThat->m_pRelation)
{
	m_pRelation = pThat->m_pRelation;
	m_pRoot = pThat->m_pRoot->DeepCopy(pThat->m_pRelation, pInterestingNode, ppOutInterestingCopy);
}

GDecisionTree::~GDecisionTree()
{
	delete(m_pRoot);
}

// Trains on a portion of pData and uses the remainder as a validation set
// for reduced-error pruning. pData is restored (merged back) before return.
void GDecisionTree::Train(GArffData* pData)
{
	int nTrainRows = (int)(m_dTrainingPortion * pData->GetSize());
	GArffData* pPruningData = pData->SplitBySize(nTrainRows);
	TrainWithoutPruning(pData);
	Prune(pPruningData);
	pData->Merge(pPruningData);
	delete(pPruningData);
}

// Builds the full (unpruned) tree from pTrainingData, replacing any
// previously trained tree.
void GDecisionTree::TrainWithoutPruning(GArffData* pTrainingData)
{
	delete(m_pRoot);
	if(pTrainingData->GetSize() > 0)
	{
		m_pRoot = new GDecisionTreeNode();
		int nAttributes = m_pRelation->GetAttributeCount();
		Holder<bool*> hUsedAttributes(new bool[nAttributes]);
		bool* pUsedAttributes = hUsedAttributes.Get();
		int n;
		for(n = 0; n < nAttributes; n++)
			pUsedAttributes[n] = false;
		BuildNode(m_pRoot, pTrainingData, pUsedAttributes);
	}
	else
		m_pRoot = NULL;
}

// This constructs the decision tree in a recursive depth-first manner.
// pUsedAttributes marks attributes already split on along the current
// path; it is restored before returning so siblings see a clean state.
// pData is partitioned among the children during recursion and merged
// back together afterward, so the caller's data set is preserved.
void GDecisionTree::BuildNode(GDecisionTreeNode* pNode, GArffData* pData, bool* pUsedAttributes)
{
	int n;
#ifdef DEBUGLOG
	// Log debug stuff
	dbglog1("BuildNode from %d rows\n", pData->GetSize());
	int nAttrCount = m_pRelation->GetAttributeCount(); // FIX: was pRelation (undeclared in this scope)
	for(n = 0; n < pData->GetSize(); n++)
	{
		double* pRow = pData->GetRow(n);
		dbglog0("\t");
		int i;
		for(i = 0; i < nAttrCount; i++)
		{
			GArffAttribute* pAttr = m_pRelation->GetAttribute(i); // FIX: was pRelation
			dbglog1("%s, ", pAttr->GetValue((int)pRow[i]));
		}
		dbglog0("\n");
	}
#endif // DEBUGLOG

	// Pick the best attribute to divide on
	pNode->m_nSampleSize = pData->GetSize();
	GAssert(pNode->m_nSampleSize > 0, "Can't work without data");
	double dBestGain = -1e100;
	double dBestPivot = 0;
	int nBestAttribute = -1;
	if(!pData->IsOutputHomogenous(m_pRelation))
	{
		double dGain;
		double dPivot;
		int nAttr;
		int nInputCount = m_pRelation->GetInputCount();
		for(n = 0; n < nInputCount; n++)
		{
			nAttr = m_pRelation->GetInputIndex(n);
			if(pUsedAttributes[nAttr])
				continue;
			dGain = MeasureInfoGain(pData, nAttr, &dPivot);
			if(nBestAttribute < 0 || dGain > dBestGain)
			{
				dBestGain = dGain;
				nBestAttribute = nAttr;
				dBestPivot = dPivot;
			}
		}
	}
	Holder<double*> hMostCommonOutputs(pData->MakeSetOfMostCommonOutputs(m_pRelation));
	GAssert(hMostCommonOutputs.Get(), "Failed to get output values");
	if(nBestAttribute < 0)
	{
		// There are no input attributes left on which to divide, so this is a leaf
		dbglog0("Leaf\n");
		pNode->m_pOutputValues = hMostCommonOutputs.Drop();
		return;
	}

	// Get rid of any unknown values for the best attribute
	pData->ReplaceMissingAttributeWithMostCommonValue(m_pRelation, nBestAttribute);

	// Create child nodes
	pUsedAttributes[nBestAttribute] = true;
	pNode->m_nAttribute = nBestAttribute;
	pNode->m_dPivot = dBestPivot;
	GArffAttribute* pAttr = m_pRelation->GetAttribute(nBestAttribute);
	dbglog2("Attribute=%d (%s)\n", nBestAttribute, pAttr->GetName());
	GAssert(pAttr->IsInput(), "Expected an input");
	GArffData** ppParts;
	int nChildCount;
	if(pAttr->IsContinuous())
	{
		// Binary split: rows <= pivot go to child 0, the rest to child 1
		ppParts = new GArffData*[2];
		ppParts[0] = pData->SplitByPivot(nBestAttribute, dBestPivot);
		ppParts[1] = new GArffData(pData->GetSize());
		ppParts[1]->Merge(pData);
		nChildCount = 2;
	}
	else
	{
		// One child per discrete value of the attribute
		ppParts = pData->SplitByAttribute(m_pRelation, nBestAttribute);
		nChildCount = pAttr->GetValueCount();
	}
	pNode->m_nChildren = nChildCount;
	pNode->m_ppChildren = new GDecisionTreeNode*[nChildCount];
	for(n = 0; n < nChildCount; n++)
	{
		GDecisionTreeNode* pChildNode = new GDecisionTreeNode();
		pNode->m_ppChildren[n] = pChildNode;
		if(ppParts[n] && ppParts[n]->GetSize() > 0)
		{
			BuildNode(pChildNode, ppParts[n], pUsedAttributes);
			pData->Merge(ppParts[n]);
		}
		else
		{
			// There's no data for this child, so just use the most common outputs of the parent
			pChildNode->m_nSampleSize = 0;
			GAssert(hMostCommonOutputs.Get(), "no outputs");
			pChildNode->m_pOutputValues = new double[m_pRelation->GetOutputCount()];
			memcpy(pChildNode->m_pOutputValues, hMostCommonOutputs.Get(), sizeof(double) * m_pRelation->GetOutputCount());
		}
		delete(ppParts[n]);
	}
	delete[] ppParts; // FIX: allocated with new[] (both branches), so delete[] (was plain delete -- UB)
	pUsedAttributes[nBestAttribute] = false;
}

// Returns the information gain obtained by splitting pData on nAttribute.
// For continuous attributes, *pPivot receives the threshold that minimizes
// the resulting output info; for discrete attributes *pPivot is set to 0.
// pData is left in its original state (parts are merged back).
double GDecisionTree::MeasureInfoGain(GArffData* pData, int nAttribute, double* pPivot)
{
	// Measure initial output info
	double dGain = m_pRelation->MeasureTotalOutputInfo(pData);
	if(dGain == 0)
		return 0;

	// Seperate by attribute values and measure difference in output info
	GArffAttribute* pAttr = m_pRelation->GetAttribute(nAttribute);
	GAssert(pAttr->IsInput(), "expected an input attribute");
	if(pAttr->IsContinuous())
	{
		double dSumOutputInfo;
		*pPivot = pData->ComputeMinimumInfoPivot(m_pRelation, nAttribute, &dSumOutputInfo);
		dGain -= dSumOutputInfo;
		return dGain;
	}
	else
	{
		*pPivot = 0;
		int nRowCount = pData->GetSize();
		GArffData** ppParts = pData->SplitByAttribute(m_pRelation, nAttribute);
		int nCount = pAttr->GetValueCount();
		int n;
		for(n = 0; n < nCount; n++)
		{
			// Subtract each partition's output info, weighted by its share of the rows
			dGain -= ((double)ppParts[n]->GetSize() / nRowCount) * m_pRelation->MeasureTotalOutputInfo(ppParts[n]);
			pData->Merge(ppParts[n]);
			delete(ppParts[n]);
		}
		delete[] ppParts; // FIX: allocated with new[], so delete[] (was plain delete -- UB)
		GAssert(pData->GetSize() == nRowCount, "Didn't reassemble data correctly");
		return dGain;
	}
}

// Walks the tree using pRow's input values and writes the predicted
// output values into pRow's output positions. An unknown discrete value
// (-1) is replaced with a random legal value.
void GDecisionTree::Eval(double* pRow)
{
	GAssert(m_pRoot, "No tree constructed");
	GDecisionTreeNode* pNode = m_pRoot;
	GArffAttribute* pAttr;
	int nVal;
	while(pNode->m_ppChildren)
	{
		pAttr = m_pRelation->GetAttribute(pNode->m_nAttribute);
		GAssert(pAttr->IsInput(), "expected an input");
		if(pAttr->IsContinuous())
		{
			if(pRow[pNode->m_nAttribute] <= pNode->m_dPivot)
				pNode = pNode->m_ppChildren[0];
			else
				pNode = pNode->m_ppChildren[1];
		}
		else
		{
			nVal = (int)pRow[pNode->m_nAttribute];
			if(nVal < 0)
			{
				GAssert(nVal == -1, "out of range");
				nVal = rand() % pAttr->GetValueCount();
			}
			GAssert(nVal < pAttr->GetValueCount(), "value out of range");
			pNode = pNode->m_ppChildren[nVal];
		}
	}
	GAssert(pNode->m_pOutputValues, "Leaf node has no output values");

	// Copy the output values into the row
	int n;
	int nOutputCount = m_pRelation->GetOutputCount();
	for(n = 0; n < nOutputCount; n++)
		pRow[m_pRelation->GetOutputIndex(n)] = pNode->m_pOutputValues[n];
}

// Prints the whole tree to stdout.
void GDecisionTree::Print()
{
	m_pRoot->Print(m_pRelation, 0, "All");
}

// Reduced-error pruning (bottom-up): for each interior node, prune its
// children in a cloned tree and keep the pruned form if the clone scores
// at least as well on the validation set.
void GDecisionTree::DeepPruneNode(GDecisionTreeNode* pNode, GArffData* pValidationSet)
{
	if(!pNode->m_ppChildren)
		return;
	int n;
	for(n = 0; n < pNode->m_nChildren; n++)
		DeepPruneNode(pNode->m_ppChildren[n], pValidationSet);
	GDecisionTreeNode* pNodeCopy;
	GDecisionTree tmp(this, pNode, &pNodeCopy);
	pNodeCopy->PruneChildren(m_pRelation);
	double dOriginalScore = MeasurePredictiveAccuracy(pValidationSet);
	double dPrunedScore = tmp.MeasurePredictiveAccuracy(pValidationSet);
	if(dPrunedScore >= dOriginalScore)
		pNode->PruneChildren(m_pRelation);
}

void GDecisionTree::Prune(GArffData* pValidationSet)
{
	DeepPruneNode(m_pRoot, pValidationSet);
}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -