📄 ccartpolemodel.cpp
字号:
// Copyright (C) 2003
// Gerhard Neumann (gerhard@igi.tu-graz.ac.at)
//
// This file is part of RL Toolbox.
// http://www.igi.tugraz.at/ril_toolbox
//
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// 1. Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// 2. Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// 3. The name of the author may not be used to endorse or promote products
// derived from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
// IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
// OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
// IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT,
// INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
// NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
// THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include "ccartpolemodel.h"
#include <math.h>
#ifndef M_PI
#define M_PI 3.14159265358979323846
#endif
CCartPoleModel::CCartPoleModel(rlt_real dt, rlt_real uMax, rlt_real lengthTrack, rlt_real lengthPole, rlt_real massCart, rlt_real massPole, rlt_real mu_c, rlt_real mu_p, rlt_real g, bool endLeaveTrack, bool endOverRotate) : CDynamicLinearActionContinuousTimeModel(new CStateProperties(5,0), new CContinuousAction(new CContinuousActionProperties(1)), dt)
{
	// Creates a cart-pole model with 5 continuous state variables and a
	// single continuous action (the force applied to the cart), bounded
	// to [-uMax, uMax].
	//
	// State layout used throughout this file:
	//   0: cart position x          1: cart velocity dx
	//   2: pole angle Phi (periodic) 3: pole angular velocity dPhi
	//   4: accumulated pole angle (not wrapped; integrates dPhi)

	// Physical constants of the system.
	this->g = g;
	this->massCart = massCart;
	this->massPole = massPole;
	this->lengthPole = lengthPole;
	this->lengthTrack = lengthTrack;
	this->mu_c = mu_c;
	this->mu_p = mu_p;
	this->uMax = uMax;

	// Episode-termination options and the reward given on failure.
	this->endLeaveTrack = endLeaveTrack;
	this->endOverRotate = endOverRotate;
	failedReward = -100.0;

	rlt_real halfTrack = lengthTrack / 2;

	// Cart position: the track plus a small 0.1 margin on each side.
	properties->setMaxValue(0, halfTrack + 0.1);
	properties->setMinValue(0, -halfTrack - 0.1);

	// Cart velocity.
	properties->setMaxValue(1, lengthTrack);
	properties->setMinValue(1, -lengthTrack);

	// Pole angle, wrapped to [-pi, pi].
	properties->setMaxValue(2, M_PI);
	properties->setMinValue(2, -M_PI);
	properties->setPeriodicity(2, true);

	// Pole angular velocity.
	properties->setMaxValue(3, 15);
	properties->setMinValue(3, -15);

	// Accumulated (unwrapped) pole angle.
	properties->setMaxValue(4, M_PI * 11);
	properties->setMinValue(4, -M_PI * 11);

	// Symmetric force bounds.
	actionProp->setMaxActionValue(0,uMax);
	actionProp->setMinActionValue(0,-uMax);
}
// Releases the state properties, action properties and continuous action
// object; presumably these are the objects allocated with `new` in the
// constructor's base-class initializer.
// NOTE(review): confirm the base class does not also delete them (double delete),
// and that CContinuousAction's destructor does not touch its properties, since
// actionProp is deleted before contAction here.
CCartPoleModel::~CCartPoleModel()
{
delete properties;
delete actionProp;
delete contAction;
}
// Returns the control matrix B(s): how the applied force u changes the state
// derivative. Only the accelerations are affected.
//   row 1 (ddx):   -(4/3) * l / denum
//   row 3 (ddPhi): -cos(Phi) / denum
// Here denum is the determinant of the 2x2 dynamics system that getA() also
// solves. The returned matrix is the member B, overwritten on every call.
CMyMatrix *CCartPoleModel::getB(CState *state)
{
	rlt_real cosPhi = cos(state->getContinuousState(2));

	rlt_real denum = - 4 * lengthPole / 3 * (massCart + massPole) + lengthPole * massPole * cosPhi * cosPhi;

	B->initMatrix(0.0);
	B->setElement(1, 0, - 4 * lengthPole / denum / 3);
	B->setElement(3, 0, - cosPhi / denum);
	return B;
}
// Returns the force-independent part A(s) of the state derivative. The
// applied force enters separately through getB().
//
// The accelerations solve the 2x2 linear system
//   -(mc + mp) * ddx + l * mp * cos(Phi) * ddPhi = b2
//   -cos(Phi)  * ddx + (4/3) * l        * ddPhi = b1
// using its explicit inverse, where denum is the system determinant.
//   b1: gravity minus pole-joint friction mu_p.
//   b2: centrifugal term plus cart friction mu_c.
// The returned vector is the member A, overwritten on every call.
CMyVector *CCartPoleModel::getA(CState *state)
{
	rlt_real dx = state->getContinuousState(1);
	rlt_real Phi = state->getContinuousState(2);
	rlt_real dPhi = state->getContinuousState(3);

	rlt_real sinPhi = sin(Phi);
	rlt_real cosPhi = cos(Phi);

	// Coulomb friction direction: sgn(dx), with no friction force on a
	// cart that is at rest (the old code pushed a stationary cart by mu_c).
	rlt_real sign_dx = 0.0;
	if (dx > 0)
	{
		sign_dx = 1.0;
	}
	else if (dx < 0)
	{
		sign_dx = - 1.0;
	}

	rlt_real denum = - 4 * lengthPole / 3 * (massCart + massPole) + lengthPole * massPole * cosPhi * cosPhi;
	rlt_real b1 = g * sinPhi - mu_p * dPhi / (lengthPole * massPole);
	rlt_real b2 = lengthPole * massPole * dPhi * dPhi * sinPhi + mu_c * sign_dx;

	// The 4/3 factor must be computed in floating point. The former
	// "4 / 3 * ..." was integer division (== 1) and disagreed with denum and getB().
	rlt_real ddx = (- lengthPole * massPole * cosPhi * b1 + 4 * lengthPole / 3 * b2) / denum;
	rlt_real ddPhi = (- (massCart + massPole) * b1 + cosPhi * b2) / denum;

	A->setElement(0, dx);
	A->setElement(1, ddx);
	A->setElement(2, dPhi);
	A->setElement(3, ddPhi);
	// The accumulated angle (state 4) changes at the same rate as Phi.
	A->setElement(4, dPhi);
	return A;
}
// Reports whether the episode has failed. Each condition applies only when
// its option is enabled:
//  - endLeaveTrack: the cart is beyond the end of the track (|x| > lengthTrack / 2);
//  - endOverRotate: the accumulated pole angle exceeds 10 * pi in magnitude.
bool CCartPoleModel::isFailedState(CState *state)
{
	if (endLeaveTrack && fabs(state->getContinuousState(0)) > lengthTrack / 2)
	{
		return true;
	}
	return endOverRotate && fabs(state->getContinuousState(4)) > 10 * M_PI;
}
// Advances `state` in place by one step of length `timestep` under `action`.
// Every state variable gets an explicit Euler update. The position-like
// variables (x, Phi, accumulated Phi) also get the second-order term
// 0.5 * dt^2 * acceleration.
// If the episode does not end at the track limits (!endLeaveTrack), the
// track ends act as inelastic walls. The cart is held at the wall and only
// the velocity component pointing into the wall is cancelled, so it can
// still move back onto the track.
void CCartPoleModel::doSimulationStep(CState *state, rlt_real timestep, CAction *action, CActionData *data)
{
	getDerivationX(state, action, derivation, data);

	// Accelerations at the start of the step (derivation is not touched below).
	rlt_real ddx = derivation->getElement(1);
	rlt_real ddPhi = derivation->getElement(3);
	rlt_real halfStepSq = timestep * timestep / 2;

	for (unsigned int i = 0; i < state->getNumContinuousStates(); i++)
	{
		state->setContinuousState(i, state->getContinuousState(i) + timestep * derivation->getElement(i));
	}

	state->setContinuousState(0, state->getContinuousState(0) + halfStepSq * ddx);
	state->setContinuousState(2, state->getContinuousState(2) + halfStepSq * ddPhi);
	state->setContinuousState(4, state->getContinuousState(4) + halfStepSq * ddPhi);

	if (!endLeaveTrack)
	{
		rlt_real halfTrack = lengthTrack / 2;
		rlt_real x = state->getContinuousState(0);
		if (fabs(x) >= halfTrack)
		{
			rlt_real wall = (x > 0) ? halfTrack : - halfTrack;
			state->setContinuousState(0, wall);
			// Stop only motion into the wall; motion away from it is kept.
			if (state->getContinuousState(1) * wall > 0)
			{
				state->setContinuousState(1, 0.0);
			}
		}
	}
}
// Produces a start state. The generic reset state from CTransitionFunction
// is adjusted as follows:
//  - the cart position is scaled by 0.8, moving it towards the track centre;
//  - both velocities are zeroed;
//  - the accumulated angle (state 4) is synchronised with the pole angle.
void CCartPoleModel::getResetState(CState *state)
{
	CTransitionFunction::getResetState(state);

	rlt_real startX = state->getContinuousState(0) * 0.8;
	rlt_real startPhi = state->getContinuousState(2);

	state->setContinuousState(0, startX);
	state->setContinuousState(1, 0.0);
	state->setContinuousState(3, 0.0);
	state->setContinuousState(4, startPhi);
}
// Shaped reward for the cart-pole model, defined over the model's state
// properties. Both optional terms (upright peak bonus and over-rotation
// penalty) are enabled by default.
CCartPoleRewardFunction::CCartPoleRewardFunction(CCartPoleModel *model) : CStateReward(model->getStateProperties())
{
	punishOverRotate = true;
	useHeighPeak = true;
	cartpoleModel = model;
}
// Shaped state reward:
//   cos(Phi) - 1                               height term (0 at Phi = 0)
//   - 100 * my_exp(25 * (|x| - lengthTrack/2))   steep penalty near the track ends
//   + exp(-25 * Phi^2)                         narrow upright bonus (if useHeighPeak)
//   - 20 * exp(|phi_acc| - 10*pi)              over-rotation penalty (if punishOverRotate)
// where phi_acc is the accumulated pole angle (state 4).
rlt_real CCartPoleRewardFunction::getStateReward(CState *state)
{
	rlt_real x = state->getContinuousState(0);
	rlt_real Phi = state->getContinuousState(2);
	rlt_real halfTrack = cartpoleModel->lengthTrack / 2;

	rlt_real reward = cos(Phi) - 1 - 100 * my_exp((fabs(x) - halfTrack) * 25);

	if (useHeighPeak)
	{
		rlt_real peakBonus = exp(-pow(Phi, 2.0) * 25);
		reward += peakBonus;
	}

	if (punishOverRotate)
	{
		rlt_real rotationPenalty = 20 * exp(fabs(state->getContinuousState(4)) - 10 * M_PI);
		reward -= rotationPenalty;
	}

	return reward;
}
// Writes the gradient of getStateReward() with respect to the 5 continuous
// state variables into targetState (elements 0..4).
// Every element is written on every call. Previously element 4 kept a stale
// value when punishOverRotate was false.
// NOTE(review): the track-penalty slope uses exp() while getStateReward() uses
// my_exp(); confirm they agree on the reachable range of x.
void CCartPoleRewardFunction::getInputDerivation(CState *modelState, CMyVector *targetState)
{
	rlt_real Phi = modelState->getState(properties)->getContinuousState(2);
	rlt_real x = modelState->getContinuousState(0);

	// d/dx of -100 * exp(25 * (|x| - L/2)) = -2500 * sgn(x) * exp(...)
	rlt_real trackSlope = 25 * 100 * exp((fabs(x) - cartpoleModel->lengthTrack / 2) * 25);
	targetState->setElement(0, (x < 0) ? trackSlope : - trackSlope);

	// The reward does not depend on the velocities.
	targetState->setElement(1, 0);
	targetState->setElement(3, 0);

	// d/dPhi of cos(Phi) - 1, plus the upright-peak term if enabled.
	rlt_real dPhi = - sin(Phi);
	if (useHeighPeak)
	{
		dPhi += - 50 * Phi * exp(- Phi * Phi * 25);
	}
	targetState->setElement(2, dPhi);

	// d/dphi_acc of -20 * exp(|phi_acc| - 10*pi); zero when the term is disabled.
	rlt_real dOverRotate = 0.0;
	if (punishOverRotate)
	{
		rlt_real phi_ = modelState->getContinuousState(4);
		rlt_real rotationSlope = 20 * exp(fabs(phi_) - 10 * M_PI);
		dOverRotate = (phi_ < 0) ? rotationSlope : - rotationSlope;
	}
	targetState->setElement(4, dOverRotate);
}
// Unshaped "height" reward for the cart-pole model, defined over the
// model's state properties.
CCartPoleHeightRewardFunction::CCartPoleHeightRewardFunction(CCartPoleModel *model) : CStateReward(model->getStateProperties())
{
	cartpoleModel = model;
}
// Height reward cos(Phi) - 1: 0 at Phi = 0 and -2 at Phi = +/-pi.
// Depends only on the pole angle (state 2); the cart position is not used,
// so the former dead read of state 0 is dropped.
rlt_real CCartPoleHeightRewardFunction::getStateReward(CState *state)
{
	rlt_real Phi = state->getContinuousState(2);
	return cos(Phi) - 1;
}
// Writes the gradient of getStateReward() (cos(Phi) - 1) with respect to the
// 5 continuous state variables into targetState. Only the Phi component is
// non-zero. Elements 0 and 4 are now written explicitly as well; previously
// they kept whatever value the caller's vector held.
void CCartPoleHeightRewardFunction::getInputDerivation(CState *modelState, CMyVector *targetState)
{
	rlt_real Phi = modelState->getState(properties)->getContinuousState(2);

	targetState->setElement(0, 0);
	targetState->setElement(1, 0);
	targetState->setElement(2, - sin(Phi));
	targetState->setElement(3, 0);
	targetState->setElement(4, 0);
}
#ifdef RL_TOOLBOX_USE_QT
// Creates a fixed-size 700x400 widget that draws the cart-pole state.
// The displayed quantities start at zero until newDrawState() is called.
// NOTE(review): `parent` is not forwarded; NULL is passed to the base class
// instead. Confirm this is intended (e.g. to open a separate window).
CQTCartPoleVisualizer::CQTCartPoleVisualizer(CCartPoleModel *cartModel, QWidget *parent, const char *name) : CQTModelVisualizer(NULL, name)
{
	this->cartModel = cartModel;
	x = dx = phi = dphi = 0;
	setFixedSize(700, 400);
}
// Draws the last state received by newDrawState(). The method has three parts:
//  - a text readout of x, x', Phi, Phi' in the top-left corner;
//  - two walls at the track ends;
//  - the cart with the pole.
// Positions are scaled by 100 pixels per length unit. The origin is moved to
// the widget centre. Statement order matters because each translate/rotate
// changes the painter's coordinate system for everything drawn after it.
void CQTCartPoleVisualizer::doDrawState( QPainter *painter)
{
QString s1 = "x = " + QString::number( x );
QString s2 = "x' = " + QString::number( dx );
QString s3 = "Phi = " + QString::number( phi );
QString s4 = "Phi' = " + QString::number( dphi );
painter->drawText(10,20, s1);
painter->drawText(10,40, s2);
painter->drawText(10,60, s3);
painter->drawText(10,80, s4);
// Origin at the widget centre.
painter->translate(this->width() / 2, this->height() / 2);
painter->setBrush(black);
// Walls: 5 px wide rectangles just outside the track ends. The extra 50-60 px
// offset accounts for the 100 px wide cart body drawn below.
painter->drawRect(- cartModel->lengthTrack / 2 * 100 - 60, -25, 5, 50);
painter->drawRect(cartModel->lengthTrack / 2 * 100 + 50, -25, 5, 50);
// Move to the cart position and draw the 100x50 cart body.
painter->translate(x * 100, 0);
painter->setBrush(white);
painter->drawRect(-50, -25, 100, 50);
// phi is in degrees (see newDrawState); the +180 makes Phi == 0 point up on screen.
painter->rotate(- phi + 180);
painter->setBrush(black);
// Pole. NOTE(review): x range is [-5, 0], i.e. not centred on the pivot -- confirm intent.
painter->drawRect(- 5, -5 , 5, (cartModel->lengthPole * 100) + 5);
painter->flush();
}
// Caches the values to display from the current state. Cart position and
// velocity are kept as-is; the pole angle and angular velocity are converted
// from radians to degrees for drawing and the text readout.
void CQTCartPoleVisualizer::newDrawState(CStateCollection *state)
{
	CState *current = state->getState();

	x = current->getContinuousState(0);
	dx = current->getContinuousState(1);

	phi = current->getContinuousState(2) * 180 / M_PI;
	dphi = current->getContinuousState(3) * 180 / M_PI;
}
#endif
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -