⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 ctaxidomain.cpp

📁 强化学习算法(R-Learning)难得的珍贵资料
💻 CPP
字号:
#include "ril_debug.h"

#include "ctaxidomain.h"

/**
 * Builds a taxi domain from a grid-world map file.
 *
 * The grid-world base is constructed with 100 as its second argument. The
 * state properties created by the base are then discarded and replaced by
 * CStateProperties(0, 5), presumably 0 continuous and 5 discrete state
 * variables (this file reads discrete indices 0..4). initTargetVector() then
 * fills in the sizes of those variables.
 */
CTaxiDomain::CTaxiDomain(char* filename) : CGridWorldModel(filename, 100)
{
	// Swap out the base-class state description before it gets sized below.
	delete properties;
	properties = new CStateProperties(0, 5);

	// No target list exists yet; initTargetVector() allocates a fresh one.
	targetXYValues = NULL;
	initTargetVector();
}

/**
 * Frees every heap-allocated target coordinate pair and then the vector that
 * holds them. Nothing happens if the list was never allocated.
 */
CTaxiDomain::~CTaxiDomain()
{
	if (targetXYValues == NULL)
	{
		return;
	}

	for (unsigned int idx = 0; idx < targetXYValues->size(); idx++)
	{
		delete (*targetXYValues)[idx];
	}
	delete targetXYValues;
}

/**
 * Rebuilds the list of target cells from the current map and resizes the
 * discrete state variables to match.
 *
 * A cell becomes a target when its grid value is contained in target_values.
 * Cells are scanned x-outer / y-inner, and that scan order defines the target
 * numbering used by getTargetPositionX/Y.
 *
 * NOTE(review): the sizes set here (index 2 = number of targets + 1,
 * index 3 = number of targets) match getReward(), but transitionFunction(),
 * isResetState() and getResetState() keep passenger location/destination in
 * discrete indices 3 and 4, and index 4 is never sized here. Confirm which
 * layout is intended.
 */
void CTaxiDomain::initTargetVector()
{
	if (targetXYValues != NULL)
	{
		// Reuse the vector, but release the pairs left over from the previous map.
		for (unsigned int idx = 0; idx < targetXYValues->size(); idx++)
		{
			delete (*targetXYValues)[idx];
		}
		targetXYValues->clear();
	}
	else
	{
		targetXYValues = new std::vector<std::pair<int, int> *>();
	}

	for (unsigned int col = 0; col < this->getSizeX(); col++)
	{
		for (unsigned int row = 0; row < this->getSizeY(); row++)
		{
			const bool isTargetCell = target_values->find(getGridValue(col, row)) != target_values->end();
			if (isTargetCell)
			{
				targetXYValues->push_back(new std::pair<int, int>(col, row));
			}
		}
	}

	// Position variables span the grid; passenger variables span the targets.
	properties->setDiscreteStateSize(0, getSizeX());
	properties->setDiscreteStateSize(1, getSizeY());
	properties->setDiscreteStateSize(2, targetXYValues->size() + 1);
	properties->setDiscreteStateSize(3, targetXYValues->size());
}

void CTaxiDomain::load(FILE *stream)
{
	CGridWorldModel::load(stream);
	initTargetVector();
}

/**
 * Returns the x coordinate of target number numTarget, a 0-based index into
 * the list built by initTargetVector(). The index is not range-checked.
 */
int CTaxiDomain::getTargetPositionX(int numTarget)
{
	const std::pair<int, int> *cell = (*targetXYValues)[numTarget];
	return cell->first;
}

/**
 * Returns the y coordinate of target number numTarget, a 0-based index into
 * the list built by initTargetVector(). The index is not range-checked.
 */
int CTaxiDomain::getTargetPositionY(int numTarget)
{
	const std::pair<int, int> *cell = (*targetXYValues)[numTarget];
	return cell->second;
}

/**
 * Taxi transition function.
 *
 * Passenger encoding used here (discrete indices 3 and 4):
 *   index 3, location:    0 = inside the taxi, k > 0 = waiting at target k - 1
 *   index 4, destination: 0-based index of the destination target
 *
 * Grid-world (movement) actions are delegated to the base model. A pickup
 * moves the passenger into the taxi when the taxi stands on the passenger's
 * target. A put-down drops the passenger when the taxi stands on the
 * destination target.
 *
 * NOTE(review): for pickup/put-down actions this function does not write
 * newState's position entries; presumably the caller initialises newState from
 * oldState. Confirm this.
 */
void CTaxiDomain::transitionFunction(CState *oldState, CAction *action, CState *newState, CActionData *data)
{
	if (action->isType(GRIDWORLDACTION))
	{
		CGridWorldModel::transitionFunction(oldState, action, newState, data);
	}
	int pos_x = oldState->getDiscreteState(0);
	int pos_y = oldState->getDiscreteState(1);

	int pasLocation = oldState->getDiscreteState(3);
	int pasDestination = oldState->getDiscreteState(4);

	if (action->isType(PICKUPACTION))
	{
		// Passenger waits at target pasLocation - 1; pick up only when standing on it.
		if (pasLocation > 0)
		{
			if (getTargetPositionX(pasLocation - 1) == pos_x && getTargetPositionY(pasLocation - 1) == pos_y)
			{
				pasLocation = 0;
			}
		}
	}
	if (action->isType(PUTDOWNACTION))
	{
		if (pasLocation == 0)
		{
			if (getTargetPositionX(pasDestination) == pos_x && getTargetPositionY(pasDestination) == pos_y)
			{
				// Same +1 offset as the pickup branch. Storing pasDestination itself
				// pointed at the wrong target and, for destination 0, was
				// indistinguishable from "still inside the taxi".
				pasLocation = pasDestination + 1;
			}
		}
	}

	newState->setDiscreteState(3, pasLocation);
	newState->setDiscreteState(4, pasDestination);
}

/**
 * A state ends the episode when the grid-world base reports failure, or when
 * the passenger is waiting at its destination target (location encoding
 * destination + 1, matching transitionFunction()).
 */
bool CTaxiDomain::isResetState(CState *state)
{
	return (CGridWorldModel::isFailedState(state) || state->getDiscreteState(3) == state->getDiscreteState(4) + 1);
}

/**
 * Produces a start state: the grid-world base chooses the taxi state, then the
 * passenger is placed at a random target and given a random destination.
 *
 * Encoding (discrete indices 3 and 4): location in 1..n means "waiting at
 * target location - 1"; destination is a 0-based target index in 0..n-1.
 *
 * The original `(rand() * RAND_MAX) % n` overflowed a signed int (undefined
 * behaviour) whenever RAND_MAX is large. It also skewed the distribution.
 * Plain `rand() % n` is used instead. With no targets, the passenger variables
 * are left as the base set them, which avoids a modulo by zero.
 */
void CTaxiDomain::getResetState(CState *resetState)
{
	CGridWorldModel::getResetState(resetState);

	const int numTargets = static_cast<int>(targetXYValues->size());
	if (numTargets <= 0)
	{
		return;
	}

	resetState->setDiscreteState(3, rand() % numTargets + 1);
	resetState->setDiscreteState(4, rand() % numTargets);
}


/**
 * Reward of the flat taxi task.
 *
 * Every step receives getRewardStandard(). On top of that:
 *  - a movement action that leaves x/y unchanged adds getRewardBounce();
 *  - a pickup that did not change discrete index 2 from non-zero to zero adds
 *    getRewardBounce();
 *  - a put-down adds getRewardSuccess() when index 2 was 0 and the taxi stood on
 *    the target numbered by index 3, otherwise getRewardBounce().
 *
 * NOTE(review): this reads passenger location/destination from indices 2/3,
 * while transitionFunction(), isResetState() and getResetState() use 3/4.
 * Confirm which layout is intended.
 */
rlt_real CTaxiDomain::getReward(CStateCollection *oldStateCol, CAction *action, CStateCollection *newStateCol)
{
	CState *before = oldStateCol->getState();
	CState *after = newStateCol->getState();

	rlt_real reward = this->getRewardStandard();

	if (action->isType(GRIDWORLDACTION))
	{
		const bool stayedPut = before->getDiscreteState(0) == after->getDiscreteState(0)
			&& before->getDiscreteState(1) == after->getDiscreteState(1);
		if (stayedPut)
		{
			reward += this->getRewardBounce();
		}
	}

	if (action->isType(PICKUPACTION))
	{
		const bool pickedUp = before->getDiscreteState(2) != 0 && after->getDiscreteState(2) == 0;
		if (!pickedUp)
		{
			// Pickup attempted at the wrong place or with the passenger already aboard.
			reward += this->getRewardBounce();
		}
	}

	if (action->isType(PUTDOWNACTION))
	{
		const bool delivered = before->getDiscreteState(2) == 0
			&& before->getDiscreteState(0) == getTargetPositionX(before->getDiscreteState(3))
			&& before->getDiscreteState(1) == getTargetPositionY(before->getDiscreteState(3));
		if (delivered)
		{
			reward += getRewardSuccess();
		}
		else
		{
			reward += this->getRewardBounce();
		}
	}

	return reward;
}

/**
 * Sub-behaviour whose goal is target number `target` of `taximodel`. It
 * finishes on that target (isFinished) and gets a bonus for reaching it
 * (getReward).
 */
CTaxiHierarchicalBehaviour::CTaxiHierarchicalBehaviour(CEpisode *currentEpisode, int target, CTaxiDomain *taximodel) : CHierarchicalSemiMarkovDecisionProcess(currentEpisode)
{
	this->target = target;
	this->model = taximodel;
}

/**
 * Does nothing: the taxi model pointer is not deleted here (presumably the
 * caller keeps ownership of it).
 */
CTaxiHierarchicalBehaviour::~CTaxiHierarchicalBehaviour()
{
}

/**
 * The sub-behaviour is finished once the new state's position (discrete
 * indices 0/1) equals the coordinates of its target. The previous state is
 * not consulted.
 */
bool CTaxiHierarchicalBehaviour::isFinished(CStateCollection *stateCol, CStateCollection *newStateCol)
{
	CState *reached = newStateCol->getState();
	const bool onTargetColumn = reached->getDiscreteState(0) == model->getTargetPositionX(target);
	return onTargetColumn && reached->getDiscreteState(1) == model->getTargetPositionY(target);
}

/**
 * Reward of the "drive to target" sub-task.
 *
 * Every step receives the model's standard reward. For movement actions only:
 * staying in the same cell adds the bounce reward, and arriving on this
 * sub-behaviour's target adds half of the model's success reward. Other action
 * types get just the standard reward.
 */
rlt_real  CTaxiHierarchicalBehaviour::getReward(CStateCollection *oldStateCol, CAction *action, CStateCollection *newStateCol)
{
	CState *before = oldStateCol->getState();
	CState *after = newStateCol->getState();

	rlt_real reward = model->getRewardStandard();
	if (!action->isType(GRIDWORLDACTION))
	{
		return reward;
	}

	const bool stayedPut = before->getDiscreteState(0) == after->getDiscreteState(0)
		&& before->getDiscreteState(1) == after->getDiscreteState(1);
	if (stayedPut)
	{
		reward += model->getRewardBounce();
	}

	const bool arrived = after->getDiscreteState(0) == model->getTargetPositionX(target)
		&& after->getDiscreteState(1) == model->getTargetPositionY(target);
	if (arrived)
	{
		reward += model->getRewardSuccess() / 2;
	}
	return reward;
}


/**
 * Discretizer with getNumTargets() + 1 states: 0 means "not on any target",
 * and i + 1 means "on target i".
 */
CTaxiIsTargetDiscreteState::CTaxiIsTargetDiscreteState(CTaxiDomain *model) : CAbstractStateDiscretizer(model->getNumTargets() + 1)
{
	// Kept so target coordinates can be looked up when discretizing.
	this->model = model;
}

/**
 * Maps the taxi position (discrete indices 0/1) to a target number.
 *
 * Returns i + 1 for the first target i whose coordinates match the position,
 * or 0 when the taxi is not on any target.
 */
unsigned int CTaxiIsTargetDiscreteState::getDiscreteStateNumber(CStateCollection *stateCol)
{
	CState *state = stateCol->getState();
	const int posX = state->getDiscreteState(0);
	const int posY = state->getDiscreteState(1);

	for (int idx = 0; idx < model->getNumTargets(); idx++)
	{
		if (posX == model->getTargetPositionX(idx) && posY == model->getTargetPositionY(idx))
		{
			return idx + 1;
		}
	}
	return 0;
}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -