ggreedysearch.cpp

来自「一个由Mike Gashler完成的机器学习方面的includes neural」· C++ 代码 · 共 548 行

CPP
548
字号
/*
	Copyright (C) 2006, Mike Gashler

	This library is free software; you can redistribute it and/or
	modify it under the terms of the GNU Lesser General Public
	License as published by the Free Software Foundation; either
	version 2.1 of the License, or (at your option) any later version.

	see http://www.gnu.org/copyleft/lesser.html
*/

#include "GGreedySearch.h"
#include "GBits.h"
#include "GAVLTree.h"
#include <math.h>
#include <stdio.h>
#include <stdlib.h> // rand
#include <string.h> // memcpy
#include <vector>
#include "GImage.h"
#include "GBitTable.h"

// --------------------------------------------------------------------------------

// Coordinate-wise greedy search over a real vector. Each call to Iterate()
// adjusts one dimension. Every dimension starts at 0 with a step size of 1.
GMomentumGreedySearch::GMomentumGreedySearch(GRealVectorCritic* pCritic)
: GRealVectorSearch(pCritic)
{
	m_nDimensions = pCritic->GetVectorSize();
	m_nCurrentDim = 0;
	m_pVector = new double[m_nDimensions];
	m_pStepSizes = new double[m_nDimensions];
	m_dChangeFactor = .93; // a step shrinks by this factor on failure and grows by its reciprocal on success
	int i;
	for(i = 0; i < m_nDimensions; i++)
	{
		m_pVector[i] = 0;
		m_pStepSizes[i] = 1;
	}
	m_dError = 1e200; // sentinel: no error has been measured yet
}

/*virtual*/ GMomentumGreedySearch::~GMomentumGreedySearch()
{
	// Both buffers were allocated with new[], so they must be released with delete[].
	delete[] m_pVector;
	delete[] m_pStepSizes;
}

// Overwrites the current vector. pVector must point to at least
// m_nDimensions doubles.
void GMomentumGreedySearch::SetState(double* pVector)
{
	memcpy(m_pVector, pVector, sizeof(double) * m_nDimensions);
}

// Sets the step size of every dimension to dStepSize.
void GMomentumGreedySearch::SetAllStepSizes(double dStepSize)
{
	int i;
	for(i = 0; i < m_nDimensions; i++)
		m_pStepSizes[i] = dStepSize;
}

// Tries moving the current dimension forward by its step, then backward.
// The first move that lowers the error is kept and the step grows.
// If neither move helps, the value is restored and the step shrinks.
// The search then advances (cyclically) to the next dimension.
/*virtual*/ void GMomentumGreedySearch::Iterate()
{
	// Nothing to adjust, and indexing element 0 would be out of bounds.
	if(m_nDimensions <= 0)
		return;
	double& dValue = m_pVector[m_nCurrentDim];
	double& dStep = m_pStepSizes[m_nCurrentDim];
	dValue += dStep;
	double dError = m_pCritic->Critique(m_pVector);
	if(dError >= m_dError)
	{
		// Forward did not help: try one step backward from the original value.
		// (Two separate subtractions keep the original floating-point results.)
		dValue -= dStep;
		dValue -= dStep;
		dError = m_pCritic->Critique(m_pVector);
		if(dError >= m_dError)
			dValue += dStep; // neither direction helped: restore the original value
	}
	if(dError >= m_dError)
		dStep *= m_dChangeFactor;
	else
	{
		dStep /= m_dChangeFactor;
		m_dError = dError;
	}
	if(++m_nCurrentDim >= m_nDimensions)
		m_nCurrentDim = 0;
}

// --------------------------------------------------------------------------------

// Random-restart-free stochastic hill climber. The initial vector is drawn as
// GBits::GetRandomDouble() * dRange + dMin per coordinate (presumably uniform in
// [dMin, dMin + dRange); confirm GetRandomDouble's range).
GStochasticGreedySearch::GStochasticGreedySearch(GRealVectorCritic* pCritic, double dMin, double dRange)
: GRealVectorSearch(pCritic)
{
	m_dRange = dRange;
	m_dConservativeness = 4; // exponent applied to the random step magnitude; larger => smaller typical steps
	m_nDimensions = pCritic->GetVectorSize();
	m_pVector = new double[m_nDimensions];
	m_pTest = new double[m_nDimensions];
	int i;
	for(i = 0; i < m_nDimensions; i++)
		m_pVector[i] = GBits::GetRandomDouble() * dRange + dMin;
	m_dError = pCritic->Critique(m_pVector);
}

/*virtual*/ GStochasticGreedySearch::~GStochasticGreedySearch()
{
	// Both buffers were allocated with new[], so they must be released with delete[].
	delete[] m_pVector;
	delete[] m_pTest;
}

// Proposes a random neighbor of the current vector and keeps it only if its
// error is strictly lower.
/*virtual*/ void GStochasticGreedySearch::Iterate()
{
	int i;
	for(i = 0; i < m_nDimensions; i++)
	{
		// rand() is called before GetRandomDouble() to keep the original RNG call order.
		bool bUp = (rand() & 1) != 0;
		double dDelta = m_dRange * pow(GBits::GetRandomDouble(), m_dConservativeness);
		m_pTest[i] = bUp ? m_pVector[i] + dDelta : m_pVector[i] - dDelta;
	}
	double dError = m_pCritic->Critique(m_pTest);
	if(dError < m_dError)
	{
		// Accept: swap buffers so the old vector becomes scratch space.
		double* pTmp = m_pTest;
		m_pTest = m_pVector;
		m_pVector = pTmp;
		m_dError = dError;
	}
}

// --------------------------------------------------------------------------------

GActionGreedySearch::GActionGreedySearch(GActionPathState* pStartState, int nActionCount)
: GActionPathSearch(pStartState, nActionCount)
{
	m_pPath = new GActionPath(pStartState);
	m_dPrevError = 1e200; // sentinel: no path has been critiqued yet
}

// virtual
GActionGreedySearch::~GActionGreedySearch()
{
	delete(m_pPath);
}

// Extends the current path by the single action whose resulting path has the
// lowest Critique() value. Ties go to the lowest action index.
// Returns true when that best extension did not improve on the previous best
// error (the search has converged), and false otherwise.
// NOTE(review): m_pPath is replaced by the best extension even when true is
// returned, so GetBestPath() then includes one non-improving final step while
// GetBestPathError() still reports the earlier error.
// virtual
bool GActionGreedySearch::Iterate()
{
	GActionPath* pPathBest = m_pPath->Fork();
	pPathBest->DoAction(0);
	double dBest = pPathBest->Critique();
	if(m_nActionCount < 2)
	{
		// The loop below is what normally consumes m_pPath. With a single action
		// it never runs, so release the old path here instead of leaking it.
		delete(m_pPath);
	}
	int i;
	for(i = 1; i < m_nActionCount; i++)
	{
		// The last candidate reuses m_pPath itself rather than forking it; after
		// this loop m_pPath has either been deleted or become pPathBest.
		GActionPath* pPathTest = (i == m_nActionCount - 1) ? m_pPath : m_pPath->Fork();
		pPathTest->DoAction(i);
		double dTest = pPathTest->Critique();
		if(dTest < dBest)
		{
			dBest = dTest;
			GActionPath* pTmp = pPathBest;
			pPathBest = pPathTest;
			pPathTest = pTmp;
		}
		delete(pPathTest); // discard whichever of the two paths lost
	}
	m_pPath = pPathBest;
	if(dBest >= m_dPrevError)
		return true;
	m_dPrevError = dBest;
	return false;
}

// virtual
GActionPath* GActionGreedySearch::GetBestPath()
{
	return m_pPath;
}

// virtual
double GActionGreedySearch::GetBestPathError()
{
	return m_dPrevError;
}

#ifndef NO_TEST_CODE
// Test state: a walker on a grid image starting at (10, 10) and heading for (49, 49).
// Pixels whose green channel is >= 10 are impassable.
class GActionGreedySearchTestState : public GActionPathState
{
protected:
	int m_state[2]; // current (x, y)
	GImage* m_pImage;

public:
	GActionGreedySearchTestState(GImage* pImage)
	{
		m_state[0] = 10;
		m_state[1] = 10;
		m_pImage = pImage;
	}

	virtual ~GActionGreedySearchTestState()
	{
	}

	virtual GActionPathState* Copy()
	{
		GActionGreedySearchTestState* pNewState = new GActionGreedySearchTestState(m_pImage);
		pNewState->m_state[0] = m_state[0];
		pNewState->m_state[1] = m_state[1];
		return pNewState;
	}

	// Returns the squared Euclidean distance to (49, 49). As a side effect it
	// paints the current pixel with green 88 and blue 255 - nPathLen. Because
	// PerformAction treats green >= 10 as a wall, critiqued cells block later moves.
	// Unlike the A* test state, this one does not prune revisits.
	virtual double CritiquePath(int nPathLen, GAction* pLastAction)
	{
		int x = m_state[0];
		int y = m_state[1];
		m_pImage->SetPixel(x, y, gARGB(0xff, 0, 88, 255 - nPathLen)); // record the path length in the blue channel
		double dError = 0;
		double d;
		d = 49 - m_state[0];
		dError += d * d;
		d = 49 - m_state[1];
		dError += d * d;
		return dError;
	}

	// Actions: 0 = x+1, 1 = x-1, 2 = y+1, 3 = y-1. A move into the image edge or
	// an impassable pixel leaves the state unchanged.
	virtual void PerformAction(int nAction)
	{
		switch(nAction)
		{
			case 0:
				if(m_state[0] < m_pImage->GetWidth() - 1 && gGreen(m_pImage->GetPixel(m_state[0] + 1, m_state[1])) < 10)
					m_state[0]++;
				break;
			case 1:
				if(m_state[0] > 0 && gGreen(m_pImage->GetPixel(m_state[0] - 1, m_state[1])) < 10)
					m_state[0]--;
				break;
			case 2:
				if(m_state[1] < m_pImage->GetHeight() - 1 && gGreen(m_pImage->GetPixel(m_state[0], m_state[1] + 1)) < 10)
					m_state[1]++;
				break;
			case 3:
				if(m_state[1] > 0 && gGreen(m_pImage->GetPixel(m_state[0], m_state[1] - 1)) < 10)
					m_state[1]--;
				break;
			default:
				GAssert(false, "unrecognized action");
				break;
		}
	}
};

// static
void GActionGreedySearch::Test()
{
	GImage image;
	image.SetSize(50, 50);
	image.Clear(0xff000000);
	image.DrawLine(17, 20, 27, 20, 0xffffffff);
	image.DrawLine(30, 40, 40, 40, 0xffffffff);
	image.DrawLine(40, 30, 40, 40, 0xffffffff);
	GActionGreedySearch search(new GActionGreedySearchTestState(&image), 4);
	while(!search.Iterate())
	{
	}
	GActionPath* pPath = search.GetBestPath();
	int nLen = pPath->GetLength();
	if(nLen < 50)
		throw "impossible";
	if(nLen > 70)
		throw "path unnecessarily long";
	std::vector<int> path(nLen); // nLen >= 50 here, so &path[0] is valid
	pPath->GetPath(nLen, &path[0]);
	int x = 10;
	int y = 10;
	int i;
	for(i = 0; i < nLen; i++)
	{
		image.SetPixel(x, y, 0xff00ffff); // trace the path for visual inspection
		if(path[i] == 0)
			x++;
		else if(path[i] == 2)
			y++;
		else if(path[i] == 1 || path[i] == 3)
			throw "unnecessary action"; // moving left or up never approaches (49, 49)
	}
	//image.SavePNGFile("greedy.png");
}
#endif // !NO_TEST_CODE

// --------------------------------------------------------------------------------

// Priority-queue node for GAStarSearch. It owns its GActionPath and caches that
// path's Critique() value, which is the ordering key.
class GAStarNode : public GAVLNode
{
protected:
	double m_dError;
	GActionPath* m_pPath;

private:
	// The node owns m_pPath, so a copy would double-delete it. Copying is disabled.
	GAStarNode(const GAStarNode&);
	GAStarNode& operator=(const GAStarNode&);

public:
	// Takes ownership of pPath
	explicit GAStarNode(GActionPath* pPath) : GAVLNode()
	{
		m_dError = pPath->Critique();
		m_pPath = pPath;
	}

	virtual ~GAStarNode()
	{
		delete(m_pPath);
	}

	// Replaces (and deletes) the owned path, taking ownership of pPath and
	// re-critiquing it. Only call this while the node is not linked into a tree,
	// because it changes the ordering key.
	void SetNewPath(GActionPath* pPath)
	{
		delete(m_pPath);
		m_pPath = pPath;
		m_dError = pPath->Critique();
	}

	// Orders nodes by ascending error.
	virtual int Compare(GAVLNode* pThat)
	{
		GAStarNode* pOther = (GAStarNode*)pThat;
		if(m_dError < pOther->m_dError)
			return -1;
		else if(m_dError > pOther->m_dError)
			return 1;
		else
			return 0;
	}

	// Releases ownership of the path to the caller. The node then holds NULL.
	GActionPath* DropPath()
	{
		GActionPath* pPath = m_pPath;
		m_pPath = NULL;
		return pPath;
	}

	double GetError()
	{
		return m_dError;
	}
};

// Best-first search over action paths. The open set is bounded to nMaxPaths
// entries, and the worst entries are discarded when it overflows.
GAStarSearch::GAStarSearch(GActionPathState* pStartState, int nActionCount, int nMaxPaths)
	: GActionPathSearch(pStartState, nActionCount)
{
	m_nMaxPaths = nMaxPaths;
	m_dBestError = 1e200; // sentinel: nothing expanded yet
	m_pBestPath = NULL;
	m_pPriorityQueue = new GAVLTree();
	GActionPath* pEmptyPath = new GActionPath(pStartState);
	m_pPriorityQueue->Insert(new GAStarNode(pEmptyPath));
}

// virtual
GAStarSearch::~GAStarSearch()
{
	delete(m_pPriorityQueue);
	delete(m_pBestPath);
}

// Removes the node at index 0 of the queue. Given GAStarNode::Compare, this is
// presumably the lowest-error path; confirm GAVLTree's index order. Every action
// is applied to a fork of that path and the results are inserted. When the
// queue grows past m_nMaxPaths, the highest-index (worst) node is unlinked and
// recycled for the next child. The expanded path is kept if it beats the best
// seen so far.
// Returns true once a path with error exactly 0 has been expanded, or when the
// queue is empty. The exact comparison relies on the critic returning exactly 0
// at the goal.
// virtual
bool GAStarSearch::Iterate()
{
	// Nothing to expand (possible with degenerate settings such as m_nMaxPaths < 1):
	// report completion rather than dereference a missing node.
	if(m_pPriorityQueue->GetSize() < 1)
		return true;
	GAStarNode* pSpare = (GAStarNode*)m_pPriorityQueue->Unlink(0);
	GActionPath* pPath = pSpare->DropPath();
	double dError = pSpare->GetError();
	int i;
	for(i = 0; i < m_nActionCount; i++)
	{
		GActionPath* pNewPath = pPath->Fork();
		pNewPath->DoAction(i);
		if(pSpare)
		{
			// Recycle an unlinked node instead of allocating a new one.
			pSpare->SetNewPath(pNewPath);
			m_pPriorityQueue->Insert(pSpare);
			pSpare = NULL;
		}
		else
			m_pPriorityQueue->Insert(new GAStarNode(pNewPath));
		if(m_pPriorityQueue->GetSize() > m_nMaxPaths)
			pSpare = (GAStarNode*)m_pPriorityQueue->Unlink(m_pPriorityQueue->GetSize() - 1);
	}
	delete(pSpare); // any node trimmed on the final pass (NULL-safe)
	if(dError < m_dBestError)
	{
		delete(m_pBestPath);
		m_pBestPath = pPath;
		m_dBestError = dError;
	}
	else
		delete(pPath);
	return m_dBestError == 0;
}

// virtual
GActionPath* GAStarSearch::GetBestPath()
{
	return m_pBestPath;
}

// virtual
double GAStarSearch::GetBestPathError()
{
	return m_dBestError;
}

#ifndef NO_TEST_CODE
// Test state: a walker on a grid image starting at (10, 10) and heading for (49, 49).
// Pixels whose green channel is >= 10 are impassable.
class GAStarSearchTestState : public GActionPathState
{
protected:
	int m_state[2]; // current (x, y)
	GImage* m_pImage;

public:
	GAStarSearchTestState(GImage* pImage)
	{
		m_state[0] = 10;
		m_state[1] = 10;
		m_pImage = pImage;
	}

	virtual ~GAStarSearchTestState()
	{
	}

	virtual GActionPathState* Copy()
	{
		GAStarSearchTestState* pNewState = new GAStarSearchTestState(m_pImage);
		pNewState->m_state[0] = m_state[0];
		pNewState->m_state[1] = m_state[1];
		return pNewState;
	}

	// A*-style score: 1.4 * Euclidean distance to (49, 49) plus nPathLen, or
	// exactly 0 at the goal. The shortest length that reached each pixel is
	// stored in its blue channel as 255 - nPathLen. Paths that arrive at a pixel
	// no sooner than a recorded visit score 1e200. Since visited pixels get green
	// 88, PerformAction also treats them as walls afterwards.
	virtual double CritiquePath(int nPathLen, GAction* pLastAction)
	{
		int x = m_state[0];
		int y = m_state[1];
		GColor col = m_pImage->GetPixel(x, y);
		int nPrevLen = 255 - (int)gBlue(col);
		if(nPrevLen <= nPathLen) // if a shorter path ever got to this state
			return 1e200;
		m_pImage->SetPixel(x, y, gARGB(0xff, 0, 88, 255 - nPathLen)); // record the path length in the blue channel
		double dError = 0;
		double d;
		d = 49 - m_state[0];
		dError += d * d;
		d = 49 - m_state[1];
		dError += d * d;
		if(dError == 0)
			return 0;
		dError = 1.4 * sqrt(dError) + nPathLen;
		return dError;
	}

	// Actions: 0 = x+1, 1 = x-1, 2 = y+1, 3 = y-1. A move into the image edge or
	// an impassable pixel leaves the state unchanged.
	virtual void PerformAction(int nAction)
	{
		switch(nAction)
		{
			case 0:
				if(m_state[0] < m_pImage->GetWidth() - 1 && gGreen(m_pImage->GetPixel(m_state[0] + 1, m_state[1])) < 10)
					m_state[0]++;
				break;
			case 1:
				if(m_state[0] > 0 && gGreen(m_pImage->GetPixel(m_state[0] - 1, m_state[1])) < 10)
					m_state[0]--;
				break;
			case 2:
				if(m_state[1] < m_pImage->GetHeight() - 1 && gGreen(m_pImage->GetPixel(m_state[0], m_state[1] + 1)) < 10)
					m_state[1]++;
				break;
			case 3:
				if(m_state[1] > 0 && gGreen(m_pImage->GetPixel(m_state[0], m_state[1] - 1)) < 10)
					m_state[1]--;
				break;
			default:
				GAssert(false, "unrecognized action");
				break;
		}
	}
};

// static
void GAStarSearch::Test()
{
	GImage image;
	image.SetSize(50, 50);
	image.Clear(0xff000000);
	image.DrawLine(30, 40, 40, 40, 0xffffffff);
	image.DrawLine(40, 30, 40, 40, 0xffffffff);
	GAStarSearch search(new GAStarSearchTestState(&image), 4, 1000);
	while(!search.Iterate())
	{
	}
	GActionPath* pPath = search.GetBestPath();
	int nLen = pPath->GetLength();
	if(nLen < 70)
		throw "impossible";
	if(nLen > 80)
		throw "path unnecessarily long";
	std::vector<int> path(nLen); // nLen >= 70 here, so &path[0] is valid
	pPath->GetPath(nLen, &path[0]);
	int x = 10;
	int y = 10;
	int i;
	for(i = 0; i < nLen; i++)
	{
		image.SetPixel(x, y, 0xff00ffff); // trace the path for visual inspection
		if(path[i] == 0)
			x++;
		else if(path[i] == 2)
			y++;
		else if(path[i] == 1 || path[i] == 3)
			throw "unnecessary action"; // moving left or up never approaches (49, 49)
	}
	if(x != 49 || y != 49)
		throw "wrong target state";
	//image.SavePNGFile("astar.png");
}
#endif // !NO_TEST_CODE

⌨️ 快捷键说明

复制代码Ctrl + C
搜索代码Ctrl + F
全屏模式F11
增大字号Ctrl + =
减小字号Ctrl + -
显示快捷键?