// ctransitionfunction.cpp
}
// Copies the action matrix returned by getB(oldstate) into *derivation.
// (For this model class that matrix is presumably the derivative of the
// dynamics with respect to the action u, as the method name suggests.)
void CDynamicLinearActionContinuousTimeModel::getDerivationU(CState *oldstate, CMyMatrix *derivation)
{
	CMyMatrix *actionMatrix = getB(oldstate);
	derivation->setMatrix(actionMatrix);
}
// Linear model with a constant state matrix A (getA returns a(x) = A * x)
// and a constant action matrix B (returned by getB), presumably
// dx/dt = A x + B u.
// Both matrices are copied via setMatrix, so the caller keeps ownership of
// the matrices passed in.
// @param properties state properties; A must be (numContinuousStates x numContinuousStates)
// @param action     continuous action; B must be (numContinuousStates x numActionValues)
// @param dt         time step, forwarded to the base class
// @param A          state matrix to copy
// @param B          action matrix to copy
CDynamicLinearContinuousTimeModel::CDynamicLinearContinuousTimeModel(CStateProperties *properties, CContinuousAction *action, rlt_real dt, CMyMatrix *A, CMyMatrix *B) : CDynamicLinearActionContinuousTimeModel(properties, action, dt)
{
	assert(A->getNumRows() == properties->getNumContinuousStates() && A->getNumColumns() == properties->getNumContinuousStates() && B->getNumColumns() == action->getContinuousActionProperties()->getNumActionValues() && B->getNumRows() == properties->getNumContinuousStates());

	// The parameter B shadows the member B returned by getB(). The previous
	// `B->setMatrix(B)` copied the argument onto itself and never filled the
	// member matrix. Assumes the base-class constructor allocates the member
	// matrix (as it does the member vector A used by getA) -- confirm against
	// CDynamicLinearActionContinuousTimeModel.
	this->B->setMatrix(B);

	// The member name A is already used for the drift vector, so the state
	// matrix is kept in a separately owned AMatrix (freed in the destructor).
	AMatrix = new CMyMatrix(properties->getNumContinuousStates(), properties->getNumContinuousStates());
	AMatrix->setMatrix(A);
}
// Frees the state matrix that the constructor allocated.
CDynamicLinearContinuousTimeModel::~CDynamicLinearContinuousTimeModel()
{
	delete this->AMatrix;
}
// Returns the constant action matrix B. The state argument is not used,
// because B does not vary with the state in this linear model.
CMyMatrix *CDynamicLinearContinuousTimeModel::getB(CState *state)
{
	return this->B;
}
// Computes a(x) = A * x for the given state and returns it.
// The result is written into the member vector A, so the returned pointer
// refers to a shared buffer that the next call overwrites.
CMyVector *CDynamicLinearContinuousTimeModel::getA(CState *state)
{
	CMyVector *drift = this->A;
	AMatrix->multVector(state, drift);
	return drift;
}
// Wraps a transition function as an environment.
// The environment keeps a current state and a scratch successor state, and
// starts without any start-state list or override regions.
// @param model transition function used to simulate steps (not owned)
CTransitionFunctionEnvironment::CTransitionFunctionEnvironment(CTransitionFunction *model) : CEnvironmentModel(model->getStateProperties())
{
	TransitionFunction = model;

	// Nothing configured yet: resets fall back to the transition function's
	// own reset state, and terminal tests to its own reset/failed checks.
	startStates = NULL;
	createdStartStates = false;
	nEpisode = 0;
	sampleRegion = NULL;
	failedRegion = NULL;
	targetRegion = NULL;

	// Current state and successor buffer; doNextState swaps the two.
	modelState = new CState(getStateProperties());
	nextState = new CState(getStateProperties());

	// Must run last: presumably dispatches to doResetModel, which reads the
	// members initialised above.
	resetModel();
}
// Frees the two state buffers, and the start-state list when
// createdStartStates marks it as owned by this object. Externally supplied
// lists and regions are left to their owners.
CTransitionFunctionEnvironment::~CTransitionFunctionEnvironment()
{
	delete modelState;
	delete nextState;

	if (!createdStartStates)
	{
		return;
	}
	delete startStates;
}
// Advances the environment by one step with the given action, then updates
// the reset and failed flags for the new state.
// A configured target/failed region overrides the corresponding test of the
// transition function.
void CTransitionFunctionEnvironment::doNextState(CPrimitiveAction *action)
{
	TransitionFunction->transitionFunction(modelState, action, nextState);

	// Swap buffers so modelState now holds the freshly computed successor.
	CState *previousState = modelState;
	modelState = nextState;
	nextState = previousState;

	reset = (targetRegion != NULL) ? targetRegion->isStateInRegion(modelState) : TransitionFunction->isResetState(modelState);
	failed = (failedRegion != NULL) ? failedRegion->isStateInRegion(modelState) : TransitionFunction->isFailedState(modelState);
}
// Chooses the initial state of a new episode, using the first applicable rule:
//  - with a start-state list: the list entries in cyclic order, one per episode;
//  - otherwise with a sample region: a random sample from that region;
//  - otherwise: the reset state of the transition function.
void CTransitionFunctionEnvironment::doResetModel()
{
	if (startStates == NULL)
	{
		if (sampleRegion != NULL)
		{
			sampleRegion->getRandomStateSample(modelState);
		}
		else
		{
			TransitionFunction->getResetState(modelState);
		}
		return;
	}

	startStates->getState(nEpisode, modelState);
	// Wrap around to the first entry after the last one.
	nEpisode = (nEpisode + 1) % startStates->getNumStates();
}
// Copies the current model state into *state.
// The target must use this environment's state properties (checked by assert).
void CTransitionFunctionEnvironment::getState(CState *state)
{
	assert(state->getStateProperties()->equals(getStateProperties()));
	state->setState(this->modelState);
}
// Overwrites the current model state with *state.
// The source must use this environment's state properties (checked by assert).
void CTransitionFunctionEnvironment::setState(CState *state)
{
	assert(state->getStateProperties()->equals(getStateProperties()));
	this->modelState->setState(state);
}
// Installs a caller-owned list of start states and restarts the cycle at
// its first entry. A list previously owned by this object is freed first.
// The new list is not freed by this object.
// @param startStates list to use (may be NULL to disable start-state lists)
void CTransitionFunctionEnvironment::setStartStates(CStateList *startStates)
{
	if (createdStartStates)
	{
		createdStartStates = false;
		delete this->startStates;
	}
	nEpisode = 0;
	this->startStates = startStates;
}
void CTransitionFunctionEnvironment::setStartStates(char *filename)
{
FILE *startStateFile = fopen(filename, "r");
startStates = new CStateList(getStateProperties());
startStates->loadASCII(startStateFile);
fclose(startStateFile);
nEpisode = 0;
}
// Sets the region that episode start states are sampled from when no
// start-state list is installed (NULL = use the transition function's reset
// state). The region is not owned.
void CTransitionFunctionEnvironment::setSampleRegion(CRegion *l_sampleRegion)
{
	sampleRegion = l_sampleRegion;
}
// Sets the region whose states count as failed in doNextState
// (NULL = use the transition function's isFailedState). The region is not owned.
void CTransitionFunctionEnvironment::setFailedRegion(CRegion *l_failedRegion)
{
	failedRegion = l_failedRegion;
}
// Sets the region whose states trigger a reset in doNextState
// (NULL = use the transition function's isResetState). The region is not owned.
void CTransitionFunctionEnvironment::setTargetRegion(CRegion *l_targetRegion)
{
	targetRegion = l_targetRegion;
}
// Q-function computed by look-ahead through a transition model:
// Q(s,a) = reward + discounted value of the simulated successor
// (see getValueDepthSearch).
// The V-function, model and reward function are referenced, not owned.
// Parameters registered here:
//   SearchDepth    (default 1)    look-ahead depth
//   DiscountFactor (default 0.95) discount applied per time step
//   VFunctionScale (default 1.0)  factor applied to V-function values
CQFunctionFromTransitionFunction::CQFunctionFromTransitionFunction(CActionSet *actions, CAbstractVFunction *vfunction, CTransitionFunction *model, CRewardFunction *rewardfunction, std::list<CStateModifier *> *modifiers) : CAbstractQFunction(actions), CStateModifiersObject(model->getStateProperties())
{
	this->model = model;
	this->vfunction = vfunction;
	this->rewardfunction = rewardfunction;

	// Per-action data buffers used while simulating actions.
	actionDataSet = new CActionDataSet(actions);

	// Scratch state collections reused by every search; owned by this object.
	nextState = new CStateCollectionImpl(model->getStateProperties());
	intermediateState = new CStateCollectionImpl(model->getStateProperties());
	stateCollectionList = new CStateCollectionList(model->getStateProperties());

	addParameter("SearchDepth", 1);
	addParameter("DiscountFactor", 0.95);
	addParameter("VFunctionScale", 1.0);

	// Must come after the buffers exist: presumably forwards each modifier to
	// addStateModifier, which also registers it with those buffers.
	addStateModifiers(modifiers);
}
// Frees the buffers the constructor allocated. The V-function, model and
// reward function are not owned and are left alone.
CQFunctionFromTransitionFunction::~CQFunctionFromTransitionFunction()
{
	delete this->actionDataSet;
	delete this->nextState;
	delete this->intermediateState;
	delete this->stateCollectionList;
}
// Registers a state modifier with this object and with every internal state
// buffer, so that simulated states also carry the modified state
// representations.
void CQFunctionFromTransitionFunction::addStateModifier(CStateModifier *modifier)
{
	CStateModifiersObject::addStateModifier(modifier);

	// Internal buffers used during the look-ahead search.
	this->nextState->addStateModifier(modifier);
	this->intermediateState->addStateModifier(modifier);
	this->stateCollectionList->addStateModifier(modifier);
}
// Returns the look-ahead Q-value of (state, action).
// The search depth is the rounded "SearchDepth" parameter.
// @param data optional action data; when non-NULL it is copied into the
//             internal action data before simulating
rlt_real CQFunctionFromTransitionFunction::getValue(CStateCollection *state, CAction *action, CActionData *data)
{
	// The search path starts with just the queried state.
	stateCollectionList->clearStateLists();
	stateCollectionList->addStateCollection(state);

	int searchDepth = my_round(getParameter("SearchDepth"));
	return getValueDepthSearch(stateCollectionList, action, data, searchDepth);
}
// Recursive depth-limited look-ahead estimate of Q(s, a), where s is the last
// state collection in stateList.
//  - depth == 0: returns VFunctionScale * V(s).
//  - otherwise: simulates `action` from s with the model, obtaining the
//    successor s' and a reward r, and returns
//        r + DiscountFactor^duration * value(s')
//    where value(s') is the maximum over all actions of a search one level
//    shallower (depth > 1), or VFunctionScale * V(s') (depth == 1).
// @param stateList search path; the successor is appended before recursing
//                  and removed again afterwards
// @param action    action to evaluate
// @param data      optional action data; when non-NULL it is copied into
//                  the internal action data for `action`
// @param depth     remaining look-ahead depth
// NOTE(review): intermediateState and nextState are member scratch buffers
// that the recursive calls overwrite, so their contents are not valid after
// the recursion; the statement order here matters.
rlt_real CQFunctionFromTransitionFunction::getValueDepthSearch(CStateCollectionList *stateList, CAction *action, CActionData *data, int depth)
{
// Load the current (last) state of the search path into the scratch buffer.
stateList->getStateCollection(stateList->getNumStateCollections() - 1, intermediateState);
if (depth == 0)
{
rlt_real vFunctionScale = getParameter("VFunctionScale");
return vfunction->getValue(intermediateState) * vFunctionScale;
}
if (data)
{
actionDataSet->getActionData(action)->setData(data);
}
CActionData *ldata = actionDataSet->getActionData(action);
int duration = 1;
rlt_real rewardValue = 0;
if (model->isType(DM_EXTENDEDACTIONMODEL))
{
// Extended-action models compute the successor and the (discounted) reward
// in one call.
// NOTE(review): the dynamic_cast result is not checked for NULL.
CExtendedActionTransitionFunction *extModel = dynamic_cast<CExtendedActionTransitionFunction *>(model);
rewardValue = extModel->transitionFunctionAndReward(intermediateState->getState(model->getStateProperties()), action, nextState->getState(model->getStateProperties()), ldata, rewardfunction, getParameter("DiscountFactor"));
nextState->newModelState();
}
else
{
// Plain models: simulate one transition, then query the reward function.
model->transitionFunction(intermediateState->getState(model->getStateProperties()), action, nextState->getState(model->getStateProperties()), ldata);
nextState->newModelState();
rewardValue = rewardfunction->getReward(intermediateState, action, nextState);
}
// Number of time steps the action took; used as the discount exponent.
// Multi-step actions report it through their action data.
// NOTE(review): the dynamic_cast result is not checked for NULL.
if ((action)->isType(MULTISTEPACTION))
{
CActionData *actionData = actionDataSet->getActionData(action);
CMultiStepActionData *multiStepActionData = dynamic_cast<CMultiStepActionData *>(actionData);
duration = multiStepActionData->duration;
}
else
{
duration = action->getDuration();
}
if (DebugIsEnabled('q'))
{
DebugPrint('q', "Calculated NextState for Action: %d (", actions->getIndex(action));
if (ldata)
{
ldata->saveASCII(DebugGetFileHandle('q'));
}
// data->saveASCII(DebugGetFileHandle('q'));
DebugPrint('q', ")\n");
nextState->getState()->saveASCII(DebugGetFileHandle('q'));
DebugPrint('q',"\n");
}
rlt_real value = 0.0;
if(depth > 1)
{
// Extend the search path with the successor (presumably stored as a copy,
// since nextState is overwritten by the recursive calls -- confirm), then
// take the best value over all actions one level deeper.
// NOTE(review): assumes the action set is non-empty; *begin() is
// dereferenced unconditionally.
stateList->addStateCollection(nextState);
CActionSet::iterator it = actions->begin();
value = getValueDepthSearch(stateList, *it, NULL, depth - 1);
rlt_real max = value;
it ++;
for (; it != actions->end();it ++)
{
value = getValueDepthSearch(stateList, *it, NULL, depth - 1);
if (max < value)
{
max = value;
}
}
value = max;
// Restore the search path for the caller.
stateList->removeLastStateCollection();
}
else
{
// Leaf of the search: bootstrap from the scaled V-function.
rlt_real vFunctionScale = getParameter("VFunctionScale");
value = vfunction->getValue(nextState) * vFunctionScale;
}
DebugPrint('q', "Value: %f Reward %f\n", value, rewardValue);
return rewardValue + pow(getParameter("DiscountFactor"), duration) * value;
}
// Continuous-time Q-function that scores actions via the V-function's input
// derivative and the model's state derivative (see getValueVDerivation).
// The V-function, model and reward function are referenced, not owned.
// @param modifiers state modifiers registered with this object and its
//                  successor-state buffer
CContinuousTimeQFunctionFromTransitionFunction::CContinuousTimeQFunctionFromTransitionFunction(CActionSet *actions, CVFunctionInputDerivationCalculator *vfunction, CDynamicContinuousTimeModel *model, CRewardFunction *rewardfunction, std::list<CStateModifier *> *modifiers) : CAbstractQFunction(actions), CStateModifiersObject(model->getStateProperties())
{
	this->model = model;
	this->vfunction = vfunction;
	this->rewardfunction = rewardfunction;

	// Owned scratch buffers: successor state and the two derivative vectors.
	nextState = new CStateCollectionImpl(model->getStateProperties());
	derivationXModel = new CState(model->getStateProperties());
	derivationXVFunction = new CState(model->getStateProperties());

	// After the buffers exist, since addStateModifier also touches nextState.
	addStateModifiers(modifiers);
}
// Same as the constructor taking a modifier list, but registers no state
// modifiers. The V-function, model and reward function are referenced,
// not owned.
CContinuousTimeQFunctionFromTransitionFunction::CContinuousTimeQFunctionFromTransitionFunction(CActionSet *actions, CVFunctionInputDerivationCalculator *vfunction, CDynamicContinuousTimeModel *model, CRewardFunction *rewardfunction) : CAbstractQFunction(actions), CStateModifiersObject(model->getStateProperties())
{
	this->model = model;
	this->vfunction = vfunction;
	this->rewardfunction = rewardfunction;

	// Owned scratch buffers: successor state and the two derivative vectors.
	nextState = new CStateCollectionImpl(model->getStateProperties());
	derivationXModel = new CState(model->getStateProperties());
	derivationXVFunction = new CState(model->getStateProperties());
}
// Frees the scratch buffers the constructors allocated. The V-function,
// model and reward function are not owned.
CContinuousTimeQFunctionFromTransitionFunction::~CContinuousTimeQFunctionFromTransitionFunction()
{
	delete this->nextState;
	delete this->derivationXModel;
	delete this->derivationXVFunction;
}
// Scores an action as the scalar product of a precomputed V-function input
// derivative with the model's state derivative for (state, action).
// @param derivationXVFunction V-function input derivative for `state`
//                             (computed by the caller)
// @return derivationXVFunction . derivationXModel
// NOTE(review): the successor state and the reward are computed but the
// reward does not enter the returned value; confirm whether it should
// (e.g. r + dV/dx . f(x,u)) or whether these two calls can be dropped.
rlt_real CContinuousTimeQFunctionFromTransitionFunction::getValueVDerivation(CStateCollection *state, CAction *action, CActionData *data, CMyVector *derivationXVFunction)
{
	CState *currentState = state->getState(model->getStateProperties());

	// Model derivative with respect to the state, written into derivationXModel.
	model->getDerivationX(currentState, action, derivationXModel, data);

	CState *successorState = nextState->getState(model->getStateProperties());
	model->transitionFunction(currentState, action, successorState, data);

	rlt_real unusedReward = rewardfunction->getReward(state, action, nextState);
	(void) unusedReward;

	return derivationXVFunction->scalarProduct(derivationXModel);
}
// Fills actionValues[i] with the score of the i-th action of `actions`
// (see getValueVDerivation).
// The V-function input derivative depends only on the state, so it is
// computed once for all actions.
// @param actionValues  output array with at least actions->size() entries
// @param actionDataSet optional per-action data; when NULL, each action is
//                      evaluated with NULL action data, as a direct
//                      getValue(state, action, NULL) call would be
//                      (previously this dereferenced NULL)
void CContinuousTimeQFunctionFromTransitionFunction::getActionValues(CStateCollection *state, CActionSet *actions, rlt_real *actionValues, CActionDataSet *actionDataSet)
{
	vfunction->getInputDerivation(state, derivationXVFunction);

	CActionSet::iterator it = actions->begin();
	for (int i = 0; it != actions->end(); it ++, i ++)
	{
		CActionData *data = (actionDataSet != NULL) ? actionDataSet->getActionData(*it) : NULL;
		actionValues[i] = getValueVDerivation(state, *it, data, derivationXVFunction);
	}

	if (DebugIsEnabled('v'))
	{
		DebugPrint('v', "CTQ Function: ");
		for (unsigned int i = 0; i < actions->size(); i++)
		{
			DebugPrint('v', "%f ", actionValues[i]);
		}
		DebugPrint('v', "\n");
	}
}
// Scores a single action for `state`.
// First computes the V-function input derivative for the state, then defers
// to getValueVDerivation.
// @param data optional action data forwarded to the model (may be NULL)
rlt_real CContinuousTimeQFunctionFromTransitionFunction::getValue(CStateCollection *state, CAction *action, CActionData *data)
{
	vfunction->getInputDerivation(state, this->derivationXVFunction);
	rlt_real score = getValueVDerivation(state, action, data, this->derivationXVFunction);
	return score;
}
// Registers a state modifier with this object and with the successor-state
// buffer used when simulating actions.
void CContinuousTimeQFunctionFromTransitionFunction::addStateModifier(CStateModifier *modifier)
{
	CStateModifiersObject::addStateModifier(modifier);
	this->nextState->addStateModifier(modifier);
}
// end of ctransitionfunction.cpp