⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 cadaptivesoftmaxnetwork.cpp

📁 强化学习算法(R-Learning)难得的珍贵资料
💻 CPP
📖 第 1 页 / 共 2 页
字号:
 			memcpy(parameters + arrayPos, (*it).second->centers->getData(), sizeof(rlt_real) * numDim);
			arrayPos += numDim;
			memcpy(parameters + arrayPos, (*it).second->sigmas->getData(), sizeof(rlt_real) * numDim);
			arrayPos += numDim;

		}
	}
}

/**
 * Rebuilds the set of RBF centers from a flat parameter array.
 *
 * The array is read as maxCenters consecutive records, each consisting of
 * numDim center coordinates followed by numDim sigma values. A record is
 * turned into a center only if the absolute value of its first sigma is
 * larger than 1e-7; other records are skipped.
 *
 * @param parameters flat array holding at least 2 * numDim * maxCenters values.
 */
void CAdaptiveSoftMaxNetwork::setWeights(rlt_real *parameters)
{
	// Drop all existing centers before re-creating them from the array.
	// (No iterator into the center map is kept across this call.)
	clearCenters();

	int arrayPos = 0;

	// Scratch center that receives the next record. Once it has been handed
	// to addRBFCenter() a fresh one is allocated for the following record.
	CRBFCenter *rbfCenter = new CRBFCenter(numDim);

	for (int i = 0; i < maxCenters; i++)
	{
		memcpy(rbfCenter->centers->getData(), parameters + arrayPos, sizeof(rlt_real) * numDim);
		arrayPos += numDim;
		memcpy(rbfCenter->sigmas->getData(), parameters + arrayPos, sizeof(rlt_real) * numDim);
		arrayPos += numDim;

		// A (near) zero first sigma marks a record that holds no center.
		if (fabs(rbfCenter->sigmas->getElement(0)) > 0.0000001)
		{
			addRBFCenter(rbfCenter);
			rbfCenter = new CRBFCenter(numDim);
		}
	}

	// The last scratch center was never added, so it is still owned here.
	delete rbfCenter;
}

void CAdaptiveSoftMaxNetwork::saveData(FILE *stream)
{
	fprintf(stream, "Adaptive RBF Network: %d Centers\n", centers->size());
	std::map<int, CRBFCenter *>::iterator it = centers->begin();


	for (int i = 0; it !=  centers->end(); it ++, i++)
	{
		fprintf(stream, "%d : ", i);
		(*it).second->centers->saveASCII(stream);
		(*it).second->sigmas->saveASCII(stream);
		fprintf(stream, "\n");
	}
}

void CAdaptiveSoftMaxNetwork::loadData(FILE *stream)
{
	resetData();
	int buffer = 0;
	int buffer1 = 0;
	fscanf(stream, "Adaptive RBF Network: %d Centers\n", &buffer);

	for (int i = 0; i < buffer; i++)
	{
		CRBFCenter *rbfCenter = new CRBFCenter(numDim);

		fscanf(stream, "%d : ", &buffer1);
		rbfCenter->centers->loadASCII(stream);
		rbfCenter->sigmas->loadASCII(stream);
		fscanf(stream, "\n");
	}
}

void CAdaptiveSoftMaxNetwork::getModifiedState(CStateCollection *state, CState *targetState)
{
	CState *modelState = state->getState(originalState);

	targetState->resetState();

	// search 1st dimension
	std::map<int, CRBFCenter *>::iterator it = centers->begin();

	searchList1->clear();
	searchList2->clear();

	rlt_real minVal = modelState->getContinuousState(0) - epsilon->getElement(0);
	rlt_real maxVal = modelState->getContinuousState(0) + epsilon->getElement(0);

	minVal = modelState->getStateProperties()->getMirroredStateValue(0, minVal);
	maxVal = modelState->getStateProperties()->getMirroredStateValue(0, maxVal);

	bool periodic = modelState->getStateProperties()->getPeriodicity(0);

	DebugPrint('s', "\nBeginning searching for state: ");
	if (DebugIsEnabled('s'))
	{
		modelState->saveASCII(DebugGetFileHandle('s'));
	}
	DebugPrint('s', "\nsearching Dimension %d, %d centers left\n", 0, centers->size());
	DebugPrint('s', "search range: [%f, %f]\n", minVal, maxVal);

	for (; it != centers->end(); it ++)
	{
		CRBFCenter *center = (*it).second;
		int num = (*it).first;


		if ((*it).second != NULL)
		{
			rlt_real rbfCenterVal = (*it).second->centers->getElement(0);

			if (periodic && maxVal < minVal)
			{
				if (rbfCenterVal < maxVal || rbfCenterVal > minVal)
				{
					DebugPrint('s', "Center Number %d is in the area of dimension %d (Value %f)\n", (*it).first, 0, rbfCenterVal);
					searchList1->push_back((*it).second);
				}
				else
				{
					DebugPrint('s', "Center Number %d is NOT in the area of dimension %d (Value %f)\n", (*it).first, 0, rbfCenterVal);
				}
			}
			else
			{
				if (rbfCenterVal < maxVal && rbfCenterVal > minVal)
				{
					DebugPrint('s', "Center Number %d is in the area of dimension %d (Value %f)\n", (*it).first, 0, rbfCenterVal);
					searchList1->push_back((*it).second);
				}
				else
				{
					DebugPrint('s', "Center Number %d is NOT in the area of dimension %d (Value %f)\n", (*it).first, 0, rbfCenterVal);
				}
			}
		}
		else
		{
			DebugPrint('s', "RBF center is NULL\n");
		}
	}

	// search remaining dimensions


	for (int i = 1; i < numDim; i++)
	{
		minVal = modelState->getContinuousState(i) - epsilon->getElement(i);
		maxVal = modelState->getContinuousState(i) + epsilon->getElement(i);

		minVal = modelState->getStateProperties()->getMirroredStateValue(i, minVal);
		maxVal = modelState->getStateProperties()->getMirroredStateValue(i, maxVal);

		DebugPrint('s',"searching Dimension %d, %d centers left\n", i, searchList1->size());
		DebugPrint('s',"search range: [%f, %f]\n", minVal, maxVal);


		periodic = modelState->getStateProperties()->getPeriodicity(i);

		std::list<CRBFCenter *>::iterator itList = searchList1->begin();

		searchList2->clear();

		for (; itList != searchList1->end(); itList ++)
		{
			rlt_real rbfCenterVal = (*itList)->centers->getElement(i);

			if (periodic && maxVal < minVal)
			{
				if (rbfCenterVal < maxVal || rbfCenterVal > minVal)
				{
					DebugPrint('s', "Center Number %d is in the area of dimension %d (Value %f)\n", (*itList)->numCenter, i, rbfCenterVal);
					searchList2->push_back((*itList));
				}
				else
				{
					DebugPrint('s', "Center Number %d is NOT in the area of dimension %d (Value %f)\n", (*itList)->numCenter, i, rbfCenterVal);
				}
			}
			else
			{
				if (rbfCenterVal < maxVal && rbfCenterVal > minVal)
				{
					DebugPrint('s', "Center Number %d is in the area of dimension %d (Value %f)\n", (*itList)->numCenter, i, rbfCenterVal);
					searchList2->push_back((*itList));
				}
				else
				{
					DebugPrint('s', "Center Number %d is NOT in the area of dimension %d (Value %f)\n", (*itList)->numCenter, i, rbfCenterVal);
				}
			}
		}

		std::list<CRBFCenter *> *buffer = searchList2;
		searchList2 = searchList1;
		searchList1 = buffer;
	}

	// sorting centers with factors
	sortedList->clear();
	
	std::list<CRBFCenter *>::iterator itList = searchList1->begin();

	DebugPrint('s', "remaining centers: %d\n", searchList1->size());

	for (; itList != searchList1->end(); itList ++)
	{
		rlt_real factor = (*itList)->getFactor(modelState);
		sortedList->set((*itList)->numCenter, factor);

		if (DebugIsEnabled('s'))
		{
			DebugPrint('s', "Center %d, ", (*itList)->numCenter);
			(*itList)->centers->saveASCII(DebugGetFileHandle('s'));
			DebugPrint('s', ", factor %f\n", factor);
		}
	}

	CFeatureList::iterator featIt = sortedList->begin();

	

	for (int i = 0;i < maxActiveCenters && featIt != sortedList->end(); featIt  ++, i++)
	{
		targetState->setDiscreteState(i, (*featIt)->featureIndex);
		targetState->setContinuousState(i, (*featIt)->factor);
	}
	
	if (sortedList->size() < maxActiveCenters)
	{
		targetState->setNumActiveContinuousStates(sortedList->size());
		targetState->setNumActiveDiscreteStates(sortedList->size());
	}

	normalizeFeatures(targetState);
}

/**
 * Creates a value function on top of the given adaptive soft-max network.
 *
 * The network is passed to the CFeatureVFunction base class, stored as a
 * member and registered through addParameters(). Two scratch feature lists
 * used by updateWeights() and getGradient() are allocated here.
 *
 * @param adaptiveSoftMaxNetwork network to use; the pointer is stored, not copied.
 */
CAdaptiveSoftMaxVFunction::CAdaptiveSoftMaxVFunction(CAdaptiveSoftMaxNetwork *adaptiveSoftMaxNetwork)
	: CFeatureVFunction(adaptiveSoftMaxNetwork)
{
	CAdaptiveSoftMaxNetwork *network = adaptiveSoftMaxNetwork;

	this->adaptiveSoftMaxNetwork = network;
	addParameters(network);

	// Scratch lists owned by this object; released in the destructor.
	this->gradient1List = new CFeatureList();
	this->gradient2List = new CFeatureList();
}

/**
 * Releases the two scratch feature lists allocated in the constructor.
 * The network pointer is not deleted here.
 */
CAdaptiveSoftMaxVFunction::~CAdaptiveSoftMaxVFunction()
{
	delete this->gradient1List;
	delete this->gradient2List;
}

/**
 * Applies a gradient update to both the feature weights and the network.
 *
 * Entries with a feature index below numFeatures are forwarded to
 * CFeatureVFunction::updateWeights(). All other entries are re-indexed by
 * subtracting numFeatures and forwarded to the network's updateWeights().
 *
 * @param gradientFeatures combined gradient list to split and apply.
 */
void CAdaptiveSoftMaxVFunction::updateWeights(CFeatureList *gradientFeatures)
{
	CFeatureList::iterator gradIt = gradientFeatures->begin();

	gradient1List->clear();
	gradient2List->clear();

	while (gradIt != gradientFeatures->end())
	{
		if ((*gradIt)->featureIndex >= numFeatures)
		{
			// Entry belongs to the network: shift into the network's index space.
			gradient2List->set((*gradIt)->featureIndex - numFeatures, (*gradIt)->factor);
		}
		else
		{
			// Entry belongs to the feature weights of the base class.
			gradient1List->add(*gradIt);
		}
		++gradIt;
	}

	CFeatureVFunction::updateWeights(gradient1List);
	adaptiveSoftMaxNetwork->updateWeights(gradient2List);
}

/**
 * Creates the default eligibility traces object for this value function.
 *
 * @return a newly allocated CGradientVETraces bound to this object; the
 *         object is allocated with new and returned as a raw pointer.
 */
CAbstractVETraces *CAdaptiveSoftMaxVFunction::getStandardETraces()
{
	CAbstractVETraces *traces = new CGradientVETraces(this);
	return traces;
}

/**
 * Computes the gradient of the value function for the given state.
 *
 * First the base class gradient is written to gradientFeatures. Then, for
 * every active continuous state of the feature state, the network gradient
 * for that entry is computed into gradient1List and added to
 * gradientFeatures together with the value of getFeature() for the entry's
 * discrete state.
 *
 * @param state            state collection to evaluate.
 * @param gradientFeatures output list that receives the combined gradient.
 */
void CAdaptiveSoftMaxVFunction::getGradient(CStateCollection *state, CFeatureList *gradientFeatures)
{
	// Gradient with respect to the feature weights (base class part).
	CFeatureVFunction::getGradient(state, gradientFeatures);

	CState *featState = state->getState(properties);

	DebugPrint('s', "Beginning Ada Gradient Calculation\n ");

	if (DebugIsEnabled('s'))
	{
		state->getState()->saveASCII(DebugGetFileHandle('s'));
		DebugPrint('s', "\n");
	}

	// Gradient contributions of the network, one active feature at a time.
	int activeIdx = 0;
	while (activeIdx < featState->getNumActiveContinuousStates())
	{
		gradient1List->clear();
		adaptiveSoftMaxNetwork->getGradient(state, activeIdx, gradient1List);
		DebugPrint('s', "Adaptive RBF Gradient for feature %d: ", featState->getDiscreteState(activeIdx));

		if (DebugIsEnabled('s'))
		{
			gradient1List->saveASCII(DebugGetFileHandle('s'));
			DebugPrint('s', "\n");
		}

		// NOTE(review): the second argument presumably scales the added list
		// by the weight of the feature - confirm against CFeatureList::add.
		gradientFeatures->add(gradient1List, getFeature(featState->getDiscreteState(activeIdx)));

		++activeIdx;
	}

	DebugPrint('s', "Adaptive RBF Gradient: ");
	if (DebugIsEnabled('s'))
	{
		gradientFeatures->saveASCII(DebugGetFileHandle('s'));
		DebugPrint('s', "\n");
	}
}

/**
 * @return the total number of weights: numFeatures plus the number of
 *         weights reported by the network.
 */
int CAdaptiveSoftMaxVFunction::getNumWeights()
{
	const int totalWeights = numFeatures + adaptiveSoftMaxNetwork->getNumWeights();
	return totalWeights;
}

void CAdaptiveSoftMaxVFunction::resetData()
{
	CFeatureVFunction::resetData();
	adaptiveSoftMaxNetwork->resetData();
}

void CAdaptiveSoftMaxVFunction::saveData(FILE *stream)
{	
	CFeatureVFunction::saveData(stream);
	adaptiveSoftMaxNetwork->saveData(stream);
}

void CAdaptiveSoftMaxVFunction::loadData(FILE *stream)
{
	CFeatureVFunction::loadData(stream);
	adaptiveSoftMaxNetwork->loadData(stream);
}

/**
 * Copies all weights into a flat array: the base class weights go to the
 * start of the array, the network weights follow at offset numFeatures.
 *
 * @param parameters output array; must be large enough for both parts.
 */
void CAdaptiveSoftMaxVFunction::getWeights(rlt_real *parameters)
{
	CFeatureVFunction::getWeights(parameters);

	// The network section starts right behind the feature weights.
	rlt_real *networkParameters = parameters + numFeatures;
	adaptiveSoftMaxNetwork->getWeights(networkParameters);
}

/**
 * Sets all weights from a flat array: the base class weights are read from
 * the start of the array, the network weights from offset numFeatures.
 *
 * @param parameters input array holding both parts.
 */
void CAdaptiveSoftMaxVFunction::setWeights(rlt_real *parameters)
{
	CFeatureVFunction::setWeights(parameters);

	// The network section starts right behind the feature weights.
	rlt_real *networkParameters = parameters + numFeatures;
	adaptiveSoftMaxNetwork->setWeights(networkParameters);
}

/**
 * Gives the network the chance to add a new center for the given error.
 *
 * The current value of the state is evaluated before the network is asked
 * to add a center. If a center was added (non-negative return value), the
 * feature with the new center's index is set to that value.
 *
 * @param error  error passed on to the network.
 * @param state  state collection the error belongs to.
 * @param action not used in this method.
 * @param data   not used in this method.
 */
void CAdaptiveSoftMaxVFunction::receiveError(rlt_real error, CStateCollection *state, CAction *action, CActionData *data)
{
	// Must be evaluated before a center is possibly added below.
	rlt_real valueBefore = getValue(state->getState(properties));

	int addedCenter = adaptiveSoftMaxNetwork->addCenterOnError(error, state);
	if (addedCenter < 0)
	{
		// No center was added; nothing to initialise.
		return;
	}

	setFeature(addedCenter, valueBefore);
}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -