⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 cgridworldmodel.cpp

📁 强化学习算法(R-Learning)难得的珍贵资料
💻 CPP
📖 第 1 页 / 共 3 页
字号:
// Copyright (C) 2003
// Gerhard Neumann (gerhard@igi.tu-graz.ac.at)

//                
// This file is part of RL Toolbox.
// http://www.igi.tugraz.at/ril_toolbox
//
// All rights reserved.
// 
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// 1. Redistributions of source code must retain the above copyright
//    notice, this list of conditions and the following disclaimer.
// 2. Redistributions in binary form must reproduce the above copyright
//    notice, this list of conditions and the following disclaimer in the
//    documentation and/or other materials provided with the distribution.
// 3. The name of the author may not be used to endorse or promote products
//    derived from this software without specific prior written permission.
// 
// THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
// IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
// OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
// IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT,
// INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
// NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
// THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

#include "ril_debug.h"
#include "cgridworldmodel.h"

#include <stdlib.h>
#include <stdio.h>
#include <string.h>
#include <math.h>
#include <set>
#include <string>
#include <stdexcept>
#include <vector>

// Creates a grid world by reading it from the text file `filename`
// (see load(FILE*) for the accepted format).
// Throws whatever load() throws (pointers to std::runtime_error allocated with
// new); in that case everything allocated here is released again.
CGridWorld::CGridWorld(char* filename)
{
	start_values = new std::set<char>();
	target_values = new std::set<char>();
	prohibited_values = new std::set<char>();
	grid = new std::vector<char *>();
	size_x = 0;
	size_y = 0;
	// Must be initialised explicitly: it was previously left indeterminate, and
	// deallocGrid()/initGrid()/getGridValue() all read it.
	is_allocated = false;
	try
	{
		load(filename);
	}
	catch (...)
	{
		// A constructor that throws never runs the destructor, so release all
		// resources owned so far (including rows a failed load may have left
		// in the grid vector) before propagating the error.
		for (std::vector<char *>::size_type row = 0; row < grid->size(); row++)
		{
			delete[] (*grid)[row];
		}
		delete grid;
		delete prohibited_values;
		delete target_values;
		delete start_values;
		throw;
	}
}

// Creates an empty grid world of size_x * size_y cells. The grid rows are
// allocated immediately when both dimensions are non-zero.
CGridWorld::CGridWorld(unsigned int size_x, unsigned int size_y)
{
	start_values = new std::set<char>();
	target_values = new std::set<char>();
	prohibited_values = new std::set<char>();
	grid = new std::vector<char *>();
	this->size_x = size_x;
	this->size_y = size_y;
	// initGrid() tests this flag, so it must hold a defined value first
	// (previously it was read uninitialised).
	is_allocated = false;
	initGrid();
}

// Releases the grid rows and every container owned by this object.
CGridWorld::~CGridWorld()
{
	deallocGrid();

	// Empty the value sets, then free them together with the row vector.
	start_values->clear();
	target_values->clear();
	prohibited_values->clear();

	delete start_values;
	delete target_values;
	delete prohibited_values;
	delete grid;
}

// (Re)allocates size_y rows of size_x cells each, discarding any previous
// grid first, and marks the grid as allocated. All cells start out as 0.
void CGridWorld::allocGrid()
{
	deallocGrid();
	// Reserve up front so push_back cannot reallocate (and throw) after a row
	// has already been allocated, which would leak that row.
	grid->reserve(size_y);
	for (int row = 0; row < (int)size_y; row++)
	{
		// "()" value-initialises the row: cells are 0 instead of indeterminate
		// bytes that save(), getGridValue() or getUsedValues() would read.
		grid->push_back(new char[size_x]());
	}
	is_allocated = true;
}

// Frees all grid rows, empties the row vector and marks the grid as not
// allocated. Safe to call repeatedly.
void CGridWorld::deallocGrid()
{
	// Free the rows the vector actually holds rather than trusting size_y or
	// is_allocated. The old code skipped rows left behind by a failed load and
	// could miscount if size_y changed after allocation.
	for (std::vector<char *>::size_type row = 0; row < grid->size(); row++)
	{
		// Rows are created with new char[...], so they must be released with
		// delete[] (plain delete on array memory is undefined behaviour).
		delete[] (*grid)[row];
	}
	grid->clear();
	is_allocated = false;
}

// Allocates the grid unless it already exists or one of the dimensions is
// not positive.
void CGridWorld::initGrid()
{
	if (is_allocated || size_x <= 0 || size_y <= 0)
	{
		return;
	}
	allocGrid();
}

bool CGridWorld::isValid()
{
	return is_allocated;
}

// Opens `filename` for reading and parses it with load(FILE*).
// Throws a pointer to std::runtime_error (allocated with new) if the file
// cannot be opened, and propagates any exception thrown while parsing.
// The file is closed in every case.
void CGridWorld::load(char* filename)
{
	FILE* stream = fopen(filename, "r");
	if (!stream)
	{
		throw new std::runtime_error("Gridworld: Datei konnte nicht gefunden werden!");
	}
	try
	{
		load(stream);
	}
	catch (...)
	{
		// Previously the handle leaked whenever the file contents were invalid.
		fclose(stream);
		throw;
	}
	fclose(stream);
}

void CGridWorld::load(FILE *stream)
{
    int xsize = 0, ysize = 0, i = 0;
	bool flg_grid_world = false;
	char buffer[1024];
	std::string text;
	std::set<char>* tmpset;
	char temp;

    if (fgets(buffer, 1024, stream) == NULL)
	{
		throw new std::runtime_error("Gridworld: Datei enthaelt keine gueltige Gridworld! (unerwartetes Dateiende)");
	}
    
	while (!((buffer[0] <= 57 && buffer[0] >= 48) || (buffer[0] <= 70 && buffer[0] >= 65))) // != 1..9,A..F
	{
		if (!(buffer[0] == '#' || buffer[0] == 13 || buffer[0] == 10)) // ignore # and enter
		{
			text = buffer;
			if(text.find("Gridworld") == 0)
			{
				flg_grid_world = true;
			}
			else if (text.find("Size: ") == 0)
			{
				if (sscanf(buffer, "Size: %dx%d", &xsize, &ysize) != 2)
					throw new std::runtime_error("Gridworld: Datei enthaelt keine gueltigen Gridworldgroessenangaben!");
				this->size_x = xsize;
				this->size_y = ysize;
			}
			else
			{
				if (text.find("StartValues: ") == 0)
				{
					tmpset = start_values;
					i = strlen("StartValues:");
				}
				else if (text.find("TargetValues: ") == 0)
				{
					tmpset = target_values;
					i = strlen("TargetValues:");
				}
				else if (text.find("ProhibitedValues: ") == 0)
				{
					tmpset = prohibited_values;
					i = strlen("ProhibitedValues:");
				}
				else 
				{
					throw new std::runtime_error("Gridworld: Datei enthaelt keine gueltige Gridworld! (ungueltiges Token)");
				}

				while (i > 0)
				{
					if (sscanf(text.substr(i + 1).c_str(), "%c", &temp) != 1)
					{
						throw new std::runtime_error("Gridworld: Datei enthaelt keine gueltige Gridworld! (ungueltiger Wert)");
					}
					tmpset->insert((char)temp);
					i = text.find(",", i + 1);
				}
			}
		}

		if (fgets(buffer, 1024, stream) == NULL)
			throw new std::runtime_error("Gridworld: Datei enthaelt keine gueltige Gridworld! (unerwartetes Dateiende)");
	}

    if (!flg_grid_world)
		throw new std::runtime_error("Gridworld: Datei enthaelt keine gueltige Gridworld!");

	i = 0;
	if (size_x == 0) size_x = strlen(buffer) - 1;
		
	do 
	{
		if (strlen(buffer) < (unsigned int)(size_x))
			throw new std::runtime_error("Gridworld: Datei enthaelt keine gueltige Gridworld! (size_x != Gridgroesse)");
		grid->push_back(new char[size_x]);
		for (int j = size_x - 1; j >= 0; j--)
		{
			/*if (sscanf (&buffer[j], "%X", &temp) != 1)
				throw new std::runtime_error("Gridworld: Datei enthaelt keine gueltige Gridworld! (ungueltiger Gridwert)");*/
			(*grid)[i][j] = (char)buffer[j];
			buffer[j] = 0;
		}
		i++;
	}
	while (fgets(buffer, 1024, stream) != NULL);

	if (size_y == 0)
		size_y = i;
	if (i != size_y)
		throw new std::runtime_error("Gridworld: Datei enthaelt keine gueltige Gridworld! (size_y != Gridgroesse)");

	is_allocated = true;
}

// Creates (or truncates) `filename` and writes the grid world into it via
// save(FILE*). Throws a pointer to std::runtime_error if the file cannot be
// created.
void CGridWorld::save(char* filename)
{
	FILE* out = fopen(filename, "w");
	if (out == NULL)
	{
		throw new std::runtime_error("Gridworld: Datei konnte nicht erstellt werden!");
	}
	save(out);
	fclose(out);
}

// Writes one header line for a non-empty value set. `first_format` prints the
// key together with the first value, and every further value is printed with
// ", %1X". An empty set produces no output at all.
static void writeGridValueList(FILE *stream, const char *first_format, const std::set<char> &values)
{
	if (values.empty())
	{
		return;
	}
	std::set<char>::const_iterator it = values.begin();
	fprintf(stream, first_format, *it);
	for (++it; it != values.end(); ++it)
	{
		fprintf(stream, ", %1X", *it);
	}
	fprintf(stream, "\n");
}

// Writes the grid world to an open stream: the "Gridworld" marker, the size,
// the non-empty value sets, a blank line and then one line per grid row.
// NOTE(review): every value is written as hexadecimal ("%1X"), while
// load(FILE*) stores cells and values as the raw characters read. A loaded
// '1' (0x31) is therefore saved as "31", so save/load does not round-trip.
// Confirm which representation is intended.
void CGridWorld::save(FILE *stream)
{
	fprintf(stream, "Gridworld\n");
	fprintf(stream, "Size: %dx%d\n", size_x, size_y);

	writeGridValueList(stream, "StartValues: %1X", *start_values);
	writeGridValueList(stream, "TargetValues: %1X", *target_values);
	writeGridValueList(stream, "ProhibitedValues: %1X", *prohibited_values);

	fprintf(stream, "\n");
	for (int row = 0; row < size_y; row++)
	{
		const char *cells = (*grid)[row];
		for (int col = 0; col < size_x; col++)
		{
			fprintf(stream, "%1X", cells[col]);
		}
		fprintf(stream, "\n");
	}
}

void CGridWorld::setGridValue(unsigned int pos_x, unsigned int pos_y, char value)
{
	if (is_allocated && ((int)pos_x < size_x) && ((int)pos_y < size_y))
	{
		(*grid)[pos_y][pos_x] = value;
	}
}

char CGridWorld::getGridValue(unsigned int pos_x, unsigned int pos_y)
{
	if (is_allocated && ((int)pos_x < size_x) && ((int)pos_y < size_y))
	{
		return (*grid)[pos_y][pos_x];
	}
	else throw new std::invalid_argument("Gridworld_getGridValue: ungueltiger Parameter oder Grid nicht initialisiert)");
}

void CGridWorld::addStartValue(char value)
{
	start_values->insert(value);
}

void CGridWorld::removeStartValue(char value)
{
	start_values->erase(value);
}

std::set<char> *CGridWorld::getStartValues()
{
	return start_values;
}

void CGridWorld::addTargetValue(char value)
{
	target_values->insert(value);
}

void CGridWorld::removeTargetValue(char value)
{
	target_values->erase(value);
}

std::set<char> *CGridWorld::getTargetValues()
{
	return target_values;
}

void CGridWorld::addProhibitedValue(char value)
{
	prohibited_values->insert(value);
}

void CGridWorld::removeProhibitedValue(char value)
{
	prohibited_values->erase(value);
}

std::set<char> *CGridWorld::getProhibitedValues()
{
	return prohibited_values;
}

void CGridWorld::setSize(unsigned int size_x, unsigned int size_y)
{
	this->deallocGrid();
	this->size_x = size_x;
	this->size_y = size_y;
}

unsigned int CGridWorld::getSizeX()
{
	return this->size_x;
}

unsigned int CGridWorld::getSizeY()
{
	return this->size_y;
}

std::set<char> *CGridWorld::getUsedValues()
{
	static std::set<char> *values = new std::set<char>();
	values->clear();
	char tmp;
	for (int j = 0; j < size_y; j++)
	{
		for (int i = 0; i < size_x; i++)
		{
			tmp = (*grid)[j][i];
			values->insert((*grid)[j][i]);
		}
	}
	return values;
}


// Builds a grid world model whose grid is loaded from `filename`.
// Default rewards: -1.0 per ordinary step, -10 for a bounce, 100.0 on success.
// The start point list and the reward map start out empty.
CGridWorldModel::CGridWorldModel(char* filename, unsigned int max_bounces) : CGridWorld(filename), CTransitionFunction(new CStateProperties(0,3), new CActionSet()) {
	// `this->` is required here: the parameter shadows the member.
	this->max_bounces = max_bounces;

	reward_standard = -1.0;
	reward_bounce = -10;
	reward_success = 100.0;

	is_parsed = false;

	start_points = new std::vector<std::pair<int, int>* >();
	rewards = new std::map<char, rlt_real>();
}

// Builds a grid world model with an empty size_x * size_y grid.
// Default rewards: -1.0 per ordinary step, -10 for a bounce, 100.0 on success.
// The start point list and the reward map start out empty.
CGridWorldModel::CGridWorldModel(unsigned int size_x, unsigned int size_y, unsigned int max_bounces) : CGridWorld(size_x, size_y), CTransitionFunction(new CStateProperties(0,3), new CActionSet()) {
	// `this->` is required here: the parameter shadows the member.
	this->max_bounces = max_bounces;

	reward_standard = -1.0;
	reward_bounce = -10;
	reward_success = 100.0;

	is_parsed = false;

	start_points = new std::vector<std::pair<int, int>* >();
	rewards = new std::map<char, rlt_real>();
}

// Releases the start points, the state properties, the action set and the
// reward map.
// NOTE(review): `properties`/`actions` are presumably the objects created in
// the constructors' CTransitionFunction(...) call. Confirm that
// CTransitionFunction does not delete them as well.
CGridWorldModel::~CGridWorldModel()
{
	unsigned int idx = 0;
	while (idx < start_points->size())
	{
		delete (*start_points)[idx];
		++idx;
	}
	delete start_points;

	delete properties;
	delete actions;
	delete rewards;
}

// Sets the max_bounces member. Its exact effect on the model is defined
// elsewhere in the file and is not visible here.
void CGridWorldModel::setMaxBounces(unsigned int value)
{
	max_bounces = value;
}

// Returns the value last stored via the constructor or setMaxBounces().
unsigned int CGridWorldModel::getMaxBounces()
{
	const unsigned int bounces = this->max_bounces;
	return bounces;
}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -