📄 cgridworldmodel.cpp
字号:
/*
void CGridWorldModel::setActualBounces(unsigned int value)
{
this->actual_bounces = value;
}
unsigned int CGridWorldModel::getActualBounces()
{
return actual_bounces;
}
void CGridWorldModel::setPosX(unsigned int value)
{
if (value >= (unsigned int)size_x)
throw new std::invalid_argument("GridWorldModel_setPosX: ungueltiger Parameter!");
pos_x = value;
}
int CGridWorldModel::getPosX()
{
return pos_x;
}
void CGridWorldModel::setPosY(unsigned int value)
{
if (value >= (unsigned int)size_y)
throw new std::invalid_argument("GridWorldModel_setPosY: ungueltiger Parameter!");
pos_y = value;
}
int CGridWorldModel::getPosY()
{
return pos_y;
}
*/
// Stores the standard reward. getRewardForSymbol() returns this value for a
// symbol that has no custom reward and is not contained in target_values.
void CGridWorldModel::setRewardStandard(rlt_real value)
{
    reward_standard = value;
}
// Returns the currently stored standard reward (member reward_standard).
rlt_real CGridWorldModel::getRewardStandard()
{
    return reward_standard;
}
// Stores the success reward. getRewardForSymbol() returns this value for a
// symbol that has no custom reward but is contained in target_values.
void CGridWorldModel::setRewardSuccess(rlt_real value)
{
    reward_success = value;
}
// Returns the currently stored success reward (member reward_success).
rlt_real CGridWorldModel::getRewardSuccess()
{
    return reward_success;
}
// Stores the bounce reward. getReward() returns this value whenever the
// bounce counter of the new state is larger than that of the old state.
void CGridWorldModel::setRewardBounce(rlt_real value)
{
    this->reward_bounce = value;
}
// Returns the currently stored bounce reward (member reward_bounce).
rlt_real CGridWorldModel::getRewardBounce()
{
    return this->reward_bounce;
}
void CGridWorldModel::setRewardForSymbol(char symbol, rlt_real reward)
{
(*rewards)[symbol] = reward;
}
// Looks up the reward associated with a grid symbol.
//
// Resolution order:
//   1. a custom reward registered in `rewards` for this symbol,
//   2. reward_success if the symbol is contained in target_values,
//   3. reward_standard otherwise.
rlt_real CGridWorldModel::getRewardForSymbol(char symbol)
{
    std::map<char, rlt_real>::iterator custom = rewards->find(symbol);
    if (custom != rewards->end())
    {
        return custom->second;
    }

    if (target_values->find(symbol) != target_values->end())
    {
        return reward_success;
    }

    return reward_standard;
}
// Rebuilds the list of start points from the grid and marks the grid as parsed.
//
// The rebuild only happens when isValid() reports true and the grid has not
// been parsed yet. Every cell whose value is contained in start_values is
// recorded as a heap-allocated pair (row, column), i.e. first = y, second = x.
//
// NOTE(review): is_parsed is set to true even when isValid() is false; in that
// case previously collected start points are kept as they are. Confirm this is
// intended.
void CGridWorldModel::parseGrid()
{
    if (!isValid() || is_parsed)
    {
        is_parsed = true;
        return;
    }

    // Release the pairs owned by the old list before dropping the pointers.
    for (unsigned int idx = 0; idx < start_points->size(); ++idx)
    {
        delete (*start_points)[idx];
    }
    start_points->clear();

    for (int row = 0; row < size_y; ++row)
    {
        for (int col = 0; col < size_x; ++col)
        {
            const bool is_start = start_values->find(getGridValue(col, row)) != start_values->end();
            if (is_start)
            {
                start_points->push_back(new std::pair<int, int>(row, col));
            }
        }
    }

    is_parsed = true;
}
void CGridWorldModel::load(FILE *stream)
{
CGridWorld::load(stream);
is_parsed = false;
}
void CGridWorldModel::initGrid()
{
CGridWorld::initGrid();
is_parsed = false;
}
void CGridWorldModel::setGridValue(unsigned int pos_x, unsigned int pos_y, char value)
{
CGridWorld::setGridValue(pos_x, pos_y, value);
is_parsed = false;
}
void CGridWorldModel::addStartValue(char value)
{
CGridWorld::addStartValue(value);
is_parsed = false;
}
void CGridWorldModel::removeStartValue(char value)
{
CGridWorld::removeStartValue(value);
is_parsed = false;
}
// Computes the successor state for one grid-world step.
//
// State layout (discrete components): 0 = x position, 1 = y position,
// 2 = number of bounces so far.
//
// The action's x/y move is added to the old position. If the resulting cell
// lies outside the grid or holds a value from prohibited_values, the position
// stays unchanged and the bounce counter is incremented; otherwise the agent
// moves to the new cell.
//
// oldstate  state before the step (read only)
// action    expected to be a CGridWorldAction
// newState  receives the resulting position and bounce counter
// data      not used by this model
//
// If `action` is not a CGridWorldAction the dynamic_cast yields a null
// pointer; instead of dereferencing it, the step is treated as a no-op and the
// old position and bounce counter are copied to newState unchanged.
void CGridWorldModel::transitionFunction(CState *oldstate, CAction *action, CState *newState, CActionData *data)
{
    int pos_x = oldstate->getDiscreteState(0);
    int pos_y = oldstate->getDiscreteState(1);
    int actual_bounces = oldstate->getDiscreteState(2);

    CGridWorldAction* gridAction = dynamic_cast<CGridWorldAction*>(action);
    if (gridAction)
    {
        const int tmp_pos_x = pos_x + gridAction->getXMove();
        const int tmp_pos_y = pos_y + gridAction->getYMove();

        const bool outside_grid = tmp_pos_x < 0 || tmp_pos_x >= size_x || tmp_pos_y < 0 || tmp_pos_y >= size_y;

        // The grid must only be queried for cells inside its bounds, hence the
        // short-circuit on outside_grid.
        if (outside_grid || prohibited_values->find(getGridValue(tmp_pos_x, tmp_pos_y)) != prohibited_values->end())
        {
            actual_bounces++;
        }
        else
        {
            pos_x = tmp_pos_x;
            pos_y = tmp_pos_y;
        }
    }

    newState->setDiscreteState(0, pos_x);
    newState->setDiscreteState(1, pos_y);
    newState->setDiscreteState(2, actual_bounces);
}
// Returns true when the episode has to be reset: either the bounce counter
// (discrete component 2) exceeds max_bounces, or the cell at the state's
// position (components 0 and 1) holds a value from target_values.
bool CGridWorldModel::isResetState(CState *state)
{
    if ((unsigned int) state->getDiscreteState(2) > this->max_bounces)
    {
        return true;
    }

    const char symbol = getGridValue(state->getDiscreteState(0), state->getDiscreteState(1));
    return target_values->find(symbol) != target_values->end();
}
// Returns true when the bounce counter (discrete component 2) of the state
// exceeds max_bounces.
bool CGridWorldModel::isFailedState(CState *state)
{
    const unsigned int bounces = (unsigned int) state->getDiscreteState(2);
    return bounces > this->max_bounces;
}
// Fills resetState with a start position and a bounce counter of zero.
//
// The grid is parsed first if necessary. When start points exist, one is
// picked with rand(); otherwise position (0, 0) is used. Start points are
// stored as (row, column), so `second` is x and `first` is y.
void CGridWorldModel::getResetState(CState *resetState)
{
    if (!is_parsed)
    {
        parseGrid();
    }

    int start_x = 0;
    int start_y = 0;
    if (start_points->size() > 0)
    {
        const int pick = rand() % start_points->size();
        start_x = (*start_points)[pick]->second;
        start_y = (*start_points)[pick]->first;
    }

    resetState->setDiscreteState(0, start_x);
    resetState->setDiscreteState(1, start_y);
    resetState->setDiscreteState(2, 0);
}
// Returns the reward for the transition oldState -> newState.
//
// If the bounce counter (discrete component 2) increased, the step was a
// bounce and reward_bounce is returned. Otherwise the reward is derived from
// the symbol of the cell the new state is on, via getRewardForSymbol().
// The `action` argument is not inspected.
rlt_real CGridWorldModel::getReward(CStateCollection *oldState, CAction *action, CStateCollection *newState)
{
    const bool bounced = newState->getState()->getDiscreteState(2) > oldState->getState()->getDiscreteState(2);
    if (bounced)
    {
        return reward_bounce;
    }

    const int cell_x = newState->getState()->getDiscreteState(0);
    const int cell_y = newState->getState()->getDiscreteState(1);
    return getRewardForSymbol(getGridValue(cell_x, cell_y));
}
// Creates a state modifier with 0 continuous and 4 discrete components.
// Each discrete component is given a size equal to the number of values the
// grid world reports through getUsedValues(). The grid world pointer is stored
// for later lookups.
CLocal4GridWorldState::CLocal4GridWorldState(CGridWorld* grid_world) : CStateModifier(0, 4)
{
    this->grid_world = grid_world;

    int component = 0;
    while (component < 4)
    {
        setDiscreteStateSize(component, grid_world->getUsedValues()->size());
        ++component;
    }
}
// Destructor. The body is empty; in particular the grid_world pointer received
// by the constructor is not deleted here.
CLocal4GridWorldState::~CLocal4GridWorldState() {}
// Writes the grid values of the four axis-aligned neighbours of the agent's
// position into `state`:
//   component 0: (x,     y - 1)
//   component 1: (x + 1, y    )
//   component 2: (x,     y + 1)
//   component 3: (x - 1, y    )
// A neighbour that would fall outside the grid is replaced by the agent's own
// cell (x, y).
void CLocal4GridWorldState::getModifiedState(CStateCollection *originalState, CState *state)
{
    const int x = originalState->getState()->getDiscreteState(0);
    const int y = originalState->getState()->getDiscreteState(1);
    const int last_x = (int)grid_world->getSizeX() - 1;
    const int last_y = (int)grid_world->getSizeY() - 1;

    // Step along one axis, staying on the current coordinate at the border.
    const int prev_y = (y > 0) ? y - 1 : y;
    const int next_x = (x < last_x) ? x + 1 : x;
    const int next_y = (y < last_y) ? y + 1 : y;
    const int prev_x = (x > 0) ? x - 1 : x;

    state->setDiscreteState(0, grid_world->getGridValue(x, prev_y));
    state->setDiscreteState(1, grid_world->getGridValue(next_x, y));
    state->setDiscreteState(2, grid_world->getGridValue(x, next_y));
    state->setDiscreteState(3, grid_world->getGridValue(prev_x, y));
}
// Creates a state modifier with 0 continuous and 4 discrete components.
// Each discrete component is given a size equal to the number of values the
// grid world reports through getUsedValues(). The grid world pointer is stored
// for later lookups.
CLocal4XGridWorldState::CLocal4XGridWorldState(CGridWorld* grid_world) : CStateModifier(0, 4)
{
    this->grid_world = grid_world;

    int component = 0;
    while (component < 4)
    {
        setDiscreteStateSize(component, grid_world->getUsedValues()->size());
        ++component;
    }
}
// Destructor. The body is empty; in particular the grid_world pointer received
// by the constructor is not deleted here.
CLocal4XGridWorldState::~CLocal4XGridWorldState() {}
// Writes the grid values of the four diagonal neighbours of the agent's
// position into `state`:
//   component 0: (x - 1, y - 1)
//   component 1: (x + 1, y - 1)
//   component 2: (x + 1, y + 1)
//   component 3: (x - 1, y + 1)
// A neighbour that would fall outside the grid is replaced by the agent's own
// cell (x, y).
void CLocal4XGridWorldState::getModifiedState(CStateCollection *originalState, CState *state)
{
    // Offsets of the diagonal neighbours, indexed by state component.
    static const int step_x[4] = { -1,  1, 1, -1 };
    static const int step_y[4] = { -1, -1, 1,  1 };

    const int x = originalState->getState()->getDiscreteState(0);
    const int y = originalState->getState()->getDiscreteState(1);
    const int last_x = (int)grid_world->getSizeX() - 1;
    const int last_y = (int)grid_world->getSizeY() - 1;

    for (int k = 0; k < 4; ++k)
    {
        // Only the border in the direction of the step is checked.
        const bool x_fits = (step_x[k] < 0) ? (x > 0) : (x < last_x);
        const bool y_fits = (step_y[k] < 0) ? (y > 0) : (y < last_y);

        if (y_fits && x_fits)
        {
            state->setDiscreteState(k, grid_world->getGridValue(x + step_x[k], y + step_y[k]));
        }
        else
        {
            state->setDiscreteState(k, grid_world->getGridValue(x, y));
        }
    }
}
// Creates a state modifier with 0 continuous and 8 discrete components.
// Each discrete component is given a size equal to the number of values the
// grid world reports through getUsedValues(). The grid world pointer is stored
// for later lookups.
CLocal8GridWorldState::CLocal8GridWorldState(CGridWorld* grid_world) : CStateModifier(0, 8)
{
    this->grid_world = grid_world;

    int component = 0;
    while (component < 8)
    {
        setDiscreteStateSize(component, grid_world->getUsedValues()->size());
        ++component;
    }
}
// Destructor. The body is empty; in particular the grid_world pointer received
// by the constructor is not deleted here.
CLocal8GridWorldState::~CLocal8GridWorldState() {}
// Writes the grid values of all eight neighbours of the agent's position into
// `state`, starting at (x, y - 1) and continuing in this order:
//   component 0: (x,     y - 1)
//   component 1: (x + 1, y - 1)
//   component 2: (x + 1, y    )
//   component 3: (x + 1, y + 1)
//   component 4: (x,     y + 1)
//   component 5: (x - 1, y + 1)
//   component 6: (x - 1, y    )
//   component 7: (x - 1, y - 1)
// A neighbour that would fall outside the grid is replaced by the agent's own
// cell (x, y).
void CLocal8GridWorldState::getModifiedState(CStateCollection *originalState, CState *state)
{
    // Offsets of the neighbours, indexed by state component.
    static const int step_x[8] = {  0,  1, 1, 1, 0, -1, -1, -1 };
    static const int step_y[8] = { -1, -1, 0, 1, 1,  1,  0, -1 };

    const int x = originalState->getState()->getDiscreteState(0);
    const int y = originalState->getState()->getDiscreteState(1);
    const int last_x = (int)grid_world->getSizeX() - 1;
    const int last_y = (int)grid_world->getSizeY() - 1;

    for (int k = 0; k < 8; ++k)
    {
        // Only the border in the direction of the step is checked; an axis
        // without movement needs no check at all.
        bool x_fits = true;
        if (step_x[k] < 0)
        {
            x_fits = x > 0;
        }
        else if (step_x[k] > 0)
        {
            x_fits = x < last_x;
        }

        bool y_fits = true;
        if (step_y[k] < 0)
        {
            y_fits = y > 0;
        }
        else if (step_y[k] > 0)
        {
            y_fits = y < last_y;
        }

        if (y_fits && x_fits)
        {
            state->setDiscreteState(k, grid_world->getGridValue(x + step_x[k], y + step_y[k]));
        }
        else
        {
            state->setDiscreteState(k, grid_world->getGridValue(x, y));
        }
    }
}
// Creates a discretizer with size_x * size_y + 1 discrete states: one per grid
// cell plus state number 0, which getDiscreteStateNumber() uses for positions
// outside the grid. The grid dimensions are stored for that computation.
CGlobalGridWorldDiscreteState::CGlobalGridWorldDiscreteState(unsigned int size_x, unsigned int size_y) : CAbstractStateDiscretizer(size_x * size_y + 1)
{
    // The parameters shadow the members of the same name, hence `this->`.
    this->size_y = size_y;
    this->size_x = size_x;
}
// Maps the agent's position (discrete components 0 = x, 1 = y) to a single
// state number.
//
// Returns 0 when the position lies outside the grid, otherwise the row-major
// cell index shifted by one: y * size_x + x + 1.
unsigned int CGlobalGridWorldDiscreteState::getDiscreteStateNumber(CStateCollection *state)
{
    const int x = state->getState()->getDiscreteState(0);
    const int y = state->getState()->getDiscreteState(1);

    const bool inside_grid = x >= 0 && (unsigned int)x < size_x && y >= 0 && (unsigned int)y < size_y;
    if (!inside_grid)
    {
        return 0;
    }

    return y * size_x + x + 1;
}
// Returns base^exponent computed with integer arithmetic only.
// Used instead of pow(): converting a floating-point pow() result to int
// truncates, so a result that comes out marginally below the exact integer
// would yield a state count that is one too small.
static int gridWorldIntegerPower(unsigned int base, unsigned int exponent)
{
    unsigned int result = 1;
    for (unsigned int k = 0; k < exponent; ++k)
    {
        result *= base;
    }
    return (int) result;
}

// Creates a discretizer over the local neighbourhood delivered by orig_state.
//
// orig_state       state modifier whose discrete components hold grid symbols
// neighbourhood    number of neighbour cells; only used to size the state space
// possible_values  set of symbols that may occur; must not be null
//
// The number of discrete states is |possible_values| ^ neighbourhood. Every
// symbol is assigned a consecutive index (in the set's iteration order) in
// valuemap, which getDiscreteStateNumber() uses as digit value.
CLocalGridWorldDiscreteState::CLocalGridWorldDiscreteState(CStateModifier* orig_state, unsigned int neighbourhood, std::set<char> *possible_values) : CAbstractStateDiscretizer(gridWorldIntegerPower((unsigned int) possible_values->size(), neighbourhood))
{
    this->orig_state = orig_state;
    valuemap = new std::map<char, short>();

    short index = 0;
    for (std::set<char>::iterator it = possible_values->begin(); it != possible_values->end(); ++it, ++index)
    {
        (*valuemap)[*it] = index;
    }
}
// Releases the symbol-to-index map allocated by the constructor.
CLocalGridWorldDiscreteState::~CLocalGridWorldDiscreteState()
{
    // Deleting the map destroys its elements; no separate clear() is needed.
    delete valuemap;
}
// Encodes the symbols of the local neighbourhood state as one number.
//
// The discrete components of the state produced by orig_state are read as
// grid symbols, translated to their index in valuemap and combined as digits
// of a number in base valuemap->size() (first component = most significant).
//
// A symbol that is not contained in valuemap contributes the digit 0. It is
// looked up with find() so that the map is never modified here: operator[]
// would insert the unknown symbol and thereby change the base used for this
// and all later encodings.
//
// NOTE(review): the loop stops one component before the last one
// (getNumDiscreteStates() - 1), so the final component is not part of the
// encoding. This is kept as in the original; confirm whether it is intended.
unsigned int CLocalGridWorldDiscreteState::getDiscreteStateNumber(CStateCollection *state)
{
    CState *source_state = state->getState(orig_state);

    const unsigned int radix = (unsigned int) valuemap->size();
    const unsigned int component_count = source_state->getNumDiscreteStates();

    unsigned int discstate = 0;
    // "i + 1 < count" equals "i < count - 1" but cannot wrap when count is 0.
    for (unsigned int i = 0; i + 1 < component_count; i++)
    {
        std::map<char, short>::iterator entry = valuemap->find((char) source_state->getDiscreteState(i));
        const unsigned int digit = (entry != valuemap->end()) ? (unsigned int) entry->second : 0;
        discstate = discstate * radix + digit;
    }
    return discstate;
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -