⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 cerrorfunction.cpp

📁 强化学习算法(R-Learning)难得的珍贵资料
💻 CPP
字号:
// Copyright (C) 2003
// Gerhard Neumann (gerhard@igi.tu-graz.ac.at)

//                
// This file is part of RL Toolbox.
// http://www.igi.tugraz.at/ril_toolbox
//
// All rights reserved.
// 
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// 1. Redistributions of source code must retain the above copyright
//    notice, this list of conditions and the following disclaimer.
// 2. Redistributions in binary form must reproduce the above copyright
//    notice, this list of conditions and the following disclaimer in the
//    documentation and/or other materials provided with the distribution.
// 3. The name of the author may not be used to endorse or promote products
//    derived from this software without specific prior written permission.
// 
// THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
// IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
// OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
// IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT,
// INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
// NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
// THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

#include "ril_debug.h"
#include "cerrorfunction.h"

#include <math.h>


// Squared-error loss built from the signed residual returned by
// getDerivatedError():  E = 0.5 * delta^2.
// The residual is squared with a plain multiplication instead of pow(x, 2):
// pow is a general-purpose library call, whereas delta * delta is one multiply.
// The leading 0.5 is a double literal, so the product is still evaluated in
// double precision before being converted to rlt_real.
rlt_real CErrorFunction::getError(CStateCollection *oldState, CAction *action ,rlt_real reward, CStateCollection *newState, CActionData *data)
{
	const rlt_real delta = getDerivatedError(oldState, action, reward, newState, data);
	return 0.5 * delta * delta;
}

// Builds a TD-residual error function over a gradient Q-function.
//   estimationPolicy         - chooses the successor-state action whose Q-value
//                              forms the bootstrap term of the residual
//   residual                 - combines oldQ, reward, duration and newQ
//   residualGradientFunction - combines the gradients of the old and new Q-values
//   qFunction                - approximator whose weights are reported/updated
// The collaborators are kept as plain pointers and are not deleted by the
// destructor; only the scratch feature lists and the action-data set are owned.
CTDErrorFunction::CTDErrorFunction(CAgentController *estimationPolicy, CResidualFunction *residual,CResidualGradientFunction *residualGradientFunction, CGradientQFunction *qFunction)
{
	// Non-owning references to the collaborators.
	this->qFunction = qFunction;
	this->estimationPolicy = estimationPolicy;
	this->residual = residual;
	this->residualGradientFunction = residualGradientFunction;

	// Owned scratch buffers, cleared and refilled on every gradient request.
	residualGradient = new CFeatureList();
	newGradient = new CFeatureList();
	oldGradient = new CFeatureList();

	// Storage for the action data of the action picked by the estimation policy.
	actionDataSet = new CActionDataSet(this->qFunction->getActions());

	// Expose the collaborators' parameters through this object.
	addParameters(qFunction);
	addParameters(estimationPolicy);
	addParameters(residual);
	addParameters(residualGradientFunction);
}

// Frees the objects allocated by the constructor. The policy, residual,
// residual-gradient function and Q-function are not owned and are left alone.
CTDErrorFunction::~CTDErrorFunction()
{
	delete actionDataSet;

	delete residualGradient;
	delete newGradient;
	delete oldGradient;
}



// Signed TD residual of the transition (oldState, action, reward, newState).
// The estimation policy picks the successor action (its action data lands in
// actionDataSet); the residual function then combines Q(oldState, action),
// the reward, the action's duration and Q(newState, pickedAction).
rlt_real CTDErrorFunction::getDerivatedError(CStateCollection *oldState, CAction *action, rlt_real reward, CStateCollection *newState, CActionData *data)
{
	// Value of the executed action where it was taken.
	const rlt_real qTaken = qFunction->getValue(oldState, action, data);

	// Bootstrap value: the policy's choice in the successor state.
	CAction *pickedAction = estimationPolicy->getNextAction(newState, actionDataSet);
	const rlt_real qNext = qFunction->getValue(newState, pickedAction, actionDataSet->getActionData(pickedAction));

	return residual->getResidual(qTaken, reward, action->getDuration(), qNext);
}

// Writes the gradient of the TD residual w.r.t. the Q-function weights into
// `gradient`. The gradients of Q(oldState, action) and of
// Q(newState, policyAction) are gathered in the scratch lists and combined by
// the residual-gradient function. The result is not multiplied by the residual
// value itself, and `reward` is not used here.
void CTDErrorFunction::getErrorGradient(CStateCollection *oldState, CAction *action, CStateCollection *newState, rlt_real reward, CFeatureList *gradient, CActionData *data)
{
	rlt_real stepDuration = action->getDuration();

	oldGradient->clear();
	newGradient->clear();

	// dQ(oldState, action)/dw
	qFunction->getGradient(oldState, action, data, oldGradient);

	// dQ(newState, a')/dw, with a' chosen by the estimation policy
	CAction *policyAction = estimationPolicy->getNextAction(newState, actionDataSet);
	qFunction->getGradient(newState, policyAction, actionDataSet->getActionData(policyAction), newGradient);

	residualGradientFunction->getResidualGradient(oldGradient, newGradient, stepDuration, gradient);
}

// The trainable weights are exactly those of the wrapped Q-function.
int CTDErrorFunction::getNumWeights()
{
	const int weightCount = qFunction->getNumWeights();
	return weightCount;
}

// Forwards a weight update unchanged to the wrapped Q-function.
void CTDErrorFunction::updateGradient(CFeatureList *gradientFeatures)
{
	qFunction->updateGradient(gradientFeatures);
	return;
}

// Advantage-learning error function over a gradient Q-function.
//   qFunction - approximator whose values and gradients are used (not owned)
//   dt        - time interval, stored as a plain member
//   K         - registered as parameter "TimeScale"
//   beta      - registered as parameter "ResidualBeta"
// "DiscountFactor" is registered with a fixed initial value of 0.95.
CAdvantageLearningErrorFunction::CAdvantageLearningErrorFunction(CGradientQFunction *qFunction, rlt_real dt,rlt_real K,rlt_real beta)
{
	this->qFunction = qFunction;
	this->dt = dt;

	// Owned scratch objects, released in the destructor.
	this->qFuncGradient = new CFeatureList();
	this->actionDataSet = new CActionDataSet(qFunction->getActions());

	// Parameter registration.
	addParameters(qFunction);
	addParameter("TimeScale", K);
	addParameter("ResidualBeta", beta);
	addParameter("DiscountFactor", 0.95);
}

// Frees the scratch gradient list and the action-data set; the Q-function is not owned.
CAdvantageLearningErrorFunction::~CAdvantageLearningErrorFunction()
{
	delete actionDataSet;
	delete qFuncGradient;
}

// Advantage-learning residual of the transition:
//   r + gamma * max_a Q(s', a) - (dt/K) * Q(s, a) + (dt/K - 1) * max_a Q(s, a)
// with gamma = "DiscountFactor" and K = "TimeScale".
// NOTE(review): getErrorGradient scales its terms with gamma^(dt*duration) and
// 1/(dt*duration*K), while this residual uses gamma and dt/K without the action
// duration. Confirm which convention is intended before combining the two.
rlt_real CAdvantageLearningErrorFunction::getDerivatedError(CStateCollection *oldState, CAction *action ,rlt_real reward, CStateCollection *newState, CActionData *data)
{
	rlt_real timeScale = getParameter("TimeScale");

	rlt_real qTaken = qFunction->getValue(oldState, action, data);
	rlt_real bestNext = qFunction->getMaxValue(newState, qFunction->getActions());
	rlt_real bestHere = qFunction->getMaxValue(oldState, qFunction->getActions());

	return reward + getParameter("DiscountFactor") * bestNext - dt/timeScale * qTaken + (dt/timeScale - 1) * bestHere;
}

// Accumulates the residual gradient of the advantage-learning error into
// `gradient` as the sum of three weighted Q-function gradients:
//   +1                                         * dQ(s,  a)/dw
//   -beta * gamma^(dt*d) / (dt*d*K)            * dQ(s', argmax)/dw
//   -beta * (1 - 1/(dt*d*K))                   * dQ(s,  argmax)/dw
// where beta = "ResidualBeta", gamma = "DiscountFactor", K = "TimeScale" and
// d is the action's duration. The result is not scaled by the residual value.
// The greedy actions come from qFunction->getMax, whose action data is kept in
// actionDataSet (reused for both states).
void CAdvantageLearningErrorFunction::getErrorGradient(CStateCollection *oldState, CAction *action, CStateCollection *newState, rlt_real reward,CFeatureList *gradient, CActionData *actionData)
{
	rlt_real residualBeta = getParameter("ResidualBeta");
	rlt_real discount = getParameter("DiscountFactor");
	rlt_real timeScale = getParameter("TimeScale");

	rlt_real actionDuration = action->getDuration();

	// Term 1: executed action in the old state.
	qFuncGradient->clear();
	qFunction->getGradient(oldState, action, actionData, qFuncGradient);
	gradient->add(qFuncGradient);

	// Term 2: greedy action in the new state.
	qFuncGradient->clear();
	CAction *greedyNext = qFunction->getMax(newState, qFunction->getActions(), actionDataSet);
	qFunction->getGradient(newState, greedyNext, actionDataSet->getActionData(greedyNext), qFuncGradient);
	gradient->add(qFuncGradient,(- residualBeta) * pow(discount, dt * actionDuration) / (dt * actionDuration * timeScale));

	// Term 3: greedy action in the old state.
	qFuncGradient->clear();
	CAction *greedyHere = qFunction->getMax(oldState, qFunction->getActions(), actionDataSet);
	qFunction->getGradient(oldState, greedyHere, actionDataSet->getActionData(greedyHere), qFuncGradient);
	gradient->add(qFuncGradient, (- residualBeta) * (1 - 1/(dt * actionDuration * timeScale)));
}

// The trainable weights are exactly those of the wrapped Q-function.
int CAdvantageLearningErrorFunction::getNumWeights()
{
	const int weightCount = qFunction->getNumWeights();
	return weightCount;
}

// Forwards a weight update unchanged to the wrapped Q-function.
void CAdvantageLearningErrorFunction::updateGradient(CFeatureList *gradientFeatures)
{
	qFunction->updateGradient(gradientFeatures);
	return;
}

// Sets the "TimeScale" parameter (K) used by the residual and its gradient.
void CAdvantageLearningErrorFunction::setK(rlt_real K)
{
	setParameter("TimeScale", K);
	return;
}

// Returns the current "TimeScale" parameter (K).
rlt_real CAdvantageLearningErrorFunction::getK()
{
	rlt_real timeScale = getParameter("TimeScale");
	return timeScale;
}

// Replaces the time interval dt (a plain member, not a registered parameter).
void CAdvantageLearningErrorFunction::setTimeIntervall(rlt_real dt)
{
	CAdvantageLearningErrorFunction::dt = dt;
}

// Returns the time interval dt.
rlt_real CAdvantageLearningErrorFunction::getTimeIntervall()
{
	return this->dt;
}

// Residual error function for a gradient state-value function V.
//   vFunction        - approximator being trained (not owned)
//   residualFunction - combines V(s), reward, duration and V(s') into the residual
//   gradientFunction - combines the gradients of V(s) and V(s')
// Only the two scratch feature lists are owned by this object.
CValueIterationErrorFunction::CValueIterationErrorFunction(CGradientVFunction *vFunction, CResidualFunction *residualFunction, CResidualGradientFunction *gradientFunction)
{
	this->vFunction = vFunction;
	this->residualFunction = residualFunction;
	this->gradientFunction = gradientFunction;

	// Scratch buffers refilled by getErrorGradient().
	oldGradient = new CFeatureList();
	newGradient = new CFeatureList();

	addParameters(vFunction);
	addParameters(residualFunction);
	addParameters(gradientFunction);
}

// Frees the scratch gradient lists; the V-function and the residual helpers are not owned.
CValueIterationErrorFunction::~CValueIterationErrorFunction()
{
	delete newGradient;
	delete oldGradient;
}

// Signed residual of the state-value function for the transition
// (oldState, action, reward, newState): the residual function combines
// V(oldState), the reward, the transition duration and V(newState).
//
// Duration: for a multi-step action whose action data really is a
// CMultiStepActionData, the duration recorded in that data is used. In every
// other case (no data, not multi-step, or data of another type so the
// dynamic_cast yields a null pointer) the action's own getDuration() is used,
// instead of dereferencing a null cast result.
rlt_real CValueIterationErrorFunction::getDerivatedError(CStateCollection *oldState, CAction *action ,rlt_real reward, CStateCollection *newState, CActionData *actionData)
{
	CMultiStepActionData *multiStepData = 0;

	if (actionData && action->isType(MULTISTEPACTION))
	{
		multiStepData = dynamic_cast<CMultiStepActionData *>(actionData);
	}

	int duration;
	if (multiStepData)
	{
		duration = multiStepData->duration;
	}
	else
	{
		duration = action->getDuration();
	}

	rlt_real error = residualFunction->getResidual(vFunction->getValue(oldState), reward, duration,vFunction->getValue(newState));
	DebugPrint('t', "ValueIteration ErrorFunction: Error %f\n", error);

	return error;
}

// Writes the residual gradient of the value function into `gradient`: the
// gradients of V(oldState) and V(newState) are gathered in the scratch lists
// and combined by the residual-gradient function. `reward` is not used here.
//
// Duration: the action's getDuration() is used unless the action is
// multi-step and its action data really is a CMultiStepActionData, in which
// case the duration stored in that data wins. A failed dynamic_cast (data of
// another type) keeps the action's own duration instead of dereferencing a
// null pointer.
void CValueIterationErrorFunction::getErrorGradient(CStateCollection *oldState, CAction *action, CStateCollection *newState, rlt_real reward, CFeatureList *gradient, CActionData *actionData)
{
	rlt_real duration = action->getDuration();

	if (actionData && action->isType(MULTISTEPACTION))
	{
		CMultiStepActionData *multiStepData = dynamic_cast<CMultiStepActionData *>(actionData);
		if (multiStepData)
		{
			duration = multiStepData->duration;
		}
	}

	oldGradient->clear();
	newGradient->clear();

	vFunction->getGradient(oldState, oldGradient);
	vFunction->getGradient(newState, newGradient);

	gradientFunction->getResidualGradient(oldGradient, newGradient, duration, gradient);
}

// The trainable weights are exactly those of the wrapped V-function.
int CValueIterationErrorFunction::getNumWeights()
{
	const int weightCount = vFunction->getNumWeights();
	return weightCount;
}

// Forwards a weight update to the wrapped V-function with a factor of 1.0.
void CValueIterationErrorFunction::updateGradient(CFeatureList *gradientFeatures)
{
	vFunction->updateGradient(gradientFeatures, 1.0);
	return;
}

// Wraps another error function and blends a reward-based term into its error
// (see getError). Gradient-related calls are delegated to the wrapped function.
//   errorFunction - wrapped error function (not owned)
//   beta          - "PolicySearchBeta", blend weight between the two terms
//   b             - "PolicySearchB", offset b in the term beta * (b - gamma^t * r)
//   gamma         - "DiscountFactor", discounts the reward by the elapsed time t
// The elapsed-time counter starts at zero.
CPolicySearchErrorFunction::CPolicySearchErrorFunction(CErrorFunction *errorFunction, rlt_real beta, rlt_real b, rlt_real gamma)
{
	this->timeStep = 0;
	this->errorFunction = errorFunction;

	// Registered explicitly on the CErrorFunction base (presumably to
	// disambiguate between several parameter-holding bases — confirm in header).
	CErrorFunction::addParameter("DiscountFactor", gamma);
	CErrorFunction::addParameter("PolicySearchBeta", beta);
	CErrorFunction::addParameter("PolicySearchB", b);
}

// Blended error:
//   (1 - beta) * E_wrapped + beta * (b - gamma^t * reward)
// where t is the elapsed time accumulated by nextStep() since newEpisode().
rlt_real CPolicySearchErrorFunction::getError(CStateCollection *oldState, CAction *action ,rlt_real reward, CStateCollection *newState, CActionData *actionData)
{
	rlt_real offset = CErrorFunction::getParameter("PolicySearchB");
	rlt_real blend = CErrorFunction::getParameter("PolicySearchBeta");
	rlt_real discount = CErrorFunction::getParameter("DiscountFactor");

	rlt_real wrappedPart = (1 - blend) * errorFunction->getError(oldState, action, reward, newState, actionData);
	return wrappedPart + blend * (offset - pow(discount, timeStep) * reward);
}

// Delegates to the wrapped error function; the policy-search term is not included here.
rlt_real CPolicySearchErrorFunction::getDerivatedError(CStateCollection *oldState, CAction *action ,rlt_real reward, CStateCollection *newState, CActionData *actionData)
{
	CErrorFunction *wrapped = errorFunction;
	return wrapped->getDerivatedError(oldState, action, reward, newState, actionData);
}

// Delegates the gradient computation unchanged to the wrapped error function.
void CPolicySearchErrorFunction::getErrorGradient(CStateCollection *oldState, CAction *action, CStateCollection *newState, rlt_real reward, CFeatureList *gradient, CActionData *actionData)
{
	CErrorFunction *wrapped = errorFunction;
	wrapped->getErrorGradient(oldState, action, newState, reward, gradient, actionData);
}

// Reports the weight count of the wrapped error function.
int CPolicySearchErrorFunction::getNumWeights()
{
	const int weightCount = errorFunction->getNumWeights();
	return weightCount;
}

// Forwards a weight update unchanged to the wrapped error function.
void CPolicySearchErrorFunction::updateGradient(CFeatureList *gradientFeatures)
{
	CErrorFunction *wrapped = errorFunction;
	wrapped->updateGradient(gradientFeatures);
}

// Advances the elapsed-time counter used for discounting in getError() by the
// executed action's duration. The states are not inspected.
void CPolicySearchErrorFunction::nextStep(CStateCollection *oldState, CAction *action, CStateCollection *newState)
{
	this->timeStep += action->getDuration();
}

void CPolicySearchErrorFunction::newEpisode()
{
	timeStep = 0.0;
}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -