📄 cadaptivesoftmaxnetwork.cpp
字号:
memcpy(parameters + arrayPos, (*it).second->centers->getData(), sizeof(rlt_real) * numDim);
arrayPos += numDim;
memcpy(parameters + arrayPos, (*it).second->sigmas->getData(), sizeof(rlt_real) * numDim);
arrayPos += numDim;
}
}
}
// Rebuilds the set of RBF centers from a flat parameter array.
//
// `parameters` is read as maxCenters consecutive records, each made of
// numDim center coordinates followed by numDim sigmas. A record whose first
// sigma is (numerically) zero is treated as an unused slot and skipped; every
// other record is handed to addRBFCenter(). All existing centers are removed
// first via clearCenters().
void CAdaptiveSoftMaxNetwork::setWeights(rlt_real *parameters)
{
	int arrayPos = 0;
	// Scratch center that is filled from the array; it is only kept when it
	// is passed to addRBFCenter(), after which a fresh scratch is allocated.
	CRBFCenter *rbfCenter = new CRBFCenter(numDim);
	clearCenters();
	for (int i = 0; i < maxCenters; i++)
	{
		memcpy(rbfCenter->centers->getData(), parameters + arrayPos, sizeof(rlt_real) * numDim);
		arrayPos += numDim;
		memcpy(rbfCenter->sigmas->getData(), parameters + arrayPos, sizeof(rlt_real) * numDim);
		arrayPos += numDim;
		// a zero first sigma marks an empty center slot
		if (fabs(rbfCenter->sigmas->getElement(0)) > 0.0000001)
		{
			// NOTE(review): addRBFCenter appears to take ownership (the pointer
			// is not deleted afterwards) — confirm against its definition.
			addRBFCenter(rbfCenter);
			rbfCenter = new CRBFCenter(numDim);
		}
	}
	// the last scratch center was never added, release it
	delete rbfCenter;
}
void CAdaptiveSoftMaxNetwork::saveData(FILE *stream)
{
fprintf(stream, "Adaptive RBF Network: %d Centers\n", centers->size());
std::map<int, CRBFCenter *>::iterator it = centers->begin();
for (int i = 0; it != centers->end(); it ++, i++)
{
fprintf(stream, "%d : ", i);
(*it).second->centers->saveASCII(stream);
(*it).second->sigmas->saveASCII(stream);
fprintf(stream, "\n");
}
}
void CAdaptiveSoftMaxNetwork::loadData(FILE *stream)
{
resetData();
int buffer = 0;
int buffer1 = 0;
fscanf(stream, "Adaptive RBF Network: %d Centers\n", &buffer);
for (int i = 0; i < buffer; i++)
{
CRBFCenter *rbfCenter = new CRBFCenter(numDim);
fscanf(stream, "%d : ", &buffer1);
rbfCenter->centers->loadASCII(stream);
rbfCenter->sigmas->loadASCII(stream);
fscanf(stream, "\n");
}
}
void CAdaptiveSoftMaxNetwork::getModifiedState(CStateCollection *state, CState *targetState)
{
CState *modelState = state->getState(originalState);
targetState->resetState();
// search 1st dimension
std::map<int, CRBFCenter *>::iterator it = centers->begin();
searchList1->clear();
searchList2->clear();
rlt_real minVal = modelState->getContinuousState(0) - epsilon->getElement(0);
rlt_real maxVal = modelState->getContinuousState(0) + epsilon->getElement(0);
minVal = modelState->getStateProperties()->getMirroredStateValue(0, minVal);
maxVal = modelState->getStateProperties()->getMirroredStateValue(0, maxVal);
bool periodic = modelState->getStateProperties()->getPeriodicity(0);
DebugPrint('s', "\nBeginning searching for state: ");
if (DebugIsEnabled('s'))
{
modelState->saveASCII(DebugGetFileHandle('s'));
}
DebugPrint('s', "\nsearching Dimension %d, %d centers left\n", 0, centers->size());
DebugPrint('s', "search range: [%f, %f]\n", minVal, maxVal);
for (; it != centers->end(); it ++)
{
CRBFCenter *center = (*it).second;
int num = (*it).first;
if ((*it).second != NULL)
{
rlt_real rbfCenterVal = (*it).second->centers->getElement(0);
if (periodic && maxVal < minVal)
{
if (rbfCenterVal < maxVal || rbfCenterVal > minVal)
{
DebugPrint('s', "Center Number %d is in the area of dimension %d (Value %f)\n", (*it).first, 0, rbfCenterVal);
searchList1->push_back((*it).second);
}
else
{
DebugPrint('s', "Center Number %d is NOT in the area of dimension %d (Value %f)\n", (*it).first, 0, rbfCenterVal);
}
}
else
{
if (rbfCenterVal < maxVal && rbfCenterVal > minVal)
{
DebugPrint('s', "Center Number %d is in the area of dimension %d (Value %f)\n", (*it).first, 0, rbfCenterVal);
searchList1->push_back((*it).second);
}
else
{
DebugPrint('s', "Center Number %d is NOT in the area of dimension %d (Value %f)\n", (*it).first, 0, rbfCenterVal);
}
}
}
else
{
DebugPrint('s', "RBF center is NULL\n");
}
}
// search remaining dimensions
for (int i = 1; i < numDim; i++)
{
minVal = modelState->getContinuousState(i) - epsilon->getElement(i);
maxVal = modelState->getContinuousState(i) + epsilon->getElement(i);
minVal = modelState->getStateProperties()->getMirroredStateValue(i, minVal);
maxVal = modelState->getStateProperties()->getMirroredStateValue(i, maxVal);
DebugPrint('s',"searching Dimension %d, %d centers left\n", i, searchList1->size());
DebugPrint('s',"search range: [%f, %f]\n", minVal, maxVal);
periodic = modelState->getStateProperties()->getPeriodicity(i);
std::list<CRBFCenter *>::iterator itList = searchList1->begin();
searchList2->clear();
for (; itList != searchList1->end(); itList ++)
{
rlt_real rbfCenterVal = (*itList)->centers->getElement(i);
if (periodic && maxVal < minVal)
{
if (rbfCenterVal < maxVal || rbfCenterVal > minVal)
{
DebugPrint('s', "Center Number %d is in the area of dimension %d (Value %f)\n", (*itList)->numCenter, i, rbfCenterVal);
searchList2->push_back((*itList));
}
else
{
DebugPrint('s', "Center Number %d is NOT in the area of dimension %d (Value %f)\n", (*itList)->numCenter, i, rbfCenterVal);
}
}
else
{
if (rbfCenterVal < maxVal && rbfCenterVal > minVal)
{
DebugPrint('s', "Center Number %d is in the area of dimension %d (Value %f)\n", (*itList)->numCenter, i, rbfCenterVal);
searchList2->push_back((*itList));
}
else
{
DebugPrint('s', "Center Number %d is NOT in the area of dimension %d (Value %f)\n", (*itList)->numCenter, i, rbfCenterVal);
}
}
}
std::list<CRBFCenter *> *buffer = searchList2;
searchList2 = searchList1;
searchList1 = buffer;
}
// sorting centers with factors
sortedList->clear();
std::list<CRBFCenter *>::iterator itList = searchList1->begin();
DebugPrint('s', "remaining centers: %d\n", searchList1->size());
for (; itList != searchList1->end(); itList ++)
{
rlt_real factor = (*itList)->getFactor(modelState);
sortedList->set((*itList)->numCenter, factor);
if (DebugIsEnabled('s'))
{
DebugPrint('s', "Center %d, ", (*itList)->numCenter);
(*itList)->centers->saveASCII(DebugGetFileHandle('s'));
DebugPrint('s', ", factor %f\n", factor);
}
}
CFeatureList::iterator featIt = sortedList->begin();
for (int i = 0;i < maxActiveCenters && featIt != sortedList->end(); featIt ++, i++)
{
targetState->setDiscreteState(i, (*featIt)->featureIndex);
targetState->setContinuousState(i, (*featIt)->factor);
}
if (sortedList->size() < maxActiveCenters)
{
targetState->setNumActiveContinuousStates(sortedList->size());
targetState->setNumActiveDiscreteStates(sortedList->size());
}
normalizeFeatures(targetState);
}
// Constructs a feature value function on top of an adaptive soft-max network.
// The network serves as the feature calculator of the CFeatureVFunction base
// and is also passed to addParameters(). Two scratch feature lists are
// allocated here and released in the destructor.
CAdaptiveSoftMaxVFunction::CAdaptiveSoftMaxVFunction(CAdaptiveSoftMaxNetwork *adaptiveSoftMaxNetwork) : CFeatureVFunction(adaptiveSoftMaxNetwork)
{
	CAdaptiveSoftMaxNetwork *network = adaptiveSoftMaxNetwork;

	// remember the network before anything else touches this object
	this->adaptiveSoftMaxNetwork = network;
	addParameters(network);

	// scratch lists used by updateWeights() / getGradient()
	gradient1List = new CFeatureList();
	gradient2List = new CFeatureList();
}
// Frees the two scratch gradient lists allocated by the constructor.
// The network passed to the constructor is not deleted here.
CAdaptiveSoftMaxVFunction::~CAdaptiveSoftMaxVFunction()
{
	delete gradient2List;
	delete gradient1List;
}
// Applies a combined gradient to both parameter groups.
//
// Entries with featureIndex < numFeatures address this function's feature
// weights and are forwarded to CFeatureVFunction::updateWeights. All other
// entries address the network parameters; they are re-based by subtracting
// numFeatures and forwarded to the network's updateWeights.
void CAdaptiveSoftMaxVFunction::updateWeights(CFeatureList *gradientFeatures)
{
	gradient1List->clear();
	gradient2List->clear();

	for (CFeatureList::iterator featIt = gradientFeatures->begin(); featIt != gradientFeatures->end(); ++featIt)
	{
		if ((*featIt)->featureIndex >= numFeatures)
		{
			// network parameter: shift index into the network's own range
			gradient2List->set((*featIt)->featureIndex - numFeatures, (*featIt)->factor);
		}
		else
		{
			gradient1List->add(*featIt);
		}
	}

	CFeatureVFunction::updateWeights(gradient1List);
	adaptiveSoftMaxNetwork->updateWeights(gradient2List);
}
// Returns a newly allocated gradient-based eligibility-trace object bound to
// this function. The caller is presumably responsible for deleting it — confirm
// with the trace-owning code.
CAbstractVETraces *CAdaptiveSoftMaxVFunction::getStandardETraces()
{
	CAbstractVETraces *traces = new CGradientVETraces(this);
	return traces;
}
// Computes the gradient of the value with respect to all parameters.
//
// First the base class adds the gradient with respect to the feature weights.
// Then, for every active feature of the feature state, the network's gradient
// for that feature is fetched into the scratch list gradient1List. It is added
// to gradientFeatures together with the feature's current weight
// (getFeature(...)); presumably add() scales by that factor — confirm in
// CFeatureList.
void CAdaptiveSoftMaxVFunction::getGradient(CStateCollection *state, CFeatureList *gradientFeatures)
{
	CFeatureVFunction::getGradient(state, gradientFeatures);

	CState *featState = state->getState(properties);

	DebugPrint('s', "Beginning Ada Gradient Calculation\n ");
	if (DebugIsEnabled('s'))
	{
		state->getState()->saveASCII(DebugGetFileHandle('s'));
		DebugPrint('s', "\n");
	}

	for (int activeIdx = 0; activeIdx < featState->getNumActiveContinuousStates(); ++activeIdx)
	{
		gradient1List->clear();
		adaptiveSoftMaxNetwork->getGradient(state, activeIdx, gradient1List);

		DebugPrint('s', "Adaptive RBF Gradient for feature %d: ", featState->getDiscreteState(activeIdx));
		if (DebugIsEnabled('s'))
		{
			gradient1List->saveASCII(DebugGetFileHandle('s'));
			DebugPrint('s', "\n");
		}

		gradientFeatures->add(gradient1List, getFeature(featState->getDiscreteState(activeIdx)));
	}

	DebugPrint('s', "Adaptive RBF Gradient: ");
	if (DebugIsEnabled('s'))
	{
		gradientFeatures->saveASCII(DebugGetFileHandle('s'));
		DebugPrint('s', "\n");
	}
}
// Total parameter count: the numFeatures feature weights followed by the
// network's own parameters.
int CAdaptiveSoftMaxVFunction::getNumWeights()
{
	int total = numFeatures;
	total += adaptiveSoftMaxNetwork->getNumWeights();
	return total;
}
void CAdaptiveSoftMaxVFunction::resetData()
{
CFeatureVFunction::resetData();
adaptiveSoftMaxNetwork->resetData();
}
void CAdaptiveSoftMaxVFunction::saveData(FILE *stream)
{
CFeatureVFunction::saveData(stream);
adaptiveSoftMaxNetwork->saveData(stream);
}
void CAdaptiveSoftMaxVFunction::loadData(FILE *stream)
{
CFeatureVFunction::loadData(stream);
adaptiveSoftMaxNetwork->loadData(stream);
}
// Copies all parameters into `parameters`.
// Layout: [ numFeatures feature weights | network parameters ].
void CAdaptiveSoftMaxVFunction::getWeights(rlt_real *parameters)
{
	CFeatureVFunction::getWeights(parameters);
	rlt_real *networkParameters = parameters + numFeatures;
	adaptiveSoftMaxNetwork->getWeights(networkParameters);
}
// Loads all parameters from `parameters`, using the layout produced by
// getWeights(): [ numFeatures feature weights | network parameters ].
void CAdaptiveSoftMaxVFunction::setWeights(rlt_real *parameters)
{
	CFeatureVFunction::setWeights(parameters);
	rlt_real *networkParameters = parameters + numFeatures;
	adaptiveSoftMaxNetwork->setWeights(networkParameters);
}
// Error listener: lets the network add a center in response to `error`.
// If a center was added (non-negative index returned), its feature weight is
// initialized with the value this function had for the state before the
// insertion. `action` and `data` are not used.
void CAdaptiveSoftMaxVFunction::receiveError(rlt_real error, CStateCollection *state, CAction *action, CActionData *data)
{
	// must be evaluated before addCenterOnError may change the network
	rlt_real valueBeforeInsertion = getValue(state->getState(properties));

	int newCenter = adaptiveSoftMaxNetwork->addCenterOnError(error, state);
	if (newCenter < 0)
	{
		// presumably: no center was created for this error
		return;
	}
	setFeature(newCenter, valueBeforeInsertion);
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -