gknn.h

来自「一个由Mike Gashler完成的机器学习方面的includes neural」· C头文件 代码 · 共 169 行

H
169
字号
/*	Copyright (C) 2006, Mike Gashler	This library is free software; you can redistribute it and/or	modify it under the terms of the GNU Lesser General Public	License as published by the Free Software Foundation; either	version 2.1 of the License, or (at your option) any later version.	see http://www.gnu.org/copyleft/lesser.html*/#ifndef __GKNN_H__#define __GKNN_H__#include "GLearner.h"class GArffData;class GNeighborFinderNode;class GNeighborFinderLeaf;// This finds the k-nearest neighbors of a point in Euclidean space. It// divides the data with a kd-tree until every hyper-rectangle contains// at most a maximum number of points. Each hyper-rectangle is linked to// its immediate neighbors in each dimension. It searches outward using// these links to efficiently find the nearest neighbors. It doesn't stop// until it has searched a hyper-volume large enough to guarantee that it// has found the k nearest neighbors.class GNeighborFinder{protected:	GNeighborFinderNode* m_pRoot;	GArffRelation* m_pRelation;	GArffData* m_pData;	int m_nMaxPointsPerLeaf;	GNeighborFinderLeaf** m_ppIterators;	double* m_pMaxs;public:	// This instance will use pData to hold the vectors. (So if you call AddVector, the	// new vector will be added to pData.) But it doesn't take ownership of pData. You	// are still responsible to delete it after you delete this GNeighborFinder.	// nMaxPointsPerLeaf tells how many vectors are contained by each hyper-rectangle	// before it is divided. (Experimental results are still needed to determine how	// to select this value. If it's too small, the algorithm will waste time with	// empty hyper-rectangles. If it's too big, the algorithm will waste time	// examining unnecessary data points. Perhaps the number of neighbors might be	// a good value.)	GNeighborFinder(GArffRelation* pRelation, GArffData* pData, int nMaxPointsPerLeaf = 0);	~GNeighborFinder();#ifndef NO_TEST_CODE	static void Test();#endif // !NO_TEST_CODE	// Add a reference-to-pVector to the GArffData collection that was passed in	// to the constructor, and indexes it in the kd-tree.	void AddVector(double* pVector);	// Drops a vector from the known instances. You are responsible to delete[] the vector	// that this returns.	double* DropVector(int nIndex);	// pOutNeighbors and pOutSquaredDistances should both be arrays of size nNeighbors.	// When it returns, these arrays will hold the indexes into the data set of the	// nearest neighbors, and the squared distances respectively. nExclude is an index	// that you don't want to get in the results. For example, if you are passing in	// a vector from the data set, you may wish to exclude its index because you	// already know it's close to itself. If you don't wish to exclude any indexes, just	// set nExclude to -1. If there are not enough points in the data set to fill the	// neighbor array, the empty ones will be set to -1.	void FindNeighbors(int* pOutNeighbors, double* pOutSquaredDistances, int nNeighbors, double* pVector, int nExclude);protected:	// Adds a new vector to the kd-tree. pVector must already be added to m_pData,	// and nIndex must be the index in m_pData to the vector.	void InsertVector(double* pVector, int nIndex);	// Returns the leaf hyper-rectangle that would contain the specified vector	GNeighborFinderLeaf* FindCell(double* pVector);	// Splits all affected hyper-rectangles at the specified location	void Split(int nDimension, double dPivot);	void GetNeighborsFromCell(GNeighborFinderLeaf* pCell, double* pVector, int* pOutNeighbors, double* pOutSquaredDistances, int nNeighbors, int* pnWorstNeighbor, int nExclude);};// Implements the K-Nearest Neighbor learning algorithmclass GKNN : public GSupervisedLearner{public:	enum InterpolationMethod	{		Linear,		Mean,		Learner,	};protected:	int m_nNeighbors;	GArffData* m_pInstances;	double* m_pScaleFactors;	double* m_pEvalVector;	int* m_pEvalNeighbors;	double* m_pEvalDistances;	GNeighborFinder* m_pNeighborFinder;	InterpolationMethod m_eInterpolationMethod;	GSupervisedLearner* m_pLearner;	bool m_bOwnLearner;	bool m_bCopyInstances;public:	// If bCopyInstances is true, it will make a copy of every instance	// that you add. If bCopyInstances is false, it will just use a reference	// to the vector you pass in. That vector must remain valid for the	// duration of this instance, and you are responsible to delete it. Also,	// it won't scale the vectors.	GKNN(GArffRelation* pRelation, int nNeighbors, bool bCopyInstances);	virtual ~GKNN();	// Discard any training (but not any settings) so it can be trained again	virtual void Reset();	// Sets the technique for interpolation. (If you want to use the "Learner" method,	// you should call SetInterpolationLearner instead of this method.)	void SetInterpolationMethod(InterpolationMethod eMethod);	// Sets the interpolation method to "Learner" and sets the learner to use. If	// bTakeOwnership is true, it will delete the learner when this object is deleted.	void SetInterpolationLearner(GSupervisedLearner* pLearner, bool bTakeOwnership);	// Makes a copy of the vector and adds it to the internal set.	void AddVector(double* pVector);	// Makes a copy of the vector and adds it to the internal set. Also, if the closest	// neighbor of that vector is less than dCloseDistance from it, that neighbor is	// deleted from the internal set.	void AddVectorAndDeleteNeighborIfClose(double* pVector, double dCloseDistance);	// Compute the amount to scale each dimension so that all dimensions	// have equal weight	void ComputeScaleFactors(GArffData* pData);	// Train with all the points in pData	virtual void Train(GArffData* pData);	// Deduce the output values from the input values	virtual void Eval(double* pVector);protected:	// Finds the nearest neighbors of pVector	void FindNeighbors(double* pVector);	// Interpolate with each neighbor having equal vote	void InterpolateMean(double* pVector);	// Interpolate with each neighbor having a linear vote. (Actually it's linear with	// respect to the squared distance instead of the distance, because this is faster	// to compute.)	void InterpolateLinear(double* pVector);	// Interpolates with the provided supervised learning algorithm	void InterpolateLearner(double* pVector);};#endif // __GKNN_H__

⌨️ 快捷键说明

复制代码Ctrl + C
搜索代码Ctrl + F
全屏模式F11
增大字号Ctrl + =
减小字号Ctrl + -
显示快捷键?