gmanifold.cpp

来自「一个由Mike Gashler完成的机器学习方面的includes neural」· C++ 代码 · 共 1,591 行 · 第 1/3 页

CPP
1,591
字号
/*	Copyright (C) 2006, Mike Gashler

	This library is free software; you can redistribute it and/or
	modify it under the terms of the GNU Lesser General Public
	License as published by the Free Software Foundation; either
	version 2.1 of the License, or (at your option) any later version.

	see http://www.gnu.org/copyleft/lesser.html
*/

#include "GManifold.h"
#include <stdio.h>
#include "GArray.h"
#include <math.h>
#include "GPointerQueue.h"
#include "GArff.h"
#include "GMatrix.h"
#include "GVector.h"
#include "GBits.h"
#include "GKNN.h"
#include "GTime.h"
#include "GAVLTree.h"

//#define DOT_PRODUCT_ONLY
//#define LINEAR_WEIGHTING

// ----------------------------------------------------------------------------
// Disabled code (kept for reference, not compiled): GManifoldNeighborFinder, a
// tree-based k-nearest-neighbor search. Each interior node splits its points
// by the hyperplane through their mean whose normal is the node's pivot (a
// random axis, a normalized random direction, or a principal component,
// depending on Mode); leaves store up to nLeafCount point indexes. A query
// descends into its own side first and visits the other side only if that
// side could still hold a closer point. Nodes own their children and index
// arrays. Test() checks the search against a brute-force scan of all points.
// ----------------------------------------------------------------------------
/*
class GManifoldNeighborNode
{
public:
	// Mean of the vectors this node was built from; interior nodes use it as
	// a point on their splitting hyperplane.
	double* m_pCenter;

public:
	GManifoldNeighborNode(int nCount, int* pIndexes, int nDims, GArffData* pData)
	{
		// Compute the center
		m_pCenter = new double[nDims];
		int i, j;
		for(j = 0; j < nDims; j++)
			m_pCenter[j] = 0;
		double* pVector;
		for(i = 0; i < nCount; i++)
		{
			pVector = pData->GetVector(pIndexes[i]);
			for(j = 0; j < nDims; j++)
				m_pCenter[j] += pVector[j];
		}
		if(nCount > 0) // an empty node (empty data set) keeps a zero center instead of 0/0
		{
			for(j = 0; j < nDims; j++)
				m_pCenter[j] /= nCount;
		}
	}

	virtual ~GManifoldNeighborNode()
	{
		delete[] m_pCenter;
	}

	virtual bool IsLeaf() = 0;

	// Merges this subtree's points into the running result set.
	// pOutNeighbors / pOutSquaredDistances have nNeighbors slots (unfilled
	// slots hold index -1), *pnWorstNeighbor is the slot with the largest
	// squared distance, and the point with index nExclude is never reported.
	virtual void FindNeighbors(GArffRelation* pRelation, GArffData* pData, double* pVector, int nNeighbors, int* pOutNeighbors, double* pOutSquaredDistances, int* pnWorstNeighbor, int nExclude) = 0;
};

class GManifoldNeighborInterior : public GManifoldNeighborNode
{
protected:
	double* m_pPivot; // normal of the splitting hyperplane
	GManifoldNeighborNode* m_pLeft; // points with (v - center) . pivot < 0
	GManifoldNeighborNode* m_pRight; // points with (v - center) . pivot >= 0

public:
	GManifoldNeighborInterior(int nCount, int* pIndexes, int nDims, GArffData* pData)
		: GManifoldNeighborNode(nCount, pIndexes, nDims, pData)
	{
		m_pLeft = NULL;
		m_pRight = NULL;
		m_pPivot = new double[nDims];
	}

	virtual ~GManifoldNeighborInterior()
	{
		// Each interior node owns its subtrees; the finder only deletes the root.
		delete(m_pLeft);
		delete(m_pRight);
		delete[] m_pPivot;
	}

	virtual bool IsLeaf()
	{
		return false;
	}

	double* GetPivot()
	{
		return m_pPivot;
	}

	void SetLeft(GManifoldNeighborNode* pNode)
	{
		GAssert(!m_pLeft, "already got a left");
		m_pLeft = pNode;
	}

	void SetRight(GManifoldNeighborNode* pNode)
	{
		GAssert(!m_pRight, "already got a right");
		m_pRight = pNode;
	}

	// (pVector - center) . pivot: the signed distance from the splitting
	// hyperplane, scaled by the pivot's length.
	double SignedOffset(int nDims, const double* pVector)
	{
		double d = 0;
		int i;
		for(i = 0; i < nDims; i++)
			d += (pVector[i] - m_pCenter[i]) * m_pPivot[i];
		return d;
	}

	// True if pVector belongs to the right child.
	bool TestVector(int nDims, double* pVector)
	{
		return (SignedOffset(nDims, pVector) >= 0);
	}

	virtual void FindNeighbors(GArffRelation* pRelation, GArffData* pData, double* pVector, int nNeighbors, int* pOutNeighbors, double* pOutSquaredDistances, int* pnWorstNeighbor, int nExclude)
	{
		int nDims = pRelation->GetInputCount();
		double dOffset = SignedOffset(nDims, pVector);
		bool bRight = (dOffset >= 0);
		GManifoldNeighborNode* pFirst = bRight ? m_pRight : m_pLeft;
		GManifoldNeighborNode* pSecond = bRight ? m_pLeft : m_pRight;
		pFirst->FindNeighbors(pRelation, pData, pVector, nNeighbors, pOutNeighbors, pOutSquaredDistances, pnWorstNeighbor, nExclude);

		// Visit the far side if the result set is not full yet, or if the
		// sphere around pVector through the current worst neighbor reaches the
		// splitting plane. The test is symmetric for both sides and leaves the
		// caller's vector untouched. It assumes m_pPivot has unit length (true
		// for kd_tree and random; presumably also for the principal-component
		// modes -- verify ComputePrincipleComponent).
		bool bTrySecond;
		if(pOutNeighbors[*pnWorstNeighbor] < 0)
			bTrySecond = true;
		else
			bTrySecond = (fabs(dOffset) <= sqrt(pOutSquaredDistances[*pnWorstNeighbor]));
		if(bTrySecond)
			pSecond->FindNeighbors(pRelation, pData, pVector, nNeighbors, pOutNeighbors, pOutSquaredDistances, pnWorstNeighbor, nExclude);
	}
};

class GManifoldNeighborLeaf : public GManifoldNeighborNode
{
protected:
	int m_nCount;
	int* m_pIndexes; // owned copy of this leaf's point indexes

public:
	GManifoldNeighborLeaf(int nCount, int* pIndexes, int nDims, GArffData* pData)
		: GManifoldNeighborNode(nCount, pIndexes, nDims, pData)
	{
		m_nCount = nCount;
		m_pIndexes = new int[nCount];
		memcpy(m_pIndexes, pIndexes, sizeof(int) * nCount);
	}

	virtual ~GManifoldNeighborLeaf()
	{
		delete[] m_pIndexes;
	}

	virtual bool IsLeaf()
	{
		return true;
	}

	// Brute-force scan of this leaf's points.
	virtual void FindNeighbors(GArffRelation* pRelation, GArffData* pData, double* pVector, int nNeighbors, int* pOutNeighbors, double* pOutSquaredDistances, int* pnWorstNeighbor, int nExclude)
	{
		double* pCandidate;
		double d;
		int i, j, index;
		for(i = 0; i < m_nCount; i++)
		{
			index = m_pIndexes[i];
			if(index == nExclude)
				continue;
			pCandidate = pData->GetVector(index);
			d = pRelation->ComputeInputDistanceSquared(pCandidate, pVector);
			if(d < pOutSquaredDistances[*pnWorstNeighbor])
			{
				// Replace the current worst neighbor, then rescan for the new worst
				pOutNeighbors[*pnWorstNeighbor] = index;
				pOutSquaredDistances[*pnWorstNeighbor] = d;
				*pnWorstNeighbor = 0;
				for(j = 1; j < nNeighbors; j++)
				{
					if(pOutSquaredDistances[j] > pOutSquaredDistances[*pnWorstNeighbor])
						*pnWorstNeighbor = j;
				}
			}
		}
	}
};

class GManifoldNeighborFinder
{
public:
	// How each interior node chooses its pivot (splitting-plane normal)
	enum Mode
	{
		kd_tree, // a randomly chosen coordinate axis
		random, // a normalized random direction
		principle_component_4, // principal component; 4/8/16 is passed to
		principle_component_8, // GArffData::ComputePrincipleComponent as its
		principle_component_16, // last argument (presumably an iteration count -- verify)
	};

protected:
	GManifoldNeighborNode* m_pRoot;
	GArffRelation* m_pRelation;
	GArffData* m_pData;
	int m_nLeafCount; // maximum number of points stored in one leaf
	Mode m_eMode;

public:
	GManifoldNeighborFinder(GArffRelation* pRelation, GArffData* pData, int nLeafCount, Mode eMode)
	{
		m_pRelation = pRelation;
		m_pData = pData;
		m_nLeafCount = nLeafCount;
		m_eMode = eMode;
		int nCount = pData->GetSize();
		int* pIndexes = new int[nCount];
		ArrayHolder<int*> hIndexes(pIndexes); // allocated with new[], so use the array holder
		int i;
		for(i = 0; i < nCount; i++)
			pIndexes[i] = i;
		m_pRoot = BuildTree(nCount, pIndexes);
	}

	~GManifoldNeighborFinder()
	{
		delete(m_pRoot); // interior nodes delete their own subtrees
	}

	// Builds a subtree over pIndexes[0..nCount). Reorders pIndexes in place so
	// that the left child's points come first.
	GManifoldNeighborNode* BuildTree(int nCount, int* pIndexes)
	{
		int nDims = m_pRelation->GetInputCount();
		if(nCount <= m_nLeafCount)
			return new GManifoldNeighborLeaf(nCount, pIndexes, nDims, m_pData);
		GManifoldNeighborInterior* pNode = new GManifoldNeighborInterior(nCount, pIndexes, nDims, m_pData);

		// Choose the pivot
		GArffData data(nCount);
		int i;
		for(i = 0; i < nCount; i++)
			data.AddVector(m_pData->GetVector(pIndexes[i]));
		double* pPivot = pNode->GetPivot();
		switch(m_eMode)
		{
			case kd_tree:
				for(i = 0; i < nDims; i++)
					pPivot[i] = 0;
				pPivot[rand() % nDims] = 1;
				break;
			case random:
				for(i = 0; i < nDims; i++)
					pPivot[i] = GBits::GetRandomDouble() - .5;
				GVector::Normalize(pPivot, nDims);
				break;
			case principle_component_4:
				data.ComputePrincipleComponent(nDims, pPivot, 4);
				break;
			case principle_component_8:
				data.ComputePrincipleComponent(nDims, pPivot, 8);
				break;
			case principle_component_16:
				data.ComputePrincipleComponent(nDims, pPivot, 16);
				break;
		}

		// The temporary set only borrows m_pData's vectors; detach them now
		// that the pivot is chosen, so no return path below lets the set
		// dispose of them.
		data.DropAllVectors();

		// Partition pIndexes in place: points left of the plane (TestVector
		// false) end up in [0, nTail), the rest in [nTail, nCount).
		int nHead = 0;
		int nTail = nCount;
		double* pVector;
		while(true)
		{
			// Advance the head
			while(nHead < nTail)
			{
				pVector = m_pData->GetVector(pIndexes[nHead]);
				if(pNode->TestVector(nDims, pVector))
					break;
				nHead++;
			}

			// Advance (the other way) the tail
			while(nHead < nTail)
			{
				pVector = m_pData->GetVector(pIndexes[nTail - 1]);
				if(!pNode->TestVector(nDims, pVector))
					break;
				nTail--;
			}

			// Test for completion
			if(nHead < nTail)
			{
				// Swap the head and the tail
				int tmp = pIndexes[nHead];
				pIndexes[nHead] = pIndexes[nTail - 1];
				pIndexes[nTail - 1] = tmp;
				nHead++;
				nTail--;
			}
			else
				break;
		}

		// A split that puts every point on one side would recurse on the same
		// set again, indefinitely when all points coincide (every offset is 0,
		// so every point goes right). Store such a set in a leaf instead.
		if(nTail == 0 || nTail == nCount)
		{
			delete(pNode);
			return new GManifoldNeighborLeaf(nCount, pIndexes, nDims, m_pData);
		}
		pNode->SetLeft(BuildTree(nTail, pIndexes));
		pNode->SetRight(BuildTree(nCount - nTail, pIndexes + nTail));
		return pNode;
	}

	// Finds the nNeighbors points closest to pVector (excluding the point with
	// index nExclude). Results are unordered; a slot left at index -1 means
	// fewer than nNeighbors candidates exist.
	void FindNeighbors(int* pOutNeighbors, double* pOutSquaredDistances, int nNeighbors, double* pVector, int nExclude)
	{
		int nWorstNeighbor = 0;
		int i;
		for(i = 0; i < nNeighbors; i++)
		{
			pOutNeighbors[i] = -1;
			pOutSquaredDistances[i] = 1e200;
		}
		m_pRoot->FindNeighbors(m_pRelation, m_pData, pVector, nNeighbors, pOutNeighbors, pOutSquaredDistances, &nWorstNeighbor, nExclude);
	}

#ifndef NO_TEST_CODE
	// Checks the tree search against a brute-force scan of random points.
	static void Test()
	{
		// Generate the data
		int nDimensions = 10;
		int nNeighbors = 20;
		int nMaxPointsPerLeaf = 5;
		int nPoints = 500;
		GArffRelation rel;
		int i, j, k;
		for(i = 0; i < nDimensions; i++)
			rel.AddAttribute(new GArffAttribute(true, 0, NULL));
		GArffData data(nPoints);
		double* pVector;
		for(i = 0; i < nPoints; i++)
		{
			pVector = new double[nDimensions];
			for(j = 0; j < nDimensions; j++)
				pVector[j] = GBits::GetRandomDouble() * 100;
			data.AddVector(pVector);
		}

		// Find neighbors using GManifoldNeighborFinder
		double* pVector2;
		int* pNeighbors = new int[nNeighbors];
		ArrayHolder<int*> hNeighbors(pNeighbors);
		double* pDistances = new double[nNeighbors];
		ArrayHolder<double*> hDistances(pDistances);
		GManifoldNeighborFinder gnf(&rel, &data, nMaxPointsPerLeaf, GManifoldNeighborFinder::principle_component_8);
		for(i = 0; i < nPoints; i++)
		{
			// Find the neighbors
			pVector = data.GetVector(i);
			gnf.FindNeighbors(pNeighbors, pDistances, nNeighbors, pVector, i);

			// Check the answer
			double dWorstDistance = 0;
			double d;
			for(j = 0; j < nNeighbors; j++)
			{
				if(pNeighbors[j] < 0)
					throw "Didn't retrieve enough neighbors";
				d = rel.ComputeInputDistanceSquared(pVector, data.GetVector(pNeighbors[j]));
				if(d > dWorstDistance)
					dWorstDistance = d;
			}
			for(j = 0; j < nPoints; j++)
			{
				pVector2 = data.GetVector(j);
				d = rel.ComputeInputDistanceSquared(pVector, pVector2);
				if(d == 0)
				{
					// the query point itself (or an exact duplicate)
				}
				else if(d < dWorstDistance)
				{
					// Every strictly closer point must be among the results
					for(k = 0; k < nNeighbors; k++)
					{
						if(k > 0)
						{
							// This isn't a complete check, but it should be good enough
							if(pNeighbors[k] == pNeighbors[k - 1])
								throw "same neighbor twice";
							if(pNeighbors[k] == pNeighbors[0])
								throw "same neighbor twice";
						}
						if(pNeighbors[k] == i)
							throw "failed to exclude itself";
						if(data.GetVector(pNeighbors[k]) == pVector2)
							break;
					}
					if(k >= nNeighbors)
						throw "missed a neighbor";
				}
			}
		}
	}
#endif // !NO_TEST_CODE
};
*/

// --------------------------------------------------------------------

/*class GYetAnotherPoint : public GAVLNode{protected:	int m_nIndex;	int m_nProg;	double m_dSquaredDist;public:	GYetAnotherPoint(int nIndex)	{		m_nIndex = nIndex;		m_nProg = 0;		m_dSquaredDist = 0;	}	virtual ~GYetAnotherPoint()	{	}	virtual int 
Compare(GAVLNode* pThat)	{		if(m_dSquaredDist < ((GYetAnotherPoint*)pThat)->m_dSquaredDist)			return -1;		else if(m_dSquaredDist > ((GYetAnotherPoint*)pThat)->m_dSquaredDist)			return 1;		else			return 0;	}	int GetIndex()	{		return m_nIndex;	}	int GetProg()	{		return m_nProg;	}	double GetSquaredDist()	{		return m_dSquaredDist;	}	void AddSquaredDist(double d)	{		m_dSquaredDist += d;		m_nProg++;	}};class GYetAnotherNeighborFinder{protected:	int m_nDims;	GAVLTree* m_pPriorityQueue;	GArffData* m_pData;	double* m_pPrincipleComponent;	GIntArray m_dimOrder;public:	GYetAnotherNeighborFinder(int nDims, GArffData* pData);	~GYetAnotherNeighborFinder()	{		delete[] m_pPrincipleComponent;		delete(m_pPriorityQueue);	}	void FindNeighbors(int* pNeighbors, double* pSquaredDistances, int nNeighbors, double* pVector, int nExclude)	{		// Create the point records		GAssert(m_pPriorityQueue->GetSize() == 0, "priority queue not empty");		int i;		int nCount = m_pData->GetSize();		for(i = 0; i < nCount; i++)		{			if(i == nExclude)				continue;			m_pPriorityQueue->Insert(new GYetAnotherPoint(i));		}		// Crunch until the closest K neighbors are completed		double* pVec;		double d;		if(nCount > nNeighbors)		{			int nFirstIncomplete = 0;			int dim, prog;			while(true)			{				GYetAnotherPoint* pPoint = (GYetAnotherPoint*)m_pPriorityQueue->Unlink(nFirstIncomplete);				pVec = m_pData->GetVector(pPoint->GetIndex());				prog = pPoint->GetProg();				if(prog >= m_nDims)				{					m_pPriorityQueue->Insert(pPoint);					nFirstIncomplete++;					if(nFirstIncomplete >= nNeighbors)						break;					else						continue;				}				for(i = 0; i < 10; i++)				{					if(prog >= m_nDims)						break;					dim = m_dimOrder.GetInt(prog);					d = pVec[dim] - pVector[dim];					pPoint->AddSquaredDist(d * d);					prog = pPoint->GetProg();				}				m_pPriorityQueue->Insert(pPoint);			}		}		// Return the results		for(i = 0; i < nNeighbors; i++)		{			GYetAnotherPoint* pPoint = (GYetAnotherPoint*)m_pPriorityQueue->Unlink(0);			
if(pPoint)			{				pNeighbors[i] = pPoint->GetIndex();

⌨️ 快捷键说明

复制代码Ctrl + C
搜索代码Ctrl + F
全屏模式F11
增大字号Ctrl + =
减小字号Ctrl + -
显示快捷键?