⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 cadaptivesoftmaxnetwork.cpp

📁 强化学习算法(R-Learning)难得的珍贵资料
💻 CPP
📖 第 1 页 / 共 2 页
字号:
// Copyright (C) 2003
// Gerhard Neumann (gerhard@igi.tu-graz.ac.at)

//                
// This file is part of RL Toolbox.
// http://www.igi.tugraz.at/ril_toolbox
//
// All rights reserved.
// 
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// 1. Redistributions of source code must retain the above copyright
//    notice, this list of conditions and the following disclaimer.
// 2. Redistributions in binary form must reproduce the above copyright
//    notice, this list of conditions and the following disclaimer in the
//    documentation and/or other materials provided with the distribution.
// 3. The name of the author may not be used to endorse or promote products
//    derived from this software without specific prior written permission.
// 
// THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
// IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
// OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
// IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT,
// INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
// NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
// THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

#include "ril_debug.h"
#include "cadaptivesoftmaxnetwork.h"

/// Builds an RBF center of dimension numDim by copying the given position
/// vector (l_centers) and width vector (l_sigmas) into freshly allocated
/// vectors owned by this object (released in the destructor).
CRBFCenter::CRBFCenter(int numDim, CMyVector *l_centers, CMyVector *l_sigmas)
	: centers(new CMyVector(numDim)), sigmas(new CMyVector(numDim))
{
	centers->setVector(l_centers);
	sigmas->setVector(l_sigmas);
}

/// Builds an RBF center of dimension numDim. The position and width vectors
/// are allocated here but not filled in; callers set their contents later
/// (e.g. addCenterGrid).
CRBFCenter::CRBFCenter(int numDim)
	: centers(new CMyVector(numDim)), sigmas(new CMyVector(numDim))
{
}

/// Releases the width and position vectors allocated by the constructors.
CRBFCenter::~CRBFCenter()
{
	delete sigmas;
	delete centers;
}

/// Evaluates the unnormalised Gaussian activation of this center for `state`:
///   exp(-sum_i (d_i / sigma_i)^2)
/// where d_i = state->getSingleStateDifference(i, center_i) and sigma_i is the
/// width stored for dimension i.
/// NOTE(review): the loop runs over the state's dimensions and indexes the
/// center/width vectors with the same index; assumes they have at least that
/// many elements — confirm against callers.
rlt_real CRBFCenter::getFactor(CState *state)
{
	rlt_real sum = 0.0;
	const int numDimensions = state->getNumDimensions();
	for (int i = 0; i < numDimensions; i++)
	{
		// Square by multiplication instead of pow(x, 2.0): this runs once per
		// center per state, and pow is a far more expensive general routine.
		const rlt_real scaled = state->getSingleStateDifference(i, centers->getElement(i)) / sigmas->getElement(i);
		sum += scaled * scaled;
	}
	return my_exp(-sum);
}

/// Eta calculator for the adaptive soft-max network. It registers the two
/// learning-rate parameters ("CenterLearningRate", "SigmaLearningRate",
/// both defaulting to 1.0) and remembers numDim, the size of one block of
/// position or width weights.
CAdaptiveSoftMaxNetworkEtaCalculator::CAdaptiveSoftMaxNetworkEtaCalculator(int numDim)
{
	addParameter("CenterLearningRate", 1.0);
	addParameter("SigmaLearningRate", 1.0);

	this->numDim = numDim;
}

/// Scales every entry of `updates` in place by the learning rate of its
/// parameter group. Weight indices come in blocks of numDim: even-numbered
/// blocks are center positions (scaled by "CenterLearningRate"), odd-numbered
/// blocks are widths (scaled by "SigmaLearningRate").
void CAdaptiveSoftMaxNetworkEtaCalculator::getWeightUpdates(CFeatureList *updates)
{
	const rlt_real centerRate = getParameter("CenterLearningRate");
	const rlt_real sigmaRate = getParameter("SigmaLearningRate");

	for (CFeatureList::iterator feature = updates->begin(); feature != updates->end(); ++feature)
	{
		const bool isPositionBlock = ((*feature)->featureIndex / numDim) % 2 == 0;
		(*feature)->factor *= isPositionBlock ? centerRate : sigmaRate;
	}
}

/// Creates an adaptive soft-max RBF network over the continuous variables of
/// `stateProperties`.
/// @param stateProperties   model-state properties the centers live in; the
///                          pointer is stored, not copied or deleted.
/// @param maxCenters        upper bound on the number of RBF centers.
/// @param maxActiveCenters  forwarded to CFeatureCalculator and used to
///                          construct sortedList.
/// @param featureOffset     offset added to weight indices by getGradient.
/// @param startSigma        widths for centers created by addCenterOnError
///                          (a copy is kept).
/// @param epsilon           copied and kept by the network.
CAdaptiveSoftMaxNetwork::CAdaptiveSoftMaxNetwork(CStateProperties *stateProperties, int maxCenters, int maxActiveCenters, int featureOffset, CMyVector *startSigma, CMyVector *epsilon) : CFeatureCalculator(maxCenters, maxActiveCenters)
{
	// Containers owned by the network.
	centers = new std::map<int, CRBFCenter *>;
	startCenters = new std::list<CRBFCenter *>;

	// Scalar configuration.
	this->originalState = stateProperties;
	this->numDim = stateProperties->getNumContinuousStates();
	this->maxCenters = maxCenters;
	this->maxActiveCenters = maxActiveCenters;
	this->featureOffset = featureOffset;

	// Private copies of the caller's vectors.
	this->startSigma = new CMyVector(*startSigma);
	this->epsilon = new CMyVector(*epsilon);

	addParameter("MaxFactorForAdd",0.3);
	addParameter("MinErrorForAdd", 0.2);

	// The eta calculator needs numDim, so it is built after numDim is set.
	softMaxEtaCalc = new CAdaptiveSoftMaxNetworkEtaCalculator(numDim);
	setEtaCalculator(softMaxEtaCalc);

	searchList1 = new std::list<CRBFCenter *>;
	searchList2 = new std::list<CRBFCenter *>;
	sortedList = new CFeatureList(maxActiveCenters, true);

	currentCenters = 0;
	changeState = true;
}

/// Destroys the network. It deletes all live RBF centers (via clearCenters)
/// and every container, vector and eta calculator allocated in the constructor.
/// NOTE(review): the CRBFCenter objects stored in startCenters are not deleted
/// here, although addCenterGrid allocates such objects itself. They appear to
/// leak; confirm ownership before adding a delete.
CAdaptiveSoftMaxNetwork::~CAdaptiveSoftMaxNetwork()
{
	clearCenters();

	delete softMaxEtaCalc;
	delete sortedList;
	delete searchList2;
	delete searchList1;
	delete epsilon;
	delete startSigma;
	delete startCenters;
	delete centers;
}

void CAdaptiveSoftMaxNetwork::clearCenters()
{
	std::map<int, CRBFCenter *>::iterator it = centers->begin();

	for (;it != centers->end(); it ++)
	{
		if ((*it).second != NULL)
		{
			delete (*it).second;
		}
	}

	centers->clear();

	currentCenters = 0;
}

/// Registers `center` as a start center and adds a copy of it to the live
/// network. The center is kept in startCenters so that resetData can
/// re-create it. The pointer itself is stored, not copied.
/// NOTE(review): the destructor does not delete start centers; confirm who
/// owns them.
void CAdaptiveSoftMaxNetwork::addStartCenter(CRBFCenter *center)
{
	startCenters->push_back(center);

	CRBFCenter *copy = new CRBFCenter(numDim, center->centers, center->sigmas);
	if (addRBFCenter(copy) < 0)
	{
		// addRBFCenter rejects the center without storing it when the network
		// is already full, so the copy must be freed here to avoid a leak.
		delete copy;
	}
}

/// Stores `center` in the live network under the next free index and sets
/// center->numCenter to that index. Stored centers are deleted later by
/// clearCenters.
/// @return the assigned index, or -1 if the network already holds maxCenters
///         centers. In that case the center is NOT stored and the caller is
///         still responsible for it.
int CAdaptiveSoftMaxNetwork::addRBFCenter(CRBFCenter *center)
{
	if (currentCenters == maxCenters)
	{
		return -1;
	}

	const int newIndex = currentCenters;
	center->numCenter = newIndex;
	(*centers)[newIndex] = center;
	currentCenters = newIndex + 1;

	if (DebugIsEnabled('s'))
	{
		DebugPrint('s', "Adding RBF center %d add pos :", center->numCenter);
		center->centers->saveASCII(DebugGetFileHandle('s'));
	}

	return newIndex;
}

/// Adds one start center per feature of `grid`. Each grid feature position is
/// mapped linearly onto [minValue, maxValue] of every continuous state
/// dimension (position * (max - min) + min). That mapping presumably expects
/// grid positions normalised to [0, 1]; TODO confirm. Every center gets a copy
/// of `sigmas` as its widths.
/// NOTE(review): centers are sized with grid->getNumDimensions() but the
/// rescaling loop runs over numDim; this assumes the two agree.
void CAdaptiveSoftMaxNetwork::addCenterGrid(CGridFeatureCalculator *grid, CMyVector *sigmas)
{
	assert(grid->getNumFeatures() < maxCenters);

	const int numFeatures = grid->getNumFeatures();
	for (int feature = 0; feature < numFeatures; feature++)
	{
		CRBFCenter *rbfCenter = new CRBFCenter(grid->getNumDimensions());

		grid->getFeaturePosition(feature, rbfCenter->centers);

		// Use a distinct index name so the outer feature index is not shadowed.
		for (int dim = 0; dim < numDim; dim++)
		{
			rbfCenter->centers->setElement(dim, rbfCenter->centers->getElement(dim) * (originalState->getMaxValue(dim) - originalState->getMinValue(dim)) + originalState->getMinValue(dim));
		}

		rbfCenter->sigmas->setVector(sigmas);

		addStartCenter(rbfCenter);
	}
}

/// Adds to `gradientFeatures` the gradient contributions of the active feature
/// `featureIndex` with respect to its RBF center's position and width weights.
///
/// Weight layout: center c owns 2*numDim weights. Positions are at
/// featureOffset + c*2*numDim + i and widths at featureOffset + (c*2 + 1)*numDim + i.
/// Contributions are scaled by (1 - b) * b, where b is the feature's value in
/// the feature state, and are skipped when that factor is <= 0.001.
///
/// Nothing is added when neither "CenterLearningRate" nor "SigmaLearningRate"
/// is positive, or when the first active feature has no center. A warning is
/// printed if `featureIndex` refers to a missing center.
void CAdaptiveSoftMaxNetwork::getGradient(CStateCollection *state, int featureIndex, CFeatureList *gradientFeatures)
{
	rlt_real centerLearningRate = getParameter("CenterLearningRate");
	rlt_real sigmaLearningRate = getParameter("SigmaLearningRate");
	if (!(centerLearningRate > 0 || sigmaLearningRate > 0))
	{
		return;
	}

	CState *featState = state->getState(this);
	CState *modelState = state->getState(originalState);

	// Look centers up with find(): operator[] would insert NULL placeholder
	// entries into the map for indices that have no center.
	std::map<int, CRBFCenter *>::iterator firstIt = centers->find(featState->getDiscreteState(0));
	if (firstIt == centers->end() || firstIt->second == NULL)
	{
		return;
	}

	const rlt_real activation = featState->getContinuousState(featureIndex);
	const rlt_real dbk_dak = (1 - activation) * activation;

	const int centerNumber = featState->getDiscreteState(featureIndex);
	std::map<int, CRBFCenter *>::iterator centerIt = centers->find(centerNumber);
	CRBFCenter *center = (centerIt != centers->end()) ? centerIt->second : NULL;

	if (center == NULL)
	{
		printf("Warning: Non Existing RBF Center given\n");
		return;
	}

	if (!(dbk_dak > 0.001))
	{
		return;
	}

	if (centerLearningRate > 0)
	{
		const int positionBase = featureOffset + centerNumber * numDim * 2;
		for (int i = 0; i < numDim; i ++)
		{
			const rlt_real sigma = center->sigmas->getElement(i);
			const rlt_real diff = modelState->getSingleStateDifference(i, center->centers->getElement(i));
			gradientFeatures->update(positionBase + i, dbk_dak * diff / (sigma * sigma));
		}
	}

	if (sigmaLearningRate > 0)
	{
		const int sigmaBase = featureOffset + (centerNumber * 2 + 1) * numDim;
		for (int i = 0; i < numDim; i ++)
		{
			const rlt_real sigma = center->sigmas->getElement(i);
			const rlt_real diff = modelState->getSingleStateDifference(i, center->centers->getElement(i));
			gradientFeatures->update(sigmaBase + i, dbk_dak * (diff * diff) / (sigma * sigma * sigma));
		}
	}
}

/// Adds a new RBF center at the current model state when all of these hold:
///  - |error| exceeds "MinErrorForAdd";
///  - fewer than maxCenters - 1 centers exist;
///  - no active feature value, multiplied by normFactor, reaches
///    "MaxFactorForAdd". normFactor is the first active center's raw factor
///    divided by that feature's value, or 0 if the value is not positive.
/// The new center gets a copy of startSigma as its widths.
/// @return the index of the new center, or -1 if no center was added.
int CAdaptiveSoftMaxNetwork::addCenterOnError(rlt_real error, CStateCollection *state)
{
	int centerNum = -1;
	if (fabs(error) > getParameter("MinErrorForAdd") && currentCenters < maxCenters - 1)
	{
		CState *featState = state->getState(this);
		CState *modelState = state->getState(originalState);

		// find() instead of operator[] so a missing index does not get a NULL
		// entry inserted into the map.
		std::map<int, CRBFCenter *>::iterator nearestIt = centers->find(featState->getDiscreteState(0));
		CRBFCenter *rbfCenter = (nearestIt != centers->end()) ? nearestIt->second : NULL;
		bool addCenter = true;

		if (rbfCenter != NULL)
		{
			rlt_real normFactor = 0.0;
			if (featState->getContinuousState(0) > 0)
			{
				normFactor = rbfCenter->getFactor(modelState) / featState->getContinuousState(0);
			}
			const rlt_real amin = getParameter("MaxFactorForAdd");

			// Stop at the first active feature whose rescaled value reaches amin.
			for (int i = 0; addCenter && i < featState->getNumActiveContinuousStates(); i++)
			{
				addCenter = featState->getContinuousState(i) * normFactor < amin;
			}
		}
		else if (nearestIt != centers->end())
		{
			// Drop a stale NULL placeholder stored for this index.
			centers->erase(nearestIt);
		}

		if (addCenter)
		{
			CRBFCenter *newCenter = new CRBFCenter(numDim, modelState, startSigma);
			centerNum = addRBFCenter(newCenter);
			if (centerNum < 0)
			{
				// Network full: addRBFCenter did not store the center, so free it.
				delete newCenter;
			}
			else
			{
				stateChanged();
				// centers->size() is a size_t; cast it to match the %d specifier.
				DebugPrint('s', "Added RBF Center %d, %d centers\n", centerNum, static_cast<int>(centers->size()));
				printf("Added RBF Center %d\n", centerNum);
			}
		}
	}
	return centerNum;
}

/// Adds each entry's factor in `gradientFeatures` to the RBF weight selected
/// by its featureIndex. For center c and dimension i, the position weight is
/// at index c*2*numDim + i and the width weight at (c*2 + 1)*numDim + i.
/// Entries whose center does not exist are ignored. Nothing happens when
/// neither "CenterLearningRate" nor "SigmaLearningRate" is positive.
/// NOTE(review): getGradient emits indices shifted by featureOffset, but no
/// offset is removed here; presumably callers pass local indices — confirm.
void CAdaptiveSoftMaxNetwork::updateWeights(CFeatureList *gradientFeatures)
{
	rlt_real centerLearningRate = getParameter("CenterLearningRate");
	rlt_real sigmaLearningRate = getParameter("SigmaLearningRate");
	if (!(centerLearningRate > 0 || sigmaLearningRate > 0))
	{
		return;
	}

	for (CFeatureList::iterator it = gradientFeatures->begin(); it != gradientFeatures->end(); ++it)
	{
		const int weightIndex = (*it)->featureIndex;
		const int block = weightIndex / numDim;   // even: positions, odd: widths
		const int dim = weightIndex % numDim;

		// find() instead of operator[] so unknown indices do not insert NULL entries.
		std::map<int, CRBFCenter *>::iterator centerIt = centers->find(block / 2);
		if (centerIt == centers->end() || centerIt->second == NULL)
		{
			continue;
		}
		CRBFCenter *rbfCenter = centerIt->second;

		// Update the selected vector from its OWN current value. Previously the
		// width update started from the center position, which overwrote sigma.
		CMyVector *target = (block % 2 == 0) ? rbfCenter->centers : rbfCenter->sigmas;
		target->setElement(dim, target->getElement(dim) + (*it)->factor);
	}
}

/// Number of adjustable weights. Each of the maxCenters possible centers has
/// a position vector and a width vector of numDim entries each.
int CAdaptiveSoftMaxNetwork::getNumWeights()
{
	const int weightsPerCenter = 2 * numDim;
	return weightsPerCenter * maxCenters;
}

/// Returns the network to its initial configuration. All live centers are
/// deleted, then fresh copies of the registered start centers are added.
void CAdaptiveSoftMaxNetwork::resetData()
{
	// Drop every adaptively added or modified center...
	clearCenters();
	// ...and rebuild the network from the stored start centers.
	addStartCenters();
}

void CAdaptiveSoftMaxNetwork::addStartCenters()
{
	std::list<CRBFCenter *>::iterator it = startCenters->begin();

	for (; it != startCenters->end(); it ++)
	{
		addRBFCenter(new CRBFCenter(numDim, (*it)->centers, (*it)->sigmas));
	}
}

void CAdaptiveSoftMaxNetwork::getWeights(rlt_real *parameters)
{
	//parameters[0] = currentCenters;
	int arrayPos = 0;

	std::map<int, CRBFCenter *>::iterator it = centers->begin();
	for (; it != centers->end(); it ++)
	{
		CRBFCenter *center = (*it).second;
		int num = (*it).first;
		if ((*it).second != NULL)
		{

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -