📄 gdecisiontree.cpp
字号:
/* Copyright (C) 2006, Mike Gashler
   This library is free software; you can redistribute it and/or modify it
   under the terms of the GNU Lesser General Public License as published by
   the Free Software Foundation; either version 2.1 of the License, or (at
   your option) any later version. see http://www.gnu.org/copyleft/lesser.html */

#include "GDecisionTree.h"
#include "GArff.h"
#include "../GClasses/GMacros.h"
#include <stdio.h>   // printf/fprintf used below; include explicitly rather than relying on transitive includes
#include <stdlib.h>

//#define DEBUGLOG
#ifdef DEBUGLOG
#define dbglog0(a) fprintf(stderr, a)
#define dbglog1(a,b) fprintf(stderr, a, b)
#define dbglog2(a,b,c) fprintf(stderr, a, b, c)
#define dbglog3(a,b,c,d) fprintf(stderr, a, b, c, d)
#else // DEBUGLOG
#define dbglog0(a) ((void)0)
#define dbglog1(a,b) ((void)0)
#define dbglog2(a,b,c) ((void)0)
// BUG FIX: parameter list was (a,b,d) — only three parameters, while the
// DEBUGLOG variant above takes four, so any dbglog3(...) call site would
// fail to compile in non-debug builds.
#define dbglog3(a,b,c,d) ((void)0)
#endif // !DEBUGLOG

// Abstract base class for nodes in the decision tree. Interior nodes split
// on an attribute; leaf nodes hold the predicted output values.
class GDecisionTreeNode
{
public:
	GDecisionTreeNode()
	{
	}

	virtual ~GDecisionTreeNode()
	{
	}

	// True for leaf nodes, false for interior (splitting) nodes.
	virtual bool IsLeaf() = 0;

	// Recursively clones this branch. If pInterestingNode is found anywhere in
	// the branch, *ppOutInterestingCopy is set to the corresponding copy.
	virtual GDecisionTreeNode* DeepCopy(GArffRelation* pRelation, GDecisionTreeNode* pInterestingNode, GDecisionTreeNode** ppOutInterestingCopy) = 0;

	// Prints a human-readable representation of this branch to stdout,
	// indented by nSpaces. szValue is the attribute value that led here.
	virtual void Print(GArffRelation* pRelation, int nSpaces, const char* szValue) = 0;

	// Recursive function that counts the number of times a particular
	// value is found in a particular output in this branch of the tree.
	virtual void CountValues(int nOutput, int* pnCounts) = 0;

	// Sums (output value * sample size) over all leaves in this branch.
	virtual double FindSumOutputValue(int nOutput) = 0;
};

// An interior node splits on attribute m_nAttribute. For continuous
// attributes, m_dPivot is the split threshold; for nominal attributes the
// node has one child per attribute value.
class GDecisionTreeInteriorNode : public GDecisionTreeNode
{
friend class GDecisionTree;
protected:
	int m_nAttribute;                  // index of the attribute this node tests
	double m_dPivot;                   // split threshold (meaningful for continuous attributes)
	int m_nChildren;                   // number of entries in m_ppChildren
	GDecisionTreeNode** m_ppChildren;  // owned child branches (new[] array)

public:
	GDecisionTreeInteriorNode(int nAttribute, double dPivot) : GDecisionTreeNode()
	{
		m_nAttribute = nAttribute;
		m_dPivot = dPivot;
		m_nChildren = 0;
		m_ppChildren = NULL;
	}

	virtual ~GDecisionTreeInteriorNode()
	{
		if(m_ppChildren)
		{
			int n;
			for(n = 0; n < m_nChildren; n++)
				delete(m_ppChildren[n]);
			// BUG FIX: the child array is allocated with new[] (see DeepCopy),
			// so it must be released with delete[] — scalar delete was
			// undefined behavior.
			delete[] m_ppChildren;
		}
	}

	virtual bool IsLeaf()
	{
		return false;
	}

	virtual GDecisionTreeNode* DeepCopy(GArffRelation* pRelation, GDecisionTreeNode* pInterestingNode, GDecisionTreeNode** ppOutInterestingCopy)
	{
		GDecisionTreeInteriorNode* pNewNode = new GDecisionTreeInteriorNode(m_nAttribute, m_dPivot);
		pNewNode->m_nChildren = m_nChildren;
		pNewNode->m_ppChildren = new GDecisionTreeNode*[m_nChildren];
		int n;
		for(n = 0; n < m_nChildren; n++)
			pNewNode->m_ppChildren[n] = m_ppChildren[n]->DeepCopy(pRelation, pInterestingNode, ppOutInterestingCopy);
		if(this == pInterestingNode)
			*ppOutInterestingCopy = pNewNode;
		return pNewNode;
	}

	virtual void Print(GArffRelation* pRelation, int nSpaces, const char* szValue)
	{
		int n;
		for(n = 0; n < nSpaces; n++)
			printf(" ");
		GArffAttribute* pAttr = pRelation->GetAttribute(m_nAttribute);
		if(pAttr->IsContinuous())
			printf("%s -> %s (%f)?\n", szValue, pAttr->GetName(), m_dPivot);
		else
			printf("%s -> %s?\n", szValue, pAttr->GetName());
		for(n = 0; n < m_nChildren; n++)
			m_ppChildren[n]->Print(pRelation, nSpaces + 1, pAttr->GetValue(n));
	}

	// Recursive function that counts the number of times a particular
	// value is found in a particular output in this branch of the tree
	virtual void CountValues(int nOutput, int* pnCounts)
	{
		int n;
		for(n = 0; n < m_nChildren; n++)
			m_ppChildren[n]->CountValues(nOutput, pnCounts);
	}

	virtual double FindSumOutputValue(int nOutput)
	{
		double dSum = 0;
		int n;
		for(n = 0; n < m_nChildren; n++)
			dSum += m_ppChildren[n]->FindSumOutputValue(nOutput);
		return dSum;
	}

/*	// (Disabled in the original source; preserved for reference.)
	// Collapses this interior node into a leaf by voting among its children.
	// NOTE(review): if re-enabled, the final `delete(m_ppChildren)` below
	// should become `delete[] m_ppChildren` for the same reason as the
	// destructor fix above.
	void PruneChildren(GArffRelation* pRelation)
	{
		// Create output values by finding the most common outputs among children
		GAssert(m_ppChildren, "This is a leaf node");
		int nOutputCount = pRelation->GetOutputCount();
		m_pOutputValues = new double[nOutputCount];
		int n;
		for(n = 0; n < nOutputCount; n++)
		{
			// Count the number of occurrences of each possible value for this output attribute
			GArffAttribute* pAttr = pRelation->GetAttribute(pRelation->GetOutputIndex(n));
			int nValueCount = pAttr->GetValueCount();
			if(nValueCount <= 0)
				m_pOutputValues[n] = FindSumOutputValue(n) / m_nSampleSize;
			else
			{
				Holder<int*> hCounts(new int[nValueCount]);
				int* pnCounts = hCounts.Get();
				memset(pnCounts, '\0', sizeof(int) * nValueCount);
				CountValues(n, pnCounts);

				// Find the most frequent value
				int i;
				int nMax = 0;
				for(i = 1; i < nValueCount; i++)
				{
					if(pnCounts[i] > pnCounts[nMax])
						nMax = i;
				}
				m_pOutputValues[n] = (double)nMax;
			}
		}

		// Delete the children
		for(n = 0; n < m_nChildren; n++)
			delete(m_ppChildren[n]);
		delete(m_ppChildren);
		m_ppChildren = NULL;
	}*/
};

// A leaf node predicts one value per output attribute. For nominal outputs
// the stored double is the (cast) index of the predicted value.
class GDecisionTreeLeafNode : public GDecisionTreeNode
{
public:
	double* m_pOutputValues; // owned array, one entry per output attribute
	int m_nSampleSize;       // number of training rows that reached this leaf

public:
	// Takes ownership of pOutputValues (released with delete[]).
	GDecisionTreeLeafNode(double* pOutputValues, int nSampleSize) : GDecisionTreeNode()
	{
		m_pOutputValues = pOutputValues;
		m_nSampleSize = nSampleSize;
	}

	virtual ~GDecisionTreeLeafNode()
	{
		delete[] m_pOutputValues;
	}

	virtual bool IsLeaf()
	{
		return true;
	}

	virtual GDecisionTreeNode* DeepCopy(GArffRelation* pRelation, GDecisionTreeNode* pInterestingNode, GDecisionTreeNode** ppOutInterestingCopy)
	{
		int nCount = pRelation->GetOutputCount();
		double* pOutputValues = new double[nCount];
		int n;
		for(n = 0; n < nCount; n++)
			pOutputValues[n] = m_pOutputValues[n];
		GDecisionTreeLeafNode* pNewNode = new GDecisionTreeLeafNode(pOutputValues, m_nSampleSize);
		if(this == pInterestingNode)
			*ppOutInterestingCopy = pNewNode;
		return pNewNode;
	}

	virtual void Print(GArffRelation* pRelation, int nSpaces, const char* szValue)
	{
		int n;
		for(n = 0; n < nSpaces; n++)
			printf(" ");
		int nCount = pRelation->GetOutputCount();
		printf("%s -> ", szValue);
		for(n = 0; n < nCount; n++)
		{
			GArffAttribute* pAttr = pRelation->GetAttribute(pRelation->GetOutputIndex(n));
			if(n > 0)
				printf(", ");
			printf("%s=%s", pAttr->GetName(), pAttr->GetValue((int)m_pOutputValues[n]));
		}
		printf("\n");
	}

	virtual void CountValues(int nOutput, int* pnCounts)
	{
		int nVal = (int)m_pOutputValues[nOutput];
		pnCounts[nVal] += m_nSampleSize;
	}

	virtual double FindSumOutputValue(int nOutput)
	{
		return m_pOutputValues[nOutput] * m_nSampleSize;
	}
};

// -----------------------------------------------------------------

GDecisionTree::GDecisionTree(GArffRelation* pRelation, DivisionAlgorithm eAlg)
: GSupervisedLearner(pRelation)
{
	m_pRoot = NULL;
	m_eAlg = eAlg;
	m_dTrainingPortion = .65;
}

// Copy constructor used to clone a tree while tracking where a particular
// node of interest ends up in the copy.
GDecisionTree::GDecisionTree(GDecisionTree* pThat, GDecisionTreeNode* pInterestingNode, GDecisionTreeNode** ppOutInterestingCopy)
: GSupervisedLearner(pThat->m_pRelation)
{
	m_pRelation = pThat->m_pRelation;
	m_pRoot = pThat->m_pRoot->DeepCopy(pThat->m_pRelation, pInterestingNode, ppOutInterestingCopy);
}

GDecisionTree::~GDecisionTree()
{
	delete(m_pRoot);
}

void GDecisionTree::Train(GArffData* pData)
{
	// (Pruning path disabled in the original source; preserved for reference.)
//	int nTrainRows = (int)(m_dTrainingPortion * pData->GetSize());
//	GArffData* pPruningData = pData->SplitBySize(nTrainRows);
	TrainWithoutPruning(pData);
//	Prune(pPruningData);
//	pData->Merge(pPruningData);
//	delete(pPruningData);
}

void GDecisionTree::TrainWithoutPruning(GArffData* pTrainingData)
{
	delete(m_pRoot);
	int nAttributes = m_pRelation->GetAttributeCount();
	// NOTE(review): Holder<bool*> guards a new[] array here — verify that the
	// project's Holder releases with delete[] (its definition is not visible
	// in this file).
	Holder<bool*> hUsedAttributes(new bool[nAttributes]);
	bool* pUsedAttributes = hUsedAttributes.Get();
	int n;
	for(n = 0; n < nAttributes; n++)
		pUsedAttributes[n] = false;
	m_pRoot = BuildNode(pTrainingData, pUsedAttributes);
}

int GDecisionTree::PickDivision(GArffData* pData, double* pPivot, bool* pUsedAttributes)
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -