ggreedysearch.cpp
来自「一个由Mike Gashler完成的机器学习方面的includes neural」· C++ 代码 · 共 548 行
CPP
548 行
/*
	Copyright (C) 2006, Mike Gashler

	This library is free software; you can redistribute it and/or
	modify it under the terms of the GNU Lesser General Public
	License as published by the Free Software Foundation; either
	version 2.1 of the License, or (at your option) any later version.

	see http://www.gnu.org/copyleft/lesser.html
*/

#include "GGreedySearch.h"
#include "GBits.h"
#include "GAVLTree.h"
#include <math.h>
#include <stdio.h>
#include <stdlib.h> // rand
#include <string.h> // memcpy
#include "GImage.h"
#include "GBitTable.h"

// Builds a coordinate-wise greedy search over the vector space described by
// pCritic. The candidate vector starts at all zeros, every per-dimension step
// size starts at 1, and the best-known error starts at 1e200 so that the first
// critiqued candidate is always accepted.
GMomentumGreedySearch::GMomentumGreedySearch(GRealVectorCritic* pCritic)
: GRealVectorSearch(pCritic)
{
	m_nDimensions = pCritic->GetVectorSize();
	m_nCurrentDim = 0;
	m_pVector = new double[m_nDimensions];
	m_pStepSizes = new double[m_nDimensions];
	m_dChangeFactor = .93;
	int i;
	for(i = 0; i < m_nDimensions; i++)
	{
		m_pVector[i] = 0;
		m_pStepSizes[i] = 1;
	}
	m_dError = 1e200;
}

/*virtual*/ GMomentumGreedySearch::~GMomentumGreedySearch()
{
	// Both buffers were allocated with new[], so they must be released with delete[]
	delete[] m_pVector;
	delete[] m_pStepSizes;
}

// Overwrites the current candidate vector with a copy of pVector.
// pVector must hold at least m_nDimensions doubles.
void GMomentumGreedySearch::SetState(double* pVector)
{
	memcpy(m_pVector, pVector, sizeof(double) * m_nDimensions);
}

// Assigns the same step size to every dimension.
void GMomentumGreedySearch::SetAllStepSizes(double dStepSize)
{
	int i;
	for(i = 0; i < m_nDimensions; i++)
		m_pStepSizes[i] = dStepSize;
}

// Performs one step of the search on the current dimension: tries stepping
// forward, then (if that did not help) stepping backward. If neither direction
// improved the error, the vector is restored and the step size for this
// dimension shrinks by m_dChangeFactor; otherwise the improved position is
// kept and the step size grows by the same factor. Finally advances (with
// wrap-around) to the next dimension.
/*virtual*/ void GMomentumGreedySearch::Iterate()
{
	// Nothing to optimize (and nothing safe to index) in a zero-dimensional space
	if(m_nDimensions <= 0)
		return;

	// Try a step in the positive direction
	m_pVector[m_nCurrentDim] += m_pStepSizes[m_nCurrentDim];
	double dError = m_pCritic->Critique(m_pVector);
	if(dError >= m_dError)
	{
		// No improvement, so undo the step and try the negative direction.
		// (Two separate subtractions are kept deliberately so the floating
		// point result matches undo-then-step exactly.)
		m_pVector[m_nCurrentDim] -= m_pStepSizes[m_nCurrentDim];
		m_pVector[m_nCurrentDim] -= m_pStepSizes[m_nCurrentDim];
		dError = m_pCritic->Critique(m_pVector);
		if(dError >= m_dError)
			m_pVector[m_nCurrentDim] += m_pStepSizes[m_nCurrentDim]; // restore the original value
	}

	// Adapt the step size for this dimension
	if(dError >= m_dError)
		m_pStepSizes[m_nCurrentDim] *= m_dChangeFactor;
	else
	{
		m_pStepSizes[m_nCurrentDim] /= m_dChangeFactor;
		m_dError = dError;
	}

	// Move on to the next dimension
	if(++m_nCurrentDim >= m_nDimensions)
		m_nCurrentDim = 0;
}

// --------------------------------------------------------------------------------

// Builds a stochastic greedy search. Each element of the starting vector is
// GBits::GetRandomDouble() * dRange + dMin, and the starting error is the
// critic's evaluation of that vector. dRange also scales the random
// perturbations made by Iterate.
GStochasticGreedySearch::GStochasticGreedySearch(GRealVectorCritic* pCritic, double dMin, double dRange)
: GRealVectorSearch(pCritic)
{
	m_dRange = dRange;
	m_dConservativeness = 4;
	m_nDimensions = pCritic->GetVectorSize();
	m_pVector = new double[m_nDimensions];
	m_pTest = new double[m_nDimensions];
	int i;
	for(i = 0; i < m_nDimensions; i++)
		m_pVector[i] = GBits::GetRandomDouble() * dRange + dMin;
	m_dError = pCritic->Critique(m_pVector);
}

/*virtual*/ GStochasticGreedySearch::~GStochasticGreedySearch()
{
	// Both buffers were allocated with new[], so they must be released with delete[]
	delete[] m_pVector;
	delete[] m_pTest;
}

// Generates a randomly perturbed copy of the current vector and adopts it
// only if the critic scores it strictly better than the current error.
/*virtual*/ void GStochasticGreedySearch::Iterate()
{
	// Pick a new spot to try
	int i;
	for(i = 0; i < m_nDimensions; i++)
	{
		// The direction is drawn before the magnitude, which keeps the
		// order of random-number calls the same as the original branches.
		bool bAdd = (rand() & 1) != 0;
		double dDelta = m_dRange * pow(GBits::GetRandomDouble(), m_dConservativeness);
		if(bAdd)
			m_pTest[i] = m_pVector[i] + dDelta;
		else
			m_pTest[i] = m_pVector[i] - dDelta;
	}

	// Critique the candidate and keep it if it is an improvement
	double dError = m_pCritic->Critique(m_pTest);
	if(dError < m_dError)
	{
		// Swap buffers rather than copying the vector
		double* pTmp = m_pTest;
		m_pTest = m_pVector;
		m_pVector = pTmp;
		m_dError = dError;
	}
}

// --------------------------------------------------------------------------------

// Builds a greedy action search whose current path starts at pStartState.
// NOTE(review): ownership of pStartState is not visible in this file; the
// tests below pass a heap-allocated state and never free it, so presumably
// the path/search takes ownership -- confirm against GActionPath.
GActionGreedySearch::GActionGreedySearch(GActionPathState* pStartState, int nActionCount)
: GActionPathSearch(pStartState, nActionCount)
{
	m_pPath = new GActionPath(pStartState);
	m_dPrevError = 1e200;
}

// virtual
GActionGreedySearch::~GActionGreedySearch()
{
	delete(m_pPath);
}

// Extends the current path by whichever single action yields the lowest
// critique. The current path is always replaced by the best extension.
// Returns true if that extension's error is not lower than the error recorded
// by the previous iteration (i.e. the search stopped improving), or false if
// it improved, in which case the new error is recorded.
// virtual
bool GActionGreedySearch::Iterate()
{
	// Start with action 0 as the best candidate
	GActionPath* pPathBest = m_pPath->Fork();
	pPathBest->DoAction(0);
	double dBest = pPathBest->Critique();

	// Try the remaining actions
	double dTest;
	GActionPath* pPathTest;
	GActionPath* pTmp;
	int i;
	for(i = 1; i < m_nActionCount; i++)
	{
		// The last candidate reuses the current path object instead of
		// forking it, so the old path is consumed by this loop
		if(i == m_nActionCount - 1)
			pPathTest = m_pPath;
		else
			pPathTest = m_pPath->Fork();
		pPathTest->DoAction(i);
		dTest = pPathTest->Critique();
		if(dTest < dBest)
		{
			// Swap so that pPathTest always refers to the loser
			dBest = dTest;
			pTmp = pPathBest;
			pPathBest = pPathTest;
			pPathTest = pTmp;
		}
		delete(pPathTest);
	}

	// With fewer than two actions the loop above never ran, so the old path
	// was not consumed and must be released here to avoid leaking it
	if(m_nActionCount < 2)
		delete(m_pPath);
	m_pPath = pPathBest;

	// See if we got any better
	if(dBest >= m_dPrevError)
		return true;
	m_dPrevError = dBest;
	return false;
}

// Returns the current path. The search retains ownership (it is deleted in
// the destructor).
// virtual
GActionPath* GActionGreedySearch::GetBestPath()
{
	return m_pPath;
}

// Returns the lowest error recorded by Iterate so far (1e200 before any
// improving iteration).
// virtual
double GActionGreedySearch::GetBestPathError()
{
	return m_dPrevError;
}

#ifndef NO_TEST_CODE
// Test state for GActionGreedySearch: an (x, y) position on an image, starting
// at (10, 10). Actions 0..3 move +x, -x, +y, -y respectively, and a move is
// refused if it would leave the image or enter a pixel whose green channel is
// 10 or more.
class GActionGreedySearchTestState : public GActionPathState
{
protected:
	int m_state[2]; // m_state[0] = x, m_state[1] = y
	GImage* m_pImage; // not owned

public:
	GActionGreedySearchTestState(GImage* pImage)
	{
		m_state[0] = 10;
		m_state[1] = 10;
		m_pImage = pImage;
	}

	virtual ~GActionGreedySearchTestState()
	{
	}

	// Returns a heap-allocated copy that shares the same image
	virtual GActionPathState* Copy()
	{
		GActionGreedySearchTestState* pNewState = new GActionGreedySearchTestState(m_pImage);
		pNewState->m_state[0] = m_state[0];
		pNewState->m_state[1] = m_state[1];
		return pNewState;
	}

	// Marks the current pixel with the path length and returns the squared
	// Euclidean distance from the current position to (49, 49)
	virtual double CritiquePath(int nPathLen, GAction* pLastAction)
	{
		int x = m_state[0];
		int y = m_state[1];
		//GColor col = m_image.GetPixel(x, y);
		//int nPrevLen = 255 - (int)gBlue(col);
		//if(nPrevLen <= nLen) // if a shorter path ever got to this state
		//	return 1e200;
		m_pImage->SetPixel(x, y, gARGB(0xff, 0, 88, 255 - nPathLen)); // record the path length in the blue channel
		double dError = 0;
		double d;
		d = 49 - m_state[0];
		dError += d * d;
		d = 49 - m_state[1];
		dError += d * d;
		return dError;
	}

	// Applies one of the four moves if the destination is inside the image
	// and not blocked
	virtual void PerformAction(int nAction)
	{
		switch(nAction)
		{
			case 0:
				if(m_state[0] < m_pImage->GetWidth() - 1 && gGreen(m_pImage->GetPixel(m_state[0] + 1, m_state[1])) < 10)
					m_state[0]++;
				break;
			case 1:
				if(m_state[0] > 0 && gGreen(m_pImage->GetPixel(m_state[0] - 1, m_state[1])) < 10)
					m_state[0]--;
				break;
			case 2:
				if(m_state[1] < m_pImage->GetHeight() - 1 && gGreen(m_pImage->GetPixel(m_state[0], m_state[1] + 1)) < 10)
					m_state[1]++;
				break;
			case 3:
				if(m_state[1] > 0 && gGreen(m_pImage->GetPixel(m_state[0], m_state[1] - 1)) < 10)
					m_state[1]--;
				break;
			default:
				GAssert(false, "unrecognized action");
				break;
		}
	}
};

// Runs the greedy search on a 50x50 image containing three white line
// obstacles and throws a string literal if the resulting path has an
// unexpected length or contains a backward (-x or -y) move.
// static
void GActionGreedySearch::Test()
{
	GImage image;
	image.SetSize(50, 50);
	image.Clear(0xff000000);
	image.DrawLine(17, 20, 27, 20, 0xffffffff);
	image.DrawLine(30, 40, 40, 40, 0xffffffff);
	image.DrawLine(40, 30, 40, 40, 0xffffffff);
	GActionGreedySearch search(new GActionGreedySearchTestState(&image), 4);
	while(!search.Iterate())
	{
	}
	GActionPath* pPath = search.GetBestPath();
	int nLen = pPath->GetLength();
	if(nLen < 50)
		throw "impossible";
	if(nLen > 70)
		throw "path unnecessarily long";
	int* path = new int[nLen];
	Holder<int*> hPath(path);
	pPath->GetPath(nLen, path);

	// Replay the path, painting it on the image, and reject backward moves
	int x = 10;
	int y = 10;
	int i;
	for(i = 0; i < nLen; i++)
	{
		image.SetPixel(x, y, 0xff00ffff);
		if(path[i] == 0)
			x++;
		else if(path[i] == 2)
			y++;
		else if(path[i] == 1)
		{
			x--;
			throw "unnecessary action";
		}
		else if(path[i] == 3)
		{
			y--;
			throw "unnecessary action";
		}
	}
	//image.SavePNGFile("greedy.png");
}
#endif // !NO_TEST_CODE

// --------------------------------------------------------------------------------

// A priority-queue node for GAStarSearch. It owns one path and caches that
// path's critique, which is the key used to order nodes in the AVL tree.
class GAStarNode : public GAVLNode
{
protected:
	double m_dError; // cached result of m_pPath->Critique()
	GActionPath* m_pPath; // owned; may be NULL after DropPath

public:
	// Takes ownership of pPath
	GAStarNode(GActionPath* pPath) : GAVLNode()
	{
		m_dError = pPath->Critique();
		m_pPath = pPath;
	}

	virtual ~GAStarNode()
	{
		delete(m_pPath);
	}

	// Deletes the currently held path (if any), takes ownership of pPath,
	// and refreshes the cached error
	void SetNewPath(GActionPath* pPath)
	{
		delete(m_pPath);
		m_pPath = pPath;
		m_dError = pPath->Critique();
	}

	// Orders nodes by ascending error: returns -1, 1, or 0
	virtual int Compare(GAVLNode* pThat)
	{
		GAStarNode* pOther = (GAStarNode*)pThat;
		if(m_dError < pOther->m_dError)
			return -1;
		else if(m_dError > pOther->m_dError)
			return 1;
		else
			return 0;
	}

	// Releases ownership of the held path to the caller. The cached error
	// is left unchanged.
	GActionPath* DropPath()
	{
		GActionPath* pPath = m_pPath;
		m_pPath = NULL;
		return pPath;
	}

	double GetError()
	{
		return m_dError;
	}
};

// Builds an A*-style search whose priority queue initially holds only the
// empty path from pStartState. nMaxPaths caps how many paths the queue may
// hold; Iterate evicts the highest-error node when the cap is exceeded.
GAStarSearch::GAStarSearch(GActionPathState* pStartState, int nActionCount, int nMaxPaths)
 : GActionPathSearch(pStartState, nActionCount)
{
	m_nMaxPaths = nMaxPaths;
	m_dBestError = 1e200;
	m_pBestPath = NULL;
	m_pPriorityQueue = new GAVLTree();
	GActionPath* pEmptyPath = new GActionPath(pStartState);
	m_pPriorityQueue->Insert(new GAStarNode(pEmptyPath));
}

// virtual
GAStarSearch::~GAStarSearch()
{
	delete(m_pPriorityQueue);
	delete(m_pBestPath);
}

// Removes the lowest-error path from the queue, inserts one extension of it
// per action, and remembers the removed path if its error beats the best seen
// so far. Returns true once the best error is exactly 0.
// NOTE(review): assumes the queue is non-empty on entry; the result of
// Unlink(0) is dereferenced without a check.
// virtual
bool GAStarSearch::Iterate()
{
	// Pop the most promising path. Its node is kept as a spare so it can be
	// reused for the first child instead of allocating a new node.
	GAStarNode* pSpare = (GAStarNode*)m_pPriorityQueue->Unlink(0);
	GActionPath* pPath = pSpare->DropPath();
	double dError = pSpare->GetError();

	// Expand it with every action
	int i;
	for(i = 0; i < m_nActionCount; i++)
	{
		GActionPath* pNewPath = pPath->Fork();
		pNewPath->DoAction(i);
		if(pSpare)
		{
			pSpare->SetNewPath(pNewPath);
			m_pPriorityQueue->Insert(pSpare);
			pSpare = NULL;
		}
		else
			m_pPriorityQueue->Insert(new GAStarNode(pNewPath));

		// Enforce the size cap by unlinking the last (highest-error) node,
		// which then becomes the spare for the next insertion
		if(m_pPriorityQueue->GetSize() > m_nMaxPaths)
			pSpare = (GAStarNode*)m_pPriorityQueue->Unlink(m_pPriorityQueue->GetSize() - 1);
	}
	delete(pSpare);

	// Keep the expanded path if it is the best seen so far
	if(dError < m_dBestError)
	{
		delete(m_pBestPath);
		m_pBestPath = pPath;
		m_dBestError = dError;
	}
	else
		delete(pPath);
	return m_dBestError == 0;
}

// Returns the best path expanded so far (NULL before the first Iterate). The
// search retains ownership (it is deleted in the destructor).
// virtual
GActionPath* GAStarSearch::GetBestPath()
{
	return m_pBestPath;
}

// Returns the error of the best path expanded so far (1e200 initially).
// virtual
double GAStarSearch::GetBestPathError()
{
	return m_dBestError;
}

#ifndef NO_TEST_CODE
// Test state for GAStarSearch: an (x, y) position on an image, starting at
// (10, 10). Actions 0..3 move +x, -x, +y, -y respectively, and a move is
// refused if it would leave the image or enter a pixel whose green channel is
// 10 or more.
class GAStarSearchTestState : public GActionPathState
{
protected:
	int m_state[2]; // m_state[0] = x, m_state[1] = y
	GImage* m_pImage; // not owned

public:
	GAStarSearchTestState(GImage* pImage)
	{
		m_state[0] = 10;
		m_state[1] = 10;
		m_pImage = pImage;
	}

	virtual ~GAStarSearchTestState()
	{
	}

	// Returns a heap-allocated copy that shares the same image
	virtual GActionPathState* Copy()
	{
		GAStarSearchTestState* pNewState = new GAStarSearchTestState(m_pImage);
		pNewState->m_state[0] = m_state[0];
		pNewState->m_state[1] = m_state[1];
		return pNewState;
	}

	// Returns 1e200 if an equal-or-shorter path already reached this pixel
	// (as recorded in its blue channel), 0 if the position is (49, 49), and
	// otherwise 1.4 * (Euclidean distance to (49, 49)) + nPathLen
	virtual double CritiquePath(int nPathLen, GAction* pLastAction)
	{
		int x = m_state[0];
		int y = m_state[1];
		GColor col = m_pImage->GetPixel(x, y);
		int nPrevLen = 255 - (int)gBlue(col);
		if(nPrevLen <= nPathLen) // if a shorter path ever got to this state
			return 1e200;
		m_pImage->SetPixel(x, y, gARGB(0xff, 0, 88, 255 - nPathLen)); // record the path length in the blue channel
		double dError = 0;
		double d;
		d = 49 - m_state[0];
		dError += d * d;
		d = 49 - m_state[1];
		dError += d * d;
		if(dError == 0)
			return 0;
		dError = 1.4 * sqrt(dError) + nPathLen;
		return dError;
	}

	// Applies one of the four moves if the destination is inside the image
	// and not blocked
	virtual void PerformAction(int nAction)
	{
		switch(nAction)
		{
			case 0:
				if(m_state[0] < m_pImage->GetWidth() - 1 && gGreen(m_pImage->GetPixel(m_state[0] + 1, m_state[1])) < 10)
					m_state[0]++;
				break;
			case 1:
				if(m_state[0] > 0 && gGreen(m_pImage->GetPixel(m_state[0] - 1, m_state[1])) < 10)
					m_state[0]--;
				break;
			case 2:
				if(m_state[1] < m_pImage->GetHeight() - 1 && gGreen(m_pImage->GetPixel(m_state[0], m_state[1] + 1)) < 10)
					m_state[1]++;
				break;
			case 3:
				if(m_state[1] > 0 && gGreen(m_pImage->GetPixel(m_state[0], m_state[1] - 1)) < 10)
					m_state[1]--;
				break;
			default:
				GAssert(false, "unrecognized action");
				break;
		}
	}
};

// Runs the A* search on a 50x50 image containing two white line obstacles and
// throws a string literal if the resulting path has an unexpected length,
// contains a backward (-x or -y) move, or does not end at (49, 49).
// static
void GAStarSearch::Test()
{
	GImage image;
	image.SetSize(50, 50);
	image.Clear(0xff000000);
	image.DrawLine(30, 40, 40, 40, 0xffffffff);
	image.DrawLine(40, 30, 40, 40, 0xffffffff);
	GAStarSearch search(new GAStarSearchTestState(&image), 4, 1000);
	while(!search.Iterate())
	{
	}
	GActionPath* pPath = search.GetBestPath();
	int nLen = pPath->GetLength();
	if(nLen < 70)
		throw "impossible";
	if(nLen > 80)
		throw "path unnecessarily long";
	int* path = new int[nLen];
	Holder<int*> hPath(path);
	pPath->GetPath(nLen, path);

	// Replay the path, painting it on the image, and reject backward moves
	int x = 10;
	int y = 10;
	int i;
	for(i = 0; i < nLen; i++)
	{
		image.SetPixel(x, y, 0xff00ffff);
		if(path[i] == 0)
			x++;
		else if(path[i] == 2)
			y++;
		else if(path[i] == 1)
		{
			x--;
			throw "unnecessary action";
		}
		else if(path[i] == 3)
		{
			y--;
			throw "unnecessary action";
		}
	}
	if(x != 49 || y != 49)
		throw "wrong target state";
	//image.SavePNGFile("astar.png");
}
#endif // !NO_TEST_CODE
⌨️ 快捷键说明
复制代码Ctrl + C
搜索代码Ctrl + F
全屏模式F11
增大字号Ctrl + =
减小字号Ctrl + -
显示快捷键?