// ctdlearner.cpp
// Copyright (C) 2003
// Gerhard Neumann (gerhard@igi.tu-graz.ac.at)
//
// This file is part of RL Toolbox.
// http://www.igi.tugraz.at/ril_toolbox
//
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// 1. Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// 2. Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// 3. The name of the author may not be used to endorse or promote products
// derived from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
// IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
// OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
// IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT,
// INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
// NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
// THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include "ril_debug.h"
#include <stdlib.h>
#include <assert.h>
#include <time.h>
#include <math.h>
#include <stdlib.h>
#include <stdio.h>
#include <math.h>
#include <vector>
#include "cpolicies.h"
#include "ctdlearner.h"
#include "cqfunction.h"

CTDLearner::CTDLearner(CRewardFunction *rewardFunction, CAbstractQFunction *qfunction, CAbstractQETraces *etraces, CAgentController *estimationPolicy) : CSemiMDPRewardListener(rewardFunction) {
    this->qfunction = qfunction;
    this->etraces = etraces;
    this->estimationPolicy = estimationPolicy;

    addParameter("QLearningRate", 0.2);
    addParameter("DiscountFactor", 0.95);
    addParameters(qfunction);
    addParameters(etraces);
    addParameter("ResetETracesOnWrongEstimate", 1.0);

    if (estimationPolicy)
    {
        addParameters(estimationPolicy);
    }

    // The e-traces were supplied by the caller, so the destructor must not delete them
    this->externETraces = true;
    this->actionDataSet = new CActionDataSet(qfunction->getActions());
    lastEstimatedAction = NULL;
}

CTDLearner::CTDLearner(CRewardFunction *rewardFunction, CAbstractQFunction *qFunction, CAgentController *estimationPolicy) : CSemiMDPRewardListener(rewardFunction)
{
    this->qfunction = qFunction;
    this->etraces = qFunction->getStandardETraces();
    this->estimationPolicy = estimationPolicy;

    addParameter("QLearningRate", 0.2);
    addParameter("DiscountFactor", 0.95);
    addParameters(qfunction);
    addParameters(etraces);
    addParameter("ResetETracesOnWrongEstimate", 1.0);

    if (estimationPolicy)
    {
        addParameters(estimationPolicy);
    }

    // The e-traces were created internally via getStandardETraces(), so the
    // destructor has to delete them
    this->externETraces = false;
    this->actionDataSet = new CActionDataSet(qfunction->getActions());
    lastEstimatedAction = NULL;
}

CTDLearner::~CTDLearner()
{
    if (!this->externETraces)
    {
        delete etraces;
    }
    delete actionDataSet;
}

void CTDLearner::newEpisode() {
    lastEstimatedAction = NULL;
}
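
// Computes the one-step TD error
//   delta = r + gamma^d * Q(s', a') - Q(s, a),
// where a' is the action estimated for the new state (lastEstimatedAction,
// falling back to the greedy action) and d is the duration of the executed
// action (1 for primitive actions, > 1 for multi-step SMDP actions).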
rlt_real CTDLearner::getTemporalDifference(CStateCollection *oldState, CAction *action, rlt_real reward, CStateCollection *newState)
{
    rlt_real newQ = 0.0, oldQ;
    rlt_real temporalDiff = 0.0;

    // No action has been estimated for the new state yet, use the greedy action
    if (lastEstimatedAction == NULL)
    {
        lastEstimatedAction = qfunction->getMax(newState, qfunction->getActions(), actionDataSet);
    }

    int duration = 1;
    if (action->isType(MULTISTEPACTION))
    {
        duration = dynamic_cast<CMultiStepAction *>(action)->getDuration();
    }
    // assert(lastEstimatedAction->getIndex() >= 0);

    oldQ = qfunction->getValue(oldState, action); // Save old prediction: Q(st,at)
    newQ = qfunction->getValue(newState, lastEstimatedAction, actionDataSet->getActionData(lastEstimatedAction));

    temporalDiff = getResidual(oldQ, reward, duration, newQ);

    DebugPrint('t', "OldQValue: %f\n", oldQ);
    DebugPrint('t', "NewQValue: %f\n", newQ);
    DebugPrint('t', "Reward: %f\n", reward);
    DebugPrint('t', "TemporalDiff: %f\n", temporalDiff);

    sendErrorToListeners(temporalDiff, oldState, action, NULL);
    return temporalDiff;
}

void CTDLearner::learnStep(CStateCollection *oldState, CAction *action, rlt_real reward, CStateCollection *newState)
{
    // Reset the traces if the executed action differs from the estimated one,
    // i.e. an exploration step was taken (as in Watkins' Q(lambda))
    bool resetEtraces = getParameter("ResetETracesOnWrongEstimate") > 0.5;
    if (resetEtraces && (lastEstimatedAction == NULL || !action->isSameAction(lastEstimatedAction, actionDataSet->getActionData(lastEstimatedAction))))
    {
        etraces->resetETraces();
    }

    etraces->updateETraces(action);

    lastEstimatedAction = estimationPolicy->getNextAction(newState, actionDataSet);
    // If there is no estimated action, fall back to the greedy action
    if (lastEstimatedAction == NULL)
    {
        lastEstimatedAction = qfunction->getMax(newState, qfunction->getActions(), actionDataSet);
    }

    addETraces(oldState, newState, action);
    // assert(qfunction->getActions()->getIndex(lastEstimatedAction) >= 0);
    etraces->updateQFunction(getParameter("QLearningRate") * getTemporalDifference(oldState, action, reward, newState));
}
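
// One-step TD residual with SMDP discounting: a multi-step action of
// duration d is discounted by gamma^d. For example, with gamma = 0.95,
// r = 1, d = 2, Q(s',a') = 10 and Q(s,a) = 9 the residual is
// 1 + 0.95^2 * 10 - 9 = 1.025.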
rlt_real CTDLearner::getResidual(rlt_real oldQ, rlt_real reward, int duration, rlt_real newQ)
{
    return reward + pow(getParameter("DiscountFactor"), duration) * newQ - oldQ;
}

void CTDLearner::addETraces(CStateCollection *oldState, CStateCollection *newState, CAction *oldAction)
{
    etraces->addETrace(oldState, oldAction, 1.0);
}

void CTDLearner::nextStep(CStateCollection *oldState, CAction *action, rlt_real reward, CStateCollection *nextState)
{
    learnStep(oldState, action, reward, nextState);
}

// Intermediate steps of multi-step actions update the Q-Function directly,
// without consulting the estimation policy or the global e-traces
void CTDLearner::intermediateStep(CStateCollection *oldState, CAction *action, rlt_real reward, CStateCollection *nextState)
{
    addETraces(oldState, nextState, action);
    qfunction->updateValue(oldState, action, getParameter("QLearningRate") * getTemporalDifference(oldState, action, reward, nextState));
}

void CTDLearner::saveValues(char *filename) {
    FILE *stream = fopen(filename, "w");
    assert(stream != NULL);
    saveValues(stream);
    fclose(stream);
}

void CTDLearner::loadValues(char *filename) {
    FILE *stream = fopen(filename, "r");
    assert(stream != NULL);
    loadValues(stream);
    fclose(stream);
}

void CTDLearner::saveValues(FILE *stream) {
    assert(qfunction != NULL);
    qfunction->saveData(stream);
}

void CTDLearner::loadValues(FILE *stream) {
    assert(qfunction != NULL);
    qfunction->loadData(stream);
}

void CTDLearner::setAlpha(rlt_real alpha) {
    setParameter("QLearningRate", alpha);
}

void CTDLearner::setLambda(rlt_real lambda) {
    assert(etraces != NULL);
    etraces->setLambda(lambda);
}

CAgentController* CTDLearner::getEstimationPolicy() {
    return estimationPolicy;
}

void CTDLearner::setEstimationPolicy(CAgentController *estimationPolicy) {
    this->estimationPolicy = estimationPolicy;
}

CAbstractQFunction* CTDLearner::getQFunction() {
    return qfunction;
}

CAbstractQETraces* CTDLearner::getETraces() {
    return etraces;
}
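
// CQLearner estimates the TD target with a greedy policy over its own
// Q-Function, independent of the behavior policy (off-policy Q-Learning).
// CSarsaLearner instead estimates with the agent's own controller, so the
// target uses the action the agent will actually execute (on-policy SARSA).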
CQLearner::CQLearner(CRewardFunction *rewardFunction, CAbstractQFunction *qfunc) : CTDLearner(rewardFunction, qfunc, new CQGreedyPolicy(qfunc->getActions(), qfunc))
{
}

CQLearner::~CQLearner()
{
    // The greedy estimation policy was created by the constructor
    delete estimationPolicy;
}

CSarsaLearner::CSarsaLearner(CRewardFunction *rewardFunction, CAbstractQFunction *qfunction, CDeterministicController *agent) : CTDLearner(rewardFunction, qfunction, agent)
{
    setParameter("ResetETracesOnWrongEstimate", 1.0);
}

CSarsaLearner::~CSarsaLearner()
{
}
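
// CTDGradientLearner generalizes TD learning to differentiable (gradient)
// Q-Functions: the e-traces accumulate the gradient of Q with respect to the
// weights instead of single state-action entries, and the residual and
// residual-gradient objects make the exact update rule pluggable.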
CTDGradientLearner::CTDGradientLearner(CRewardFunction *rewardFunction, CGradientQFunction *qfunction, CAgentController *agentController, CResidualFunction *residual, CResidualGradientFunction *residualGradient) : CTDLearner(rewardFunction, qfunction, new CGradientQETraces(qfunction), agentController)
{
    assert(qfunction->isType(GRADIENTQFUNCTION));

    this->gradientQFunction = qfunction;
    this->residual = residual;
    this->residualGradient = residualGradient;

    if (residual)
    {
        addParameters(residual);
    }
    if (residualGradient)
    {
        addParameters(residualGradient);
    }

    this->gradientQETraces = dynamic_cast<CGradientQETraces *>(etraces);
    gradientQETraces->setReplacingETraces(true);

    oldGradient = new CFeatureList();
    newGradient = new CFeatureList();
    residualGradientFeatures = new CFeatureList();
}

CTDGradientLearner::~CTDGradientLearner()
{
    // The gradient e-traces were created by the constructor; the base class
    // treats them as external and does not delete them, so free them here
    delete gradientQETraces;

    delete oldGradient;
    delete newGradient;
    delete residualGradientFeatures;
}

rlt_real CTDGradientLearner::getResidual(rlt_real oldQ, rlt_real reward, int duration, rlt_real newQ)
{
    return residual->getResidual(oldQ, reward, duration, newQ);
}

void CTDGradientLearner::addETraces(CStateCollection *oldState, CStateCollection *newState, CAction *oldAction)
{
    if (lastEstimatedAction == NULL)
    {
        lastEstimatedAction = qfunction->getMax(newState, qfunction->getActions(), actionDataSet);
    }

    rlt_real duration = oldAction->getDuration();

    oldGradient->clear();
    newGradient->clear();
    residualGradientFeatures->clear();

    gradientQFunction->getGradient(oldState, oldAction, oldAction->getActionData(), oldGradient);
    gradientQFunction->getGradient(newState, lastEstimatedAction, actionDataSet->getActionData(lastEstimatedAction), newGradient);

    residualGradient->getResidualGradient(oldGradient, newGradient, duration, residualGradientFeatures);

    if (DebugIsEnabled('t'))
    {
        DebugPrint('t', "Residual Gradient: ");
        residualGradientFeatures->saveASCII(DebugGetFileHandle('t'));
        DebugPrint('t', "\n");
    }

    // The residual gradient enters the traces with a negative sign, so the
    // weight update descends the squared TD residual
    gradientQETraces->addGradientETrace(residualGradientFeatures, -1.0);
}
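
// CTDResidualLearner implements a residual algorithm in the spirit of Baird
// (1995): each update mixes the direct-gradient and residual-gradient
// directions with an adaptive weight beta from the beta calculator,
//   delta_w = alpha * delta * ((1 - beta) * e_direct + beta * e_residual).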
CTDResidualLearner::CTDResidualLearner(CRewardFunction *rewardFunction, CGradientQFunction *qfunction, CAgentController *agent, CResidualFunction *residual, CResidualGradientFunction *residualGradient, CAbstractBetaCalculator *betaCalc) : CTDGradientLearner(rewardFunction, qfunction, agent, residual, residualGradient)
{
    this->betaCalculator = betaCalc;

    residualETraces = new CGradientQETraces(qfunction);
    residualETraces->setReplacingETraces(true);

    directGradientTraces = new CGradientQETraces(qfunction);
    directGradientTraces->setReplacingETraces(true);

    residualGradientTraces = new CGradientQETraces(qfunction);
    residualGradientTraces->setReplacingETraces(true);

    addParameters(residualETraces);
    addParameters(directGradientTraces, "Gradient");
    addParameters(residualGradientTraces, "Gradient");
    addParameters(betaCalculator);

    addParameter("ScaleResidualGradient", 0.0);
}

CTDResidualLearner::~CTDResidualLearner()
{
    // All three trace objects were created by the constructor
    delete residualETraces;
    delete directGradientTraces;
    delete residualGradientTraces;
}

void CTDResidualLearner::learnStep(CStateCollection *oldState, CAction *action, rlt_real reward, CStateCollection *newState)
{
    bool resetEtraces = getParameter("ResetETracesOnWrongEstimate") > 0.5;
    if (resetEtraces && (lastEstimatedAction == NULL || !action->isSameAction(lastEstimatedAction, actionDataSet->getActionData(lastEstimatedAction))))
    {
        etraces->resetETraces();
        residualETraces->resetETraces();
    }

    etraces->updateETraces(action);
    residualETraces->updateETraces(action);
    directGradientTraces->updateETraces(action);
    residualGradientTraces->updateETraces(action);

    lastEstimatedAction = estimationPolicy->getNextAction(newState, actionDataSet);
    // If there is no estimated action, fall back to the greedy action
    if (lastEstimatedAction == NULL)
    {
        lastEstimatedAction = qfunction->getMax(newState, qfunction->getActions(), actionDataSet);
    }
    assert(qfunction->getActions()->getIndex(lastEstimatedAction) >= 0);

    rlt_real td = getParameter("QLearningRate") * getTemporalDifference(oldState, action, reward, newState);

    addETraces(oldState, newState, action, td);

    // Mix the direct and the residual update directions with the adaptive beta
    rlt_real beta = betaCalculator->getBeta(directGradientTraces->getGradientETraces(), residualGradientTraces->getGradientETraces());

    gradientQETraces->updateQFunction(td * (1.0 - beta));
    residualETraces->updateQFunction(td * beta);
}

void CTDResidualLearner::addETraces(CStateCollection *oldState, CStateCollection *newState, CAction *action, rlt_real td)
{
    if (lastEstimatedAction == NULL)
    {
        lastEstimatedAction = qfunction->getMax(newState, qfunction->getActions(), actionDataSet);
    }

    rlt_real duration = action->getDuration();

    oldGradient->clear();
    newGradient->clear();
    residualGradientFeatures->clear();

    gradientQFunction->getGradient(oldState, action, action->getActionData(), oldGradient);
    gradientQFunction->getGradient(newState, lastEstimatedAction, actionDataSet->getActionData(lastEstimatedAction), newGradient);

    // Add direct gradient
    gradientQETraces->addGradientETrace(oldGradient, 1.0);

    residualGradient->getResidualGradient(oldGradient, newGradient, duration, residualGradientFeatures);

    if (getParameter("ScaleResidualGradient") > 0.5)
    {
        // Optionally rescale the residual gradient relative to the direct gradient
        residualGradientFeatures->multFactor(oldGradient->getLength() / residualGradientFeatures->getLength());
    }

    // Add residual gradient
    residualETraces->addGradientETrace(residualGradientFeatures, -1.0);

    directGradientTraces->addGradientETrace(oldGradient, td);
    residualGradientTraces->addGradientETrace(residualGradientFeatures, -td);

    if (DebugIsEnabled('t'))
    {
        DebugPrint('t', "Residual Gradient: ");
        residualGradientFeatures->saveASCII(DebugGetFileHandle('t'));
        DebugPrint('t', "\n");
    }
}

void CTDResidualLearner::newEpisode()
{
    CTDGradientLearner::newEpisode();

    residualETraces->resetETraces();
    residualGradientTraces->resetETraces();
    directGradientTraces->resetETraces();
}