// cerrorfunction.cpp
// Copyright (C) 2003
// Gerhard Neumann (gerhard@igi.tu-graz.ac.at)
//
// This file is part of RL Toolbox.
// http://www.igi.tugraz.at/ril_toolbox
//
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// 1. Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// 2. Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// 3. The name of the author may not be used to endorse or promote products
// derived from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
// IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
// OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
// IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT,
// INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
// NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
// THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include "ril_debug.h"
#include "cerrorfunction.h"
#include <math.h>
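// Generic quadratic error: E = 0.5 * delta^2, where delta is the raw
// residual returned by getDerivatedError(). Since dE/d(delta) = delta,
// getDerivatedError() doubles as the derivative of this error with
// respect to the residual, which explains its name.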
rlt_real CErrorFunction::getError(CStateCollection *oldState, CAction *action, rlt_real reward, CStateCollection *newState, CActionData *data)
{
return 0.5 * pow(getDerivatedError(oldState, action, reward, newState, data), 2);
}
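// TD error function: combines a gradient Q-function, an estimation
// policy (used to select the successor action a'), a residual function
// computing the scalar TD residual, and a residual gradient function
// combining the Q-gradients of both states.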
CTDErrorFunction::CTDErrorFunction(CAgentController *estimationPolicy, CResidualFunction *residual, CResidualGradientFunction *residualGradientFunction, CGradientQFunction *qFunction)
{
this->estimationPolicy = estimationPolicy;
this->residual = residual;
this->residualGradientFunction = residualGradientFunction;
this->qFunction = qFunction;
oldGradient = new CFeatureList();
newGradient = new CFeatureList();
residualGradient = new CFeatureList();
actionDataSet = new CActionDataSet(qFunction->getActions());
addParameters(qFunction);
addParameters(estimationPolicy);
addParameters(residual);
addParameters(residualGradientFunction);
}
CTDErrorFunction::~CTDErrorFunction()
{
delete oldGradient;
delete newGradient;
delete residualGradient;
delete actionDataSet;
}
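// Scalar TD residual for the transition (s, a, r, s'). The successor
// action a' is chosen by the estimation policy; the exact form of the
// residual is delegated to the CResidualFunction, presumably something
// like r + gamma^d * Q(s', a') - Q(s, a) for an action of duration d
// (an assumption about the typical residual, not enforced here).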
rlt_real CTDErrorFunction::getDerivatedError(CStateCollection *oldState, CAction *action, rlt_real reward, CStateCollection *newState, CActionData *data)
{
rlt_real oldQ = qFunction->getValue(oldState, action, data);
CAction *estimationAction = estimationPolicy->getNextAction(newState, actionDataSet);
rlt_real newQ = qFunction->getValue(newState, estimationAction, actionDataSet->getActionData(estimationAction));
return residual->getResidual(oldQ, reward, action->getDuration(), newQ);
}
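// Gradient of the TD error with respect to the Q-function weights:
// collects grad Q(s, a) and grad Q(s', a') and lets the residual
// gradient function combine them. Judging by the commented-out lines
// below, the multiplication by the residual itself is left to the
// caller.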
void CTDErrorFunction::getErrorGradient(CStateCollection *oldState, CAction *action, CStateCollection *newState, rlt_real reward, CFeatureList *gradient, CActionData *data)
{
rlt_real duration = action->getDuration();
oldGradient->clear();
newGradient->clear();
qFunction->getGradient(oldState, action, data, oldGradient);
// rlt_real residualError = getDerivatedError(oldState, action, reward, newState, data);
CAction *lastEstimatedAction = estimationPolicy->getNextAction(newState, actionDataSet);
qFunction->getGradient(newState, lastEstimatedAction, actionDataSet->getActionData(lastEstimatedAction), newGradient);
residualGradientFunction->getResidualGradient(oldGradient, newGradient, duration, gradient);
// gradient->multFactor(residualError);
}
int CTDErrorFunction::getNumWeights()
{
return qFunction->getNumWeights();
}
void CTDErrorFunction::updateGradient(CFeatureList *gradientFeatures)
{
qFunction->updateGradient(gradientFeatures);
}
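// Error function for advantage learning (in the spirit of Baird's
// advantage updating), parameterized by the time step dt, the time
// scale K and the mixing factor beta used in the residual gradient.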
CAdvantageLearningErrorFunction::CAdvantageLearningErrorFunction(CGradientQFunction *qFunction, rlt_real dt, rlt_real K, rlt_real beta)
{
this->qFunction = qFunction;
addParameters(qFunction);
addParameter("TimeScale", K);
addParameter("ResidualBeta", beta);
addParameter("DiscountFactor", 0.95);
this->dt = dt;
this->qFuncGradient = new CFeatureList();
this->actionDataSet = new CActionDataSet(qFunction->getActions());
}
CAdvantageLearningErrorFunction::~CAdvantageLearningErrorFunction()
{
delete qFuncGradient;
delete actionDataSet;
}
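// Advantage-learning residual for the transition (s, a, r, s'), built
// from Q(s, a) and the max Q-values of both states; dt and K control
// how strongly Q(s, a) is pulled towards the old state's max value.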
rlt_real CAdvantageLearningErrorFunction::getDerivatedError(CStateCollection *oldState, CAction *action, rlt_real reward, CStateCollection *newState, CActionData *data)
{
rlt_real K = getParameter("TimeScale");
rlt_real oldQ = qFunction->getValue(oldState, action, data);
rlt_real newMaxQ = qFunction->getMaxValue(newState, qFunction->getActions());
rlt_real oldMaxQ = qFunction->getMaxValue(oldState, qFunction->getActions());
return reward + getParameter("DiscountFactor") * newMaxQ - dt/K * oldQ + (dt/K - 1) * oldMaxQ;
}
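// Gradient of the advantage-learning error with respect to the
// Q-function weights: grad Q(s, a) enters with weight 1, while the
// gradients of the two max terms are added with beta-scaled
// coefficients (the direct/residual mixing familiar from residual
// gradient algorithms).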
void CAdvantageLearningErrorFunction::getErrorGradient(CStateCollection *oldState, CAction *action, CStateCollection *newState, rlt_real reward, CFeatureList *gradient, CActionData *actionData)
{
rlt_real beta = getParameter("ResidualBeta");
rlt_real gamma = getParameter("DiscountFactor");
rlt_real K = getParameter("TimeScale");
rlt_real duration = action->getDuration();
// rlt_real advantage = getDerivatedError(oldState, action, reward, newState, actionData);
// gradient of Q(s, a) for the old state and the executed action
qFuncGradient->clear();
qFunction->getGradient(oldState, action, actionData, qFuncGradient);
gradient->add(qFuncGradient);
// gradient of max_a' Q(s', a') for the new state
qFuncGradient->clear();
CAction *newMaxAction = qFunction->getMax(newState, qFunction->getActions(), actionDataSet);
qFunction->getGradient(newState, newMaxAction, actionDataSet->getActionData(newMaxAction), qFuncGradient);
gradient->add(qFuncGradient, (-beta) * pow(gamma, dt * duration) / (dt * duration * K));
// gradient of max_a Q(s, a) for the old state
qFuncGradient->clear();
CAction *oldMaxAction = qFunction->getMax(oldState, qFunction->getActions(), actionDataSet);
qFunction->getGradient(oldState, oldMaxAction, actionDataSet->getActionData(oldMaxAction), qFuncGradient);
gradient->add(qFuncGradient, (-beta) * (1 - 1 / (dt * duration * K)));
// gradient->multFactor(advantage);
}
int CAdvantageLearningErrorFunction::getNumWeights()
{
return qFunction->getNumWeights();
}
void CAdvantageLearningErrorFunction::updateGradient(CFeatureList *gradientFeatures)
{
qFunction->updateGradient(gradientFeatures);
}
void CAdvantageLearningErrorFunction::setK(rlt_real K)
{
setParameter("TimeScale", K);
}
rlt_real CAdvantageLearningErrorFunction::getK()
{
return getParameter("TimeScale");
}
void CAdvantageLearningErrorFunction::setTimeIntervall(rlt_real dt)
{
this->dt = dt;
}
rlt_real CAdvantageLearningErrorFunction::getTimeIntervall()
{
return dt;
}
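// Error function for value iteration on a state value function V:
// the same residual / residual-gradient machinery as in the TD case,
// but applied to V(s) and V(s') instead of Q-values.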
CValueIterationErrorFunction::CValueIterationErrorFunction(CGradientVFunction *vFunction, CResidualFunction *residualFunction, CResidualGradientFunction *gradientFunction)
{
this->vFunction = vFunction;
oldGradient = new CFeatureList();
newGradient = new CFeatureList();
this->residualFunction = residualFunction;
this->gradientFunction = gradientFunction;
addParameters(vFunction);
addParameters(residualFunction);
addParameters(gradientFunction);
}
CValueIterationErrorFunction::~CValueIterationErrorFunction()
{
delete oldGradient;
delete newGradient;
}
rlt_real CValueIterationErrorFunction::getDerivatedError(CStateCollection *oldState, CAction *action, rlt_real reward, CStateCollection *newState, CActionData *actionData)
{
// Use the real duration stored in the action data for multi-step
// actions, otherwise fall back to the action's nominal duration.
int duration = action->getDuration();
if (actionData && action->isType(MULTISTEPACTION))
{
duration = dynamic_cast<CMultiStepActionData *>(actionData)->duration;
}
rlt_real error = residualFunction->getResidual(vFunction->getValue(oldState), reward, duration, vFunction->getValue(newState));
DebugPrint('t', "ValueIteration ErrorFunction: Error %f\n", error);
return error;
}
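// Gradient of the value-iteration residual: combines grad V(s) and
// grad V(s') through the residual gradient function, again using the
// real duration of multi-step actions where available.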
void CValueIterationErrorFunction::getErrorGradient(CStateCollection *oldState, CAction *action, CStateCollection *newState, rlt_real reward, CFeatureList *gradient, CActionData *actionData)
{
rlt_real duration = action->getDuration();
if (actionData && action->isType(MULTISTEPACTION))
{
duration = dynamic_cast<CMultiStepActionData *>(actionData)->duration;
}
oldGradient->clear();
newGradient->clear();
vFunction->getGradient(oldState, oldGradient);
vFunction->getGradient(newState, newGradient);
gradientFunction->getResidualGradient(oldGradient, newGradient, duration, gradient);
}
int CValueIterationErrorFunction::getNumWeights()
{
return vFunction->getNumWeights();
}
void CValueIterationErrorFunction::updateGradient(CFeatureList *gradientFeatures)
{
vFunction->updateGradient(gradientFeatures, 1.0);
}
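// Policy-search error: blends a wrapped error function with a pure
// reward term, (1 - beta) * E + beta * (b - gamma^t * r). This
// resembles the VAPS family of objectives (Baird and Moore), with b
// acting as a reward baseline; that reading of the parameters is an
// interpretation, not stated in the source.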
CPolicySearchErrorFunction::CPolicySearchErrorFunction(CErrorFunction *errorFunction, rlt_real beta, rlt_real b, rlt_real gamma)
{
this->errorFunction = errorFunction;
CErrorFunction::addParameter("DiscountFactor", gamma);
CErrorFunction::addParameter("PolicySearchBeta", beta);
CErrorFunction::addParameter("PolicySearchB", b);
this->timeStep = 0.0;
}
rlt_real CPolicySearchErrorFunction::getError(CStateCollection *oldState, CAction *action, rlt_real reward, CStateCollection *newState, CActionData *actionData)
{
rlt_real b = CErrorFunction::getParameter("PolicySearchB");
rlt_real beta = CErrorFunction::getParameter("PolicySearchBeta");
rlt_real gamma = CErrorFunction::getParameter("DiscountFactor");
return (1 - beta) * errorFunction->getError(oldState, action, reward, newState, actionData) + beta * (b - pow(gamma, timeStep) * reward);
}
rlt_real CPolicySearchErrorFunction::getDerivatedError(CStateCollection *oldState, CAction *action, rlt_real reward, CStateCollection *newState, CActionData *actionData)
{
return errorFunction->getDerivatedError(oldState, action, reward, newState, actionData);
}
void CPolicySearchErrorFunction::getErrorGradient(CStateCollection *oldState, CAction *action, CStateCollection *newState, rlt_real reward, CFeatureList *gradient, CActionData *actionData)
{
errorFunction->getErrorGradient(oldState, action, newState, reward, gradient, actionData);
}
int CPolicySearchErrorFunction::getNumWeights()
{
return errorFunction->getNumWeights();
}
void CPolicySearchErrorFunction::updateGradient(CFeatureList *gradientFeatures)
{
errorFunction->updateGradient(gradientFeatures);
}
void CPolicySearchErrorFunction::nextStep(CStateCollection *oldState, CAction *action, CStateCollection *newState)
{
timeStep += action->getDuration();
}
void CPolicySearchErrorFunction::newEpisode()
{
timeStep = 0.0;
}
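// Minimal usage sketch (hypothetical wiring; the concrete policy,
// residual and Q-function objects come from elsewhere in the toolbox
// and their construction is not shown in this file):
//
// CGradientQFunction *q = ...; // learned Q-function
// CAgentController *greedy = ...; // estimation policy, e.g. greedy w.r.t. q
// CResidualFunction *res = ...; // scalar TD residual
// CResidualGradientFunction *resGrad = ...; // combines the two gradients
//
// CTDErrorFunction tdError(greedy, res, resGrad, q);
// CFeatureList gradient;
// tdError.getErrorGradient(oldState, action, newState, reward, &gradient, NULL);
// tdError.updateGradient(&gradient); // apply the gradient to q's weights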