⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 cstate.cpp

📁 强化学习算法(R-Learning)难得的珍贵资料
💻 CPP
字号:
// Copyright (C) 2003
// Gerhard Neumann (gerhard@igi.tu-graz.ac.at)

//                
// This file is part of RL Toolbox.
// http://www.igi.tugraz.at/ril_toolbox
//
// All rights reserved.
// 
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// 1. Redistributions of source code must retain the above copyright
//    notice, this list of conditions and the following disclaimer.
// 2. Redistributions in binary form must reproduce the above copyright
//    notice, this list of conditions and the following disclaimer in the
//    documentation and/or other materials provided with the distribution.
// 3. The name of the author may not be used to endorse or promote products
//    derived from this software without specific prior written permission.
// 
// THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
// IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
// OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
// IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT,
// INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
// NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
// THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

#include "ril_debug.h"
#include "cstate.h"
#include "cenvironmentmodel.h"

#include <assert.h>
#include <math.h>

// Creates a state described by `properties` with every component set to zero.
// The continuous components alias the buffer returned by getData(); the
// discrete components get their own int array, released in the destructor.
CState::CState(CStateProperties *properties)
	: CStateObject(properties),
	  CMyVector(properties->getNumContinuousStates())
{
	discreteState = new int[getNumDiscreteStates()];
	continuousState = getData();

	resetState();
}

// Copy-style constructor: takes the property description of `copy`, allocates
// fresh storage, zeroes it, and then duplicates the values of `copy`.
CState::CState(CState *copy)
	: CStateObject(copy->getStateProperties()),
	  CMyVector(copy->getStateProperties()->getNumContinuousStates())
{
	discreteState = new int[getNumDiscreteStates()];
	continuousState = getData();

	// Zero first so the active-state counters are defined, then take the values.
	resetState();
	setState(copy);
}

// Creates a zeroed state using the state properties published by `model`.
CState::CState(CEnvironmentModel *model)
	: CStateObject(model->getStateProperties()),
	  CMyVector(model->getStateProperties()->getNumContinuousStates())
{
	discreteState = new int[getNumDiscreteStates()];
	continuousState = getData();

	resetState();
}

// Creates a state described by `properties` and fills it from `stream`.
//
// properties  description of the state layout
// stream      open stream positioned at a record written by saveBinary()
//             (binary == true) or saveASCII() (binary == false)
// binary      selects which loader is used
//
// The storage pointers must be set up before loading: the loaders write
// through continuousState and discreteState, and the destructor releases
// discreteState with delete[].
CState::CState(CStateProperties *properties, FILE *stream, bool binary) : CStateObject(properties), CMyVector(properties->getNumContinuousStates())
{
	this->properties = properties;

	continuousState = getData();
	discreteState = new int[getNumDiscreteStates()];

	// Gives every component a defined value and sets the active-state counters
	// from the properties, so a short or failed read cannot leave garbage behind.
	resetState();

	if (binary)
	{
		loadBinary(stream);
	}
	else
	{
		loadASCII(stream);
	}
}


void CState::resetState()
{
	unsigned int i = 0;
	for (i = 0; i < getNumContinuousStates(); i++)
	{
		continuousState[i] = 0;
	}
	for (i = 0; i < getNumDiscreteStates(); i++)
	{
		discreteState[i] = false;
	}
	numActiveContinuousStates  = properties->getNumContinuousStates();
	numActiveDiscreteStates = properties->getNumDiscreteStates();
}

// Overwrites this state with the values of `copy`, including its active-state
// counters. Both states must share the same model properties (asserted).
void CState::setState(CState *copy)
{
	assert(equalsModelProperties(copy));

	const unsigned int continuousCount = getNumContinuousStates();
	for (unsigned int c = 0; c < continuousCount; ++c)
	{
		continuousState[c] = copy->getContinuousState(c);
	}

	const unsigned int discreteCount = getNumDiscreteStates();
	for (unsigned int d = 0; d < discreteCount; ++d)
	{
		discreteState[d] = copy->getDiscreteState(d);
	}

	numActiveContinuousStates = copy->getNumActiveContinuousStates();
	numActiveDiscreteStates = copy->getNumActiveDiscreteStates();
}

// Releases the discrete-state array allocated by the constructors.
CState::~CState()
{
	// continuousState is deliberately not deleted here: it is the pointer
	// obtained from getData(), so presumably the CMyVector base owns that
	// buffer (TODO confirm against CMyVector).
	delete[] discreteState;
}

// Returns the continuous component at index `dim` (index asserted in range).
rlt_real CState::getContinuousState(unsigned int dim)
{
	assert(dim < getNumContinuousStates());
	const rlt_real component = continuousState[dim];
	return component;
}

// Returns the continuous component `dim` rescaled by the property bounds:
// (value - min) / (max - min). No guard is applied when max equals min.
rlt_real CState::getNormalizedContinuousState(unsigned int dim)
{
	assert(dim < getNumContinuousStates());

	const rlt_real offsetFromMin = continuousState[dim] - this->getStateProperties()->getMinValue(dim);
	const rlt_real span = this->getStateProperties()->getMaxValue(dim) - this->getStateProperties()->getMinValue(dim);

	return offsetFromMin / span;
}

// Returns the discrete component at index `dim` (index asserted in range).
int CState::getDiscreteState(unsigned int dim)
{
	assert(dim < getNumDiscreteStates());
	const int component = discreteState[dim];
	return component;
}

// Returns a newly allocated copy of this state; the caller owns the result.
CState* CState::clone()
{
	CState *duplicate = new CState(this);
	return duplicate;
}

void CState::setContinuousState(unsigned int dim, rlt_real value)
{
	assert(dim < getNumContinuousStates());
	continuousState[dim]= value;

	if (properties->getPeriodicity(dim))
	{
		if ((continuousState[dim] < properties->getMinValue(dim)) || (continuousState[dim] > properties->getMaxValue(dim)))
		{
			rlt_real Period = (properties->getMaxValue(dim) - properties->getMinValue(dim));
			assert(Period > 0);
			continuousState[dim] = continuousState[dim] - Period * floor((continuousState[dim] - properties->getMinValue(dim)) / Period);
		}
	}
	else
	{
		if (continuousState[dim] < properties->getMinValue(dim))
		{
			continuousState[dim] = properties->getMinValue(dim);
		}
		else
		{
			if (continuousState[dim] > properties->getMaxValue(dim))
			{
				continuousState[dim] = properties->getMaxValue(dim);
			}
		}
	}
}

// Stores `value` in discrete component `dim` (index asserted in range).
// No range check is applied to the value itself.
void CState::setDiscreteState(unsigned int dim, int value)
{
	assert(dim < getNumDiscreteStates());
	int &slot = discreteState[dim];
	slot = value;
}

// Returns this object; the `properties` argument is not consulted.
CState *CState::getState(CStateProperties *properties)
{
	CState *self = this;
	return self;
}

// Returns this object.
CState *CState::getState()
{
	CState *self = this;
	return self;
}

// True when `stateModifier` is the very same properties object (pointer
// identity) that this state was created with.
bool CState::isMember(CStateProperties *stateModifier)
{
	const bool sameProperties = (stateModifier == getStateProperties());
	return sameProperties;
}

int CState::getNumActiveDiscreteStates()
{
	return numActiveDiscreteStates;
}

int CState::getNumActiveContinuousStates()
{
	return numActiveContinuousStates;
}

void CState::setNumActiveDiscreteStates(int numActiveStates)
{
	numActiveDiscreteStates = numActiveStates;
}

void CState::setNumActiveContinuousStates(int numActiveStates)
{
	numActiveContinuousStates = numActiveStates;
}


/*CActionSet* CState::getAvailableActions()
{
	return availableActions;
}

void CState::setAvailableActions(CActionSet *aset)
{
	CActionSet::iterator it;

	availableActions->clear();

	for (it = aset->begin(); it != aset->end(); it++)
	{
		availableActions->push_back(*it);
	}

}
*/

// Exact comparison: true only when both states share the same model
// properties and every continuous and discrete component is identical.
// Continuous values are compared without any tolerance.
bool CState::equals(CState *state)
{
	if (! this->equalsModelProperties(state))
	{
		return false;
	}

	bool identical = true;

	for (unsigned int c = 0; identical && c < getNumContinuousStates(); c++)
	{
		identical = (getContinuousState(c) == state->getContinuousState(c));
	}
	for (unsigned int d = 0; identical && d < getNumDiscreteStates(); d++)
	{
		identical = (getDiscreteState(d) == state->getDiscreteState(d));
	}

	return identical;
}


void CState::saveBinary(FILE *stream)
{
	fwrite(this->continuousState, sizeof(rlt_real), getNumContinuousStates(), stream);
	fprintf(stream, "\n");
	fwrite(this->discreteState, sizeof(bool), getNumDiscreteStates(), stream);
	fprintf(stream, "\n");
}

void CState::loadBinary(FILE *stream)
{
	fread(this->continuousState, sizeof(rlt_real), getNumContinuousStates(), stream);
	fscanf(stream, "\n");
	fread(this->discreteState, sizeof(bool), getNumDiscreteStates(), stream);
	fscanf(stream, "\n");
}

// Writes the state as one bracketed, space-separated list: all continuous
// values first, then all discrete values, e.g. "[0.500000 1.000000 3 ]".
void CState::saveASCII(FILE *stream)
{
	fprintf(stream, "[");

	unsigned int c = 0;
	while (c < getNumContinuousStates())
	{
		fprintf(stream, "%lf ", getContinuousState(c));
		++c;
	}

	unsigned int d = 0;
	while (d < getNumDiscreteStates())
	{
		fprintf(stream, "%d ", (int)getDiscreteState(d));
		++d;
	}

	fprintf(stream, "]");
}

void CState::loadASCII(FILE *stream)
{
	unsigned int i;
	int bBuf;
	fscanf(stream, "[");
	for (i = 0; i < getNumContinuousStates(); i++)
	{
		fscanf(stream, "%lf ", &continuousState[i]);
	}
	for (i = 0; i < getNumDiscreteStates(); i++)
	{
		fscanf(stream, "%d ", &bBuf);
        discreteState[i] = (bBuf != 0);
	}
	fscanf(stream, "]");
}

// Euclidean distance between this state's continuous vector and `vector`.
// Each per-dimension difference comes from getSingleStateDifference(), which
// takes periodic dimensions into account. Dimensions must match (asserted).
rlt_real CState::getDistance(CMyVector *vector)
{
	const unsigned int dimensions = getNumDimensions();
	assert(dimensions == vector->getNumDimensions());

	rlt_real *other = vector->getData();
	rlt_real sumOfSquares = 0.0;

	for (unsigned int d = 0; d < dimensions; ++d)
	{
		sumOfSquares += pow(getSingleStateDifference(d, other[d]), 2);
	}

	return sqrt(sumOfSquares);
}


// Signed difference (own component i) - value.
// For a periodic dimension, a difference larger in magnitude than half the
// period is shifted by one full period towards zero.
rlt_real CState::getSingleStateDifference(int i, rlt_real value)
{
	if (!properties->getPeriodicity(i))
	{
		return data[i] - value;
	}

	rlt_real period = properties->getMaxValue(i) - properties->getMinValue(i);
	assert(period != 0);

	rlt_real difference = (data[i]-value);
	if (difference < - period / 2)
	{
		difference += period;
	}
	else if (difference > period / 2)
	{
		difference -= period;
	}
	return difference;
}


// Creates an empty list laid out column-wise: one heap-allocated value vector
// per continuous dimension and one per discrete dimension of `properties`.
CStateList::CStateList(CStateProperties *properties) : CStateObject(properties)
{
	numStates = 0;

	continuousStates = new std::vector<std::vector<rlt_real> *>();
	unsigned int c = 0;
	while (c < properties->getNumContinuousStates())
	{
		std::vector<rlt_real> *column = new std::vector<rlt_real>();
		continuousStates->push_back(column);
		++c;
	}

	discreteStates = new std::vector<std::vector<int> *>();
	unsigned int d = 0;
	while (d < properties->getNumDiscreteStates())
	{
		std::vector<int> *column = new std::vector<int>();
		discreteStates->push_back(column);
		++d;
	}
}

// Frees every per-dimension column and then the two outer containers.
CStateList::~CStateList()
{
	const unsigned int continuousCount = getNumContinuousStates();
	for (unsigned int c = 0; c < continuousCount; ++c)
	{
		std::vector<rlt_real> *column = (*continuousStates)[c];
		delete column;
	}

	const unsigned int discreteCount = getNumDiscreteStates();
	for (unsigned int d = 0; d < discreteCount; ++d)
	{
		std::vector<int> *column = (*discreteStates)[d];
		delete column;
	}

	delete continuousStates;
	delete discreteStates;
}

// Appends a copy of the values of `state` to the end of the list.
// The state must share this list's model properties (asserted).
void CStateList::addState(CState *state)
{
	assert(state->equalsModelProperties(this));

	const unsigned int continuousCount = getNumContinuousStates();
	for (unsigned int c = 0; c < continuousCount; ++c)
	{
		std::vector<rlt_real> *column = (*continuousStates)[c];
		column->push_back(state->getContinuousState(c));
	}

	const unsigned int discreteCount = getNumDiscreteStates();
	for (unsigned int d = 0; d < discreteCount; ++d)
	{
		std::vector<int> *column = (*discreteStates)[d];
		column->push_back(state->getDiscreteState(d));
	}

	numStates ++;
}

// Removes the most recently added state. Calling it on an empty list is a
// no-op: pop_back() on an empty std::vector is undefined behaviour, and the
// counter must not drop below zero.
void CStateList::removeLastState()
{
	if (numStates == 0)
	{
		return;
	}

	unsigned int i;
	for (i = 0; i < getNumContinuousStates(); i++)
	{
		(*continuousStates)[i]->pop_back();
	}
	for (i = 0; i < getNumDiscreteStates(); i++)
	{
		(*discreteStates)[i]->pop_back();
	}
	numStates --;
}

unsigned int CStateList::getNumStates()
{
	return numStates;
}

void CStateList::clear()
{
	unsigned int i;
	for (i = 0; i < getNumContinuousStates(); i++)
	{
		(*continuousStates)[i]->clear();
	}
	for (i = 0; i < getNumDiscreteStates(); i++)
	{
		(*discreteStates)[i]->clear();
	}
	numStates  = 0;
}

void CStateList::getState(unsigned int num, CState *state)
{
	unsigned int i;
	for (i = 0; i < getNumContinuousStates(); i++)
	{
		state->setContinuousState(i, (*(*continuousStates)[i])[num]);
	}
	for (i = 0; i < getNumDiscreteStates(); i++)
	{
		state->setDiscreteState(i, (*(*discreteStates)[i])[num]);
	}
}

void CStateList::loadBIN(FILE *stream)
{
	unsigned int i, j, buf;
	rlt_real dBuf;
	int nBuf;
	buf = 0;
	
	fread(&buf, sizeof(int), 1, stream);
	
	for (i = 0; i < properties->getNumContinuousStates(); i++)
	{
		for (j = 0; j < buf; j++)
		{
			dBuf = 0.0;

			int r  = fread( &dBuf, sizeof(rlt_real), 1, stream);
			assert(r == 1);
			(*continuousStates)[i]->push_back(dBuf);
		}
	}

	for (i = 0; i < properties->getNumDiscreteStates(); i++)
	{
		for (j = 0; j < buf; j++)
		{
			int r = fread( &nBuf, sizeof(int), 1, stream);
			assert(r == 1);
			(*discreteStates)[i]->push_back(nBuf);
		}
	}	
	numStates = buf;
}

void CStateList::saveBIN(FILE *stream)
{
	int buf = getNumStates();
	unsigned int i, j;
    int nBuf;
    rlt_real dBuf;
	
	fwrite(&buf, sizeof(int), 1, stream);
	for (i = 0; i < properties->getNumContinuousStates(); i++)
	{
		for (j = 0; j < getNumStates(); j++)
		{
            dBuf = (*(*continuousStates)[i])[j];
			fwrite(&dBuf, sizeof(rlt_real), 1, stream);
		}
	}
	for (i = 0; i < properties->getNumDiscreteStates(); i++)
	{
		for (j = 0; j < getNumStates(); j++)
		{
            nBuf = (*(*discreteStates)[i])[j];
			fwrite( &nBuf, sizeof(int), 1, stream);
		}
	}
}

void CStateList::loadASCII(FILE *stream)
{
	int buf1, buf2, res;
	unsigned int i, j, buf;

	rlt_real dBuf;
	int bBuf;
	res = fscanf(stream,"States: %d\n", &buf);
	assert(res == 1);
	fscanf(stream, "\n");
	for (i = 0; i < properties->getNumContinuousStates(); i++)
	{
		res = fscanf(stream, "ContinuousState %d:\n", &buf1);
		assert(res == 1);
		for (j = 0; j < buf; j++)
		{
			assert(fscanf(stream, "%lf ", &dBuf) == 1);

			(*continuousStates)[i]->push_back(dBuf);
		}
		fscanf(stream, "\n");
	}
	for (i = 0; i < properties->getNumDiscreteStates(); i++)
	{
		res = fscanf(stream, "DiscreteState %d:\n", &buf2);
		assert(res == 1);
		for (j = 0; j < buf; j++)
		{
			assert(fscanf( stream, "%d ", &bBuf) == 1);
			(*discreteStates)[i]->push_back(bBuf);
		}
		fscanf(stream, "\n");
	}
	numStates = buf;
}

// Writes the list as text: a "States: N" header and a blank line, then one
// labelled section per continuous dimension and one per discrete dimension,
// each holding that dimension's values separated by spaces on a single line.
void CStateList::saveASCII(FILE *stream)
{
	fprintf(stream,"States: %d\n", getNumStates());
	fprintf(stream, "\n");

	unsigned int dim = 0;
	while (dim < properties->getNumContinuousStates())
	{
		fprintf(stream, "ContinuousState %d:\n", dim);
		for (unsigned int row = 0; row < getNumStates(); ++row)
		{
			fprintf(stream, "%f ", (*(*continuousStates)[dim])[row]);
		}
		fprintf(stream, "\n");
		++dim;
	}

	dim = 0;
	while (dim < properties->getNumDiscreteStates())
	{
		fprintf(stream, "DiscreteState %d:\n", dim);
		for (unsigned int row = 0; row < getNumStates(); ++row)
		{
			fprintf( stream, "%d ", (int)((*(*discreteStates)[dim])[row]));
		}
		fprintf(stream, "\n");
		++dim;
	}
}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -