gknn.cpp

来自「一个由Mike Gashler完成的机器学习方面的includes neural」· C++ 代码 · 共 1,007 行 · 第 1/2 页

CPP
1,007
字号
/*
	Copyright (C) 2006, Mike Gashler

	This library is free software; you can redistribute it and/or
	modify it under the terms of the GNU Lesser General Public
	License as published by the Free Software Foundation; either
	version 2.1 of the License, or (at your option) any later version.

	see http://www.gnu.org/copyleft/lesser.html
*/

#include "GKNN.h"
#include "GArff.h"
#include "GArray.h"
#include <math.h>
#include "GVector.h"
#include "GMacros.h"
#include <stdlib.h>
#include <string.h> // memset
#include "GBits.h"

// Abstract base class for the nodes of the tree that GNeighborFinder builds.
// The two concrete kinds in this file are GNeighborFinderInterior (a pivot
// with two children) and GNeighborFinderLeaf (a cell holding point indexes).
class GNeighborFinderNode
{
public:
	GNeighborFinderNode()
	{
	}

	virtual ~GNeighborFinderNode()
	{
	}

	// Returns true for leaf nodes and false for interior nodes
	virtual bool IsLeaf() = 0;

	// Splits the sub-tree rooted at this node at dPivot along input nInput.
	// Returns the node that should take this node's place in its parent (a
	// leaf replaces itself with a new interior node; an interior node
	// returns itself). *pDidSomeGood is set to true if some leaf ended up
	// with points on both sides of the pivot; it is never set to false.
	virtual GNeighborFinderNode* Split(int nInput, double dPivot, GArffRelation* pRelation, GArffData* pPoints, int nMaxPointsPerLeaf, bool* pDidSomeGood) = 0;
};

// An interior node. It remembers which input it divides on and the pivot
// value, and it owns its two children (they are deleted in the destructor).
class GNeighborFinderInterior : public GNeighborFinderNode
{
protected:
	double m_dPivot;
	int m_nInput;
	GNeighborFinderNode* m_pLesser;
	GNeighborFinderNode* m_pGreater;

public:
	// Takes ownership of pLesser and pGreater
	GNeighborFinderInterior(int nInput, double dPivot, GNeighborFinderNode* pLesser, GNeighborFinderNode* pGreater)
	: GNeighborFinderNode()
	{
		m_dPivot = dPivot;
		m_nInput = nInput;
		m_pLesser = pLesser;
		m_pGreater = pGreater;
	}

	virtual ~GNeighborFinderInterior()
	{
		delete(m_pLesser);
		delete(m_pGreater);
	}

	virtual bool IsLeaf()
	{
		return false;
	}

	inline GNeighborFinderNode* GetLesser()
	{
		return m_pLesser;
	}

	inline GNeighborFinderNode* GetGreater()
	{
		return m_pGreater;
	}

	inline int GetInput()
	{
		return m_nInput;
	}

	inline double GetPivot()
	{
		return m_dPivot;
	}

	// Pushes the split down to the children. If this node already divides
	// on the same input, only the child on the pivot's side is split
	// (the greater child when dPivot >= m_dPivot). Otherwise both children
	// are split. The child pointers are replaced with whatever the
	// children return. Always returns this.
	virtual GNeighborFinderNode* Split(int nInput, double dPivot, GArffRelation* pRelation, GArffData* pPoints, int nMaxPointsPerLeaf, bool* pDidSomeGood)
	{
		if(m_nInput == nInput)
		{
			if(dPivot >= m_dPivot)
				m_pGreater = m_pGreater->Split(nInput, dPivot, pRelation, pPoints, nMaxPointsPerLeaf, pDidSomeGood);
			else
				m_pLesser = m_pLesser->Split(nInput, dPivot, pRelation, pPoints, nMaxPointsPerLeaf, pDidSomeGood);
		}
		else
		{
			m_pLesser = m_pLesser->Split(nInput, dPivot, pRelation, pPoints, nMaxPointsPerLeaf, pDidSomeGood);
			m_pGreater = m_pGreater->Split(nInput, dPivot, pRelation, pPoints, nMaxPointsPerLeaf, pDidSomeGood);
		}
		return this;
	}
};

// Context handed to GNeighborFinderSorterFunc through its void* parameter
struct GNeighborFinderSorterStruct
{
	GArffData* pPoints; // the data set that the sorted indexes refer to
	int nDimension; // the element of each vector to compare
};

// Comparison callback for sorting point indexes. a and b are indexes into
// pPoints; they are ordered by the value of element nDimension of their
// vectors. Returns -1, 0, or 1.
int GNeighborFinderSorterFunc(void* pThis, int a, int b)
{
	struct GNeighborFinderSorterStruct* pStruct = (struct GNeighborFinderSorterStruct*)pThis;
	double* pA = pStruct->pPoints->GetVector(a);
	double* pB = pStruct->pPoints->GetVector(b);
	if(pA[pStruct->nDimension] < pB[pStruct->nDimension])
		return -1;
	else if(pA[pStruct->nDimension] > pB[pStruct->nDimension])
		return 1;
	else
		return 0;
}

// A leaf node (a cell). For each input it keeps a lower and upper bound
// and a pointer to one neighboring leaf in the lesser and in the greater
// direction. It also keeps the indexes of the points that fall in the cell.
// Neighbor leaves are not owned by this object.
class GNeighborFinderLeaf : public GNeighborFinderNode
{
protected:
	double* m_pMins;
	double* m_pMaxs;
	GNeighborFinderLeaf** m_ppGreater;
	GNeighborFinderLeaf** m_ppLesser;
	GIntArray m_data;

public:
	// Allocates the per-input arrays. The neighbor pointers start out
	// NULL. The bounds are left uninitialized here; they are written by
	// SetToGlobalRange or by Split.
	GNeighborFinderLeaf(GArffRelation* pRelation, int nMaxPointsPerLeaf)
	: GNeighborFinderNode(), m_data(nMaxPointsPerLeaf + 1)
	{
		int nInputs = pRelation->GetInputCount();
		m_pMins = new double[nInputs];
		m_pMaxs = new double[nInputs];
		m_ppLesser = new GNeighborFinderLeaf*[nInputs];
		memset(m_ppLesser, '\0', sizeof(GNeighborFinderLeaf*) * nInputs);
		m_ppGreater = new GNeighborFinderLeaf*[nInputs];
		memset(m_ppGreater, '\0', sizeof(GNeighborFinderLeaf*) * nInputs);
	}

	virtual ~GNeighborFinderLeaf()
	{
		delete[] m_ppGreater;
		delete[] m_ppLesser;
		delete[] m_pMaxs;
		delete[] m_pMins;
	}

	// Sets the bounds of every input to [-1e200, 1e200]
	void SetToGlobalRange(GArffRelation* pRelation)
	{
		int nInputs = pRelation->GetInputCount();
		int i;
		for(i = 0; i < nInputs; i++)
		{
			m_pMins[i] = -1e200;
			m_pMaxs[i] = 1e200;
		}
	}

	virtual bool IsLeaf()
	{
		return true;
	}

	// Returns the array of point indexes held by this leaf
	GIntArray* GetData()
	{
		return &m_data;
	}

	inline GNeighborFinderLeaf* GetLesser(int nInput)
	{
		return m_ppLesser[nInput];
	}

	inline GNeighborFinderLeaf* GetGreater(int nInput)
	{
		return m_ppGreater[nInput];
	}

	inline double GetMin(int nInput)
	{
		return m_pMins[nInput];
	}

	inline double GetMax(int nInput)
	{
		return m_pMaxs[nInput];
	}

	// Proposes where to split this leaf. On success, *pnInput receives the
	// input (0-based position among the inputs) with the largest spread of
	// values in this leaf, *pdPivot receives a value near the median of that
	// input that is strictly greater than the smallest value, and true is
	// returned. Returns false (leaving the outputs untouched) if no useful
	// split was found. As a side effect, m_data is left sorted by the last
	// input that was examined.
	bool SuggestPivot(int* pnInput, double* pdPivot, GArffRelation* pRelation, GArffData* pPoints)
	{
		// An empty leaf can't be split (and "rand() % 0" would be undefined)
		int nPoints = m_data.GetSize();
		if(nPoints <= 0)
			return false;

		// Make sure a split would do some good. This is only a spot check on
		// two randomly picked points. They may turn out to be the same point,
		// in which case no split is suggested this time; InsertVector asks
		// again the next time a point is added to an over-full leaf.
		if(pRelation->ComputeInputDistanceSquared(
				pPoints->GetVector(m_data.GetInt(rand() % nPoints)),
				pPoints->GetVector(m_data.GetInt(rand() % nPoints))
			) == 0)
			return false;

		// Pivot at the median of the dimension with the biggest range
		double dBiggestGap = 0;
		int nBestInput = -1;
		double dMedian = 0;
		int nInputs = pRelation->GetInputCount();
		double dGap;
		int i, index, nMidIndex;
		struct GNeighborFinderSorterStruct compareStruct;
		compareStruct.pPoints = pPoints;
		for(i = 0; i < nInputs; i++)
		{
			// Sort the points by this input, so the range is (last - first)
			index = pRelation->GetInputIndex(i);
			compareStruct.nDimension = index;
			m_data.Sort(GNeighborFinderSorterFunc, &compareStruct);
			dGap = pPoints->GetVector(m_data.GetInt(m_data.GetSize() - 1))[index] - pPoints->GetVector(m_data.GetInt(0))[index];
			if(dGap > dBiggestGap)
			{
				dBiggestGap = dGap;
				nMidIndex = m_data.GetSize() / 2;
				dMedian = pPoints->GetVector(m_data.GetInt(nMidIndex))[index];

				// Make sure the median greater than the smallest value (or else this pivot is worthless, and may cause infinite recursion)
				while(dMedian <= pPoints->GetVector(m_data.GetInt(0))[index])
				{
					GAssert(dMedian == pPoints->GetVector(m_data.GetInt(0))[index], "values not sorted properly");
					GAssert(nMidIndex < m_data.GetSize() - 1, "huh? we already checked that there was a gap at the beginning of this method");

					// Advance toward the upper end: by halving the remaining
					// distance while it is large, then one element at a time
					if(m_data.GetSize() - nMidIndex > 8)
						nMidIndex = (m_data.GetSize() + nMidIndex - 1) / 2;
					else
						nMidIndex++;
					dMedian = pPoints->GetVector(m_data.GetInt(nMidIndex))[index];
				}
				nBestInput = i;
			}
		}
		GAssert(nBestInput >= 0, "internal error");

		// If GAssert does not stop execution, never hand out -1 as an input,
		// because Split would use it to index the per-input arrays. (This can
		// happen if no input has a positive range, e.g. with NaN values.)
		if(nBestInput < 0)
			return false;
		*pnInput = nBestInput;
		*pdPivot = dMedian;
		return true;
	}

	// Splits this leaf at dPivot along input nInput. A new leaf is made for
	// the side at or above the pivot; this leaf keeps the lesser side. The
	// bounds and neighbor links of both leaves are updated, and the points
	// whose value is >= dPivot are moved to the new leaf. *pDidSomeGood is
	// set to true if both leaves end up non-empty. Returns a new interior
	// node that takes ownership of this leaf and of the new leaf.
	virtual GNeighborFinderNode* Split(int nInput, double dPivot, GArffRelation* pRelation, GArffData* pPoints, int nMaxPointsPerLeaf, bool* pDidSomeGood)
	{
		// Split the cell
		GNeighborFinderLeaf* pNewGreater = new GNeighborFinderLeaf(pRelation, nMaxPointsPerLeaf);
		int nInputs = pRelation->GetInputCount();
		int i;
		for(i = 0; i < nInputs; i++)
		{
			if(i == nInput)
			{
				// Along the split input, insert the new leaf between this
				// leaf and its former greater neighbor, and divide the range
				if(m_ppGreater[i])
					m_ppGreater[i]->m_ppLesser[i] = pNewGreater;
				pNewGreater->m_ppLesser[i] = this;
				pNewGreater->m_ppGreater[i] = m_ppGreater[i];
				m_ppGreater[i] = pNewGreater;
				pNewGreater->m_pMins[i] = dPivot;
				pNewGreater->m_pMaxs[i] = m_pMaxs[i];
				m_pMaxs[i] = dPivot;
			}
			else
			{
				// Along every other input the new leaf has the same range
				pNewGreater->m_pMins[i] = m_pMins[i];
				pNewGreater->m_pMaxs[i] = m_pMaxs[i];

				// Link the new leaf to the greater-side (along nInput) leaf of
				// each neighbor of this leaf, if that leaf exists and its
				// link in that direction is still unset
				if(m_ppLesser[i])
				{
					GNeighborFinderLeaf* pNewGreaterNeighbor = m_ppLesser[i]->m_ppGreater[nInput];
					if(pNewGreaterNeighbor && pNewGreaterNeighbor->m_ppGreater[i] == NULL)
					{
						pNewGreater->m_ppLesser[i] = pNewGreaterNeighbor;
						pNewGreaterNeighbor->m_ppGreater[i] = pNewGreater;
					}
				}
				if(m_ppGreater[i])
				{
					GNeighborFinderLeaf* pNewGreaterNeighbor = m_ppGreater[i]->m_ppGreater[nInput];
					if(pNewGreaterNeighbor && pNewGreaterNeighbor->m_ppLesser[i] == NULL)
					{
						pNewGreaterNeighbor->m_ppLesser[i] = pNewGreater;
						pNewGreater->m_ppGreater[i] = pNewGreaterNeighbor;
					}
				}
			}
		}

		// Split the data
		GIntArray* pGreaterData = pNewGreater->GetData();
		int tmp;
		for(i = 0; i < m_data.GetSize(); i++)
		{
			tmp = m_data.GetInt(i);
			if(pPoints->GetVector(tmp)[pRelation->GetInputIndex(nInput)] >= dPivot)
			{
				// Step back so the cell that now occupies slot i is examined too
				m_data.DeleteCell(i--);
				pGreaterData->AddInt(tmp);
			}
		}
		if(m_data.GetSize() > 0 && pGreaterData->GetSize() > 0)
			*pDidSomeGood = true;
		return new GNeighborFinderInterior(nInput, dPivot, this, pNewGreater);
	}
};

// -------------------------------------------------------------------------------

// Builds the structure over the points in pData. nMaxPointsPerLeaf is raised
// to at least the number of inputs, and to at least 4. pRelation and pData
// are kept by pointer and are not deleted by this object. The points are
// inserted in a random order (drawn with rand()).
GNeighborFinder::GNeighborFinder(GArffRelation* pRelation, GArffData* pData, int nMaxPointsPerLeaf)
{
	int nInputs = pRelation->GetInputCount();
	if(nMaxPointsPerLeaf < pRelation->GetInputCount())
		nMaxPointsPerLeaf = pRelation->GetInputCount();
	if(nMaxPointsPerLeaf < 4)
		nMaxPointsPerLeaf = 4;
	m_pRelation = pRelation;
	m_pData = pData;
	m_nMaxPointsPerLeaf = nMaxPointsPerLeaf;

	// Start with a single leaf that covers the whole range
	m_pRoot = new GNeighborFinderLeaf(pRelation, nMaxPointsPerLeaf);
	((GNeighborFinderLeaf*)m_pRoot)->SetToGlobalRange(pRelation);
	m_ppIterators = new GNeighborFinderLeaf*[nInputs];
	m_pMaxs = new double[nInputs];

	// Make a random ordering
	int nCount = pData->GetSize();
	int* pIndexes = new int[nCount];
	ArrayHolder<int*> hIndexes(pIndexes);
	int i, index, tmp;
	for(i = 0; i < nCount; i++)
		pIndexes[i] = i;
	while(i > 1)
	{
		// Swap a randomly chosen element of the first i into position i - 1
		index = rand() % i;
		tmp = pIndexes[index];
		pIndexes[index] = pIndexes[--i];
		pIndexes[i] = tmp;
	}

	// Add each point
	double* pVector;
	for(i = 0; i < nCount; i++)
	{
		index = pIndexes[i];
		pVector = pData->GetVector(index);
		InsertVector(pVector, index);
	}
}

GNeighborFinder::~GNeighborFinder()
{
	// m_pData is intentionally not deleted here
	delete(m_pRoot);
	delete[] m_ppIterators;
	delete[] m_pMaxs;
}

#ifndef NO_TEST_CODE
// Self-test. Builds 500 random 7-dimensional points, asks for the 10 nearest
// neighbors of every point, and compares the answer against a brute-force
// scan. Throws a string literal (const char*) describing the first failure.
/*static*/ void GNeighborFinder::Test()
{
	// Generate the data
	int nDimensions = 7;
	int nNeighbors = 10;
	int nMaxPointsPerLeaf = 5;
	int nPoints = 500;
	GArffRelation rel;
	int i, j, k;
	for(i = 0; i < nDimensions; i++)
		rel.AddAttribute(new GArffAttribute(true, 0, NULL));
	GArffData data(nPoints);
	double* pVector;
	for(i = 0; i < nPoints; i++)
	{
		pVector = new double[nDimensions];
		for(j = 0; j < nDimensions; j++)
			pVector[j] = GBits::GetRandomDouble() * 100;
		data.AddVector(pVector);
	}

	// Find neighbors using GNeighborFinder
	double* pVector2;
	int* pNeighbors = new int[nNeighbors];
	ArrayHolder<int*> hNeighbors(pNeighbors);
	double* pDistances = new double[nNeighbors];
	ArrayHolder<double*> hDistances(pDistances);
	GNeighborFinder gnf(&rel, &data, nMaxPointsPerLeaf);
	for(i = 0; i < nPoints; i++)
	{
		// Find the neighbors
		pVector = data.GetVector(i);
		gnf.FindNeighbors(pNeighbors, pDistances, nNeighbors, pVector, i);

		// Check the answer. First find the distance to the farthest
		// neighbor that was reported
		double dWorstDistance = 0;
		double d;
		for(j = 0; j < nNeighbors; j++)
		{
			if(pNeighbors[j] < 0)
				throw "Didn't retrieve enough neighbors";
			d = rel.ComputeInputDistanceSquared(pVector, data.GetVector(pNeighbors[j]));
			if(d > dWorstDistance)
				dWorstDistance = d;
		}

		// Every point that is closer than that (other than points at
		// distance zero) must be among the reported neighbors
		for(j = 0; j < nPoints; j++)
		{
			pVector2 = data.GetVector(j);
			d = rel.ComputeInputDistanceSquared(pVector, pVector2);
			if(d == 0)
			{
			}
			else if(d < dWorstDistance)
			{
				for(k = 0; k < nNeighbors; k++)
				{
					if(k > 0)
					{
						// This isn't a complete check, but it should be good enough
						if(pNeighbors[k] == pNeighbors[k - 1])
							throw "same neighbor twice";
						if(pNeighbors[k] == pNeighbors[0])
							throw "same neighbor twice";
					}
					if(pNeighbors[k] == i)
						throw "failed to exclude itself";
					if(data.GetVector(pNeighbors[k]) == pVector2)
						break;
				}
				if(k >= nNeighbors)
					throw "missed a neighbor";
			}
		}
	}
}
#endif // !NO_TEST_CODE

// Appends pVector to the data set and inserts it into the tree under the
// index it received there
void GNeighborFinder::AddVector(double* pVector)
{
	int index = m_pData->GetSize();
	m_pData->AddVector(pVector);
	InsertVector(pVector, index);
}

// Records nIndex in the leaf that FindCell returns for pVector. If that
// leaf now holds more than m_nMaxPointsPerLeaf points and the leaf can
// suggest a pivot, the split is applied starting from the root (so other
// leaves may be split by it as well).
void GNeighborFinder::InsertVector(double* pVector, int nIndex)
{
	GNeighborFinderLeaf* pLeaf = FindCell(pVector);
	GIntArray* pLeafData = pLeaf->GetData();
	pLeafData->AddInt(nIndex);
	if(pLeafData->GetSize() > m_nMaxPointsPerLeaf)
	{
		int nInput;
		double dPivot;
		if(pLeaf->SuggestPivot(&nInput, &dPivot, m_pRelation, m_pData))
		{
			bool bDidSomeGood = false;
			m_pRoot = m_pRoot->Split(nInput, dPivot, m_pRelation, m_pData, m_nMaxPointsPerLeaf, &bDidSomeGood);
			GAssert(bDidSomeGood, "The split didn't do any good. SuggestPivot must be broken");
		}
	}
}

// Removes the tree's reference to the vector at nIndex, then re-labels the
// reference to the last vector in m_pData as nIndex (presumably because the
// data set moves its last vector into the vacated slot -- confirm against
// GArffData).
double* GNeighborFinder::DropVector(int nIndex)
{
	// Remove the reference
	double* pVector = m_pData->GetVector(nIndex);
	GNeighborFinderLeaf* pLeaf = FindCell(pVector);
	GIntArray* pLeafData = pLeaf->GetData();
	int nCount = pLeafData->GetSize();
	int i;
	for(i = 0; i < nCount; i++)
	{
		if(pLeafData->GetInt(i) == nIndex)
		{
			pLeafData->SetInt(i, pLeafData->GetInt(nCount - 1));
			pLeafData->DeleteCell(nCount - 1);
			break;
		}
	}
	GAssert(i < nCount, "failed to find the vector reference");

	// Reindex the last reference to take this slot
	int nOldIndex = m_pData->GetSize() - 1;
	if(nOldIndex != nIndex)
	{
		pVector = m_pData->GetVector(nOldIndex);
		pLeaf = FindCell(pVector);
		pLeafData = pLeaf->GetData();
		nCount = pLeafData->GetSize();
		for(i = 0; i < nCount; i++)
		{
			if(pLeafData->GetInt(i) == nOldIndex)
			{
				pLeafData->SetInt(i, nIndex);
				break;
			}
		}
		GAssert(i < nCount, "failed to find the vector reference");
	}

⌨️ 快捷键说明

复制代码Ctrl + C
搜索代码Ctrl + F
全屏模式F11
增大字号Ctrl + =
减小字号Ctrl + -
显示快捷键?