📄 ctaxidomain.cpp
字号:
#include "ril_debug.h"
#include "ctaxidomain.h"
/**
 * Builds the taxi domain on top of a grid world loaded from a file.
 *
 * The state properties created by the base class are discarded and replaced
 * by a CStateProperties constructed with (0, 5); presumably that means no
 * continuous and five discrete state variables -- confirm against
 * CStateProperties. The list of target cells is then built from the grid.
 *
 * @param filename passed unchanged to CGridWorldModel, together with the
 *                 constant 100 (NOTE(review): looks like the maximum number
 *                 of bounces -- confirm against CGridWorldModel).
 */
CTaxiDomain::CTaxiDomain(char* filename) : CGridWorldModel(filename, 100)
{
	// Replace the properties object inherited from the base class.
	delete this->properties;
	this->properties = new CStateProperties(0, 5);

	// Must be NULL before initTargetVector() runs: that function uses NULL
	// to decide whether the vector still has to be allocated.
	this->targetXYValues = NULL;
	this->initTargetVector();
}
/**
 * Frees every heap-allocated coordinate pair stored in targetXYValues and
 * then the vector itself. Does nothing when the vector was never allocated.
 */
CTaxiDomain::~CTaxiDomain()
{
	if (targetXYValues == NULL)
	{
		return;
	}

	typedef std::vector<std::pair<int, int> *> TargetList;
	for (TargetList::iterator entry = targetXYValues->begin(); entry != targetXYValues->end(); ++entry)
	{
		delete *entry;
	}
	delete targetXYValues;
}
/**
 * (Re)builds the list of target cells and updates the discrete state sizes.
 *
 * On the first call the vector is allocated; on later calls the previously
 * stored coordinate pairs are deleted and the vector is emptied. The grid is
 * then scanned column by column (x outer, y inner) and every cell whose grid
 * value is contained in target_values is appended as an (x, y) pair, so the
 * position in the vector is the target index used by getTargetPositionX/Y.
 *
 * State sizes: index 0 and 1 hold the grid dimensions. transitionFunction()
 * and getResetState() store the passenger location (0 .. numTargets) in
 * index 3 and the destination (0 .. numTargets - 1) in index 4, so those two
 * variables get sizes numTargets + 1 and numTargets respectively.
 */
void CTaxiDomain::initTargetVector()
{
	typedef std::vector<std::pair<int, int> *> TargetList;

	if (targetXYValues == NULL)
	{
		targetXYValues = new TargetList();
	}
	else
	{
		// The vector owns its pairs; release the old ones before rebuilding.
		for (TargetList::iterator it = targetXYValues->begin(); it != targetXYValues->end(); ++it)
		{
			delete *it;
		}
		targetXYValues->clear();
	}

	for (unsigned int x = 0; x < this->getSizeX(); x++)
	{
		for (unsigned int y = 0; y < this->getSizeY(); y++)
		{
			if (target_values->find(getGridValue(x, y)) != target_values->end())
			{
				targetXYValues->push_back(new std::pair<int, int>(x, y));
			}
		}
	}

	properties->setDiscreteStateSize(0, getSizeX());
	properties->setDiscreteStateSize(1, getSizeY());
	// NOTE(review): index 2 is never written by this class; its size is kept
	// exactly as before. It is presumably the base class's own variable --
	// confirm what CGridWorldModel stores there and which size it needs.
	properties->setDiscreteStateSize(2, targetXYValues->size() + 1);
	// Passenger location: 0 plus one value per target.
	properties->setDiscreteStateSize(3, targetXYValues->size() + 1);
	// Passenger destination: one value per target.
	properties->setDiscreteStateSize(4, targetXYValues->size());
}
void CTaxiDomain::load(FILE *stream)
{
CGridWorldModel::load(stream);
initTargetVector();
}
/**
 * Returns the x coordinate of a target cell.
 *
 * @param numTarget index into the target list; it is not range checked.
 * @return first component of the stored (x, y) pair.
 */
int CTaxiDomain::getTargetPositionX(int numTarget)
{
	const std::pair<int, int> *cell = (*targetXYValues)[numTarget];
	return cell->first;
}
/**
 * Returns the y coordinate of a target cell.
 *
 * @param numTarget index into the target list; it is not range checked.
 * @return second component of the stored (x, y) pair.
 */
int CTaxiDomain::getTargetPositionY(int numTarget)
{
	const std::pair<int, int> *cell = (*targetXYValues)[numTarget];
	return cell->second;
}
/**
 * Computes the successor state for one action.
 *
 * Actions of type GRIDWORLDACTION are delegated to
 * CGridWorldModel::transitionFunction. Independently of that, the passenger
 * variables are taken from oldState, possibly modified, and written to
 * newState:
 *   - index 3: passenger location; 0 is the value set by a successful
 *     pick-up, a value k > 0 is compared against target k - 1.
 *   - index 4: destination, used directly as a target index.
 *
 * A pick-up succeeds when the location is > 0 and the taxi position in
 * oldState equals that target's coordinates. A put-down succeeds when the
 * location is 0 and the taxi position equals the destination's coordinates.
 *
 * @param oldState state before the action.
 * @param action   action being executed.
 * @param newState receives the successor state.
 * @param data     forwarded to the base class for movement actions.
 */
void CTaxiDomain::transitionFunction(CState *oldState, CAction *action, CState *newState, CActionData *data)
{
	if (action->isType(GRIDWORLDACTION))
	{
		CGridWorldModel::transitionFunction(oldState, action, newState, data);
	}

	// The checks below use the taxi position from before the action.
	const int taxiX = oldState->getDiscreteState(0);
	const int taxiY = oldState->getDiscreteState(1);
	int passenger = oldState->getDiscreteState(3);
	const int destination = oldState->getDiscreteState(4);

	if (action->isType(PICKUPACTION) && passenger > 0)
	{
		// Location k refers to target k - 1.
		const int waitingAt = passenger - 1;
		if (getTargetPositionX(waitingAt) == taxiX && getTargetPositionY(waitingAt) == taxiY)
		{
			passenger = 0;
		}
	}

	if (action->isType(PUTDOWNACTION) && passenger == 0)
	{
		if (getTargetPositionX(destination) == taxiX && getTargetPositionY(destination) == taxiY)
		{
			// NOTE(review): this stores the 0-based destination index in a
			// field that the pick-up branch reads as 1-based -- confirm the
			// intended encoding together with isResetState().
			passenger = destination;
		}
	}

	newState->setDiscreteState(3, passenger);
	newState->setDiscreteState(4, destination);
}
/**
 * Tells whether the episode has to be restarted from the given state.
 *
 * @param state state to examine.
 * @return true when the base class reports a failed state, or when discrete
 *         variable 3 equals discrete variable 4 (the values transitionFunction
 *         uses for passenger location and destination).
 */
bool CTaxiDomain::isResetState(CState *state)
{
	if (CGridWorldModel::isFailedState(state))
	{
		return true;
	}
	return state->getDiscreteState(3) == state->getDiscreteState(4);
}
/**
 * Fills resetState with a start state for a new episode.
 *
 * The base class initialises the state first. Then the passenger location
 * (index 3) is drawn from 1 .. numTargets and the destination (index 4) from
 * 0 .. numTargets - 1, using rand().
 *
 * The previous expression multiplied rand() by RAND_MAX before taking the
 * remainder. That product can overflow int (undefined behaviour) and adds no
 * randomness, so rand() is now used directly. A map without any target would
 * have caused a modulo by zero; in that case both variables are set to 0.
 *
 * @param resetState state object to initialise.
 */
void CTaxiDomain::getResetState(CState *resetState)
{
	CGridWorldModel::getResetState(resetState);

	const int numTargets = static_cast<int>(targetXYValues->size());
	if (numTargets == 0)
	{
		// Nothing to choose from; avoid dividing by zero.
		resetState->setDiscreteState(3, 0);
		resetState->setDiscreteState(4, 0);
		return;
	}

	// Location first, destination second: same draw order as before.
	const int passengerLocation = rand() % numTargets + 1;
	const int passengerDestination = rand() % numTargets;
	resetState->setDiscreteState(3, passengerLocation);
	resetState->setDiscreteState(4, passengerDestination);
}
/**
 * Computes the reward for one step of the taxi task.
 *
 * Every step starts from getRewardStandard(). On top of that:
 *   - a GRIDWORLDACTION that leaves the taxi position (indices 0 and 1)
 *     unchanged adds getRewardBounce();
 *   - a PICKUPACTION that did not change the passenger location from a
 *     non-zero value to 0 adds getRewardBounce();
 *   - a PUTDOWNACTION adds getRewardSuccess() when the passenger location was
 *     0 and the taxi stood on the destination target, otherwise
 *     getRewardBounce().
 *
 * The passenger location and destination are read from discrete variables 3
 * and 4, the indices that transitionFunction() and getResetState() write.
 * (The function previously read indices 2 and 3, which did not match them.)
 *
 * @param oldStateCol state collection before the action.
 * @param action      executed action.
 * @param newStateCol state collection after the action.
 * @return accumulated reward for this step.
 */
rlt_real CTaxiDomain::getReward(CStateCollection *oldStateCol, CAction *action, CStateCollection *newStateCol)
{
	// Indices of the passenger variables, as used by transitionFunction().
	const int passengerLocationIndex = 3;
	const int passengerDestinationIndex = 4;

	CState *oldState = oldStateCol->getState();
	CState *newState = newStateCol->getState();

	rlt_real reward = this->getRewardStandard();

	if (action->isType(GRIDWORLDACTION))
	{
		// The taxi did not move.
		if (oldState->getDiscreteState(0) == newState->getDiscreteState(0) && oldState->getDiscreteState(1) == newState->getDiscreteState(1))
		{
			reward += this->getRewardBounce();
		}
	}

	if (action->isType(PICKUPACTION))
	{
		// A pick-up worked only if the location went from non-zero to 0.
		const bool pickedUp = oldState->getDiscreteState(passengerLocationIndex) != 0
			&& newState->getDiscreteState(passengerLocationIndex) == 0;
		if (!pickedUp)
		{
			reward += this->getRewardBounce();
		}
	}

	if (action->isType(PUTDOWNACTION))
	{
		// Same success condition as the put-down branch of transitionFunction().
		if (oldState->getDiscreteState(passengerLocationIndex) == 0
			&& oldState->getDiscreteState(0) == getTargetPositionX(oldState->getDiscreteState(passengerDestinationIndex))
			&& oldState->getDiscreteState(1) == getTargetPositionY(oldState->getDiscreteState(passengerDestinationIndex)))
		{
			reward += getRewardSuccess();
		}
		else
		{
			reward += this->getRewardBounce();
		}
	}

	return reward;
}
/**
 * Creates a behaviour tied to one target of a taxi domain.
 *
 * @param currentEpisode forwarded to CHierarchicalSemiMarkovDecisionProcess.
 * @param target         index of the target this behaviour refers to.
 * @param taximodel      domain used to look up target coordinates and reward
 *                       values; only the pointer is stored.
 */
CTaxiHierarchicalBehaviour::CTaxiHierarchicalBehaviour(CEpisode *currentEpisode, int target, CTaxiDomain *taximodel) : CHierarchicalSemiMarkovDecisionProcess(currentEpisode)
{
	this->target = target;
	model = taximodel;
}
// Nothing to release here: the constructor only stores the model pointer and
// the target index, and this destructor deletes neither.
CTaxiHierarchicalBehaviour::~CTaxiHierarchicalBehaviour()
{
}
/**
 * Reports whether the behaviour has reached its target.
 *
 * @param stateCol    not used.
 * @param newStateCol collection whose state is examined.
 * @return true when discrete variables 0 and 1 of the new state equal the
 *         x and y coordinates of this behaviour's target.
 */
bool CTaxiHierarchicalBehaviour::isFinished(CStateCollection *stateCol, CStateCollection *newStateCol)
{
	CState *current = newStateCol->getState();

	const bool sameX = current->getDiscreteState(0) == model->getTargetPositionX(target);
	return sameX && current->getDiscreteState(1) == model->getTargetPositionY(target);
}
/**
 * Reward for one step while this behaviour is active.
 *
 * Starts from the model's standard reward. Only for actions of type
 * GRIDWORLDACTION: the bounce reward is added when the position (discrete
 * variables 0 and 1) did not change, and half of the model's success reward
 * is added when the new position equals this behaviour's target.
 *
 * @param oldStateCol state collection before the action.
 * @param action      executed action.
 * @param newStateCol state collection after the action.
 * @return accumulated reward for this step.
 */
rlt_real CTaxiHierarchicalBehaviour::getReward(CStateCollection *oldStateCol, CAction *action, CStateCollection *newStateCol)
{
	CState *before = oldStateCol->getState();
	CState *after = newStateCol->getState();

	rlt_real total = model->getRewardStandard();

	// Other action types only receive the standard reward.
	if (!action->isType(GRIDWORLDACTION))
	{
		return total;
	}

	const bool stayedInPlace = before->getDiscreteState(0) == after->getDiscreteState(0)
		&& before->getDiscreteState(1) == after->getDiscreteState(1);
	if (stayedInPlace)
	{
		total += model->getRewardBounce();
	}

	const bool arrivedAtTarget = after->getDiscreteState(0) == model->getTargetPositionX(target)
		&& after->getDiscreteState(1) == model->getTargetPositionY(target);
	if (arrivedAtTarget)
	{
		total += model->getRewardSuccess() / 2;
	}

	return total;
}
// Creates a discretizer with getNumTargets() + 1 discrete values: value 0 for
// "not on any target" plus one value per target (see getDiscreteStateNumber).
// The model pointer is dereferenced in the base initializer, so it must not
// be NULL; only the pointer is stored.
CTaxiIsTargetDiscreteState::CTaxiIsTargetDiscreteState(CTaxiDomain *model) : CAbstractStateDiscretizer(model->getNumTargets() + 1)
{
// The parameter shadows the member, hence the explicit this->.
this->model = model;
}
/**
 * Maps the current position to a "which target am I on" number.
 *
 * @param stateCol collection whose state supplies the position in discrete
 *                 variables 0 and 1.
 * @return 0 when the position matches no target, otherwise the index of the
 *         first matching target plus one.
 */
unsigned int CTaxiIsTargetDiscreteState::getDiscreteStateNumber(CStateCollection *stateCol)
{
	CState *current = stateCol->getState();

	for (int index = 0; index < model->getNumTargets(); ++index)
	{
		const bool onThisTarget = current->getDiscreteState(0) == model->getTargetPositionX(index)
			&& current->getDiscreteState(1) == model->getTargetPositionY(index);
		if (onThisTarget)
		{
			// Shifted by one so that 0 stays free for "no target".
			return index + 1;
		}
	}

	return 0;
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -