📄 gknn.cpp
字号:
/*
	Copyright (C) 2006, Mike Gashler

	This library is free software; you can redistribute it and/or
	modify it under the terms of the GNU Lesser General Public
	License as published by the Free Software Foundation; either
	version 2.1 of the License, or (at your option) any later version.

	see http://www.gnu.org/copyleft/lesser.html
*/

#include "GKNN.h"
#include "GArff.h"
#include "GArray.h"
#include <math.h>
#include "GVector.h"
#include "GMacros.h"
#include <stdlib.h>
#include <string.h> // memset
#include "GBits.h"

// Abstract node of the space-partitioning tree built by GNeighborFinder.
// Interior nodes hold an axis-aligned cut; leaves hold point indexes.
class GNeighborFinderNode
{
public:
	GNeighborFinderNode()
	{
	}

	virtual ~GNeighborFinderNode()
	{
	}

	virtual bool IsLeaf() = 0;

	// Cuts this subtree at dPivot along input number nInput and returns the
	// node that should replace this one in its parent (possibly this node).
	virtual GNeighborFinderNode* Split(int nInput, double dPivot, GArffRelation* pRelation, GArffData* pPoints, int nMaxPointsPerLeaf) = 0;
};

// Interior node. Points whose value in input m_nInput is >= m_dPivot live in
// the greater subtree; all other points live in the lesser subtree.
class GNeighborFinderInterior : public GNeighborFinderNode
{
protected:
	double m_dPivot;
	int m_nInput;
	GNeighborFinderNode* m_pLesser; // owned; deleted in the destructor
	GNeighborFinderNode* m_pGreater; // owned; deleted in the destructor

public:
	// Takes ownership of pLesser and pGreater.
	GNeighborFinderInterior(int nInput, double dPivot, GNeighborFinderNode* pLesser, GNeighborFinderNode* pGreater)
	: GNeighborFinderNode()
	{
		m_dPivot = dPivot;
		m_nInput = nInput;
		m_pLesser = pLesser;
		m_pGreater = pGreater;
	}

	virtual ~GNeighborFinderInterior()
	{
		delete(m_pLesser);
		delete(m_pGreater);
	}

	virtual bool IsLeaf() { return false; }
	inline GNeighborFinderNode* GetLesser() { return m_pLesser; }
	inline GNeighborFinderNode* GetGreater() { return m_pGreater; }
	inline int GetInput() { return m_nInput; }
	inline double GetPivot() { return m_dPivot; }

	// Forwards the cut to the children it crosses. A cut on this node's own
	// input only crosses the side that contains dPivot; a cut on any other
	// input crosses both children, so every leaf spanning the hyperplane is
	// split (presumably to keep leaves aligned for the per-input neighbor
	// links kept by GNeighborFinderLeaf).
	virtual GNeighborFinderNode* Split(int nInput, double dPivot, GArffRelation* pRelation, GArffData* pPoints, int nMaxPointsPerLeaf)
	{
		if(m_nInput == nInput)
		{
			if(dPivot >= m_dPivot)
				m_pGreater = m_pGreater->Split(nInput, dPivot, pRelation, pPoints, nMaxPointsPerLeaf);
			else
				m_pLesser = m_pLesser->Split(nInput, dPivot, pRelation, pPoints, nMaxPointsPerLeaf);
		}
		else
		{
			m_pLesser = m_pLesser->Split(nInput, dPivot, pRelation, pPoints, nMaxPointsPerLeaf);
			m_pGreater = m_pGreater->Split(nInput, dPivot, pRelation, pPoints, nMaxPointsPerLeaf);
		}
		return this;
	}

private:
	// Not copyable: a copy would share, and later double-delete, the children.
	GNeighborFinderInterior(const GNeighborFinderInterior&);
	GNeighborFinderInterior& operator=(const GNeighborFinderInterior&);
};

// Context passed to GNeighborFinderSorterFunc through GIntArray::Sort.
struct GNeighborFinderSorterStruct
{
	GArffData* pPoints; // the rows being ordered
	int nDimension; // attribute (column) index to compare on
};

// Three-way comparison of rows a and b of pThis->pPoints on column
// pThis->nDimension: returns -1, 1, or 0.
int GNeighborFinderSorterFunc(void* pThis, int a, int b)
{
	struct GNeighborFinderSorterStruct* pStruct = (struct GNeighborFinderSorterStruct*)pThis;
	double* pA = pStruct->pPoints->GetVector(a);
	double* pB = pStruct->pPoints->GetVector(b);
	if(pA[pStruct->nDimension] < pB[pStruct->nDimension])
		return -1;
	else if(pA[pStruct->nDimension] > pB[pStruct->nDimension])
		return 1;
	else
		return 0;
}

// Leaf cell. It holds the indexes of the points inside the half-open box
// [m_pMins[i], m_pMaxs[i]) for every input number i, plus links to the
// adjacent leaf on the lesser and greater side of each input (NULL if none).
class GNeighborFinderLeaf : public GNeighborFinderNode
{
protected:
	double* m_pMins; // lower bound (inclusive) per input number
	double* m_pMaxs; // upper bound (exclusive) per input number
	GNeighborFinderLeaf** m_ppGreater; // adjacent leaf above, per input number
	GNeighborFinderLeaf** m_ppLesser; // adjacent leaf below, per input number
	GIntArray m_data; // row indexes into the GArffData

public:
	// The bounds are left unset; call SetToGlobalRange or let Split fill them.
	GNeighborFinderLeaf(GArffRelation* pRelation, int nMaxPointsPerLeaf)
	: GNeighborFinderNode(), m_data(nMaxPointsPerLeaf + 1)
	{
		int nInputs = pRelation->GetInputCount();
		m_pMins = new double[nInputs];
		m_pMaxs = new double[nInputs];
		m_ppLesser = new GNeighborFinderLeaf*[nInputs];
		memset(m_ppLesser, '\0', sizeof(GNeighborFinderLeaf*) * nInputs);
		m_ppGreater = new GNeighborFinderLeaf*[nInputs];
		memset(m_ppGreater, '\0', sizeof(GNeighborFinderLeaf*) * nInputs);
	}

	virtual ~GNeighborFinderLeaf()
	{
		delete[] m_ppGreater;
		delete[] m_ppLesser;
		delete[] m_pMaxs;
		delete[] m_pMins;
	}

	// Makes this leaf cover (effectively) all of input space.
	void SetToGlobalRange(GArffRelation* pRelation)
	{
		int nInputs = pRelation->GetInputCount();
		int i;
		for(i = 0; i < nInputs; i++)
		{
			m_pMins[i] = -1e200;
			m_pMaxs[i] = 1e200;
		}
	}

	virtual bool IsLeaf() { return true; }
	GIntArray* GetData() { return &m_data; }
	inline GNeighborFinderLeaf* GetLesser(int nInput) { return m_ppLesser[nInput]; }
	inline GNeighborFinderLeaf* GetGreater(int nInput) { return m_ppGreater[nInput]; }
	inline double GetMin(int nInput) { return m_pMins[nInput]; }
	inline double GetMax(int nInput) { return m_pMaxs[nInput]; }

	// Proposes a cut for this leaf along the input whose values span the
	// widest range among the leaf's points, placed near the median.
	// On success, *pnInput receives an input number (not an attribute index)
	// and *pdPivot is chosen so that at least one of the leaf's points lies
	// on each side of a ">= pivot" split.
	// Returns false, leaving the outputs untouched, when the leaf has fewer
	// than two points or every input is constant across its points (no cut
	// could separate them).
	// Side effect: reorders m_data.
	bool SuggestPivot(int* pnInput, double* pdPivot, GArffRelation* pRelation, GArffData* pPoints)
	{
		int nPoints = m_data.GetSize();
		if(nPoints < 2)
			return false;
		double dBiggestGap = 0;
		int nBestInput = -1;
		double dBestPivot = 0;
		int nInputs = pRelation->GetInputCount();
		struct GNeighborFinderSorterStruct compareStruct;
		compareStruct.pPoints = pPoints;
		int i;
		for(i = 0; i < nInputs; i++)
		{
			int index = pRelation->GetInputIndex(i);
			compareStruct.nDimension = index;
			m_data.Sort(GNeighborFinderSorterFunc, &compareStruct);
			double dLowest = pPoints->GetVector(m_data.GetInt(0))[index];
			double dGap = pPoints->GetVector(m_data.GetInt(nPoints - 1))[index] - dLowest;
			if(dGap > dBiggestGap)
			{
				dBiggestGap = dGap;
				nBestInput = i;
				dBestPivot = pPoints->GetVector(m_data.GetInt(nPoints / 2))[index];

				// If the median equals the smallest value, a ">= pivot" split
				// would move every point to the greater side and make no
				// progress, so advance to the next distinct value. One exists
				// because dGap > 0, and it lies strictly inside this leaf.
				int j;
				for(j = nPoints / 2 + 1; dBestPivot <= dLowest; j++)
					dBestPivot = pPoints->GetVector(m_data.GetInt(j))[index];
			}
		}
		if(nBestInput < 0)
			return false; // all points share identical input values
		*pnInput = nBestInput;
		*pdPivot = dBestPivot;
		return true;
	}

	// Cuts this leaf at dPivot along input nInput. This leaf keeps the part
	// below dPivot; a new leaf takes [dPivot, old max) together with every
	// point whose value is >= dPivot. Returns a new interior node owning both.
	virtual GNeighborFinderNode* Split(int nInput, double dPivot, GArffRelation* pRelation, GArffData* pPoints, int nMaxPointsPerLeaf)
	{
		// Split the cell
		GNeighborFinderLeaf* pNewGreater = new GNeighborFinderLeaf(pRelation, nMaxPointsPerLeaf);
		int nInputs = pRelation->GetInputCount();
		int i;
		for(i = 0; i < nInputs; i++)
		{
			if(i == nInput)
			{
				// Insert the new leaf between this leaf and its old greater
				// neighbor along the cut input.
				if(m_ppGreater[i])
					m_ppGreater[i]->m_ppLesser[i] = pNewGreater;
				pNewGreater->m_ppLesser[i] = this;
				pNewGreater->m_ppGreater[i] = m_ppGreater[i];
				m_ppGreater[i] = pNewGreater;
				pNewGreater->m_pMins[i] = dPivot;
				pNewGreater->m_pMaxs[i] = m_pMaxs[i];
				m_pMaxs[i] = dPivot;
			}
			else
			{
				pNewGreater->m_pMins[i] = m_pMins[i];
				pNewGreater->m_pMaxs[i] = m_pMaxs[i];

				// If a neighbor along input i already has its own greater
				// cell along nInput with no link in this direction yet, link
				// that cell to the new leaf.
				if(m_ppLesser[i])
				{
					GNeighborFinderLeaf* pNewGreaterNeighbor = m_ppLesser[i]->m_ppGreater[nInput];
					if(pNewGreaterNeighbor && pNewGreaterNeighbor->m_ppGreater[i] == NULL)
					{
						pNewGreater->m_ppLesser[i] = pNewGreaterNeighbor;
						pNewGreaterNeighbor->m_ppGreater[i] = pNewGreater;
					}
				}
				if(m_ppGreater[i])
				{
					GNeighborFinderLeaf* pNewGreaterNeighbor = m_ppGreater[i]->m_ppGreater[nInput];
					if(pNewGreaterNeighbor && pNewGreaterNeighbor->m_ppLesser[i] == NULL)
					{
						pNewGreaterNeighbor->m_ppLesser[i] = pNewGreater;
						pNewGreater->m_ppGreater[i] = pNewGreaterNeighbor;
					}
				}
			}
		}

		// Split the data: move every point at or above the pivot to the new leaf
		GIntArray* pGreaterData = pNewGreater->GetData();
		int nAttr = pRelation->GetInputIndex(nInput);
		int tmp;
		for(i = 0; i < m_data.GetSize(); i++)
		{
			tmp = m_data.GetInt(i);
			if(pPoints->GetVector(tmp)[nAttr] >= dPivot)
			{
				m_data.DeleteCell(i--);
				pGreaterData->AddInt(tmp);
			}
		}
		return new GNeighborFinderInterior(nInput, dPivot, this, pNewGreater);
	}

private:
	// Not copyable: a copy would share, and later double-delete, the arrays.
	GNeighborFinderLeaf(const GNeighborFinderLeaf&);
	GNeighborFinderLeaf& operator=(const GNeighborFinderLeaf&);
};

// -------------------------------------------------------------------------------

// Builds the tree by inserting every row of pData in a random order. Whenever
// a leaf holds more than nMaxPointsPerLeaf points and SuggestPivot finds a
// useful cut, that cut is applied from the root. nMaxPointsPerLeaf is clamped
// to at least 4. The destructor does not free pRelation or pData.
GNeighborFinder::GNeighborFinder(GArffRelation* pRelation, GArffData* pData, int nMaxPointsPerLeaf)
{
	int nInputs = pRelation->GetInputCount();
	if(nMaxPointsPerLeaf < 4)
		nMaxPointsPerLeaf = 4;
	m_pRelation = pRelation;
	m_pData = pData;
	m_nMaxPointsPerLeaf = nMaxPointsPerLeaf;
	m_pRoot = new GNeighborFinderLeaf(pRelation, nMaxPointsPerLeaf);
	((GNeighborFinderLeaf*)m_pRoot)->SetToGlobalRange(pRelation);

	// One slot per input; presumably scratch space for the search routines
	// defined elsewhere -- NOTE(review): confirm.
	m_ppIterators = new GNeighborFinderLeaf*[nInputs];
	m_pMaxs = new double[nInputs];

	// Make a random ordering (Fisher-Yates shuffle of the row indexes)
	int nCount = pData->GetSize();
	int* pIndexes = new int[nCount];
	ArrayHolder<int*> hIndexes(pIndexes);
	int i, index, tmp;
	for(i = 0; i < nCount; i++)
		pIndexes[i] = i;
	while(i > 1)
	{
		index = rand() % i;
		tmp = pIndexes[index];
		pIndexes[index] = pIndexes[--i];
		pIndexes[i] = tmp;
	}

	// Add each point, splitting overfull leaves as we go
	for(i = 0; i < nCount; i++)
	{
		index = pIndexes[i];
		GNeighborFinderLeaf* pLeaf = FindCell(pData->GetVector(index));
		GIntArray* pLeafData = pLeaf->GetData();
		pLeafData->AddInt(index);
		if(pLeafData->GetSize() > m_nMaxPointsPerLeaf)
		{
			int nInput;
			double dPivot;
			if(pLeaf->SuggestPivot(&nInput, &dPivot, m_pRelation, pData))
				m_pRoot = m_pRoot->Split(nInput, dPivot, m_pRelation, pData, m_nMaxPointsPerLeaf);
		}
	}
}

GNeighborFinder::~GNeighborFinder()
{
	delete(m_pRoot);
	delete[] m_ppIterators;
	delete[] m_pMaxs;
}

#ifndef NO_TEST_CODE
/*static*/ void
GNeighborFinder::Test(){ // Generate the data int nDimensions = 10; int nNeighbors = 20; int nMaxPointsPerLeaf = 5; int nPoints = 500; GArffRelation rel; int i, j, k; for(i = 0; i < nDimensions; i++) rel.AddAttribute(new GArffAttribute(true, 0, NULL)); GArffData data(nPoints); double* pVector; for(i = 0; i < nPoints; i++) { pVector = new double[nDimensions]; for(j = 0; j < nDimensions; j++) pVector[j] = GBits::GetRandomDouble() * 100; data.AddVector(pVector); } // Find neighbors using GNeighborFinder double* pVector2; int* pNeighbors = new int[nNeighbors]; ArrayHolder<int*> hNeighbors(pNeighbors); double* pDistances = new double[nNeighbors]; ArrayHolder<double*> hDistances(pDistances); GNeighborFinder gnf(&rel, &data, nMaxPointsPerLeaf); for(i = 0; i < nPoints; i++) { // Find the neighbors pVector = data.GetVector(i); gnf.FindNeighbors(pNeighbors, pDistances, nNeighbors, pVector, i); // Check the answer int nWorstNeighbor = -1; double dWorstDistance = 0; double d; for(j = 0; j < nNeighbors; j++) { if(pNeighbors[j] < 0) throw "Didn't retrieve enough neighbors";
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -