⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 cmultipolemodel.cpp

📁 强化学习算法(R-Learning)难得的珍贵资料
💻 CPP
字号:
// Copyright (C) 2003
// Gerhard Neumann (gerhard@igi.tu-graz.ac.at)

//                
// This file is part of RL Toolbox.
// http://www.igi.tugraz.at/ril_toolbox
//
// All rights reserved.
// 
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// 1. Redistributions of source code must retain the above copyright
//    notice, this list of conditions and the following disclaimer.
// 2. Redistributions in binary form must reproduce the above copyright
//    notice, this list of conditions and the following disclaimer in the
//    documentation and/or other materials provided with the distribution.
// 3. The name of the author may not be used to endorse or promote products
//    derived from this software without specific prior written permission.
// 
// THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
// IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
// OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
// IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT,
// INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
// NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
// THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

#include "ril_debug.h"
#include "cmultipolemodel.h"
#include <stdlib.h>
#include <stdio.h>
#include <assert.h>
#include <math.h>

#include <algorithm>

// Physical constants of the cart-pole system. They are used by the equations
// of motion in CMultiPoleModel::doNextState and by the look-ahead in
// CMultiPoleController::getNextAction.
// NOTE(review): units are presumably SI (m, kg, s) - only TAU is documented.
#define GRAVITY 9.8
#define MASSCART 1.0
#define MASSPOLE 0.1
#define TOTAL_MASS (MASSPOLE + MASSCART)
#define LENGTH 0.5		  /* actually half the pole's length */
#define POLEMASS_LENGTH (MASSPOLE * LENGTH)
// Force magnitude tried (as +FORCE_MAG and -FORCE_MAG) by CMultiPoleController.
#define FORCE_MAG 10.0
#define TAU 0.02		  /* seconds between state updates */
// 4/3, appears in the denominator of the angular-acceleration formula.
#define FOURTHIRDS 1.3333333333333


// Angle constants in radians. twelve_degrees is the pole-angle failure limit.
// one_degree/six_degrees are bin edges for theta in CMultiPoleDiscreteState.
// fifty_degrees is used as the bin edge (and bound) for the angular velocity
// theta_dot.
#define one_degree 0.0174532	/* 2pi/360 */
#define six_degrees 0.1047192
#define twelve_degrees 0.2094384
#define fifty_degrees 0.87266

CMultiPoleModel::CMultiPoleModel() : CEnvironmentModel(4, 0)
{
	// The model starts with all four state variables at zero.
	x = 0;
	x_dot = 0;
	theta = 0;
	theta_dot = 0;

	// Symmetric bounds (+/- value) of the 4 continuous dimensions, in order:
	// cart position, cart velocity, pole angle, pole angular velocity.
	// Position and angle get 10% headroom beyond the failure limits
	// (2.4 and twelve_degrees) checked in doNextState.
	const double stateBounds[4] = {
		2.4 * 1.1,
		2,
		twelve_degrees * 1.1,
		fifty_degrees * 1.5
	};

	for (int dim = 0; dim < 4; dim++)
	{
		properties->setMinValue(dim, -stateBounds[dim]);
		properties->setMaxValue(dim, stateBounds[dim]);
	}
}

CMultiPoleModel::~CMultiPoleModel()
{
	// Nothing to clean up in this class's own destructor body.
}

void CMultiPoleModel::doNextState(CPrimitiveAction *act)
{
	// The primitive action is expected to be a CMultiPoleAction; the only
	// thing taken from it is the force applied to the cart in this step.
	CMultiPoleAction *poleAction = static_cast<CMultiPoleAction *>(act);
	rlt_real appliedForce = poleAction->getForce();

	// Accelerations from the cart-pole equations of motion, evaluated at the
	// current (not yet updated) state.
	rlt_real cosTheta = cos(theta);
	rlt_real sinTheta = sin(theta);
	rlt_real pushTerm = (appliedForce + POLEMASS_LENGTH * theta_dot * theta_dot * sinTheta) / TOTAL_MASS;
	rlt_real angularAcc = (GRAVITY * sinTheta - cosTheta * pushTerm)
		/ (LENGTH * (FOURTHIRDS - MASSPOLE * cosTheta * cosTheta / TOTAL_MASS));
	rlt_real linearAcc = pushTerm - POLEMASS_LENGTH * angularAcc * cosTheta / TOTAL_MASS;

	// Explicit Euler step. Each position is advanced with the OLD velocity,
	// so every position update must precede its velocity update.
	x += TAU * x_dot;
	x_dot += TAU * linearAcc;
	theta += TAU * theta_dot;
	theta_dot += TAU * angularAcc;

	// The episode fails when the cart leaves the track (|x| > 2.4) or the
	// pole tips past twelve degrees.
	bool outOfBounds = x < -2.4 || x > 2.4 || theta < -twelve_degrees || theta > twelve_degrees;
	if (outOfBounds)
	{
		reset = true;
		failed = true;
	}

	// Debug output. It is printed whenever reset is set, not only on
	// failure (reset may presumably also be set elsewhere - confirm).
	if (reset)
	{
		printf("Failed State: x = %f; theta = %f\n", x, theta);
	}
}

rlt_real CMultiPoleModel::getReward(CStateCollection *oldStatecol, CAction *action, CStateCollection *newStateCol) {
	// Only the successor state is inspected; oldStatecol and action are unused.
	CState *successor = newStateCol->getState(getStateProperties());

	rlt_real cartPosition = successor->getContinuousState(0);
	rlt_real poleAngle = successor->getContinuousState(2);

	// Same failure test as doNextState: cart off the track or pole past
	// twelve degrees.
	bool failedState = cartPosition < -2.4 || cartPosition > 2.4
		|| poleAngle < -twelve_degrees || poleAngle > twelve_degrees;

	// Reward: -1 for a failed state, 0 otherwise.
	return failedState ? -1.0 : 0.0;
}

CMultiPoleContinuousReward::CMultiPoleContinuousReward(CStateProperties *modelState)
	: CStateReward(modelState)
{
	// All setup is delegated to CStateReward, which receives the model's
	// state properties.
}

// Shaped (continuous) reward for a cart-pole model state.
//   modelState: state whose continuous dimension 0 is the cart position and
//               dimension 2 is the pole angle.
// Returns -5 * |theta|, minus a further 5 when the cart is off the track
// (|x| > 2.4). The result is always <= 0 and is 0 only for an upright pole
// on the track.
// Fix: the previous version computed this value but returned the raw angle
// theta, so the shaped penalty was never delivered.
rlt_real CMultiPoleContinuousReward::getStateReward(CState *modelState)
{
	rlt_real theta = modelState->getContinuousState(2);
	rlt_real x = modelState->getContinuousState(0);

	// Penalty proportional to how far the pole is from upright.
	rlt_real reward = -fabs(theta) * 5;

	// Fixed extra penalty for leaving the track.
	if (fabs(x) > 2.4)
	{
		reward -= 5;
	}
	return reward;
}


// Store the model state to the given state object
void CMultiPoleModel::getState(CState *state)
{
	// Let the base class reset/prepare the state object first.
	CEnvironmentModel::getState(state);

	// Copy the internal variables into continuous dimensions 0..3, in the
	// same order as the bounds set up in the constructor.
	const rlt_real internalState[4] = {x, x_dot, theta, theta_dot};
	for (int dim = 0; dim < 4; ++dim)
	{
		state->setContinuousState(dim, internalState[dim]);
	}
}

void CMultiPoleModel::doResetModel()
{
	// A new episode starts with all four state variables zeroed.
	theta_dot = 0;
	theta = 0;
	x_dot = 0;
	x = 0;
}

// 3 position bins * 3 velocity bins * 6 angle bins * 3 angular-velocity bins
// = 162 boxes, plus one extra discrete state (0) for failed states.
CMultiPoleDiscreteState::CMultiPoleDiscreteState()
	: CAbstractStateDiscretizer(3 * 3 * 6 * 3 + 1)
{
}

// Returns the index of the first edge in edges[0..numEdges) that value is
// strictly below, or numEdges if there is none. A NaN value compares false
// everywhere and so lands in the last bin, just like the trailing "else" of
// an if/else-if ladder.
static int multiPoleBinIndex(double value, const double *edges, int numEdges)
{
	int bin = 0;
	while (bin < numEdges && !(value < edges[bin]))
	{
		++bin;
	}
	return bin;
}

// Maps the model state to one of 163 discrete states:
// 0 for a failed state, otherwise 1 + box with
// box = xBin + 3 * xDotBin + 9 * thetaBin + 54 * thetaDotBin (0..161).
unsigned int CMultiPoleDiscreteState::getDiscreteStateNumber(CStateCollection *stateCol)
{
	CState *modelState = stateCol->getState();

	rlt_real x = modelState->getContinuousState(0);
	rlt_real x_dot = modelState->getContinuousState(1);
	rlt_real theta = modelState->getContinuousState(2);
	rlt_real theta_dot = modelState->getContinuousState(3);

	// Cart off the track or pole past twelve degrees: failure state.
	if (x < -2.4 || x > 2.4 || theta < -twelve_degrees || theta > twelve_degrees)
	{
		return 0;
	}

	// Ascending bin edges for each dimension.
	static const double xEdges[2] = {-0.8, 0.8};
	static const double xDotEdges[2] = {-0.5, 0.5};
	static const double thetaEdges[5] = {-six_degrees, -one_degree, 0.0, one_degree, six_degrees};
	static const double thetaDotEdges[2] = {-fifty_degrees, fifty_degrees};

	int box = multiPoleBinIndex(x, xEdges, 2)
		+ 3 * multiPoleBinIndex(x_dot, xDotEdges, 2)
		+ 9 * multiPoleBinIndex(theta, thetaEdges, 5)
		+ 54 * multiPoleBinIndex(theta_dot, thetaDotEdges, 2);

	// Shift by one so that 0 stays reserved for the failure state.
	return box + 1;
}

// Two discrete states: 0 = failed, 1 = not failed.
CMultiPoleFailedState::CMultiPoleFailedState()
	: CAbstractStateDiscretizer(2)
{
}

unsigned int CMultiPoleFailedState::getDiscreteStateNumber(CStateCollection *stateCol)
{
	// Only the cart position (dim 0) and pole angle (dim 2) of the model
	// state decide failure.
	CState *modelState = stateCol->getState();
	rlt_real cartPos = modelState->getContinuousState(0);
	rlt_real poleAngle = modelState->getContinuousState(2);

	// 0 = failed (cart off the track or pole beyond twelve degrees), 1 = running.
	bool hasFailed = cartPos < -2.4 || cartPos > 2.4
		|| poleAngle < -twelve_degrees || poleAngle > twelve_degrees;
	return hasFailed ? 0 : 1;
}

CMultiPoleAction::CMultiPoleAction(rlt_real forceValue) : CPrimitiveAction()
{
	// Remember the force this action applies to the cart; it is read back
	// via getForce() in CMultiPoleModel::doNextState.
	force = forceValue;
}

// Returns the force value this action was constructed with.
rlt_real CMultiPoleAction::getForce()
{
	return force;
}


CMultiPoleController::CMultiPoleController(CActionSet *actions)
	: CAgentController(actions)
{
	// The action set is forwarded to CAgentController; nothing else to set up.
}

CMultiPoleController::~CMultiPoleController()
{
	// Nothing to clean up in this class's own destructor body.
}

// Hand-coded controller. It predicts the pole angle after pushing with
// +FORCE_MAG and with -FORCE_MAG, and picks the action expected to keep the
// pole closer to upright. The data argument is unused.
CAction* CMultiPoleController::getNextAction(CStateCollection *stateCol, CActionDataSet *data)
{
	CState *modelState = stateCol->getState();

	// Only the pole angle (dim 2) and angular velocity (dim 3) are used;
	// cart position and velocity are ignored by this heuristic.
	rlt_real angle = modelState->getContinuousState(2);
	rlt_real angularVel = modelState->getContinuousState(3);

	// Trigonometric terms at the current angle, shared by both predictions.
	rlt_real cosAngle = cos(angle);
	rlt_real sinAngle = sin(angle);

	// NOTE(review): the angle is advanced by one Euler step here, and each
	// prediction below adds TAU * angularVel twice more. So the velocity
	// enters three times, whereas a plain two-step Euler look-ahead would use
	// it twice. Preserved as-is; confirm whether this is intended.
	rlt_real startAngle = angle + TAU * angularVel;

	// Force-independent parts of the angular-acceleration formula (kept in
	// double, matching the precision of the inline expressions).
	double centripetal = POLEMASS_LENGTH * angularVel * angularVel * sinAngle;
	double denominator = LENGTH * (FOURTHIRDS - MASSPOLE * cosAngle * cosAngle / TOTAL_MASS);

	// Prediction when pushing with +FORCE_MAG.
	rlt_real pushPos = (FORCE_MAG + centripetal) / TOTAL_MASS;
	rlt_real accPos = (GRAVITY * sinAngle - cosAngle * pushPos) / denominator;
	rlt_real anglePos = startAngle + TAU * angularVel + TAU * (angularVel + TAU * accPos);

	// Prediction when pushing with -FORCE_MAG.
	rlt_real pushNeg = (-FORCE_MAG + centripetal) / TOTAL_MASS;
	rlt_real accNeg = (GRAVITY * sinAngle - cosAngle * pushNeg) / denominator;
	rlt_real angleNeg = startAngle + TAU * angularVel + TAU * (angularVel + TAU * accNeg);

	// Action index 1 when the +FORCE_MAG prediction deviates at least as much
	// as the -FORCE_MAG one, else index 0. This presumably assumes the action
	// set is ordered {+F, -F} - confirm against the caller.
	int chosen = (fabs(anglePos) >= fabs(angleNeg)) ? 1 : 0;
	return actions->get(chosen);
}

CMultiPoleDiscreteController::CMultiPoleDiscreteController(CActionSet *actions, CStateProperties *discState)
	: CAgentController(actions),
	  CStateObject(discState)
{
	// discState is the discrete state (properties) looked up in getNextAction.
}

CMultiPoleDiscreteController::~CMultiPoleDiscreteController()
{
	// Nothing to clean up in this class's own destructor body.
}

// Fixed tabular policy over the 163 discrete states produced by
// CMultiPoleDiscreteState (0 = failed, 1..162 = box + 1).
//   stateCol: collection holding the discrete state described by `properties`.
//   data:     unused.
// Returns action 0 for the failure state and for every box listed in
// kAction0Boxes, and action 1 for all other boxes.
//
// Changes from the earlier version:
// - The policy table is now a static sorted list instead of a 163-int array
//   rebuilt (memset plus 73 stores) on every call.
// - Discrete state numbers outside 1..163 used to index past the table
//   (undefined behaviour). They now get the same action as the failure state.
CAction* CMultiPoleDiscreteController::getNextAction(CStateCollection *stateCol, CActionDataSet *data)
{
	// Zero-based box indices (discrete state number - 1) for which the policy
	// selects action 0. Must stay sorted ascending for std::binary_search.
	static const int kAction0Boxes[] = {
		2, 9, 21, 31, 32, 34, 51, 52, 53, 68, 71, 73, 75, 76, 77, 78, 79, 81,
		85, 86, 87, 88, 89, 91, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102,
		103, 104, 105, 106, 107, 109, 110, 112,
		117, 118, 119, 120, 122, 125, 127, 128, 129, 130, 131, 135, 136, 137,
		138, 139, 140, 144, 145, 146, 147, 148, 149, 151, 152, 153, 154, 155,
		156, 157, 161
	};
	static const int kNumAction0Boxes = sizeof(kAction0Boxes) / sizeof(kAction0Boxes[0]);
	// Highest discrete state number the original 163-entry table accepted.
	static const int kMaxStateNum = 163;

	CState *state = stateCol->getState(properties);
	int discreteStateNum = state->getDiscreteState(0);

	// State 0 is the failure state; out-of-range numbers are treated the same.
	if (discreteStateNum <= 0 || discreteStateNum > kMaxStateNum)
	{
		return actions->get(0);
	}

	int box = discreteStateNum - 1;
	bool useAction0 = std::binary_search(kAction0Boxes, kAction0Boxes + kNumAction0Boxes, box);
	return actions->get(useAction0 ? 0 : 1);
}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -