gknn.h
来自「一个由Mike Gashler完成的机器学习方面的includes neural」· C头文件 代码 · 共 169 行
H
169 行
/* Copyright (C) 2006, Mike Gashler This library is free software; you can redistribute it and/or modify it under the terms of the GNU Lesser General Public License as published by the Free Software Foundation; either version 2.1 of the License, or (at your option) any later version. see http://www.gnu.org/copyleft/lesser.html*/#ifndef __GKNN_H__#define __GKNN_H__#include "GLearner.h"class GArffData;class GNeighborFinderNode;class GNeighborFinderLeaf;// This finds the k-nearest neighbors of a point in Euclidean space. It// divides the data with a kd-tree until every hyper-rectangle contains// at most a maximum number of points. Each hyper-rectangle is linked to// its immediate neighbors in each dimension. It searches outward using// these links to efficiently find the nearest neighbors. It doesn't stop// until it has searched a hyper-volume large enough to guarantee that it// has found the k nearest neighbors.class GNeighborFinder{protected: GNeighborFinderNode* m_pRoot; GArffRelation* m_pRelation; GArffData* m_pData; int m_nMaxPointsPerLeaf; GNeighborFinderLeaf** m_ppIterators; double* m_pMaxs;public: // This instance will use pData to hold the vectors. (So if you call AddVector, the // new vector will be added to pData.) But it doesn't take ownership of pData. You // are still responsible to delete it after you delete this GNeighborFinder. // nMaxPointsPerLeaf tells how many vectors are contained by each hyper-rectangle // before it is divided. (Experimental results are still needed to determine how // to select this value. If it's too small, the algorithm will waste time with // empty hyper-rectangles. If it's too big, the algorithm will waste time // examining unnecessary data points. Perhaps the number of neighbors might be // a good value.) GNeighborFinder(GArffRelation* pRelation, GArffData* pData, int nMaxPointsPerLeaf = 0); ~GNeighborFinder();#ifndef NO_TEST_CODE static void Test();#endif // !NO_TEST_CODE // Add a reference-to-pVector to the GArffData collection that was passed in // to the constructor, and indexes it in the kd-tree. void AddVector(double* pVector); // Drops a vector from the known instances. You are responsible to delete[] the vector // that this returns. double* DropVector(int nIndex); // pOutNeighbors and pOutSquaredDistances should both be arrays of size nNeighbors. // When it returns, these arrays will hold the indexes into the data set of the // nearest neighbors, and the squared distances respectively. nExclude is an index // that you don't want to get in the results. For example, if you are passing in // a vector from the data set, you may wish to exclude its index because you // already know it's close to itself. If you don't wish to exclude any indexes, just // set nExclude to -1. If there are not enough points in the data set to fill the // neighbor array, the empty ones will be set to -1. void FindNeighbors(int* pOutNeighbors, double* pOutSquaredDistances, int nNeighbors, double* pVector, int nExclude);protected: // Adds a new vector to the kd-tree. pVector must already be added to m_pData, // and nIndex must be the index in m_pData to the vector. void InsertVector(double* pVector, int nIndex); // Returns the leaf hyper-rectangle that would contain the specified vector GNeighborFinderLeaf* FindCell(double* pVector); // Splits all affected hyper-rectangles at the specified location void Split(int nDimension, double dPivot); void GetNeighborsFromCell(GNeighborFinderLeaf* pCell, double* pVector, int* pOutNeighbors, double* pOutSquaredDistances, int nNeighbors, int* pnWorstNeighbor, int nExclude);};// Implements the K-Nearest Neighbor learning algorithmclass GKNN : public GSupervisedLearner{public: enum InterpolationMethod { Linear, Mean, Learner, };protected: int m_nNeighbors; GArffData* m_pInstances; double* m_pScaleFactors; double* m_pEvalVector; int* m_pEvalNeighbors; double* m_pEvalDistances; GNeighborFinder* m_pNeighborFinder; InterpolationMethod m_eInterpolationMethod; GSupervisedLearner* m_pLearner; bool m_bOwnLearner; bool m_bCopyInstances;public: // If bCopyInstances is true, it will make a copy of every instance // that you add. If bCopyInstances is false, it will just use a reference // to the vector you pass in. That vector must remain valid for the // duration of this instance, and you are responsible to delete it. Also, // it won't scale the vectors. GKNN(GArffRelation* pRelation, int nNeighbors, bool bCopyInstances); virtual ~GKNN(); // Discard any training (but not any settings) so it can be trained again virtual void Reset(); // Sets the technique for interpolation. (If you want to use the "Learner" method, // you should call SetInterpolationLearner instead of this method.) void SetInterpolationMethod(InterpolationMethod eMethod); // Sets the interpolation method to "Learner" and sets the learner to use. If // bTakeOwnership is true, it will delete the learner when this object is deleted. void SetInterpolationLearner(GSupervisedLearner* pLearner, bool bTakeOwnership); // Makes a copy of the vector and adds it to the internal set. void AddVector(double* pVector); // Makes a copy of the vector and adds it to the internal set. Also, if the closest // neighbor of that vector is less than dCloseDistance from it, that neighbor is // deleted from the internal set. void AddVectorAndDeleteNeighborIfClose(double* pVector, double dCloseDistance); // Compute the amount to scale each dimension so that all dimensions // have equal weight void ComputeScaleFactors(GArffData* pData); // Train with all the points in pData virtual void Train(GArffData* pData); // Deduce the output values from the input values virtual void Eval(double* pVector);protected: // Finds the nearest neighbors of pVector void FindNeighbors(double* pVector); // Interpolate with each neighbor having equal vote void InterpolateMean(double* pVector); // Interpolate with each neighbor having a linear vote. (Actually it's linear with // respect to the squared distance instead of the distance, because this is faster // to compute.) void InterpolateLinear(double* pVector); // Interpolates with the provided supervised learning algorithm void InterpolateLearner(double* pVector);};#endif // __GKNN_H__
⌨️ 快捷键说明
复制代码Ctrl + C
搜索代码Ctrl + F
全屏模式F11
增大字号Ctrl + =
减小字号Ctrl + -
显示快捷键?