⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 ctransitionfunction.cpp

📁 强化学习算法(R-Learning)难得的珍贵资料
💻 CPP
📖 第 1 页 / 共 2 页
字号:
}

/**
 * Derivative of the model dynamics with respect to the control input u.
 *
 * For a model that is (presumably) linear in the action, this derivative is the
 * input matrix returned by getB() for the given state.
 *
 * @param oldstate   state at which the input matrix is evaluated
 * @param derivation output matrix; set via setMatrix() from getB(oldstate)
 */
void CDynamicLinearActionContinuousTimeModel::getDerivationU(CState *oldstate, CMyMatrix *derivation)
{
	CMyMatrix *inputMatrix = getB(oldstate);
	derivation->setMatrix(inputMatrix);
}



/**
 * Continuous-time model with drift a(x) = A * x (see getA) and a constant
 * input matrix B (see getB).
 *
 * @param properties state properties; with n continuous state variables, A must
 *                   be (n x n) and B must have n rows (asserted)
 * @param action     continuous action; B must have one column per action value (asserted)
 * @param dt         time step, forwarded to the base-class constructor
 * @param A          system matrix, copied into the newly allocated AMatrix
 * @param B          input matrix, copied into the member matrix B returned by getB()
 */
CDynamicLinearContinuousTimeModel::CDynamicLinearContinuousTimeModel(CStateProperties *properties, CContinuousAction *action, rlt_real dt, CMyMatrix *A, CMyMatrix *B) : CDynamicLinearActionContinuousTimeModel(properties, action, dt)
{
	assert(A->getNumRows() == properties->getNumContinuousStates() && A->getNumColumns() == properties->getNumContinuousStates() && B->getNumColumns() == action->getContinuousActionProperties()->getNumActionValues() && B->getNumRows() == properties->getNumContinuousStates());

	// The parameter B hides the member B that getB() returns. The former
	// "B->setMatrix(B)" copied the argument onto itself and never set the
	// member. Qualify with this-> so the member receives the argument's values.
	// NOTE(review): assumes the base-class constructor allocates member B with
	// matching dimensions - confirm.
	this->B->setMatrix(B);

	AMatrix = new CMyMatrix(properties->getNumContinuousStates(), properties->getNumContinuousStates());
	AMatrix->setMatrix(A);
}

/**
 * Releases the system matrix allocated in the constructor.
 */
CDynamicLinearContinuousTimeModel::~CDynamicLinearContinuousTimeModel()
{
	if (AMatrix != NULL)
	{
		delete AMatrix;
	}
}

/**
 * Returns the input matrix of this model.
 *
 * @param state unused; the matrix does not depend on the state
 * @return pointer to the member matrix B (not a copy)
 */
CMyMatrix *CDynamicLinearContinuousTimeModel::getB(CState *state)
{
	CMyMatrix *inputMatrix = this->B;
	return inputMatrix;
}

/**
 * Evaluates the drift term a(x) = A * x for the given state.
 *
 * The product is written into the member vector A via multVector(), and that
 * same vector is returned. Each call therefore overwrites the previous result.
 */
CMyVector *CDynamicLinearContinuousTimeModel::getA(CState *state)
{
	CMyVector *drift = this->A;
	AMatrix->multVector(state, drift);
	return drift;
}

/**
 * Wraps a transition function as an environment model.
 *
 * The constructor:
 * - allocates the current-state and scratch-state buffers;
 * - clears all optional configuration (start-state list, sample/failed/target
 *   regions);
 * - finally calls resetModel().
 *
 * All members are therefore initialised before the reset runs (presumably
 * resetModel() dispatches to doResetModel() - confirm).
 */
CTransitionFunctionEnvironment::CTransitionFunctionEnvironment(CTransitionFunction *model) : CEnvironmentModel(model->getStateProperties())
{
	TransitionFunction = model;

	modelState = new CState(getStateProperties());
	nextState = new CState(getStateProperties());

	// No start-state list yet, and nothing owned.
	createdStartStates = false;
	startStates = NULL;
	nEpisode = 0;

	// NULL regions mean "use the transition function's own logic".
	targetRegion = NULL;
	sampleRegion = NULL;
	failedRegion = NULL;

	resetModel();
}
	
/**
 * Frees both state buffers, plus the start-state list when createdStartStates
 * is set.
 *
 * Regions and the transition function are not deleted here.
 */
CTransitionFunctionEnvironment::~CTransitionFunctionEnvironment()
{
	if (createdStartStates)
	{
		delete startStates;
	}

	delete nextState;
	delete modelState;
}

/**
 * Advances the model by one step with the given primitive action.
 *
 * The successor is computed into the scratch buffer, and the two buffers are
 * then swapped so that modelState holds the successor. Afterwards the reset and
 * failed flags are refreshed:
 * - a configured target/failed region takes precedence;
 * - otherwise the transition function's isResetState()/isFailedState() is used.
 */
void CTransitionFunctionEnvironment::doNextState(CPrimitiveAction *action)
{
	TransitionFunction->transitionFunction(modelState, action, nextState);

	// Swap pointers instead of copying; the old state becomes the next scratch buffer.
	CState *successor = nextState;
	nextState = modelState;
	modelState = successor;

	reset = (targetRegion != NULL) ? targetRegion->isStateInRegion(modelState) : TransitionFunction->isResetState(modelState);
	failed = (failedRegion != NULL) ? failedRegion->isStateInRegion(modelState) : TransitionFunction->isFailedState(modelState);
}

/**
 * Puts modelState into a start state for a new episode.
 *
 * The start state is chosen in this order of priority:
 * 1. the start-state list, if one is set. Its entries are used in order and the
 *    index wraps around at the end.
 * 2. otherwise the sample region, if set, via getRandomStateSample().
 * 3. otherwise the transition function's getResetState().
 */
void CTransitionFunctionEnvironment::doResetModel()
{
	if (startStates == NULL)
	{
		if (sampleRegion != NULL)
		{
			sampleRegion->getRandomStateSample(modelState);
		}
		else
		{
			TransitionFunction->getResetState(modelState);
		}
		return;
	}

	startStates->getState(nEpisode, modelState);
	nEpisode = (nEpisode + 1) % startStates->getNumStates();
}

/**
 * Copies the current model state into @p state.
 *
 * @p state must use the same state properties as this environment (asserted).
 */
void CTransitionFunctionEnvironment::getState(CState *state)
{
	assert(state->getStateProperties()->equals(this->getStateProperties()));

	state->setState(this->modelState);
}

/**
 * Overwrites the current model state with a copy of @p state.
 *
 * @p state must use the same state properties as this environment (asserted).
 */
void CTransitionFunctionEnvironment::setState(CState *state)
{
	assert(state->getStateProperties()->equals(this->getStateProperties()));

	this->modelState->setState(state);
}

/**
 * Uses the given list of start states for subsequent resets and restarts at its
 * first entry.
 *
 * If the current list was created by this object (createdStartStates), it is
 * freed first. The new list is not marked as owned.
 */
void CTransitionFunctionEnvironment::setStartStates(CStateList *startStates)
{
	// Release only a list this object allocated itself.
	if (createdStartStates)
	{
		createdStartStates = false;
		delete this->startStates;
	}

	nEpisode = 0;
	this->startStates = startStates;
}

void CTransitionFunctionEnvironment::setStartStates(char *filename)
{
	FILE *startStateFile = fopen(filename, "r");
	startStates = new CStateList(getStateProperties());
	startStates->loadASCII(startStateFile);
	fclose(startStateFile);
	nEpisode = 0;

}


/**
 * Sets the region from which start states are sampled when no start-state list
 * is configured.
 *
 * NULL falls back to the transition function's getResetState(). The region is
 * not deleted by this object.
 */
void CTransitionFunctionEnvironment::setSampleRegion(CRegion *l_sampleRegion)
{
	sampleRegion = l_sampleRegion;
}

/**
 * Sets the region whose states are reported as failed after each step.
 *
 * NULL falls back to the transition function's isFailedState(). The region is
 * not deleted by this object.
 */
void CTransitionFunctionEnvironment::setFailedRegion(CRegion *l_failedRegion)
{
	failedRegion = l_failedRegion;
}

/**
 * Sets the region whose states trigger a reset after each step.
 *
 * NULL falls back to the transition function's isResetState(). The region is
 * not deleted by this object.
 */
void CTransitionFunctionEnvironment::setTargetRegion(CRegion *l_targetRegion)
{
	targetRegion = l_targetRegion;
}

/**
 * Q-function obtained by a depth-limited look-ahead through a transition model.
 *
 * Values have the form reward + DiscountFactor^duration * value(successor).
 * At the leaves, value(successor) is VFunctionScale * V(successor).
 *
 * Registers the parameters SearchDepth (1), DiscountFactor (0.95) and
 * VFunctionScale (1.0), and adds @p modifiers to the internal state collections.
 *
 * @param actions        action set of the Q-function
 * @param vfunction      value function evaluated at the search leaves (not deleted here)
 * @param model          transition model used to predict successor states (not deleted here)
 * @param rewardfunction reward used along the search path (not deleted here)
 * @param modifiers      state modifiers to register
 */
CQFunctionFromTransitionFunction::CQFunctionFromTransitionFunction(CActionSet *actions, CAbstractVFunction *vfunction, CTransitionFunction *model, CRewardFunction *rewardfunction, std::list<CStateModifier *> *modifiers) : CAbstractQFunction(actions), CStateModifiersObject(model->getStateProperties())
{
	this->model = model;
	this->vfunction = vfunction;
	this->rewardfunction = rewardfunction;

	addParameter("SearchDepth", 1);
	addParameter("DiscountFactor", 0.95);
	addParameter("VFunctionScale", 1.0);

	this->actionDataSet = new CActionDataSet(actions);

	// Scratch collections for the search. They must exist before the modifiers
	// are registered, because addStateModifier() forwards to them.
	nextState = new CStateCollectionImpl(model->getStateProperties());
	intermediateState = new CStateCollectionImpl(model->getStateProperties());
	this->stateCollectionList = new CStateCollectionList(model->getStateProperties());

	addStateModifiers(modifiers);
}

/**
 * Frees the objects allocated by the constructor: the search-path list, the
 * scratch state collections and the action data set.
 */
CQFunctionFromTransitionFunction::~CQFunctionFromTransitionFunction()
{
	delete stateCollectionList;

	delete intermediateState;
	delete nextState;

	delete actionDataSet;
}

/**
 * Registers a state modifier with this object and with every internal state
 * collection used during the look-ahead search.
 */
void CQFunctionFromTransitionFunction::addStateModifier(CStateModifier *modifier)
{
	CStateModifiersObject::addStateModifier(modifier);

	// Scratch collections for successor and current search states.
	this->nextState->addStateModifier(modifier);
	this->intermediateState->addStateModifier(modifier);

	// Search-path list.
	this->stateCollectionList->addStateModifier(modifier);
}

/**
 * Q(state, action) via a look-ahead search.
 *
 * The search depth is the SearchDepth parameter rounded with my_round(). The
 * search path is reset so that it contains only @p state.
 */
rlt_real CQFunctionFromTransitionFunction::getValue(CStateCollection *state, CAction *action, CActionData *data)
{
	// The search path starts with just the queried state.
	stateCollectionList->clearStateLists();
	stateCollectionList->addStateCollection(state);

	int searchDepth = my_round(getParameter("SearchDepth"));
	return getValueDepthSearch(stateCollectionList, action, data, searchDepth);
}

/**
 * Recursive depth-limited look-ahead used by getValue().
 *
 * Takes the last entry of @p stateList as the current state, applies @p action
 * through the model and returns
 *     rewardValue + DiscountFactor^duration * value
 * where value is:
 * - VFunctionScale * V(successor) when depth is 1;
 * - the maximum over all actions of a recursive search with depth - 1 when
 *   depth > 1.
 * With depth == 0, VFunctionScale * V(current state) is returned directly.
 *
 * NOTE(review): intermediateState and nextState are member scratch buffers that
 * the recursive calls overwrite. Everything needed after the recursion
 * (rewardValue, duration) is computed before it. The successor is presumably
 * copied into stateList by addStateCollection() - confirm.
 *
 * @param stateList search path from the root state; its last entry is expanded
 * @param action    action to evaluate
 * @param data      optional action data; if non-NULL it is copied into the
 *                  internal action data for @p action
 * @param depth     remaining search depth
 */
rlt_real CQFunctionFromTransitionFunction::getValueDepthSearch(CStateCollectionList *stateList, CAction *action, CActionData *data, int depth)
{
	// Current state = last entry on the search path.
	stateList->getStateCollection(stateList->getNumStateCollections() - 1, intermediateState);
	if (depth == 0)
	{
		rlt_real vFunctionScale = getParameter("VFunctionScale");
		return vfunction->getValue(intermediateState) * vFunctionScale;
	}

	// Caller-supplied action data overrides the stored data for this action.
	if (data)
	{
		actionDataSet->getActionData(action)->setData(data);
	}

	CActionData *ldata = actionDataSet->getActionData(action);

	// Overwritten below in both branches of the duration check.
	int duration = 1;
	
	rlt_real rewardValue = 0;
	if (model->isType(DM_EXTENDEDACTIONMODEL))
	{
		// Extended-action models compute the successor and the reward in one call
		// (the discount factor is passed in).
		CExtendedActionTransitionFunction *extModel = dynamic_cast<CExtendedActionTransitionFunction *>(model);
		rewardValue = extModel->transitionFunctionAndReward(intermediateState->getState(model->getStateProperties()), action, nextState->getState(model->getStateProperties()), ldata, rewardfunction, getParameter("DiscountFactor"));
		// NOTE(review): presumably refreshes derived (modified) states after the raw
		// model state was written - confirm.
		nextState->newModelState();
	}
	else
	{
		model->transitionFunction(intermediateState->getState(model->getStateProperties()), action, nextState->getState(model->getStateProperties()), ldata);
		nextState->newModelState();
		rewardValue = rewardfunction->getReward(intermediateState, action, nextState);
	}

	// Multi-step actions take their duration from their action data; others from
	// getDuration().
	// NOTE(review): the dynamic_cast result is not NULL-checked.
	if ((action)->isType(MULTISTEPACTION))
	{
		CActionData *actionData = actionDataSet->getActionData(action);
		CMultiStepActionData *multiStepActionData  = dynamic_cast<CMultiStepActionData *>(actionData);
		duration = multiStepActionData->duration;
	}
	else
	{
		duration = action->getDuration();
	}

	if (DebugIsEnabled('q'))
	{
		DebugPrint('q', "Calculated NextState for Action: %d (", actions->getIndex(action));
		
		if (ldata)
		{
			ldata->saveASCII(DebugGetFileHandle('q'));
		}
		
//		data->saveASCII(DebugGetFileHandle('q'));

		DebugPrint('q', ")\n");
		nextState->getState()->saveASCII(DebugGetFileHandle('q'));
		DebugPrint('q',"\n");
	}
	
	rlt_real value = 0.0;

	if(depth > 1)
	{
		// Expand the successor: push it onto the search path, take the maximum
		// over all actions, then pop it again.
		// NOTE(review): assumes actions is non-empty; *begin() on an empty set is undefined.
		stateList->addStateCollection(nextState);

		CActionSet::iterator it = actions->begin();

		value = getValueDepthSearch(stateList, *it, NULL, depth - 1);
		rlt_real max = value;

		it ++;

		for (; it != actions->end();it ++)
		{
			value = getValueDepthSearch(stateList, *it, NULL, depth - 1);
			if (max < value)
			{
				max = value;
			}
		}
		value = max;

		stateList->removeLastStateCollection();
	}
	else
	{
		// Leaf: scaled value-function estimate of the successor.
		rlt_real vFunctionScale = getParameter("VFunctionScale");
		value = vfunction->getValue(nextState) * vFunctionScale;
	}
	DebugPrint('q', "Value: %f Reward %f\n", value, rewardValue);

	// Discount the successor value by the number of elapsed time steps.
	return rewardValue + pow(getParameter("DiscountFactor"), duration) * value;
}

/**
 * Continuous-time Q-function built from the value-function input derivative and
 * the model dynamics (see getValueVDerivation()).
 *
 * Registers @p modifiers on this object and on the internal successor-state
 * collection.
 *
 * @param actions        action set of the Q-function
 * @param vfunction      provides the input derivative of the value function (not deleted here)
 * @param model          continuous-time dynamics (not deleted here)
 * @param rewardfunction reward function (not deleted here)
 * @param modifiers      state modifiers to register
 */
CContinuousTimeQFunctionFromTransitionFunction::CContinuousTimeQFunctionFromTransitionFunction(CActionSet *actions, CVFunctionInputDerivationCalculator *vfunction, CDynamicContinuousTimeModel *model, CRewardFunction *rewardfunction, std::list<CStateModifier *> *modifiers) : CAbstractQFunction(actions), CStateModifiersObject(model->getStateProperties())
{
	this->rewardfunction = rewardfunction;
	this->model = model;
	this->vfunction = vfunction;

	// Buffers for the value-function derivative and the model derivative.
	derivationXVFunction = new CState(model->getStateProperties());
	derivationXModel = new CState(model->getStateProperties());

	// Must exist before the modifiers are registered, because addStateModifier()
	// forwards to it.
	nextState = new CStateCollectionImpl(model->getStateProperties());

	addStateModifiers(modifiers);
}

/**
 * Same as the other constructor, but registers no state modifiers.
 */
CContinuousTimeQFunctionFromTransitionFunction::CContinuousTimeQFunctionFromTransitionFunction(CActionSet *actions, CVFunctionInputDerivationCalculator *vfunction, CDynamicContinuousTimeModel *model, CRewardFunction *rewardfunction) : CAbstractQFunction(actions), CStateModifiersObject(model->getStateProperties())
{
	this->rewardfunction = rewardfunction;
	this->model = model;
	this->vfunction = vfunction;

	// Buffers for the value-function derivative and the model derivative.
	derivationXVFunction = new CState(model->getStateProperties());
	derivationXModel = new CState(model->getStateProperties());

	// Successor-state collection used for reward evaluation.
	nextState = new CStateCollectionImpl(model->getStateProperties());
}

/**
 * Frees the derivative buffers and the successor-state collection allocated by
 * the constructors.
 */
CContinuousTimeQFunctionFromTransitionFunction::~CContinuousTimeQFunctionFromTransitionFunction()
{
	delete derivationXVFunction;
	delete derivationXModel;

	delete nextState;
}

/**
 * Continuous-time action value of one action, given the value-function
 * derivative at the current state.
 *
 * Steps:
 * 1. Evaluate the model's derivative at @p state for @p action (presumably the
 *    state time-derivative f(x, u)).
 * 2. Predict the successor state so that the reward can be evaluated.
 * 3. Return reward + dV/dx . f(x, u).
 *
 * @param state                current state
 * @param action               action to evaluate
 * @param data                 action data forwarded to the model (may be NULL)
 * @param derivationXVFunction value-function input derivative at @p state
 * @return reward plus the scalar product of the value derivative and the model derivative
 */
rlt_real CContinuousTimeQFunctionFromTransitionFunction::getValueVDerivation(CStateCollection *state, CAction *action, CActionData *data, CMyVector *derivationXVFunction)
{
	model->getDerivationX(state->getState(model->getStateProperties()), action, derivationXModel, data);
	model->transitionFunction(state->getState(model->getStateProperties()), action, nextState->getState(model->getStateProperties()), data);

	// The reward was previously computed (including a model transition done only
	// for it) and then discarded, so actions were ranked by the value gradient
	// alone. The continuous-time action value is r + dV/dx . f.
	rlt_real reward = rewardfunction->getReward(state, action, nextState);
	return reward + derivationXVFunction->scalarProduct(derivationXModel);
}


/**
 * Computes the continuous-time action value of every action in @p actions.
 *
 * The value-function derivative at @p state is computed once and shared by all
 * actions. When debug channel 'v' is enabled, the values are printed.
 *
 * @param state         current state
 * @param actions       actions to evaluate, in iteration order
 * @param actionValues  output array with one entry per action, in the same order
 * @param actionDataSet source of the action data passed for each action
 */
void CContinuousTimeQFunctionFromTransitionFunction::getActionValues(CStateCollection *state, CActionSet *actions, rlt_real *actionValues, CActionDataSet *actionDataSet)
{
	vfunction->getInputDerivation(state, derivationXVFunction);

	int index = 0;
	for (CActionSet::iterator actionIt = actions->begin(); actionIt != actions->end(); actionIt ++)
	{
		CAction *currentAction = *actionIt;
		actionValues[index] = getValueVDerivation(state, currentAction, actionDataSet->getActionData(currentAction), derivationXVFunction);
		index ++;
	}

	if (!DebugIsEnabled('v'))
	{
		return;
	}

	DebugPrint('v', "CTQ Function: ");
	for (unsigned int k = 0; k < actions->size(); k ++)
	{
		DebugPrint('v', "%f ", actionValues[k]);
	}
	DebugPrint('v', "\n");
}


/**
 * Continuous-time action value of a single action.
 *
 * Computes the value-function derivative at @p state into the member buffer and
 * delegates to getValueVDerivation().
 */
rlt_real CContinuousTimeQFunctionFromTransitionFunction::getValue(CStateCollection *state, CAction *action, CActionData *data)
{
	vfunction->getInputDerivation(state, this->derivationXVFunction);

	rlt_real qValue = getValueVDerivation(state, action, data, this->derivationXVFunction);
	return qValue;
}

/**
 * Registers a state modifier with this object and with the internal
 * successor-state collection used for reward evaluation.
 */
void CContinuousTimeQFunctionFromTransitionFunction::addStateModifier(CStateModifier *modifier)
{
	CStateModifiersObject::addStateModifier(modifier);
	this->nextState->addStateModifier(modifier);
}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -