⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 cgridworldmodel.cpp

📁 强化学习算法(R-Learning)难得的珍贵资料
💻 CPP
📖 第 1 页 / 共 3 页
字号:
/*
void CGridWorldModel::setActualBounces(unsigned int value)
{
	this->actual_bounces = value;
}

unsigned int CGridWorldModel::getActualBounces()
{
	return actual_bounces;
}

void CGridWorldModel::setPosX(unsigned int value)
{
	if (value >= (unsigned int)size_x)
		throw new std::invalid_argument("GridWorldModel_setPosX: ungueltiger Parameter!");
	pos_x = value;
}

int CGridWorldModel::getPosX()
{
	return pos_x;
}

void CGridWorldModel::setPosY(unsigned int value)
{
	if (value >= (unsigned int)size_y)
		throw new std::invalid_argument("GridWorldModel_setPosY: ungueltiger Parameter!");
	pos_y = value;
}

int CGridWorldModel::getPosY()
{
	return pos_y;
}
*/
// Sets the fallback reward that getRewardForSymbol returns for a cell whose
// symbol has no explicit reward entry and is not a target symbol.
void CGridWorldModel::setRewardStandard(rlt_real value)
{
	reward_standard = value;
}

// Returns the fallback reward used for cells that are neither targets nor
// have an explicit per-symbol reward.
rlt_real CGridWorldModel::getRewardStandard()
{
	return reward_standard;
}

// Sets the reward returned for target symbols that have no explicit
// per-symbol reward entry.
void CGridWorldModel::setRewardSuccess(rlt_real value)
{
	reward_success = value;
}

// Returns the reward used for target symbols without an explicit entry.
rlt_real CGridWorldModel::getRewardSuccess()
{
	return reward_success;
}

// Sets the reward given for a transition that increased the bounce counter.
void CGridWorldModel::setRewardBounce(rlt_real value)
{
	this->reward_bounce = value;
}


// Returns the reward given for a transition that increased the bounce counter.
rlt_real CGridWorldModel::getRewardBounce()
{
	return this->reward_bounce;
}

void CGridWorldModel::setRewardForSymbol(char symbol, rlt_real reward)
{
	(*rewards)[symbol] = reward;
}

// Returns the reward for entering a cell showing `symbol`.
// Lookup order:
//   1. an explicit entry in `rewards`,
//   2. reward_success if the symbol is one of target_values,
//   3. reward_standard otherwise.
rlt_real CGridWorldModel::getRewardForSymbol(char symbol)
{
	std::map<char, rlt_real>::iterator entry = rewards->find(symbol);
	if (entry != rewards->end())
	{
		return entry->second;
	}

	const bool isTarget = target_values->find(symbol) != target_values->end();
	return isTarget ? reward_success : reward_standard;
}



// Rebuilds the cached list of start cells from the current grid contents.
//
// Every cell whose symbol is in start_values is recorded as a heap-allocated
// (y, x) pair. Cells are scanned row by row, so the list order is row-major.
// The rebuild only happens when isValid() holds and the cache is stale
// (!is_parsed). is_parsed is set to true afterwards in every case.
void CGridWorldModel::parseGrid()
{
	const bool rebuild = isValid() && !is_parsed;
	if (rebuild)
	{
		// Free the pairs collected by the previous parse before discarding them.
		while (!start_points->empty())
		{
			delete start_points->back();
			start_points->pop_back();
		}

		for (int row = 0; row < size_y; ++row)
		{
			for (int col = 0; col < size_x; ++col)
			{
				const bool isStart = start_values->find(getGridValue(col, row)) != start_values->end();
				if (isStart)
				{
					// Stored as (row, column), i.e. (y, x).
					start_points->push_back(new std::pair<int, int>(row, col));
				}
			}
		}
	}
	is_parsed = true;
}

void CGridWorldModel::load(FILE *stream)
{
	CGridWorld::load(stream);
	is_parsed = false;
}

void CGridWorldModel::initGrid()
{
	CGridWorld::initGrid();
	is_parsed = false;
}

void CGridWorldModel::setGridValue(unsigned int pos_x, unsigned int pos_y, char value)
{
	CGridWorld::setGridValue(pos_x, pos_y, value);
	is_parsed = false;
}

void CGridWorldModel::addStartValue(char value)
{
	CGridWorld::addStartValue(value);
	is_parsed = false;
}

void CGridWorldModel::removeStartValue(char value)
{
	CGridWorld::removeStartValue(value);
	is_parsed = false;
}

// Computes the successor of `oldstate` under `action` and writes it to `newState`.
//
// Discrete state layout: [0] = x position, [1] = y position, [2] = bounce count.
// The action's (x, y) displacement is added to the position. If the target
// cell lies outside the grid, or its symbol is in prohibited_values, the move
// is rejected: the agent stays where it is and the bounce counter is
// incremented. `data` is not used.
//
// An action that is not a CGridWorldAction (including a null action) has no
// displacement. Position and bounce count are copied unchanged in that case.
void CGridWorldModel::transitionFunction(CState *oldstate, CAction *action, CState *newState, CActionData *data)
{
	int pos_x = oldstate->getDiscreteState(0);
	int pos_y = oldstate->getDiscreteState(1);
	int actual_bounces = oldstate->getDiscreteState(2);

	CGridWorldAction* gridAction = dynamic_cast<CGridWorldAction*>(action);

	// dynamic_cast returns a null pointer when `action` is not a grid action.
	// The previous code dereferenced it unconditionally, which is undefined
	// behaviour.
	if (gridAction)
	{
		int tmp_pos_x = pos_x + gridAction->getXMove();
		int tmp_pos_y = pos_y + gridAction->getYMove();

		if (tmp_pos_x < 0 || tmp_pos_x >= size_x || tmp_pos_y < 0 || tmp_pos_y >= size_y)
		{
			// Would leave the grid: bounce off the border.
			actual_bounces++;
		}
		else if (prohibited_values->find(getGridValue(tmp_pos_x, tmp_pos_y)) != prohibited_values->end())
		{
			// Target cell is blocked: bounce off it.
			actual_bounces++;
		}
		else
		{
			pos_x = tmp_pos_x;
			pos_y = tmp_pos_y;
		}
	}

	newState->setDiscreteState(0, pos_x);
	newState->setDiscreteState(1, pos_y);
	newState->setDiscreteState(2, actual_bounces);
}

// An episode ends when either condition holds:
//   - the bounce counter (discrete dimension 2) has exceeded max_bounces, or
//   - the agent stands on a cell whose symbol is in target_values.
bool CGridWorldModel::isResetState(CState *state)
{
	if (static_cast<unsigned int>(state->getDiscreteState(2)) > this->max_bounces)
	{
		return true;
	}
	return target_values->find(getGridValue(state->getDiscreteState(0), state->getDiscreteState(1))) != target_values->end();
}

// A state counts as failed once its bounce counter (discrete dimension 2)
// exceeds max_bounces.
bool CGridWorldModel::isFailedState(CState *state)
{
	const unsigned int bounces = static_cast<unsigned int>(state->getDiscreteState(2));
	return bounces > this->max_bounces;
}

// Fills `resetState` with a fresh episode start.
//
// The stale start-point cache is rebuilt first. One cached start cell is then
// chosen with rand(); if there are none, (0, 0) is used. The bounce counter
// is always reset to 0.
void CGridWorldModel::getResetState(CState *resetState)
{
	if (!is_parsed)
		parseGrid();

	int startX = 0;
	int startY = 0;
	if (!start_points->empty())
	{
		// Cached pairs are stored as (y, x).
		const std::pair<int, int> *chosen = (*start_points)[rand() % start_points->size()];
		startX = chosen->second;
		startY = chosen->first;
	}

	resetState->setDiscreteState(0, startX);
	resetState->setDiscreteState(1, startY);
	resetState->setDiscreteState(2, 0);
}

// Reward for the transition oldState -> newState.
//
// If the move increased the bounce counter (discrete dimension 2), the
// reward is reward_bounce. Otherwise it is looked up from the symbol of the
// cell the agent now occupies. `action` is not used.
rlt_real CGridWorldModel::getReward(CStateCollection *oldState, CAction *action, CStateCollection *newState) {
	const bool bounced = newState->getState()->getDiscreteState(2) > oldState->getState()->getDiscreteState(2);
	if (bounced)
	{
		return reward_bounce;
	}

	const int cellX = newState->getState()->getDiscreteState(0);
	const int cellY = newState->getState()->getDiscreteState(1);
	return getRewardForSymbol(getGridValue(cellX, cellY));
}

// Builds a modifier with 0 continuous and 4 discrete dimensions (one per
// orthogonal neighbour). Each dimension is sized to the number of symbols in
// grid_world->getUsedValues(). grid_world is stored but not deleted by this class.
CLocal4GridWorldState::CLocal4GridWorldState(CGridWorld* grid_world) : CStateModifier(0, 4)
{
	this->grid_world = grid_world;
	const unsigned int symbolCount = grid_world->getUsedValues()->size();
	for (int slot = 0; slot < 4; ++slot)
	{
		setDiscreteStateSize(slot, symbolCount);
	}
}

// Nothing to release: grid_world is not deleted by this class.
CLocal4GridWorldState::~CLocal4GridWorldState()
{
}

// Writes the grid symbols of the four orthogonal neighbours of the agent's
// position (discrete dimensions 0/1 of the original state) into `state`.
// Slot order is north (y-1), east (x+1), south (y+1), west (x-1).
// A neighbour that would lie outside the grid is replaced by the symbol of
// the agent's own cell.
void CLocal4GridWorldState::getModifiedState(CStateCollection *originalState, CState *state)
{
	const int x = originalState->getState()->getDiscreteState(0);
	const int y = originalState->getState()->getDiscreteState(1);
	const int width = static_cast<int>(grid_world->getSizeX());
	const int height = static_cast<int>(grid_world->getSizeY());

	static const int offsets[4][2] = { {0, -1}, {1, 0}, {0, 1}, {-1, 0} };

	for (int slot = 0; slot < 4; ++slot)
	{
		const int dx = offsets[slot][0];
		const int dy = offsets[slot][1];

		// Only the border that the offset moves towards is checked.
		const bool xInside = (dx < 0) ? (x > 0) : ((dx > 0) ? (x < width - 1) : true);
		const bool yInside = (dy < 0) ? (y > 0) : ((dy > 0) ? (y < height - 1) : true);

		if (xInside && yInside)
			state->setDiscreteState(slot, grid_world->getGridValue(x + dx, y + dy));
		else
			state->setDiscreteState(slot, grid_world->getGridValue(x, y));
	}
}


// Builds a modifier with 0 continuous and 4 discrete dimensions (one per
// diagonal neighbour). Each dimension is sized to the number of symbols in
// grid_world->getUsedValues(). grid_world is stored but not deleted by this class.
CLocal4XGridWorldState::CLocal4XGridWorldState(CGridWorld* grid_world) : CStateModifier(0, 4)
{
	this->grid_world = grid_world;
	const unsigned int symbolCount = grid_world->getUsedValues()->size();
	for (int slot = 0; slot < 4; ++slot)
	{
		setDiscreteStateSize(slot, symbolCount);
	}
}

// Nothing to release: grid_world is not deleted by this class.
CLocal4XGridWorldState::~CLocal4XGridWorldState()
{
}

// Writes the grid symbols of the four diagonal neighbours of the agent's
// position (discrete dimensions 0/1 of the original state) into `state`.
// Slot order is (x-1,y-1), (x+1,y-1), (x+1,y+1), (x-1,y+1).
// A neighbour that would lie outside the grid is replaced by the symbol of
// the agent's own cell.
void CLocal4XGridWorldState::getModifiedState(CStateCollection *originalState, CState *state)
{
	const int x = originalState->getState()->getDiscreteState(0);
	const int y = originalState->getState()->getDiscreteState(1);
	const int width = static_cast<int>(grid_world->getSizeX());
	const int height = static_cast<int>(grid_world->getSizeY());

	static const int offsets[4][2] = { {-1, -1}, {1, -1}, {1, 1}, {-1, 1} };

	for (int slot = 0; slot < 4; ++slot)
	{
		const int dx = offsets[slot][0];
		const int dy = offsets[slot][1];

		// Both borders the diagonal moves towards must still be inside the grid.
		const bool xInside = (dx < 0) ? (x > 0) : ((dx > 0) ? (x < width - 1) : true);
		const bool yInside = (dy < 0) ? (y > 0) : ((dy > 0) ? (y < height - 1) : true);

		if (xInside && yInside)
			state->setDiscreteState(slot, grid_world->getGridValue(x + dx, y + dy));
		else
			state->setDiscreteState(slot, grid_world->getGridValue(x, y));
	}
}


// Builds a modifier with 0 continuous and 8 discrete dimensions (one per
// neighbour in the 8-neighbourhood). Each dimension is sized to the number of
// symbols in grid_world->getUsedValues(). grid_world is stored but not deleted
// by this class.
CLocal8GridWorldState::CLocal8GridWorldState(CGridWorld* grid_world) : CStateModifier(0, 8)
{
	this->grid_world = grid_world;
	const unsigned int symbolCount = grid_world->getUsedValues()->size();
	for (int slot = 0; slot < 8; ++slot)
	{
		setDiscreteStateSize(slot, symbolCount);
	}
}

// Nothing to release: grid_world is not deleted by this class.
CLocal8GridWorldState::~CLocal8GridWorldState()
{
}

// Writes the grid symbols of all eight neighbours of the agent's position
// (discrete dimensions 0/1 of the original state) into `state`, going
// clockwise from north: N, NE, E, SE, S, SW, W, NW (y grows "south").
// A neighbour that would lie outside the grid is replaced by the symbol of
// the agent's own cell.
void CLocal8GridWorldState::getModifiedState(CStateCollection *originalState, CState *state)
{
	const int x = originalState->getState()->getDiscreteState(0);
	const int y = originalState->getState()->getDiscreteState(1);
	const int width = static_cast<int>(grid_world->getSizeX());
	const int height = static_cast<int>(grid_world->getSizeY());

	static const int offsets[8][2] = {
		{0, -1}, {1, -1}, {1, 0}, {1, 1},
		{0, 1}, {-1, 1}, {-1, 0}, {-1, -1}
	};

	for (int slot = 0; slot < 8; ++slot)
	{
		const int dx = offsets[slot][0];
		const int dy = offsets[slot][1];

		// Only the borders the offset moves towards are checked.
		const bool xInside = (dx < 0) ? (x > 0) : ((dx > 0) ? (x < width - 1) : true);
		const bool yInside = (dy < 0) ? (y > 0) : ((dy > 0) ? (y < height - 1) : true);

		if (xInside && yInside)
			state->setDiscreteState(slot, grid_world->getGridValue(x + dx, y + dy));
		else
			state->setDiscreteState(slot, grid_world->getGridValue(x, y));
	}
}


// Discretizer with one state per grid cell (size_x * size_y) plus one extra
// state (index 0) for positions outside the grid.
CGlobalGridWorldDiscreteState::CGlobalGridWorldDiscreteState(unsigned int size_x, unsigned int size_y) : CAbstractStateDiscretizer(size_x * size_y + 1)
{
	this->size_y = size_y;
	this->size_x = size_x;
}

// Maps the (x, y) position in discrete dimensions 0/1 to a single index.
// Index 0 is any position outside the grid. Index y * size_x + x + 1 is a
// cell inside the grid (row-major, shifted past the reserved 0).
unsigned int CGlobalGridWorldDiscreteState::getDiscreteStateNumber(CStateCollection *state) {
	const int x = state->getState()->getDiscreteState(0);
	const int y = state->getState()->getDiscreteState(1);

	const bool outside = x < 0 || y < 0
		|| static_cast<unsigned int>(x) >= size_x
		|| static_cast<unsigned int>(y) >= size_y;
	if (outside)
	{
		return 0;
	}
	return y * size_x + x + 1;
}


// Computes base^exponent exactly in unsigned integer arithmetic.
// Used instead of (int)pow(...): truncating a floating-point power can lose
// one state on platforms where pow is not exact (e.g. 624.999 -> 624).
static unsigned int localGridWorldIntegerPower(unsigned int base, unsigned int exponent)
{
	unsigned int result = 1;
	for (unsigned int i = 0; i < exponent; i++)
	{
		result *= base;
	}
	return result;
}

// Discretizer that encodes the neighbour symbols produced by `orig_state` as
// a base-|possible_values| number.
//
// orig_state      - modifier whose discrete dimensions hold the neighbour symbols
// neighbourhood   - number of neighbour slots; the state count is
//                   |possible_values|^neighbourhood
// possible_values - the symbols that can appear; each gets a digit
//                   0..|possible_values|-1 in set iteration order
//
// Ownership: valuemap is allocated here and deleted in the destructor.
// NOTE(review): the header is not visible. If this class is copyable, copies
// would share valuemap — confirm copying is disabled.
CLocalGridWorldDiscreteState::CLocalGridWorldDiscreteState(CStateModifier* orig_state, unsigned int neighbourhood, std::set<char> *possible_values) : CAbstractStateDiscretizer(static_cast<int>(localGridWorldIntegerPower(static_cast<unsigned int>(possible_values->size()), neighbourhood)))
{
	this->orig_state = orig_state;
	valuemap = new std::map<char, short>();
	std::set<char>::iterator it = possible_values->begin();
	for (short i = 0; it != possible_values->end(); i++, it++)
	{
		(*valuemap)[(*it)] = i;
	}
}

// Releases the symbol-to-digit map allocated in the constructor. Deleting the
// map destroys its entries, so an explicit clear() beforehand is unnecessary.
CLocalGridWorldDiscreteState::~CLocalGridWorldDiscreteState()
{
	delete valuemap;
}

// Encodes the neighbour symbols held in the discrete dimensions of
// orig_state's state as one base-B number. B is the number of symbols
// registered in valuemap at construction time; the first slot is the most
// significant digit.
//
// Changes against the previous version:
//  - Every slot is encoded. The old bound `getNumDiscreteStates() - 1`
//    silently dropped the last neighbour, and wrapped to UINT_MAX when the
//    count was 0.
//  - Symbols missing from valuemap map to digit 0, which is what operator[]
//    produced before. Unlike operator[], this does not insert them, so the
//    map (and therefore the radix used for every later state) no longer grows.
// Assumes `neighbourhood` passed to the constructor equals the number of
// discrete dimensions of orig_state — TODO confirm at call sites.
unsigned int CLocalGridWorldDiscreteState::getDiscreteStateNumber(CStateCollection *state)
{
	CState *source_state = state->getState(orig_state);
	const unsigned int base = static_cast<unsigned int>(valuemap->size());
	const unsigned int numSlots = source_state->getNumDiscreteStates();

	unsigned int discstate = 0;
	for (unsigned int i = 0; i < numSlots; i++)
	{
		const char symbol = static_cast<char>(source_state->getDiscreteState(i));
		std::map<char, short>::iterator found = valuemap->find(symbol);
		const unsigned int digit = (found != valuemap->end()) ? static_cast<unsigned int>(found->second) : 0u;
		discstate = discstate * base + digit;
	}
	return discstate;
}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -