⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 cgridworldmodel.cpp

📁 强化学习算法(R-Learning)难得的珍贵资料
💻 CPP
📖 第 1 页 / 共 3 页
字号:

/**
 * Discretizer that encodes the neighbourhood stored in orig_state as a
 * base-3 number (see getDiscreteStateNumber). The value 3^neighbourhood is
 * passed to the base discretizer — presumably its number of discrete states.
 *
 * @param orig_state    state modifier whose discrete values are encoded
 * @param neighbourhood number of encoded neighbour cells
 * @param gridworld     grid world providing the target/prohibited value sets
 */
CSmallLocalGridWorldDiscreteState::CSmallLocalGridWorldDiscreteState(CStateModifier* orig_state, unsigned int neighbourhood, CGridWorld *gridworld)
	: CAbstractStateDiscretizer(static_cast<unsigned int>(pow(static_cast<rlt_real>(3.0), static_cast<rlt_real>(neighbourhood))))
{
	this->gridworld = gridworld;
	this->orig_state = orig_state;
}

/** Empty: orig_state and gridworld are not deleted by this class. */
CSmallLocalGridWorldDiscreteState::~CSmallLocalGridWorldDiscreteState()
{
}

/**
 * Encodes the neighbourhood stored in orig_state's discrete values as a
 * base-3 number.
 *
 * Every discrete value except the last one is mapped to one digit:
 * 2 if it is in the grid world's target value set, 1 if it is in the
 * prohibited value set, 0 otherwise. The first value ends up as the most
 * significant digit.
 *
 * @param state state collection; the state belonging to orig_state is read
 * @return base-3 code of the neighbourhood (0 if the state has fewer than
 *         two discrete values)
 */
unsigned int CSmallLocalGridWorldDiscreteState::getDiscreteStateNumber(CStateCollection *state)
{
	CState *source_state = state->getState(orig_state);
	unsigned int numValues = source_state->getNumDiscreteStates();
	unsigned int discstate = 0;

	// The last discrete value is not encoded.
	// NOTE(review): presumably it is not a neighbour cell — confirm.
	// The bound is written "i + 1 < numValues" rather than
	// "i < numValues - 1": the latter wraps around for numValues == 0 and
	// would read far beyond the state's values.
	for (unsigned int i = 0; i + 1 < numValues; i++)
	{
		char cellValue = (char)source_state->getDiscreteState(i);
		unsigned int digit = 0;
		if (gridworld->getTargetValues()->find(cellValue) != gridworld->getTargetValues()->end())
			digit = 2;
		else if (gridworld->getProhibitedValues()->find(cellValue) != gridworld->getProhibitedValues()->end())
			digit = 1;
		discstate = discstate * 3 + digit;
	}
	return discstate;
}


CGridWorldAction::CGridWorldAction(int x_move, int y_move)
{
	addType(GRIDWORLDACTION);

    this->x_move = x_move;
	this->y_move = y_move;
}

int CGridWorldAction::getXMove()
{
    return this->x_move;
}

int CGridWorldAction::getYMove()
{
    return this->y_move;
}


#ifdef WIN32

/**
 * Controller that picks grid actions by their (weighted) squared distance to
 * the grid world's target cells (see getNextAction).
 *
 * @param gridworld grid world whose target cells are collected by init()
 * @param actions   action set; one record entry is created per action
 */
CGridWorldController::CGridWorldController(CGridWorld *gridworld, CActionSet *actions) : CAgentStatisticController(actions)
{
	this->gridworld = gridworld;

	// One value-initialized record per action.
	record = new std::vector<GridControllerRecord>(actions->size());
	target_points = new std::set<std::pair<unsigned int, unsigned int>*>();

	lastXMove = lastYMove = 0;

	// Collect the target cells of the grid.
	init();
}

/** Frees the record vector, every target point allocated by init(), and the target set. */
CGridWorldController::~CGridWorldController()
{
	delete record;

	std::set<std::pair<unsigned int, unsigned int> *>::iterator point;
	for (point = target_points->begin(); point != target_points->end(); ++point)
	{
		delete *point;
	}
	delete target_points;
}

void CGridWorldController::init()
{
	std::pair<unsigned int, unsigned int> *target_point;
	if (target_points->size() > 0)
	{
		std::set<std::pair<unsigned int, unsigned int> *>::iterator it = target_points->begin();
        for(;it != target_points->end(); it++)
		{
			delete *it;
		}
		target_points->clear();
	}

	for (unsigned int j = 0; j < gridworld->getSizeY(); j++)
	{
		for (unsigned int i = 0; i < gridworld->getSizeX(); i++)
		{
			if (gridworld->getTargetValues()->find(gridworld->getGridValue(i ,j)) != gridworld->getTargetValues()->end())
			{
				target_point = new std::pair<unsigned int, unsigned int>(i, j);
				target_points->insert(target_point);
			}
		}
	}
}

void CGridWorldController::newEpisode()
{
	for (unsigned int i = 0; i < record->size(); i++)
	{
		(*record)[i].factor = 1.0;
	}
}

/**
 * Chooses the next grid action with a greedy nearest-target heuristic.
 *
 * The agent position (x, y) is read from discrete values 0 and 1 of
 * state->getState(). For every action in the action set, the resulting cell
 * (x + xMove, y + yMove) is scored:
 *  - 2 * maxdist if that cell holds a prohibited value;
 *  - otherwise min(2 * maxdist, smallest squared distance to any target
 *    point), multiplied by the action's adaptive factor.
 * An action that exactly reverses the previous move gets a score of at least
 * maxdist - 0.5. The lowest score wins; ties keep the earlier action.
 * Afterwards the winner's factor shrinks by 1% (floor 0.7) and every other
 * factor grows by 1% (cap 1.3).
 *
 * NOTE(review): the dynamic_cast result is not checked, so every action in the
 * set is assumed to be a CGridWorldAction.
 * NOTE(review): the neighbour cell is passed to getGridValue without bounds
 * checks — presumably the grid border is always prohibited; confirm.
 * NOTE(review): record is indexed by action position; it is sized to
 * actions->size() in the constructor.
 *
 * @param state current state collection
 * @param stat  optional statistics; if non-NULL it is filled (equal = 1,
 *              probability = 0.5, superior = 0) and its action is returned
 * @return the selected action
 */
CAction* CGridWorldController::getNextAction(CStateCollection *state, CActionStatistics *stat)
{
	std::set<std::pair<unsigned int, unsigned int> *>::iterator itp;
	int x = state->getState()->getDiscreteState(0);
	int y = state->getState()->getDiscreteState(1);
	unsigned int bestind = 0;
	rlt_real dist;
	// Squared length of the grid diagonal (sizeX^2 + sizeY^2).
	rlt_real maxdist = gridworld->getSizeX() * gridworld->getSizeX() + gridworld->getSizeY() * gridworld->getSizeY();
	CActionSet::iterator it = actions->begin();
	for (int i = 0; it != actions->end(); i++, it++)
	{
		CGridWorldAction *gridAction = dynamic_cast<CGridWorldAction*>(*it);
		(*record)[i].pos_x = x + gridAction->getXMove();
		(*record)[i].pos_y = y + gridAction->getYMove();
		(*record)[i].action = gridAction;
		// Default score, kept for prohibited cells.
		dist = maxdist * 2.0;
		if (gridworld->getProhibitedValues()->find(gridworld->getGridValue((*record)[i].pos_x ,(*record)[i].pos_y)) == gridworld->getProhibitedValues()->end())
		{ // not a prohibited value: score by squared distance to the nearest target point
			for (itp = target_points->begin(); itp != target_points->end(); itp++)
			{
				// NOTE(review): if pos_x/pos_y are signed, the subtraction with the
				// unsigned target coordinates is done in unsigned arithmetic; the
				// square is still correct while it fits in 32 bits.
		        dist = min(dist, ((*record)[i].pos_x - (*itp)->first) * ((*record)[i].pos_x - (*itp)->first) + ((*record)[i].pos_y - (*itp)->second) * ((*record)[i].pos_y - (*itp)->second));
			}
			dist *= (*record)[i].factor;
		}
		else
		{
			//(*record)[i].factor = 1.0;
		}
		// NOTE(review): while the last move was (0,0) (initially the case),
		// an action with move (0,0) also counts as a reversal here.
		if ((lastXMove == -gridAction->getXMove()) && (lastYMove == -gridAction->getYMove()))
		{ // we do not want to walk back the way we just came
			(*record)[i].distance = max(maxdist - 0.5, dist);
		}
		else
		{
			(*record)[i].distance = dist;
		}
		// Strict '<': on ties the earlier action stays the best one.
		if ((*record)[i].distance < (*record)[bestind].distance)
			bestind = i;
	}
	lastXMove = (*record)[bestind].action->getXMove();
	lastYMove = (*record)[bestind].action->getYMove();

	// Adapt the factors: make the chosen action slightly cheaper and all
	// others slightly more expensive, within [0.7, 1.3].
	for (unsigned int j = 0; j < record->size(); j++)
	{
		if (j == bestind)
			(*record)[j].factor = max(0.7, (*record)[j].factor * 0.99);
		else
			(*record)[j].factor = min(1.3, (*record)[j].factor * 1.01);
	}

	if (stat != NULL)
	{
		stat->action = (*record)[bestind].action;
		stat->owner = this;
		stat->equal = 1;
		stat->probability = 0.5;
		stat->superior = 0;
		return stat->action;
	}
	else
	{
		return (*record)[bestind].action;
	}
}


/**
 * Console visualizer for a grid world model (Windows console API).
 *
 * Drawing starts enabled and untransposed. The grid is drawn with a
 * one-column left offset and a three-row top offset.
 *
 * @param gridworld grid world model whose cell values are printed
 */
CGridWorldVisualizer::CGridWorldVisualizer(CGridWorldModel *gridworld)
{
	this->gridworld = gridworld;
	console = GetStdHandle(STD_OUTPUT_HANDLE);

	flgDisplay = true;
	flgTranspose = false;

	xpos = ypos = 0;
	xoffset = 1;
	yoffset = 3;
}

/** Empty: neither the model nor the standard output handle is released here. */
CGridWorldVisualizer::~CGridWorldVisualizer()
{
}

/**
 * Updates the console picture after one transition.
 *
 * xpos/ypos are always updated to the screen position of the cell that was
 * left. If drawing is enabled, that cell is reprinted with its grid value as
 * a hexadecimal digit (a blank for 0) and given a red background attribute.
 * Then a '#' is printed at the cell the agent moved to.
 * Grid positions are read from discrete values 0 (x) and 1 (y). In
 * transposed mode they are swapped to obtain screen positions, matching the
 * layout drawn by newEpisode().
 *
 * @param oldStateCol  state collection before the step
 * @param action       executed action (not used)
 * @param nextStateCol state collection after the step
 */
void CGridWorldVisualizer::nextStep(CStateCollection *oldStateCol, CAction *action, CStateCollection *nextStateCol)
{
	COORD coord;
	const WORD attribute = BACKGROUND_RED;
	DWORD dummy;

	CState *nextState = nextStateCol->getState();
	CState *oldState = oldStateCol->getState();

	// Grid coordinates of the cell that was left.
	short oldGridX = (short)oldState->getDiscreteState(0);
	short oldGridY = (short)oldState->getDiscreteState(1);

	// Screen coordinates of that cell (axes swapped in transposed mode).
	if (flgTranspose)
	{
		xpos = oldGridY;
		ypos = oldGridX;
	}
	else
	{
		xpos = oldGridX;
		ypos = oldGridY;
	}

	if (flgDisplay)
	{
		coord.X = xpos + xoffset;
		coord.Y = ypos + yoffset;
		SetConsoleCursorPosition(console, coord);
		// Look the value up with GRID coordinates. xpos/ypos are swapped in
		// transposed mode (which is used when sizeX < sizeY). Using them here
		// would print a different cell, possibly one with x >= sizeX.
		if (gridworld->getGridValue(oldGridX, oldGridY) != 0)
		{
			printf("%1X", gridworld->getGridValue(oldGridX, oldGridY));
		}
		else
		{
			printf(" ");
		}
		// Give the visited cell a red background.
		WriteConsoleOutputAttribute(console, &attribute, 1, coord, &dummy);

		if (flgTranspose)
		{
			xpos = (short)nextState->getDiscreteState(1);
			ypos = (short)nextState->getDiscreteState(0);
		}
		else
		{
			xpos = (short)nextState->getDiscreteState(0);
			ypos = (short)nextState->getDiscreteState(1);
		}
		coord.X = xpos + xoffset;
		coord.Y = ypos + yoffset;
		SetConsoleCursorPosition(console, coord);
		printf("#");
	}
}

void CGridWorldVisualizer::newEpisode()
{
	DWORD dummy;
	COORD coord;
	CONSOLE_SCREEN_BUFFER_INFO sInfo;
	GetConsoleScreenBufferInfo(console, &sInfo);
	coord.X = 0;
	coord.Y = 0;
	FillConsoleOutputCharacter(console, ' ', (1 + sInfo.srWindow.Right - sInfo.srWindow.Left) * (1 + sInfo.srWindow.Bottom - sInfo.srWindow.Top), coord, &dummy);
	
	if (flgDisplay)
	{
		flgTranspose = (gridworld->getSizeX() < gridworld->getSizeY());
		if (flgTranspose)
		{
			coord.X = xoffset;
			for (unsigned int h = 0; h < gridworld->getSizeX(); h++)
			{
				coord.Y = (short)h + yoffset;
				SetConsoleCursorPosition(console, coord);
				for (unsigned int j = 0; j < gridworld->getSizeY(); j++)
				{
					if (gridworld->getGridValue(h, j) != 0)
					{
						printf("%1X", gridworld->getGridValue(h, j));
					}
					else
					{
						printf(" ");
					}
				}
			}
			//xpos = (short)gridworld->getPosY();
			//ypos = (short)gridworld->getPosX();
		}
		else
		{
			coord.X = xoffset;
			for (unsigned int h = 0; h < gridworld->getSizeY(); h++)
			{
				coord.Y = (short)h + yoffset;
				SetConsoleCursorPosition(console, coord);
				for (unsigned int j = 0; j < gridworld->getSizeX(); j++)
				{
					if (gridworld->getGridValue(j, h) != 0)
					{
						printf("%1X", gridworld->getGridValue(j, h));
					}
					else
					{
						printf(" ");
					}
				}
			}
			//xpos = (short)gridworld->getPosX();
			//ypos = (short)gridworld->getPosY();
		}
		coord.X = xpos + xoffset;
		coord.Y = ypos + yoffset;
		//SetConsoleCursorPosition(console, coord);
		//printf("#");
	}
}

/** @return whether grid drawing is enabled (checked by nextStep() and newEpisode()). */
bool CGridWorldVisualizer::getDisplay()
{
	bool enabled = flgDisplay;
	return enabled;
}

/** Enables (true) or disables (false) grid drawing. */
void CGridWorldVisualizer::setDisplay(bool flgDisplay)
{
	(*this).flgDisplay = flgDisplay;
}

#endif // WIN32


/**
 * Fills a grid world with a randomly generated race track.
 *
 * Layout (x runs along the track, 0..length-1; y across it, 0..width-1):
 *  - 2: outer wall (prohibited) on the border rows and columns,
 *  - 3: start column at x = 1,
 *  - 4: target column at x = length - 2,
 *  - 1: random obstacles (prohibited) between start and target,
 *  - 0: free cells.
 *
 * Obstacles: the first one is placed at x = 2 + ceil(r * dx_max); each later
 * one at a gap of dx_min + ceil(r * dx_max), where r = rand() / RAND_MAX.
 * NOTE(review): the first gap ignores dx_min — confirm this is intended.
 * Each obstacle is a wall at column x from y1 to y2, with random y1 <= y2.
 * If that wall would close the whole track, it is shortened so that the cell
 * next to the top or the bottom border stays free. Side walls are added along
 * rows y1 and y2 over columns x - h_len + 2 .. x, where
 * h_len = min(dx, ceil(r * h_max)).
 *
 * Degenerate sizes (width == 0 or length < 2) leave the freshly initialised
 * grid empty instead of writing outside it.
 *
 * @param gridworld grid world to (re)initialise
 * @param width     number of cells across the track (grid y size)
 * @param length    number of cells along the track (grid x size)
 * @param h_max     upper bound of the random side-wall draw
 * @param dx_min    minimum gap between consecutive obstacles (after the first)
 * @param dx_max    range of the random part of an obstacle gap
 */
void CRaceTrack::generateRaceTrack(CGridWorld *gridworld, unsigned int width, unsigned int length, unsigned int h_max, unsigned int dx_min, unsigned int dx_max)
{
	unsigned int i, j, x, dx, tmp1, tmp2, y1, y2, h_draw, h_len;
	gridworld->setSize(length, width);
	gridworld->initGrid();
	gridworld->addProhibitedValue(1);
	gridworld->addProhibitedValue(2);
	gridworld->addStartValue(3);
	gridworld->addTargetValue(4);

	// "width - 1" and "length - 2" below are unsigned. For these sizes they
	// would wrap around and the loops would write far outside the grid.
	if (width == 0 || length < 2)
		return;

	for (j = 1; j < width - 1; j++)
	{
		gridworld->setGridValue(0, j, 2);
		gridworld->setGridValue(1, j, 3);
		gridworld->setGridValue(length - 2, j, 4);
		gridworld->setGridValue(length - 1, j, 2);
		for (i = 2; i < length - 2; i++)
		{
			gridworld->setGridValue(i, j, 0);
		}
	}
	for (i = 0; i < length; i++)
	{
		gridworld->setGridValue(i, 0, 2);
		gridworld->setGridValue(i, width - 1, 2);
	}

	// The order of the rand() calls below determines the generated track.
	dx = (int)ceil((rlt_real)rand() / (rlt_real)RAND_MAX * (rlt_real)dx_max);
	x = 2 + dx;
	while (x < length - 2)
	{
		tmp1 = (int)ceil((rlt_real)rand() / (rlt_real)RAND_MAX * (rlt_real)(width - 1));
		tmp2 = (int)ceil((rlt_real)rand() / (rlt_real)RAND_MAX * (rlt_real)(width - 1));
		y1 = min(tmp1, tmp2);
		y2 = max(tmp1, tmp2);
		// The wall would block the whole track: leave y = 1 or y = width - 2 open.
		if (( y1 < 2) && (y2 > width - 3))
		{
			if (rand() < (RAND_MAX / 2))
			{
				y1 = 2;
				y2 = width - 2;
			}
			else
			{
				y1 = 1;
				y2 = width - 3;
			}
		}
		// Draw the random value once. If min is a function-like macro (as in
		// <windows.h>), passing rand() inside min(...) would evaluate it twice
		// and compare one random value while returning another.
		h_draw = (unsigned int)ceil((rlt_real)rand() / (rlt_real)RAND_MAX * (rlt_real)h_max);
		h_len = min(dx, h_draw);
		for(i = y1; i <= y2; i++)
		{
			gridworld->setGridValue(x, i, 1);
		}
		// Side walls cover columns x - h_len + 2 .. x. For h_len < 2 that range
		// is empty. The guard avoids the unsigned wrap-around the previous
		// "h = min(...) - 1" formulation relied on. h_len <= dx <= x - 2, so
		// the start column cannot underflow.
		if (h_len >= 2)
		{
			for (i = x + 2 - h_len; i <= x; i++)
			{
				gridworld->setGridValue(i, y1, 1);
				gridworld->setGridValue(i, y2, 1);
			}
		}
		dx = dx_min + (int)ceil((rlt_real)rand() / (rlt_real)RAND_MAX * (rlt_real)dx_max);
		x += dx;
	}
}


/**
 * Discretizer that encodes the neighbourhood stored in orig_state as a
 * base-2 number (prohibited / not prohibited; see getDiscreteStateNumber).
 * The value 2^neighbourhood is passed to the base discretizer — presumably
 * its number of discrete states.
 *
 * @param orig_state    state modifier whose discrete values are encoded
 * @param neighbourhood number of encoded neighbour cells
 * @param gridworld     grid world providing the prohibited value set
 */
CRaceTrackDiscreteState::CRaceTrackDiscreteState(CStateModifier* orig_state, unsigned int neighbourhood, CGridWorld *gridworld)
	: CAbstractStateDiscretizer(static_cast<unsigned int>(pow(static_cast<rlt_real>(2.0), static_cast<rlt_real>(neighbourhood))))
{
	this->gridworld = gridworld;
	this->orig_state = orig_state;
}

/** Empty: orig_state and gridworld are not deleted by this class. */
CRaceTrackDiscreteState::~CRaceTrackDiscreteState()
{
}

/**
 * Encodes the neighbourhood stored in orig_state's discrete values as a
 * base-2 number.
 *
 * Every discrete value except the last one becomes one bit: 1 if it is in the
 * grid world's prohibited value set, 0 otherwise. The first value ends up as
 * the most significant bit.
 *
 * @param state state collection; the state belonging to orig_state is read
 * @return bit code of the neighbourhood (0 if the state has fewer than two
 *         discrete values)
 */
unsigned int CRaceTrackDiscreteState::getDiscreteStateNumber(CStateCollection *state)
{
	CState *source_state = state->getState(orig_state);
	unsigned int numValues = source_state->getNumDiscreteStates();
	unsigned int discstate = 0;

	// The last discrete value is not encoded.
	// NOTE(review): presumably it is not a neighbour cell — confirm.
	// The bound is written "i + 1 < numValues" rather than
	// "i < numValues - 1": the latter wraps around for numValues == 0 and
	// would read far beyond the state's values.
	for (unsigned int i = 0; i + 1 < numValues; i++)
	{
		char cellValue = (char)source_state->getDiscreteState(i);
		unsigned int bit = 0;
		if (gridworld->getProhibitedValues()->find(cellValue) != gridworld->getProhibitedValues()->end())
			bit = 1;
		discstate = discstate * 2 + bit;
	}
	return discstate;
}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -