📄 gdecisiontree.cpp
字号:
{ if(m_eAlg == MINIMIZE_ENTROPY) { // Pick the best attribute to divide on GAssert(pData->GetSize() > 0, "Can't work without data"); double dBestGain = -1e100; double dBestPivot = 0; int nBestAttribute = -1; if(!pData->IsOutputHomogenous(m_pRelation)) { double dGain; double dPivot; int nAttr; int nInputCount = m_pRelation->GetInputCount(); int n; for(n = 0; n < nInputCount; n++) { nAttr = m_pRelation->GetInputIndex(n); if(pUsedAttributes[nAttr]) continue; dGain = MeasureInfoGain(pData, nAttr, &dPivot); if(nBestAttribute < 0 || dGain > dBestGain) { dBestGain = dGain; nBestAttribute = nAttr; dBestPivot = dPivot; } } } *pPivot = dBestPivot; return nBestAttribute; } else if(m_eAlg == RANDOM) { if(!pData->IsOutputHomogenous(m_pRelation)) { int nInputCount = m_pRelation->GetInputCount(); int n, i; double d; for(i = 0; i < 4; i++) { n = rand() % nInputCount; int nAttr = m_pRelation->GetInputIndex(n); if(pUsedAttributes[nAttr]) continue; GArffAttribute* pAttr = m_pRelation->GetAttribute(nAttr); if(pAttr->IsContinuous()) { if(!pData->PickPivotToReduceInfo(pPivot, &d, m_pRelation, nAttr)) continue; // couldn't find a suitable pivot } return nAttr; } int nStart = rand() % nInputCount; for(i = 0; i < nInputCount; i++) { int nAttr = m_pRelation->GetInputIndex((i + nStart) % nInputCount); if(!pUsedAttributes[nAttr]) { GArffAttribute* pAttr = m_pRelation->GetAttribute(nAttr); if(pAttr->IsContinuous()) { if(!pData->PickPivotToReduceInfo(pPivot, &d, m_pRelation, nAttr)) continue; // couldn't find a suitable pivot } return nAttr; } } } *pPivot = 0; return -1; } else GAssert(false, "unknown division algorithm"); return -1;}// This constructs the decision tree in a recursive depth-first mannerGDecisionTreeNode* GDecisionTree::BuildNode(GArffData* pData, bool* pUsedAttributes){ int n;#ifdef DEBUGLOG // Log debug stuff dbglog1("BuildNode from %d rows\n", pData->GetSize()); int nAttrCount = pRelation->GetAttributeCount(); for(n = 0; n < pData->GetSize(); n++) { double* pRow = pData->GetRow(n); dbglog0("\t"); int i; for(i = 0; i < nAttrCount; i++) { GArffAttribute* pAttr = pRelation->GetAttribute(i); dbglog1("%s, ", pAttr->GetValue((int)pRow[i])); } dbglog0("\n"); }#endif // DEBUGLOG // Pick the division double dBestPivot; int nBestAttribute = PickDivision(pData, &dBestPivot, pUsedAttributes); Holder<double*> hMostCommonOutputs(pData->MakeSetOfMostCommonOutputs(m_pRelation)); GAssert(hMostCommonOutputs.Get(), "Failed to get output values"); if(nBestAttribute < 0) { // There are no input attributes left on which to divide, so this is a leaf dbglog0("Leaf\n"); return new GDecisionTreeLeafNode(hMostCommonOutputs.Drop(), pData->GetSize()); } GAssert(nBestAttribute < m_pRelation->GetAttributeCount(), "out of range"); // Get rid of any unknown values for the best attribute pData->ReplaceMissingAttributeWithMostCommonValue(m_pRelation, nBestAttribute); // Create child nodes GDecisionTreeInteriorNode* pNode = new GDecisionTreeInteriorNode(nBestAttribute, dBestPivot); GArffAttribute* pAttr = m_pRelation->GetAttribute(nBestAttribute); dbglog2("Attribute=%d (%s)\n", nBestAttribute, pAttr->GetName()); GAssert(pAttr->IsInput(), "Expected an input"); GArffData** ppParts; int nChildCount; if(pAttr->IsContinuous()) { ppParts = new GArffData*[2]; ppParts[0] = pData->SplitByPivot(nBestAttribute, dBestPivot); ppParts[1] = new GArffData(pData->GetSize()); ppParts[1]->Merge(pData); nChildCount = 2; //GAssert(ppParts[0]->GetSize() > 0 && ppParts[1]->GetSize() > 0, "bad pivot"); } else { ppParts = pData->SplitByAttribute(m_pRelation, nBestAttribute); nChildCount = pAttr->GetValueCount(); pUsedAttributes[nBestAttribute] = true; } pNode->m_nChildren = nChildCount; pNode->m_ppChildren = new GDecisionTreeNode*[nChildCount]; for(n = 0; n < nChildCount; n++) { if(ppParts[n] && ppParts[n]->GetSize() > 0) { pNode->m_ppChildren[n] = BuildNode(ppParts[n], pUsedAttributes); pData->Merge(ppParts[n]); } else { // There's no data for this child, so just use the most common outputs of the parent double* pOutputValues = new double[m_pRelation->GetOutputCount()]; GAssert(hMostCommonOutputs.Get(), "no outputs"); memcpy(pOutputValues, hMostCommonOutputs.Get(), sizeof(double) * m_pRelation->GetOutputCount()); pNode->m_ppChildren[n] = new GDecisionTreeLeafNode(pOutputValues, 0); } delete(ppParts[n]); } delete[] ppParts; pUsedAttributes[nBestAttribute] = false; return pNode;}double GDecisionTree::MeasureInfoGain(GArffData* pData, int nAttribute, double* pPivot){ // Measure initial output info double dGain = m_pRelation->MeasureTotalOutputInfo(pData); if(dGain == 0) return 0; // Seperate by attribute values and measure difference in output info GArffAttribute* pAttr = m_pRelation->GetAttribute(nAttribute); GAssert(pAttr->IsInput(), "expected an input attribute"); if(pAttr->IsContinuous()) { double dSumOutputInfo; if(!pData->PickPivotToReduceInfo(pPivot, &dSumOutputInfo, m_pRelation, nAttribute)) return -1e200; // definately don't pick this attribute because it doesn't separate anything dGain -= dSumOutputInfo; return dGain; } else { *pPivot = 0; int nRowCount = pData->GetSize(); GArffData** ppParts = pData->SplitByAttribute(m_pRelation, nAttribute); int nCount = pAttr->GetValueCount(); int n; for(n = 0; n < nCount; n++) { dGain -= ((double)ppParts[n]->GetSize() / nRowCount) * m_pRelation->MeasureTotalOutputInfo(ppParts[n]); pData->Merge(ppParts[n]); delete(ppParts[n]); } delete(ppParts); GAssert(pData->GetSize() == nRowCount, "Didn't reassemble data correctly"); return dGain; }}void GDecisionTree::Eval(double* pRow){ GAssert(m_pRoot, "Not trained yet"); GDecisionTreeNode* pNode = m_pRoot; GArffAttribute* pAttr; int nVal; while(!pNode->IsLeaf()) { GDecisionTreeInteriorNode* pInterior = (GDecisionTreeInteriorNode*)pNode; pAttr = m_pRelation->GetAttribute(pInterior->m_nAttribute); GAssert(pAttr->IsInput(), "expected an input"); if(pAttr->IsContinuous()) { if(pRow[pInterior->m_nAttribute] < pInterior->m_dPivot) pNode = pInterior->m_ppChildren[0]; else pNode = pInterior->m_ppChildren[1]; } else { nVal = (int)pRow[pInterior->m_nAttribute]; if(nVal < 0) { GAssert(nVal == -1, "out of range"); nVal = rand() % pAttr->GetValueCount(); } GAssert(nVal < pAttr->GetValueCount(), "value out of range"); pNode = pInterior->m_ppChildren[nVal]; } } // Copy the output values into the row GDecisionTreeLeafNode* pLeaf = (GDecisionTreeLeafNode*)pNode; int n; int nOutputCount = m_pRelation->GetOutputCount(); for(n = 0; n < nOutputCount; n++) pRow[m_pRelation->GetOutputIndex(n)] = pLeaf->m_pOutputValues[n];}void GDecisionTree::Print(){ m_pRoot->Print(m_pRelation, 0, "All");}/*void GDecisionTree::DeepPruneNode(GDecisionTreeNode* pNode, GArffData* pValidationSet){ if(!pNode->m_ppChildren) return; int n; for(n = 0; n < pNode->m_nChildren; n++) DeepPruneNode(pNode->m_ppChildren[n], pValidationSet); GDecisionTreeNode* pNodeCopy; GDecisionTree tmp(this, pNode, &pNodeCopy); pNodeCopy->PruneChildren(m_pRelation); double dOriginalScore = MeasurePredictiveAccuracy(pValidationSet); double dPrunedScore = tmp.MeasurePredictiveAccuracy(pValidationSet); if(dPrunedScore >= dOriginalScore) pNode->PruneChildren(m_pRelation);}void GDecisionTree::Prune(GArffData* pValidationSet){ DeepPruneNode(m_pRoot, pValidationSet);}*/// virtualvoid GDecisionTree::Reset(){ delete(m_pRoot); m_pRoot = NULL;}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -