gknn.cpp
来自「一个由Mike Gashler完成的机器学习方面的includes neural」· C++ 代码 · 共 1,007 行 · 第 1/2 页
CPP
1,007 行
/*
	Copyright (C) 2006, Mike Gashler

	This library is free software; you can redistribute it and/or
	modify it under the terms of the GNU Lesser General Public
	License as published by the Free Software Foundation; either
	version 2.1 of the License, or (at your option) any later version.

	see http://www.gnu.org/copyleft/lesser.html
*/

#include "GKNN.h"
#include "GArff.h"
#include "GArray.h"
#include <math.h>
#include "GVector.h"
#include "GMacros.h"
#include <stdlib.h>
#include <string.h> // memset
#include "GBits.h"

// Abstract base class for the nodes of the tree that GNeighborFinder
// builds. A node is either an interior node (GNeighborFinderInterior)
// or a leaf (GNeighborFinderLeaf).
class GNeighborFinderNode
{
public:
	GNeighborFinderNode() { }
	virtual ~GNeighborFinderNode() { }

	// Returns true for leaf nodes and false for interior nodes
	virtual bool IsLeaf() = 0;

	// Splits the sub-tree rooted at this node at value dPivot of input
	// nInput, and returns the node that should take this node's place in
	// its parent (an interior node returns itself; a leaf returns a newly
	// allocated interior node). *pDidSomeGood is set to true if some leaf
	// ended up with points on both sides of the pivot. It is never set to
	// false, so the caller must initialize it.
	virtual GNeighborFinderNode* Split(int nInput, double dPivot, GArffRelation* pRelation, GArffData* pPoints, int nMaxPointsPerLeaf, bool* pDidSomeGood) = 0;
};

// An interior node. It divides space at m_dPivot along input m_nInput,
// and it owns (and deletes) its two children.
class GNeighborFinderInterior : public GNeighborFinderNode
{
protected:
	double m_dPivot;
	int m_nInput;
	GNeighborFinderNode* m_pLesser;
	GNeighborFinderNode* m_pGreater;

public:
	// Takes ownership of pLesser and pGreater
	GNeighborFinderInterior(int nInput, double dPivot, GNeighborFinderNode* pLesser, GNeighborFinderNode* pGreater) : GNeighborFinderNode()
	{
		m_dPivot = dPivot;
		m_nInput = nInput;
		m_pLesser = pLesser;
		m_pGreater = pGreater;
	}

	virtual ~GNeighborFinderInterior()
	{
		delete(m_pLesser);
		delete(m_pGreater);
	}

	virtual bool IsLeaf() { return false; }
	inline GNeighborFinderNode* GetLesser() { return m_pLesser; }
	inline GNeighborFinderNode* GetGreater() { return m_pGreater; }
	inline int GetInput() { return m_nInput; }
	inline double GetPivot() { return m_dPivot; }

	// Forwards the split to the children. When the split is on the same
	// input as this node, only the child on the pivot's side of m_dPivot
	// is affected; otherwise both children are split. The returned
	// pointers are stored because a child leaf replaces itself with a new
	// interior node when it splits.
	virtual GNeighborFinderNode* Split(int nInput, double dPivot, GArffRelation* pRelation, GArffData* pPoints, int nMaxPointsPerLeaf, bool* pDidSomeGood)
	{
		if(m_nInput == nInput)
		{
			if(dPivot >= m_dPivot)
				m_pGreater = m_pGreater->Split(nInput, dPivot, pRelation, pPoints, nMaxPointsPerLeaf, pDidSomeGood);
			else
				m_pLesser = m_pLesser->Split(nInput, dPivot, pRelation, pPoints, nMaxPointsPerLeaf, pDidSomeGood);
		}
		else
		{
			m_pLesser = m_pLesser->Split(nInput, dPivot, pRelation, pPoints, nMaxPointsPerLeaf, pDidSomeGood);
			m_pGreater = m_pGreater->Split(nInput, dPivot, pRelation, pPoints, nMaxPointsPerLeaf, pDidSomeGood);
		}
		return this;
	}
};

// Context passed (as pThis) to GNeighborFinderSorterFunc
struct GNeighborFinderSorterStruct
{
	GArffData* pPoints; // the vectors being compared
	int nDimension; // index of the vector element to compare
};

// Comparison callback for GIntArray::Sort. a and b are vector indexes
// into pPoints; the vectors are ordered by element nDimension.
// Returns -1, 0, or 1.
int GNeighborFinderSorterFunc(void* pThis, int a, int b)
{
	struct GNeighborFinderSorterStruct* pStruct = (struct GNeighborFinderSorterStruct*)pThis;
	double* pA = pStruct->pPoints->GetVector(a);
	double* pB = pStruct->pPoints->GetVector(b);
	if(pA[pStruct->nDimension] < pB[pStruct->nDimension])
		return -1;
	else if(pA[pStruct->nDimension] > pB[pStruct->nDimension])
		return 1;
	else
		return 0;
}

// A leaf node (a cell). For every input it stores the cell's bounds and
// pointers to the adjacent leaves on the lesser and greater side, plus
// the indexes of the vectors that fall in the cell.
class GNeighborFinderLeaf : public GNeighborFinderNode
{
protected:
	double* m_pMins; // per-input lower bound of this cell
	double* m_pMaxs; // per-input upper bound of this cell
	GNeighborFinderLeaf** m_ppGreater; // per-input neighbor on the greater side (may be NULL)
	GNeighborFinderLeaf** m_ppLesser; // per-input neighbor on the lesser side (may be NULL)
	GIntArray m_data; // indexes of the vectors in this cell

public:
	// Allocates the per-input arrays. The neighbor pointers start out
	// NULL. The bounds are left uninitialized; they are filled in by
	// SetToGlobalRange or by the Split call that creates this leaf.
	GNeighborFinderLeaf(GArffRelation* pRelation, int nMaxPointsPerLeaf) : GNeighborFinderNode(), m_data(nMaxPointsPerLeaf + 1)
	{
		int nInputs = pRelation->GetInputCount();
		m_pMins = new double[nInputs];
		m_pMaxs = new double[nInputs];
		m_ppLesser = new GNeighborFinderLeaf*[nInputs];
		memset(m_ppLesser, '\0', sizeof(GNeighborFinderLeaf*) * nInputs);
		m_ppGreater = new GNeighborFinderLeaf*[nInputs];
		memset(m_ppGreater, '\0', sizeof(GNeighborFinderLeaf*) * nInputs);
	}

	// Frees the per-input arrays. Neighbor leaves are not owned and are
	// not deleted here.
	virtual ~GNeighborFinderLeaf()
	{
		delete[] m_ppGreater;
		delete[] m_ppLesser;
		delete[] m_pMaxs;
		delete[] m_pMins;
	}

	// Sets the bounds of every input to [-1e200, 1e200]
	void SetToGlobalRange(GArffRelation* pRelation)
	{
		int nInputs = pRelation->GetInputCount();
		int i;
		for(i = 0; i < nInputs; i++)
		{
			m_pMins[i] = -1e200;
			m_pMaxs[i] = 1e200;
		}
	}

	virtual bool IsLeaf() { return true; }
	GIntArray* GetData() { return &m_data; }
	inline GNeighborFinderLeaf* GetLesser(int nInput) { return m_ppLesser[nInput]; }
	inline GNeighborFinderLeaf* GetGreater(int nInput) { return m_ppGreater[nInput]; }
	inline double GetMin(int nInput) { return m_pMins[nInput]; }
	inline double GetMax(int nInput) { return m_pMaxs[nInput]; }

	// Picks an input and a pivot value at which to split this leaf.
	// On success, stores them in *pnInput and *pdPivot and returns true.
	// Returns false (leaving the outputs untouched) if a split does not
	// look worthwhile. As a side effect, m_data is re-sorted.
	bool SuggestPivot(int* pnInput, double* pdPivot, GArffRelation* pRelation, GArffData* pPoints)
	{
		// Make sure a split would do some good. This is only a spot check:
		// two randomly chosen points (possibly the same one) are compared,
		// and no split is suggested if they coincide.
		int nPoints = m_data.GetSize();
		if(nPoints < 2)
			return false; // nothing to separate (and rand() % 0 is undefined)
		if(pRelation->ComputeInputDistanceSquared(
				pPoints->GetVector(m_data.GetInt(rand() % nPoints)),
				pPoints->GetVector(m_data.GetInt(rand() % nPoints))
			) == 0)
			return false;

		// Pivot at the median of the dimension with the biggest range
		double dBiggestGap = 0;
		int nBestInput = -1;
		double dMedian = 0;
		int nInputs = pRelation->GetInputCount();
		double dGap;
		int i, index, nMidIndex;
		struct GNeighborFinderSorterStruct compareStruct;
		compareStruct.pPoints = pPoints;
		for(i = 0; i < nInputs; i++)
		{
			// Sort the points on this input so that the range and the median can be read off
			index = pRelation->GetInputIndex(i);
			compareStruct.nDimension = index;
			m_data.Sort(GNeighborFinderSorterFunc, &compareStruct);
			dGap = pPoints->GetVector(m_data.GetInt(m_data.GetSize() - 1))[index] - pPoints->GetVector(m_data.GetInt(0))[index];
			if(dGap > dBiggestGap)
			{
				dBiggestGap = dGap;
				nMidIndex = m_data.GetSize() / 2;
				dMedian = pPoints->GetVector(m_data.GetInt(nMidIndex))[index];

				// Make sure the median is greater than the smallest value (or else this
				// pivot is worthless, and may cause infinite recursion). The candidate
				// index is advanced (by halving the remaining distance while more than 8
				// entries remain, then one at a time) until its value exceeds the smallest.
				while(dMedian <= pPoints->GetVector(m_data.GetInt(0))[index])
				{
					GAssert(dMedian == pPoints->GetVector(m_data.GetInt(0))[index], "values not sorted properly");
					GAssert(nMidIndex < m_data.GetSize() - 1, "huh? we already checked that there was a gap at the beginning of this method");
					if(m_data.GetSize() - nMidIndex > 8)
						nMidIndex = (m_data.GetSize() + nMidIndex - 1) / 2;
					else
						nMidIndex++;
					dMedian = pPoints->GetVector(m_data.GetInt(nMidIndex))[index];
				}
				nBestInput = i;
			}
		}
		GAssert(nBestInput >= 0, "internal error");
		*pnInput = nBestInput;
		*pdPivot = dMedian;
		return true;
	}

	// Splits this leaf at dPivot along input nInput. This leaf keeps the
	// points whose value is less than dPivot; a newly allocated leaf
	// receives the points whose value is >= dPivot. Returns a newly
	// allocated interior node whose children are this leaf and the new
	// leaf. *pDidSomeGood is set to true if both leaves end up non-empty.
	virtual GNeighborFinderNode* Split(int nInput, double dPivot, GArffRelation* pRelation, GArffData* pPoints, int nMaxPointsPerLeaf, bool* pDidSomeGood)
	{
		// Split the cell
		GNeighborFinderLeaf* pNewGreater = new GNeighborFinderLeaf(pRelation, nMaxPointsPerLeaf);
		int nInputs = pRelation->GetInputCount();
		int i;
		for(i = 0; i < nInputs; i++)
		{
			if(i == nInput)
			{
				// Insert the new leaf between this leaf and its former
				// greater neighbor, and divide the range at the pivot
				if(m_ppGreater[i])
					m_ppGreater[i]->m_ppLesser[i] = pNewGreater;
				pNewGreater->m_ppLesser[i] = this;
				pNewGreater->m_ppGreater[i] = m_ppGreater[i];
				m_ppGreater[i] = pNewGreater;
				pNewGreater->m_pMins[i] = dPivot;
				pNewGreater->m_pMaxs[i] = m_pMaxs[i];
				m_pMaxs[i] = dPivot;
			}
			else
			{
				// The new leaf has the same range as this leaf on the other inputs
				pNewGreater->m_pMins[i] = m_pMins[i];
				pNewGreater->m_pMaxs[i] = m_pMaxs[i];

				// Link the new leaf to the leaf that sits on the greater side
				// (along nInput) of this leaf's neighbor, provided that leaf
				// exists and does not already have a link in that direction
				if(m_ppLesser[i])
				{
					GNeighborFinderLeaf* pNewGreaterNeighbor = m_ppLesser[i]->m_ppGreater[nInput];
					if(pNewGreaterNeighbor && pNewGreaterNeighbor->m_ppGreater[i] == NULL)
					{
						pNewGreater->m_ppLesser[i] = pNewGreaterNeighbor;
						pNewGreaterNeighbor->m_ppGreater[i] = pNewGreater;
					}
				}
				if(m_ppGreater[i])
				{
					GNeighborFinderLeaf* pNewGreaterNeighbor = m_ppGreater[i]->m_ppGreater[nInput];
					if(pNewGreaterNeighbor && pNewGreaterNeighbor->m_ppLesser[i] == NULL)
					{
						pNewGreaterNeighbor->m_ppLesser[i] = pNewGreater;
						pNewGreater->m_ppGreater[i] = pNewGreaterNeighbor;
					}
				}
			}
		}

		// Split the data
		GIntArray* pGreaterData = pNewGreater->GetData();
		int tmp;
		for(i = 0; i < m_data.GetSize(); i++)
		{
			tmp = m_data.GetInt(i);
			if(pPoints->GetVector(tmp)[pRelation->GetInputIndex(nInput)] >= dPivot)
			{
				// Step back one so that the entry which moves into slot i is examined too
				m_data.DeleteCell(i--);
				pGreaterData->AddInt(tmp);
			}
		}
		if(m_data.GetSize() > 0 && pGreaterData->GetSize() > 0)
			*pDidSomeGood = true;
		return new GNeighborFinderInterior(nInput, dPivot, this, pNewGreater);
	}
};

// -------------------------------------------------------------------------------

// Builds the neighbor-finding structure over the vectors in pData.
// nMaxPointsPerLeaf is raised to at least the number of inputs and at
// least 4. pRelation and pData are referenced, not copied. The vectors
// are inserted in a random order.
GNeighborFinder::GNeighborFinder(GArffRelation* pRelation, GArffData* pData, int nMaxPointsPerLeaf)
{
	int nInputs = pRelation->GetInputCount();
	if(nMaxPointsPerLeaf < pRelation->GetInputCount())
		nMaxPointsPerLeaf = pRelation->GetInputCount();
	if(nMaxPointsPerLeaf < 4)
		nMaxPointsPerLeaf = 4;
	m_pRelation = pRelation;
	m_pData = pData;
	m_nMaxPointsPerLeaf = nMaxPointsPerLeaf;

	// Start with a single leaf that covers the whole space
	m_pRoot = new GNeighborFinderLeaf(pRelation, nMaxPointsPerLeaf);
	((GNeighborFinderLeaf*)m_pRoot)->SetToGlobalRange(pRelation);
	m_ppIterators = new GNeighborFinderLeaf*[nInputs];
	m_pMaxs = new double[nInputs];

	// Make a random ordering
	int nCount = pData->GetSize();
	int* pIndexes = new int[nCount];
	ArrayHolder<int*> hIndexes(pIndexes); // presumably frees pIndexes when it goes out of scope
	int i, index, tmp;
	for(i = 0; i < nCount; i++)
		pIndexes[i] = i;
	while(i > 1)
	{
		// Swap a randomly chosen remaining element into the last remaining slot
		index = rand() % i;
		tmp = pIndexes[index];
		pIndexes[index] = pIndexes[--i];
		pIndexes[i] = tmp;
	}

	// Add each point
	double* pVector;
	for(i = 0; i < nCount; i++)
	{
		index = pIndexes[i];
		pVector = pData->GetVector(index);
		InsertVector(pVector, index);
	}
}

GNeighborFinder::~GNeighborFinder()
{
	// m_pData is intentionally not deleted here
	delete(m_pRoot);
	delete[] m_ppIterators;
	delete[] m_pMaxs;
}

#ifndef NO_TEST_CODE
// Self-test. Fills a data set with random vectors, queries the neighbors
// of every vector, and compares the answer with a brute-force scan.
// Throws a string literal (const char*) describing the first failure.
/*static*/ void GNeighborFinder::Test()
{
	// Generate the data
	int nDimensions = 7;
	int nNeighbors = 10;
	int nMaxPointsPerLeaf = 5;
	int nPoints = 500;
	GArffRelation rel;
	int i, j, k;
	for(i = 0; i < nDimensions; i++)
		rel.AddAttribute(new GArffAttribute(true, 0, NULL));
	GArffData data(nPoints);
	double* pVector;
	for(i = 0; i < nPoints; i++)
	{
		pVector = new double[nDimensions];
		for(j = 0; j < nDimensions; j++)
			pVector[j] = GBits::GetRandomDouble() * 100;
		data.AddVector(pVector);
	}

	// Find neighbors using GNeighborFinder
	double* pVector2;
	int* pNeighbors = new int[nNeighbors];
	ArrayHolder<int*> hNeighbors(pNeighbors);
	double* pDistances = new double[nNeighbors];
	ArrayHolder<double*> hDistances(pDistances);
	GNeighborFinder gnf(&rel, &data, nMaxPointsPerLeaf);
	for(i = 0; i < nPoints; i++)
	{
		// Find the neighbors
		pVector = data.GetVector(i);
		gnf.FindNeighbors(pNeighbors, pDistances, nNeighbors, pVector, i);

		// Check the answer. First find the distance of the farthest reported neighbor.
		double dWorstDistance = 0;
		double d;
		for(j = 0; j < nNeighbors; j++)
		{
			if(pNeighbors[j] < 0)
				throw "Didn't retrieve enough neighbors";
			d = rel.ComputeInputDistanceSquared(pVector, data.GetVector(pNeighbors[j]));
			if(d > dWorstDistance)
				dWorstDistance = d;
		}

		// Every point that is closer than the farthest reported neighbor
		// (other than points at distance zero) must be among the reported neighbors
		for(j = 0; j < nPoints; j++)
		{
			pVector2 = data.GetVector(j);
			d = rel.ComputeInputDistanceSquared(pVector, pVector2);
			if(d == 0)
			{
			}
			else if(d < dWorstDistance)
			{
				for(k = 0; k < nNeighbors; k++)
				{
					if(k > 0)
					{
						// This isn't a complete check, but it should be good enough
						if(pNeighbors[k] == pNeighbors[k - 1])
							throw "same neighbor twice";
						if(pNeighbors[k] == pNeighbors[0])
							throw "same neighbor twice";
					}
					if(pNeighbors[k] == i)
						throw "failed to exclude itself";
					if(data.GetVector(pNeighbors[k]) == pVector2)
						break;
				}
				if(k >= nNeighbors)
					throw "missed a neighbor";
			}
		}
	}
}
#endif // !NO_TEST_CODE

// Appends pVector to the data set and inserts it into the tree
void GNeighborFinder::AddVector(double* pVector)
{
	int index = m_pData->GetSize();
	m_pData->AddVector(pVector);
	InsertVector(pVector, index);
}

// Records nIndex (the position of pVector in m_pData) in the leaf that
// FindCell returns for pVector. If that leaf then holds more than
// m_nMaxPointsPerLeaf points and the leaf can suggest a pivot, the tree
// is split at that pivot, starting from the root.
void GNeighborFinder::InsertVector(double* pVector, int nIndex)
{
	GNeighborFinderLeaf* pLeaf = FindCell(pVector);
	GIntArray* pLeafData = pLeaf->GetData();
	pLeafData->AddInt(nIndex);
	if(pLeafData->GetSize() > m_nMaxPointsPerLeaf)
	{
		int nInput;
		double dPivot;
		if(pLeaf->SuggestPivot(&nInput, &dPivot, m_pRelation, m_pData))
		{
			// Split only ever sets the flag to true, so it must start out false
			bool bDidSomeGood = false;
			m_pRoot = m_pRoot->Split(nInput, dPivot, m_pRelation, m_pData, m_nMaxPointsPerLeaf, &bDidSomeGood);
			GAssert(bDidSomeGood, "The split didn't do any good. SuggestPivot must be broken");
		}
	}
}

double* GNeighborFinder::DropVector(int nIndex)
{
	// Remove the reference
	double* pVector = m_pData->GetVector(nIndex);
	GNeighborFinderLeaf* pLeaf = FindCell(pVector);
	GIntArray* pLeafData = pLeaf->GetData();
	int nCount = pLeafData->GetSize();
	int i;
	for(i = 0; i < nCount; i++)
	{
		if(pLeafData->GetInt(i) == nIndex)
		{
			// Overwrite the entry with the leaf's last entry, then drop the last cell
			pLeafData->SetInt(i, pLeafData->GetInt(nCount - 1));
			pLeafData->DeleteCell(nCount - 1);
			break;
		}
	}
	GAssert(i < nCount, "failed to find the vector reference");

	// Reindex the last reference to take this slot
	// NOTE(review): presumably the data set's last vector is subsequently moved
	// into slot nIndex -- confirm against the remainder of this method
	int nOldIndex = m_pData->GetSize() - 1;
	if(nOldIndex != nIndex)
	{
		pVector = m_pData->GetVector(nOldIndex);
		pLeaf = FindCell(pVector);
		pLeafData = pLeaf->GetData();
		nCount = pLeafData->GetSize();
		for(i = 0; i < nCount; i++)
		{
			if(pLeafData->GetInt(i) == nOldIndex)
			{
				pLeafData->SetInt(i, nIndex);
				break;
			}
		}
		GAssert(i < nCount, "failed to find the vector reference");
	}
⌨️ 快捷键说明
复制代码Ctrl + C
搜索代码Ctrl + F
全屏模式F11
增大字号Ctrl + =
减小字号Ctrl + -
显示快捷键?