gmanifold.cpp
来自「一个由Mike Gashler完成的机器学习方面的includes neural」· C++ 代码 · 共 1,591 行 · 第 1/3 页
CPP
1,591 行
/*
	Copyright (C) 2006, Mike Gashler

	This library is free software; you can redistribute it and/or
	modify it under the terms of the GNU Lesser General Public
	License as published by the Free Software Foundation; either
	version 2.1 of the License, or (at your option) any later version.

	see http://www.gnu.org/copyleft/lesser.html
*/

#include "GManifold.h"
#include <stdio.h>
#include "GArray.h"
#include <math.h>
#include "GPointerQueue.h"
#include "GArff.h"
#include "GMatrix.h"
#include "GVector.h"
#include "GBits.h"
#include "GKNN.h"
#include "GTime.h"
#include "GAVLTree.h"

//#define DOT_PRODUCT_ONLY
//#define LINEAR_WEIGHTING

// The GManifoldNeighbor* classes below are disabled (the whole region is
// wrapped in a block comment). Use only line comments inside the region so
// that the block comment is not terminated early.
/*
// Base class for the nodes of the binary space-partitioning tree built by
// GManifoldNeighborFinder. Every node stores the mean of the input vectors
// that it covers.
class GManifoldNeighborNode
{
public:
	double* m_pCenter;

public:
	// Computes m_pCenter as the per-dimension mean of the nCount vectors of
	// pData that are selected by pIndexes.
	GManifoldNeighborNode(int nCount, int* pIndexes, int nDims, GArffData* pData)
	{
		m_pCenter = new double[nDims];
		int i, j;
		for(j = 0; j < nDims; j++)
			m_pCenter[j] = 0;
		for(i = 0; i < nCount; i++)
		{
			double* pVector = pData->GetVector(pIndexes[i]);
			for(j = 0; j < nDims; j++)
				m_pCenter[j] += pVector[j];
		}

		// An empty node keeps a zero center instead of computing 0 / 0
		if(nCount > 0)
		{
			for(j = 0; j < nDims; j++)
				m_pCenter[j] /= nCount;
		}
	}

	virtual ~GManifoldNeighborNode()
	{
		delete[] m_pCenter;
	}

	virtual bool IsLeaf() = 0;

	// Refines the current set of nearest-neighbor candidates for pVector.
	// pOutNeighbors and pOutSquaredDistances each hold nNeighbors slots,
	// *pnWorstNeighbor is the slot with the largest squared distance, and
	// the point with index nExclude is never reported.
	virtual void FindNeighbors(GArffRelation* pRelation, GArffData* pData, double* pVector, int nNeighbors, int* pOutNeighbors, double* pOutSquaredDistances, int* pnWorstNeighbor, int nExclude) = 0;
};

// An interior node. It splits space with the hyperplane that passes through
// m_pCenter and has m_pPivot as its normal. Vectors with a non-negative
// offset along the pivot belong to the right child, the others to the left.
class GManifoldNeighborInterior : public GManifoldNeighborNode
{
protected:
	double* m_pPivot;
	GManifoldNeighborNode* m_pLeft;
	GManifoldNeighborNode* m_pRight;

public:
	// The pivot buffer is allocated here but left unset. The caller fills
	// it in through GetPivot().
	GManifoldNeighborInterior(int nCount, int* pIndexes, int nDims, GArffData* pData)
	: GManifoldNeighborNode(nCount, pIndexes, nDims, pData)
	{
		m_pLeft = NULL;
		m_pRight = NULL;
		m_pPivot = new double[nDims];
	}

	// This node owns its children, so the whole subtree is released here.
	// (Previously the children were never deleted.)
	virtual ~GManifoldNeighborInterior()
	{
		delete(m_pLeft);
		delete(m_pRight);
		delete[] m_pPivot;
	}

	virtual bool IsLeaf()
	{
		return false;
	}

	double* GetPivot()
	{
		return m_pPivot;
	}

	void SetLeft(GManifoldNeighborNode* pNode)
	{
		GAssert(!m_pLeft, "already got a left");
		m_pLeft = pNode;
	}

	void SetRight(GManifoldNeighborNode* pNode)
	{
		GAssert(!m_pRight, "already got a right");
		m_pRight = pNode;
	}

	// Returns the dot product of (pVector - m_pCenter) with m_pPivot. The
	// sign tells which side of the splitting plane pVector is on.
	double ComputeOffset(int nDims, const double* pVector)
	{
		double d = 0;
		int i;
		for(i = 0; i < nDims; i++)
			d += (pVector[i] - m_pCenter[i]) * m_pPivot[i];
		return d;
	}

	// Returns true if pVector belongs to the right child
	bool TestVector(int nDims, double* pVector)
	{
		return (ComputeOffset(nDims, pVector) >= 0);
	}

	// Searches the child on pVector's side of the plane first. The other
	// child is searched only if a result slot is still empty, or if the
	// plane is closer to pVector than the worst neighbor found so far.
	virtual void FindNeighbors(GArffRelation* pRelation, GArffData* pData, double* pVector, int nNeighbors, int* pOutNeighbors, double* pOutSquaredDistances, int* pnWorstNeighbor, int nExclude)
	{
		int nDims = pRelation->GetInputCount();
		double dOffset = ComputeOffset(nDims, pVector);
		bool bRight = (dOffset >= 0);
		GManifoldNeighborNode* pFirst = bRight ? m_pRight : m_pLeft;
		GManifoldNeighborNode* pSecond = bRight ? m_pLeft : m_pRight;
		pFirst->FindNeighbors(pRelation, pData, pVector, nNeighbors, pOutNeighbors, pOutSquaredDistances, pnWorstNeighbor, nExclude);

		bool bTrySecond;
		if(pOutNeighbors[*pnWorstNeighbor] < 0)
			bTrySecond = true; // the worst slot is still empty
		else
		{
			// The distance from pVector to the plane is |offset| / |pivot|.
			// Compare squared quantities so that no sqrt is needed, the
			// pivot need not be unit length, and the caller's vector is
			// never modified. (The old code shifted pVector back and forth,
			// which is not exact in floating point, and it never pruned
			// when pVector was on the left side.)
			double dPivotSquaredLen = 0;
			int i;
			for(i = 0; i < nDims; i++)
				dPivotSquaredLen += m_pPivot[i] * m_pPivot[i];
			bTrySecond = (dOffset * dOffset < pOutSquaredDistances[*pnWorstNeighbor] * dPivotSquaredLen);
		}
		if(bTrySecond)
			pSecond->FindNeighbors(pRelation, pData, pVector, nNeighbors, pOutNeighbors, pOutSquaredDistances, pnWorstNeighbor, nExclude);
	}
};

// A leaf node. It keeps its own copy of the indexes of the points it covers
// and checks each of them when searched.
class GManifoldNeighborLeaf : public GManifoldNeighborNode
{
protected:
	int m_nCount;
	int* m_pIndexes;

public:
	GManifoldNeighborLeaf(int nCount, int* pIndexes, int nDims, GArffData* pData)
	: GManifoldNeighborNode(nCount, pIndexes, nDims, pData)
	{
		m_nCount = nCount;
		m_pIndexes = new int[nCount];
		int i;
		for(i = 0; i < nCount; i++)
			m_pIndexes[i] = pIndexes[i];
	}

	// Releases the private copy of the indexes. (Previously it was leaked.)
	virtual ~GManifoldNeighborLeaf()
	{
		delete[] m_pIndexes;
	}

	virtual bool IsLeaf()
	{
		return true;
	}

	// Tests every point of this leaf (except nExclude) against the current
	// worst neighbor and replaces it whenever the point is closer.
	virtual void FindNeighbors(GArffRelation* pRelation, GArffData* pData, double* pVector, int nNeighbors, int* pOutNeighbors, double* pOutSquaredDistances, int* pnWorstNeighbor, int nExclude)
	{
		int i, j;
		for(i = 0; i < m_nCount; i++)
		{
			int index = m_pIndexes[i];
			if(index == nExclude)
				continue;
			double* pCandidate = pData->GetVector(index);
			double d = pRelation->ComputeInputDistanceSquared(pCandidate, pVector);
			if(d >= pOutSquaredDistances[*pnWorstNeighbor])
				continue;

			// Take over the worst slot, then find the new worst slot
			pOutNeighbors[*pnWorstNeighbor] = index;
			pOutSquaredDistances[*pnWorstNeighbor] = d;
			*pnWorstNeighbor = 0;
			for(j = 1; j < nNeighbors; j++)
			{
				if(pOutSquaredDistances[j] > pOutSquaredDistances[*pnWorstNeighbor])
					*pnWorstNeighbor = j;
			}
		}
	}
};

// Finds the nearest neighbors of a vector by means of a binary tree that
// recursively splits the data set with hyperplanes. The Mode selects how
// the normal of each splitting plane is chosen.
class GManifoldNeighborFinder
{
public:
	enum Mode
	{
		kd_tree,
		random,
		principle_component_4,
		principle_component_8,
		principle_component_16
	};

protected:
	GManifoldNeighborNode* m_pRoot;
	GArffRelation* m_pRelation;
	GArffData* m_pData;
	int m_nLeafCount;
	Mode m_eMode;

public:
	// Builds the tree over all vectors in pData. Nodes with nLeafCount or
	// fewer points become leaves. pRelation and pData are kept by pointer
	// and are not deleted by this class.
	GManifoldNeighborFinder(GArffRelation* pRelation, GArffData* pData, int nLeafCount, Mode eMode)
	{
		m_pRelation = pRelation;
		m_pData = pData;
		m_nLeafCount = nLeafCount;
		m_eMode = eMode;
		int nCount = pData->GetSize();
		int* pIndexes = new int[nCount];

		// The buffer comes from new[], so it is held with ArrayHolder, the
		// same way Test() holds its arrays. (It used to be held with Holder.)
		ArrayHolder<int*> hIndexes(pIndexes);
		int i;
		for(i = 0; i < nCount; i++)
			pIndexes[i] = i;
		m_pRoot = BuildTree(nCount, pIndexes);
	}

	~GManifoldNeighborFinder()
	{
		delete(m_pRoot);
	}

	// Recursively builds the subtree for the nCount point indexes at
	// pIndexes. The indexes are reordered in place so that the points of the
	// left child come first. The caller owns the returned node.
	GManifoldNeighborNode* BuildTree(int nCount, int* pIndexes)
	{
		int nDims = m_pRelation->GetInputCount();
		if(nCount <= m_nLeafCount)
			return new GManifoldNeighborLeaf(nCount, pIndexes, nDims, m_pData);
		GManifoldNeighborInterior* pNode = new GManifoldNeighborInterior(nCount, pIndexes, nDims, m_pData);

		// Choose the normal of the splitting plane
		{
			// Scratch collection that refers to the vectors of this subset
			GArffData data(nCount);
			int i;
			for(i = 0; i < nCount; i++)
				data.AddVector(m_pData->GetVector(pIndexes[i]));
			double* pPivot = pNode->GetPivot();
			switch(m_eMode)
			{
				case kd_tree:
					for(i = 0; i < nDims; i++)
						pPivot[i] = 0;
					pPivot[rand() % nDims] = 1;
					break;
				case random:
					for(i = 0; i < nDims; i++)
						pPivot[i] = GBits::GetRandomDouble() - .5;
					GVector::Normalize(pPivot, nDims);
					break;
				// NOTE(review): the last argument is presumably an iteration
				// count for the principal component estimate -- confirm
				case principle_component_4:
					data.ComputePrincipleComponent(m_pRelation->GetInputCount(), pPivot, 4);
					break;
				case principle_component_8:
					data.ComputePrincipleComponent(m_pRelation->GetInputCount(), pPivot, 8);
					break;
				case principle_component_16:
					data.ComputePrincipleComponent(m_pRelation->GetInputCount(), pPivot, 16);
					break;
			}

			// The vectors still belong to m_pData, so detach them before the
			// scratch collection goes away. This now happens before the
			// recursion instead of after it, so the scratch copy is not
			// held while the children are built.
			data.DropAllVectors();
		}

		// Partition the indexes: left-side points first, right-side last
		int nHead = 0;
		int nTail = nCount;
		while(true)
		{
			// Advance the head past points that are already on the left
			while(nHead < nTail && !pNode->TestVector(nDims, m_pData->GetVector(pIndexes[nHead])))
				nHead++;

			// Advance (the other way) the tail past points on the right
			while(nHead < nTail && pNode->TestVector(nDims, m_pData->GetVector(pIndexes[nTail - 1])))
				nTail--;

			// Test for completion
			if(nHead >= nTail)
				break;

			// Swap the head and the tail
			int tmp = pIndexes[nHead];
			pIndexes[nHead] = pIndexes[nTail - 1];
			pIndexes[nTail - 1] = tmp;
			nHead++;
			nTail--;
		}

		// If every point fell on one side (for example, when all of the
		// points are identical), recursing would never terminate, so just
		// make an oversized leaf.
		if(nTail == 0 || nTail == nCount)
		{
			delete(pNode);
			return new GManifoldNeighborLeaf(nCount, pIndexes, nDims, m_pData);
		}

		pNode->SetLeft(BuildTree(nTail, pIndexes));
		pNode->SetRight(BuildTree(nCount - nTail, pIndexes + nTail));
		return pNode;
	}

	// Finds the nNeighbors points that are closest to pVector, skipping the
	// point with index nExclude. Slots for which no neighbor was found are
	// left with index -1 and squared distance 1e200. The results are not
	// sorted.
	void FindNeighbors(int* pOutNeighbors, double* pOutSquaredDistances, int nNeighbors, double* pVector, int nExclude)
	{
		int nWorstNeighbor = 0;
		int i;
		for(i = 0; i < nNeighbors; i++)
		{
			pOutNeighbors[i] = -1;
			pOutSquaredDistances[i] = 1e200;
		}
		m_pRoot->FindNeighbors(m_pRelation, m_pData, pVector, nNeighbors, pOutNeighbors, pOutSquaredDistances, &nWorstNeighbor, nExclude);
	}

#ifndef NO_TEST_CODE
	// Compares the neighbors found through the tree with a brute-force scan
	// over random data. Throws a string describing the first failure.
	static void Test()
	{
		// Generate the data
		int nDimensions = 10;
		int nNeighbors = 20;
		int nMaxPointsPerLeaf = 5;
		int nPoints = 500;
		GArffRelation rel;
		int i, j, k;
		for(i = 0; i < nDimensions; i++)
			rel.AddAttribute(new GArffAttribute(true, 0, NULL));
		GArffData data(nPoints);
		double* pVector;
		for(i = 0; i < nPoints; i++)
		{
			pVector = new double[nDimensions];
			for(j = 0; j < nDimensions; j++)
				pVector[j] = GBits::GetRandomDouble() * 100;
			data.AddVector(pVector);
		}

		// Find neighbors using GManifoldNeighborFinder
		double* pVector2;
		int* pNeighbors = new int[nNeighbors];
		ArrayHolder<int*> hNeighbors(pNeighbors);
		double* pDistances = new double[nNeighbors];
		ArrayHolder<double*> hDistances(pDistances);
		GManifoldNeighborFinder gnf(&rel, &data, nMaxPointsPerLeaf, GManifoldNeighborFinder::principle_component_8);
		for(i = 0; i < nPoints; i++)
		{
			// Find the neighbors
			pVector = data.GetVector(i);
			gnf.FindNeighbors(pNeighbors, pDistances, nNeighbors, pVector, i);

			// Check the answer
			double dWorstDistance = 0;
			double d;
			for(j = 0; j < nNeighbors; j++)
			{
				if(pNeighbors[j] < 0)
					throw "Didn't retrieve enough neighbors";
				d = rel.ComputeInputDistanceSquared(pVector, data.GetVector(pNeighbors[j]));
				if(d > dWorstDistance)
					dWorstDistance = d;
			}
			for(j = 0; j < nPoints; j++)
			{
				pVector2 = data.GetVector(j);
				d = rel.ComputeInputDistanceSquared(pVector, pVector2);
				if(d == 0)
				{
					// Zero distance (the point itself) is not checked
				}
				else if(d < dWorstDistance)
				{
					// Every point closer than the worst reported neighbor
					// must itself have been reported
					for(k = 0; k < nNeighbors; k++)
					{
						if(k > 0)
						{
							// This isn't a complete check, but it should be good enough
							if(pNeighbors[k] == pNeighbors[k - 1])
								throw "same neighbor twice";
							if(pNeighbors[k] == pNeighbors[0])
								throw "same neighbor twice";
						}
						if(pNeighbors[k] == i)
							throw "failed to exclude itself";
						if(data.GetVector(pNeighbors[k]) == pVector2)
							break;
					}
					if(k >= nNeighbors)
						throw "missed a neighbor";
				}
			}
		}
	}
#endif // !NO_TEST_CODE
};
*/

// --------------------------------------------------------------------

/*
// A point record kept in an AVL tree. Records are ordered by the squared
// distance accumulated so far, and m_nProg counts how many terms have been
// added to that distance.
class GYetAnotherPoint : public GAVLNode
{
protected:
	int m_nIndex;
	int m_nProg;
	double m_dSquaredDist;

public:
	GYetAnotherPoint(int nIndex)
	{
		m_nIndex = nIndex;
		m_nProg = 0;
		m_dSquaredDist = 0;
	}

	virtual ~GYetAnotherPoint()
	{
	}

	// Returns -1, 0, or 1 when this record's accumulated squared distance
	// is less than, equal to, or greater than that of pThat. pThat is cast
	// to GYetAnotherPoint without being checked.
	virtual int Compare(GAVLNode* pThat)
	{
		if(m_dSquaredDist < ((GYetAnotherPoint*)pThat)->m_dSquaredDist)
			return -1;
		else if(m_dSquaredDist > ((GYetAnotherPoint*)pThat)->m_dSquaredDist)
			return 1;
		else
			return 0;
	}

	int GetIndex()
	{
		return m_nIndex;
	}

	int GetProg()
	{
		return m_nProg;
	}

	double GetSquaredDist()
	{
		return m_dSquaredDist;
	}

	// Adds one term to the accumulated squared distance and advances the
	// progress counter by one.
	void AddSquaredDist(double d)
	{
		m_dSquaredDist += d;
		m_nProg++;
	}
};

class GYetAnotherNeighborFinder
{
protected:
	int m_nDims;
	GAVLTree* m_pPriorityQueue;
	GArffData* m_pData;
	double*
m_pPrincipleComponent; GIntArray m_dimOrder;public: GYetAnotherNeighborFinder(int nDims, GArffData* pData); ~GYetAnotherNeighborFinder() { delete[] m_pPrincipleComponent; delete(m_pPriorityQueue); } void FindNeighbors(int* pNeighbors, double* pSquaredDistances, int nNeighbors, double* pVector, int nExclude) { // Create the point records GAssert(m_pPriorityQueue->GetSize() == 0, "priority queue not empty"); int i; int nCount = m_pData->GetSize(); for(i = 0; i < nCount; i++) { if(i == nExclude) continue; m_pPriorityQueue->Insert(new GYetAnotherPoint(i)); } // Crunch until the closest K neighbors are completed double* pVec; double d; if(nCount > nNeighbors) { int nFirstIncomplete = 0; int dim, prog; while(true) { GYetAnotherPoint* pPoint = (GYetAnotherPoint*)m_pPriorityQueue->Unlink(nFirstIncomplete); pVec = m_pData->GetVector(pPoint->GetIndex()); prog = pPoint->GetProg(); if(prog >= m_nDims) { m_pPriorityQueue->Insert(pPoint); nFirstIncomplete++; if(nFirstIncomplete >= nNeighbors) break; else continue; } for(i = 0; i < 10; i++) { if(prog >= m_nDims) break; dim = m_dimOrder.GetInt(prog); d = pVec[dim] - pVector[dim]; pPoint->AddSquaredDist(d * d); prog = pPoint->GetProg(); } m_pPriorityQueue->Insert(pPoint); } } // Return the results for(i = 0; i < nNeighbors; i++) { GYetAnotherPoint* pPoint = (GYetAnotherPoint*)m_pPriorityQueue->Unlink(0); if(pPoint) { pNeighbors[i] = pPoint->GetIndex();
⌨️ 快捷键说明
复制代码Ctrl + C
搜索代码Ctrl + F
全屏模式F11
增大字号Ctrl + =
减小字号Ctrl + -
显示快捷键?