⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 cqetraces.cpp

📁 强化学习算法(R-Learning)难得的珍贵资料
💻 CPP
字号:
// Copyright (C) 2003
// Gerhard Neumann (gerhard@igi.tu-graz.ac.at)

//                
// This file is part of RL Toolbox.
// http://www.igi.tugraz.at/ril_toolbox
//
// All rights reserved.
// 
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// 1. Redistributions of source code must retain the above copyright
//    notice, this list of conditions and the following disclaimer.
// 2. Redistributions in binary form must reproduce the above copyright
//    notice, this list of conditions and the following disclaimer in the
//    documentation and/or other materials provided with the distribution.
// 3. The name of the author may not be used to endorse or promote products
//    derived from this software without specific prior written permission.
// 
// THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
// IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
// OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
// IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT,
// INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
// NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
// THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

#include "ril_debug.h"
#include "cqetraces.h"

#include <assert.h>
#include <math.h>

#include <iterator>

// Base constructor for Q-function eligibility traces.
// Remembers the Q-function these traces belong to and registers the default
// parameters: "Lambda" = 0.9, "DiscountFactor" = 0.95 and
// "ReplacingETraces" = 1.0 (replacing traces enabled).
CAbstractQETraces::CAbstractQETraces(CAbstractQFunction *qFunction)
{
	this->qFunction = qFunction;

	// trace-decay and discount defaults
	addParameter("Lambda", 0.9);
	addParameter("DiscountFactor", 0.95);
	// 1.0 = replacing, 0.0 = accumulating
	addParameter("ReplacingETraces", 1.0);
}

// Stores the trace-decay factor in the "Lambda" parameter.
void CAbstractQETraces::setLambda(rlt_real lambda)
{
	this->setParameter("Lambda", lambda);
}
	
// Returns the current value of the "Lambda" (trace-decay) parameter.
rlt_real CAbstractQETraces::getLambda()
{
	const rlt_real lambda = this->getParameter("Lambda");
	return lambda;
}

// Encodes the trace mode in the numeric "ReplacingETraces" parameter:
// 1.0 selects replacing traces, 0.0 selects accumulating traces.
void CAbstractQETraces::setReplacingETraces(bool bReplace)
{
	setParameter("ReplacingETraces", bReplace ? 1.0 : 0.0);
}

// Reports whether replacing traces are enabled, i.e. whether the
// "ReplacingETraces" parameter is strictly greater than 0.5.
bool CAbstractQETraces::getReplacingETraces()
{
	return 0.5 < getParameter("ReplacingETraces");
}

// Builds one V-ETrace object per action of the Q-function, taken from the
// standard traces of that action's V-function, and registers each one's
// parameters with this object. Actions that are NULL get a NULL placeholder
// so list positions stay aligned with action indices.
CQETraces::CQETraces(CQFunction *qfunction) : CAbstractQETraces(qfunction)
{
	qExFunction = qfunction;
	vETraces = new std::list<CAbstractVETraces *>();

	CActionSet::iterator actionIt = qExFunction->getActions()->begin();
	for (unsigned int slot = 0; slot < qExFunction->getNumActions(); ++slot, ++actionIt)
	{
		CAbstractVETraces *slotTrace = NULL;
		if (*actionIt != NULL)
		{
			slotTrace = qExFunction->getVFunction(*actionIt)->getStandardETraces();
			addParameters(slotTrace);
		}
		vETraces->push_back(slotTrace);
	}
}

// Deletes every owned per-action V-ETrace object, skipping NULL placeholders,
// and then the list that held them.
CQETraces::~CQETraces()
{
	for (std::list<CAbstractVETraces *>::iterator trace = vETraces->begin(); trace != vETraces->end(); ++trace)
	{
		if (*trace != NULL)
		{
			delete *trace;
		}
	}
	delete vETraces;
}

void CQETraces::resetETraces()
{
	std::list<CAbstractVETraces *>::iterator it = vETraces->begin();
	for (int i = 0; it != vETraces->end(); it++, i++)
	{
		assert((*it) != NULL);
		(*it)->resetETraces();
	}
}

// Decays all per-action V-ETraces for one transition.
// Non-multistep actions count as a single step. For multistep actions the
// duration comes from the CMultiStepActionData when `data` is one, and
// otherwise from action->getDuration().
void CQETraces::updateETraces(CAction *action, CActionData *data)
{
	int duration = 1;
	if (action->isType(MULTISTEPACTION))
	{
		// dynamic_cast yields NULL both for data == NULL and for action data of
		// another type; check it instead of dereferencing the result blindly.
		CMultiStepActionData *multiStepData = dynamic_cast<CMultiStepActionData *>(data);
		if (multiStepData != NULL)
		{
			duration = multiStepData->duration;
		}
		else
		{
			duration = action->getDuration();
		}
	}

	std::list<CAbstractVETraces *>::iterator it = vETraces->begin();
	for (; it != vETraces->end(); ++it)
	{
		assert((*it) != NULL);
		(*it)->updateETraces(duration);
	}
}


// Adds `factor` times the trace for `state` to the V-ETrace of the given
// action. The action's position is taken from the Q-function's action set.
// If the action is not found, no trace is touched. `data` is unused here.
void CQETraces::addETrace(CStateCollection *state, CAction *action, rlt_real factor, CActionData *data)
{
	const int targetIndex = qExFunction->getActions()->getIndex(action);

	int position = 0;
	for (std::list<CAbstractVETraces *>::iterator trace = vETraces->begin(); trace != vETraces->end(); ++trace, ++position)
	{
		assert(*trace != NULL);
		if (position != targetIndex)
		{
			continue;
		}
		(*trace)->addETrace(state, factor);
		return;
	}
}

// Applies the TD error `td` to each action's V-function through its V-ETrace.
// If a stored trace object no longer belongs to the Q-function's current
// V-function for that index, it is deleted first. It is then replaced by that
// V-function's standard traces, so its accumulated contents are discarded.
// NULL slots, which the constructor stores for NULL actions, are skipped
// rather than dereferenced.
void CQETraces::updateQFunction(rlt_real td) 
{
	int i = 0;
	for (std::list<CAbstractVETraces *>::iterator it = vETraces->begin(); it != vETraces->end(); ++it, ++i)
	{
		if ((*it) == NULL)
		{
			continue;
		}

		CAbstractVFunction *vFunction = qExFunction->getVFunction(i);
		if ((*it)->getVFunction() != vFunction)
		{
			// Re-creating the trace from a NULL V-function would crash below.
			assert(vFunction != NULL);
			delete *it;
			*it = vFunction->getStandardETraces();
		}

		DebugPrint('e', "ETraces Nr: %d \n", i);

		(*it)->updateVFunction(td);
	}
}

// Stores `vETrace` as the V-ETrace for action position `index` and registers
// its parameters with this object.
// The trace must belong to the Q-function's V-function at that index.
// If bDelete is set, the previously stored trace object is deleted, unless it
// is the very object being stored again. Deleting it in that case would leave
// a dangling pointer in the slot.
void CQETraces::setVETrace(CAbstractVETraces *vETrace, int index, bool bDelete)
{
	assert(vETrace != NULL);
	assert(index >= 0 && index < static_cast<int>(vETraces->size()));
	assert(qExFunction->getVFunction(index) == vETrace->getVFunction());

	std::list<CAbstractVETraces *>::iterator it = vETraces->begin();
	std::advance(it, index);

	if (bDelete && *it != NULL && *it != vETrace)
	{
		delete *it;
	}
	*it = vETrace;

	addParameters(vETrace);
}

// Returns the V-ETrace stored for action position `index`. This may be NULL
// for NULL actions. An out-of-range index is a programming error (asserted)
// instead of silently walking past the end of the list.
CAbstractVETraces *CQETraces::getVETrace(int index)
{
	assert(index >= 0 && index < static_cast<int>(vETraces->size()));

	std::list<CAbstractVETraces *>::iterator it = vETraces->begin();
	std::advance(it, index);

	return *it;
}

void CQETraces::setReplacingETraces(bool bReplace)
{
	std::list<CAbstractVETraces *>::iterator it = vETraces->begin();
	

	for (int i = 0; it != vETraces->end(); it++, i++)
	{
		if ((*it) != NULL)
		{
			(*it)->setReplacingETraces(bReplace);
		}
	}
}



// Builds one Q-ETrace object per sub-Q-function of the composed Q-function,
// taken from each sub-function's standard traces, and registers each one's
// parameters with this object. NULL sub-functions get a NULL placeholder so
// list positions stay aligned with the Q-function list.
CComposedQETraces::CComposedQETraces(CComposedQFunction *qfunction) : CAbstractQETraces(qfunction)
{
	qCompFunction = qfunction;
	qETraces = new std::list<CAbstractQETraces *>();

	std::list<CAbstractQFunction *>::iterator subFunction = qCompFunction->getQFunctions()->begin();
	for (int slot = 0; slot < qCompFunction->getNumQFunctions(); ++slot, ++subFunction)
	{
		CAbstractQETraces *slotTrace = NULL;
		if (*subFunction != NULL)
		{
			slotTrace = (*subFunction)->getStandardETraces();
			addParameters(slotTrace);
		}
		qETraces->push_back(slotTrace);
	}
}
	
// Deletes every owned sub-Q-ETrace object, skipping NULL placeholders, and
// then the list that held them.
CComposedQETraces::~CComposedQETraces()
{
	std::list<CAbstractQETraces *>::iterator trace = qETraces->begin();
	while (trace != qETraces->end())
	{
		if (*trace != NULL)
		{
			delete *trace;
		}
		++trace;
	}
	delete qETraces;
}

void CComposedQETraces::resetETraces()
{
	std::list<CAbstractQETraces *>::iterator it = qETraces->begin();
	for (int i = 0; it != qETraces->end(); it++, i++)
	{
		assert((*it) != NULL);
		(*it)->resetETraces();
	}
}

// Adds a trace for (state, action) to the Q-ETrace of the first sub-Q-function
// whose action set contains `action`.
// The trace list and the Q-function list are walked in lockstep, since the
// constructor builds one trace per sub-function in the same order. The
// Q-function iterator must therefore advance together with the trace iterator.
// Without that, every trace was compared against the first sub-function only.
void CComposedQETraces::addETrace(CStateCollection *state, CAction *action, rlt_real factor,  CActionData *data )
{
	std::list<CAbstractQFunction *> *qFunctions = qCompFunction->getQFunctions();

	std::list<CAbstractQETraces *>::iterator it = qETraces->begin();
	std::list<CAbstractQFunction *>::iterator itQFunc = qFunctions->begin();

	for (; it != qETraces->end() && itQFunc != qFunctions->end(); ++it, ++itQFunc)
	{
		assert((*it) != NULL);

		if ((*itQFunc)->getActions()->isMember(action))
		{
			(*it)->addETrace(state, action, factor, data);
			break;
		}
	}
}

// Forwards the trace decay for (action, data) to every sub-Q-ETrace object.
// Every slot is required to hold a trace object (asserted).
void CComposedQETraces::updateETraces(CAction *action,  CActionData *data)
{
	std::list<CAbstractQETraces *>::iterator trace = qETraces->begin();
	while (trace != qETraces->end())
	{
		assert(*trace != NULL);
		(*trace)->updateETraces(action, data);
		++trace;
	}
}

// Applies the TD error `td` through every sub-Q-ETrace object.
// NULL placeholders, which the constructor stores for NULL sub-functions, are
// skipped rather than dereferenced. This matches the destructor and
// setReplacingETraces.
void CComposedQETraces::updateQFunction(rlt_real td)
{
	std::list<CAbstractQETraces *>::iterator it = qETraces->begin();
	for (; it != qETraces->end(); ++it)
	{
		if ((*it) != NULL)
		{
			(*it)->updateQFunction(td);
		}
	}
}

// Stores `qETrace` at sub-function position `index`.
// If bDeleteOld is set, the previously stored trace object is deleted, unless
// it is the very object being stored again. Deleting it in that case would
// leave a dangling pointer in the slot.
void CComposedQETraces::setQETrace(CAbstractQETraces *qETrace, int index, bool bDeleteOld)
{
	assert(index >= 0 && index < static_cast<int>(qETraces->size()));

	std::list<CAbstractQETraces *>::iterator it = qETraces->begin();
	std::advance(it, index);

	if (bDeleteOld && (*it) != NULL && (*it) != qETrace)
	{
		delete *it;
	}
	*it = qETrace;
}


// Returns the Q-ETrace stored at sub-function position `index`. This may be
// NULL for NULL sub-functions. An out-of-range index is a programming error
// (asserted) instead of silently walking past the end of the list.
CAbstractQETraces *CComposedQETraces::getQETrace(int index)
{
	assert(index >= 0 && index < static_cast<int>(qETraces->size()));

	std::list<CAbstractQETraces *>::iterator it = qETraces->begin();
	std::advance(it, index);

	return *it;
}

void CComposedQETraces::setReplacingETraces(bool bReplace)
{
	std::list<CAbstractQETraces *>::iterator it = qETraces->begin();


	for (int i = 0; it != qETraces->end(); it++, i++)
	{
		if ((*it) != NULL)
		{
			(*it)->setReplacingETraces(bReplace);
		}
	}
}
/*
void CComposedQETraces::setLambda(rlt_real lambda)
{
	this->lambda = lambda;
	std::list<CAbstractQETraces *>::iterator it = qETraces->begin();


	for (int i = 0; it != qETraces->end(); it++, i++)
	{
		if ((*it) != NULL)
		{
			(*it)->setLambda(lambda);
		}
	}
}*/


// Eligibility traces for a gradient-based Q-function.
// `gradient` is a scratch list refilled by addETrace. `eTrace` holds the
// accumulated traces; the meaning of its (10, true, true) constructor
// arguments is defined by CFeatureList (TODO confirm).
// Registered parameters:
//   "ETraceTreshold" (0.001) - magnitude below which updateETraces drops an entry
//   "ETraceMaxListSize" (1000) - maximum entry count enforced by addGradientETrace
CGradientQETraces::CGradientQETraces(CGradientQFunction *qfunction) : CAbstractQETraces(qfunction)
{
	gradientQFunction = qfunction;

	gradient = new CFeatureList(10);
	eTrace = new CFeatureList(10, true, true);

	addParameter("ETraceTreshold", 0.001);
	addParameter("ETraceMaxListSize", 1000);
}

// Frees the two feature lists allocated by the constructor.
CGradientQETraces::~CGradientQETraces()
{
	delete this->gradient;
	delete this->eTrace;
}


void CGradientQETraces::resetETraces()
{
	eTrace->clear();
}

// Computes the Q-function gradient for (State, action, data) into the
// reusable scratch list. That gradient, scaled by `factor`, is then added to
// the eligibility traces.
void CGradientQETraces::addETrace(CStateCollection *State, CAction *action, rlt_real factor, CActionData *data)
{
	CFeatureList *scratch = gradient;
	scratch->clear();
	gradientQFunction->getGradient(State, action, data, scratch);
	addGradientETrace(scratch, factor);
}

// Adds factor * l_gradient to the eligibility traces.
// Accumulating mode: every contribution is added via update().
// Replacing mode: signs are compared with a "> 0" test (zero counts as
// non-positive). If the sign of the contribution differs from the stored
// entry's, the contribution is added via update(). Otherwise the entry is
// overwritten via set() only when the new magnitude is larger.
// Afterwards, entries are removed from the back of the list until at most
// "ETraceMaxListSize" remain. A value <= 0 disables trimming.
void CGradientQETraces::addGradientETrace(CFeatureList *l_gradient, rlt_real factor)
{
	const bool replacingETraces = getReplacingETraces();

	for (CFeatureList::iterator it = l_gradient->begin(); it != l_gradient->end(); ++it)
	{
		// Look the stored value up once per feature instead of up to three times.
		const rlt_real oldFactor = eTrace->getFeatureFactor((*it)->featureIndex);
		DebugPrint('e', "%d : %f -> ", (*it)->featureIndex, oldFactor);

		const rlt_real featureFactor = (*it)->factor * factor;

		if (!replacingETraces)
		{
			eTrace->update((*it)->featureIndex, featureFactor);
		}
		else if ((featureFactor > 0) != (oldFactor > 0))
		{
			eTrace->update((*it)->featureIndex, featureFactor);
		}
		else if (fabs(featureFactor) > fabs(oldFactor))
		{
			eTrace->set((*it)->featureIndex, featureFactor);
		}

		DebugPrint('e', "%f\n", eTrace->getFeatureFactor((*it)->featureIndex));
	}

	const int maxSize = my_round(getParameter("ETraceMaxListSize"));

	// Test maxSize first so the size comparison only runs with a positive limit.
	while (maxSize > 0 && eTrace->size() > maxSize)
	{
		eTrace->remove(*eTrace->rbegin());
	}
}

// Decays the eligibility traces for one transition.
// Each entry is multiplied by Lambda * DiscountFactor^duration. Entries whose
// magnitude then falls below "ETraceTreshold" are removed.
// `duration` is action->getDuration(). For multistep actions it is taken from
// the CMultiStepActionData instead, when `data` is one.
void CGradientQETraces::updateETraces(CAction *action,  CActionData *data)
{
	int duration = action->getDuration();

	if (DebugIsEnabled('e'))
	{
		DebugPrint('e', "Etraces Bevore Updating: ");
		eTrace->saveASCII(DebugGetFileHandle('e'));
		DebugPrint('e',"\n");
	}

	if (action->isType(MULTISTEPACTION))
	{
		// dynamic_cast yields NULL both for data == NULL and for action data of
		// another type; in both cases keep action->getDuration().
		CMultiStepActionData *multiStepData = dynamic_cast<CMultiStepActionData *>(data);
		if (multiStepData != NULL)
		{
			duration = multiStepData->duration;
		}
	}

	rlt_real mult = getParameter("Lambda") * pow(getParameter("DiscountFactor"), duration);
	rlt_real treshold = getParameter("ETraceTreshold");

	CFeatureList::iterator it = eTrace->begin();
	int i = 0;

	while (it != eTrace->end())
	{
		(*it)->factor *= mult;
		if (fabs((*it)->factor) < treshold)
		{
			DebugPrint('e', "Deleting Etrace %d\n", (*it)->featureIndex);
			eTrace->remove(*it);

			// `it` is not reused after remove(); re-seek to the i-th entry, which
			// is the element that followed the removed one.
			// (Assumes remove() erases only that entry - TODO confirm in CFeatureList.)
			it = eTrace->begin();
			for (int j = 0; j < i; j++, it++);
		}
		else
		{
			i++;
			it++;
		}
	}

	if (DebugIsEnabled('e'))
	{
		DebugPrint('e', "Etraces After Updating: ");
		eTrace->saveASCII(DebugGetFileHandle('e'));
		DebugPrint('e',"\n");
	}
}

// Applies the TD error `td` to the gradient Q-function along the current
// eligibility traces by handing both to gradientQFunction->updateGradient.
void CGradientQETraces::updateQFunction(rlt_real td)
{
	DebugPrint('t', "Updating GradientQ-Function with TD %f \n", td);
	this->gradientQFunction->updateGradient(this->eTrace, td);
}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -