⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 cqfunction.cpp

📁 强化学习算法(R-Learning)难得的珍贵资料
💻 CPP
📖 第 1 页 / 共 2 页
字号:
// Copyright (C) 2003
// Gerhard Neumann (gerhard@igi.tu-graz.ac.at)

//                
// This file is part of RL Toolbox.
// http://www.igi.tugraz.at/ril_toolbox
//
// All rights reserved.
// 
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// 1. Redistributions of source code must retain the above copyright
//    notice, this list of conditions and the following disclaimer.
// 2. Redistributions in binary form must reproduce the above copyright
//    notice, this list of conditions and the following disclaimer in the
//    documentation and/or other materials provided with the distribution.
// 3. The name of the author may not be used to endorse or promote products
//    derived from this software without specific prior written permission.
// 
// THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
// IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
// OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
// IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT,
// INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
// NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
// THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

#include "ril_debug.h"
#include "cqfunction.h"
#include <assert.h>
#include "cdynamicprogramming.h"
#include <sstream>

/// Builds the Q-function base for the given action set.
/// The action set is forwarded to CActionObject (second argument false);
/// the type bitmask starts at 0 and the mayDiverge flag starts cleared.
CAbstractQFunction::CAbstractQFunction(CActionSet *actions) : CActionObject(actions, false)
{
	this->mayDiverge = false;
	this->type = 0;
}

/// Destructor; this class performs no cleanup of its own here.
CAbstractQFunction::~CAbstractQFunction() {}

/// Fills actionValues with the Q-value of every action in the given set.
///
/// @param stateCol      state collection handed to getValue().
/// @param actions       actions to evaluate, visited in iteration order.
/// @param actionValues  output array; entry i receives the value of the i-th
///                      action. The caller must supply at least as many slots
///                      as there are actions.
/// @param data          optional action data set; when non-null, the data
///                      registered for each action is passed to getValue().
void CAbstractQFunction::getActionValues(CStateCollection *stateCol,  CActionSet *actions, rlt_real *actionValues, CActionDataSet *data)
{
	unsigned int slot = 0;
	for (CActionSet::iterator actionIt = actions->begin(); actionIt != actions->end(); ++actionIt, ++slot)
	{
		if (!data)
		{
			// No per-action data available: use the two-argument call.
			actionValues[slot] = this->getValue(stateCol, *actionIt);
			continue;
		}
		actionValues[slot] = this->getValue(stateCol, *actionIt, data->getActionData(*actionIt));
	}
}

/// Returns the largest Q-value among the available actions in the given state.
///
/// @param state             state collection to evaluate.
/// @param availableActions  non-empty set of candidate actions (asserted).
/// @return the maximum of the values produced by getActionValues().
rlt_real CAbstractQFunction::getMaxValue(CStateCollection *state, CActionSet *availableActions)
{
	assert(availableActions->size() > 0);

	// Temporary buffer holding one value per available action.
	rlt_real *qValues = new rlt_real[availableActions->size()];
	getActionValues(state, availableActions, qValues);

	// Linear scan for the maximum, seeded with the first entry.
	rlt_real best = qValues[0];
	unsigned int pos = 1;
	while (pos < availableActions->size())
	{
		if (best < qValues[pos])
		{
			best = qValues[pos];
		}
		++pos;
	}

	delete [] qValues;
	return best;
}


/// Returns an action with the highest Q-value among the available actions.
///
/// All actions tied for the maximum are collected, but the tie-break index is
/// fixed at 0 (the random choice is disabled), so the first maximal action in
/// iteration order is returned.
///
/// @param stateCol          state collection to evaluate.
/// @param availableActions  non-empty set of candidate actions (asserted).
/// @param data              not used by this implementation.
/// @return the selected action.
CAction* CAbstractQFunction::getMax(CStateCollection* stateCol, CActionSet *availableActions, CActionDataSet *data)
{
	assert(availableActions->size() > 0);

	rlt_real *qValues = new rlt_real[availableActions->size()];
	CActionSet *bestActions = new CActionSet();

	getActionValues(stateCol, availableActions, qValues);

	// Seed the search with the first action and its value.
	CActionSet::iterator actionIt = availableActions->begin();
	rlt_real best = qValues[0];
	bestActions->push_back(*actionIt);
	++actionIt;

	for (unsigned int pos = 1; actionIt != availableActions->end(); ++actionIt, ++pos)
	{
		const rlt_real candidate = qValues[pos];
		if (best < candidate)
		{
			// Strictly better: restart the list of maximal actions.
			bestActions->clear();
			best = candidate;
			bestActions->push_back(*actionIt);
		}
		else if (best == candidate)
		{
			// Tie: remember this action as well.
			bestActions->push_back(*actionIt);
		}
	}

	// Random tie-breaking (rand() % size) is disabled; always take the first.
	int chosenIndex = 0;
	CAction *chosen = bestActions->get(chosenIndex);

	DebugPrint('q', "ActionValues: ");
	for (unsigned int k = 0; k < availableActions->size(); ++k)
	{
		DebugPrint('q', "%f ", qValues[k]);
	}
	DebugPrint('q', "\nMax: %d\n", actions->getIndex(chosen));

	delete bestActions;
	delete [] qValues;
	return chosen;
}

/// Fills an action-statistics record for one action in the given state.
///
/// The Q-values of all available actions are passed through
/// CDistributions::getS1L0Distribution (per the original note: smallest value
/// becomes 0, values sum to 1). The record then receives the transformed value
/// of `action` plus the number of actions with an equal and with a larger
/// transformed value.
///
/// @param state             state collection to evaluate.
/// @param action            action the statistics refer to.
///                          NOTE(review): presumably must be a member of
///                          availableActions, since its index is used to
///                          address the value array - confirm getIndex().
/// @param availableActions  non-empty set of candidate actions (asserted).
/// @param statistics        output record, must not be NULL (asserted).
void CAbstractQFunction::getStatistics(CStateCollection* state, CAction* action, CActionSet *availableActions, CActionStatistics *statistics)
{
	assert(availableActions->size() > 0);
	assert(statistics != NULL);

	rlt_real *distribution = new rlt_real[availableActions->size()];

	// Fetch the Q-values, then turn them in place into the probability
	// distribution from which the statistics are computed.
	getActionValues(state, availableActions, distribution);
	CDistributions::getS1L0Distribution(distribution, availableActions->size());

	statistics->action = action;
	statistics->equal = 0;
	statistics->superior = 0;
	statistics->probability = distribution[availableActions->getIndex(action)];

	for (unsigned int k = 0; k < availableActions->size(); ++k)
	{
		if (statistics->probability == distribution[k])
		{
			statistics->equal++;
		}
		else if (statistics->probability < distribution[k])
		{
			statistics->superior++;
		}
	}

	delete [] distribution;
}

int CAbstractQFunction::getType()
{
	return type;
}

/// Sets the bits of Type in the type bitmask; bits already set are kept.
void CAbstractQFunction::addType(int Type)
{
	type |= Type;
}

/// Tests the type bitmask against the given flags.
/// @return true when the bitwise AND of the stored mask and `type` is
///         greater than zero.
bool CAbstractQFunction::isType(int type)
{
	const int matching = this->type & type;
	return matching > 0;
}


/// Writes the Q-function header to a file: a "Q-Function:" line followed by
/// the number of actions and a blank line.
///
/// The action count is converted explicitly to unsigned int and printed with
/// %u, so the argument type always matches the conversion specifier no matter
/// what width actions->size() returns.
///
/// @param file  open, writable stream.
void CAbstractQFunction::saveData(FILE *file)
{
	fprintf(file, "Q-Function:\n");
	fprintf(file, "Actions: %u\n\n", static_cast<unsigned int>(actions->size()));
}

/// Reads the header written by saveData() and checks that the stored action
/// count matches the size of this Q-function's action set.
///
/// The fscanf calls are performed outside assert(), so the stream is still
/// consumed when assertions are compiled out (NDEBUG); only the checks on the
/// results are dropped in that configuration.
///
/// @param file  open, readable stream positioned at the header.
void CAbstractQFunction::loadData(FILE *file)
{
	unsigned int storedActions = 0;
	int matched = 0;

	matched = fscanf(file, "Q-Function:\n");
	assert(matched == 0);

	// %u matches the unsigned int destination.
	matched = fscanf(file, "Actions: %u\n\n", &storedActions);
	assert(matched == 1 && storedActions == actions->size());

	matched = fscanf(file, "\n");
	assert(matched == 0);

	// Silence unused-variable warnings when assert() expands to nothing.
	(void) matched;
}

/// Creates an empty weighted sum of Q-functions over the given action set.
/// Allocates the map from member Q-function to its factor.
CQFunctionSum::CQFunctionSum(CActionSet *actions) : CAbstractQFunction(actions)
{
	this->qFunctions = new std::map<CAbstractQFunction *, rlt_real>();
}

/// Frees the factor map. The Q-functions used as keys are not deleted here.
CQFunctionSum::~CQFunctionSum()
{
	delete this->qFunctions;
}

/// Returns the weighted sum of the member Q-functions' values.
///
/// Only members whose action set contains `action` contribute; each
/// contribution is the member's value multiplied by its stored factor.
///
/// @param state   state collection to evaluate.
/// @param action  action to evaluate.
/// @param data    action data passed through to every member.
/// @return the weighted sum, or 0 when no member contains the action.
rlt_real CQFunctionSum::getValue(CStateCollection *state, CAction *action, CActionData *data)
{
	rlt_real total = 0.0;

	std::map<CAbstractQFunction *, rlt_real>::iterator entry;
	for (entry = qFunctions->begin(); entry != qFunctions->end(); ++entry)
	{
		CAbstractQFunction *member = entry->first;
		if (!member->getActions()->isMember(action))
		{
			// This member does not know the action; skip it.
			continue;
		}
		total += entry->second * member->getValue(state, action, data);
	}

	return total;
}


/// Returns the factor stored for a member Q-function.
///
/// Uses find() instead of operator[], so querying a Q-function that is not
/// part of the sum no longer inserts it into the map as a side effect.
///
/// @param qFunction  Q-function to look up.
/// @return the stored factor, or 0 when qFunction is not a member (the same
///         value the previous operator[]-based lookup produced).
rlt_real CQFunctionSum::getQFunctionFactor(CAbstractQFunction *qFunction)
{
	std::map<CAbstractQFunction *, rlt_real>::iterator it = qFunctions->find(qFunction);
	if (it == qFunctions->end())
	{
		return 0.0;
	}
	return it->second;
}

/// Stores the factor for a Q-function, overwriting any existing one.
/// The Q-function is inserted into the sum if it is not a member yet.
void CQFunctionSum::setQFunctionFactor(CAbstractQFunction *qFunction, rlt_real factor)
{
	rlt_real &storedFactor = (*qFunctions)[qFunction];
	storedFactor = factor;
}

/// Adds a Q-function to the sum with the given factor. If the Q-function is
/// already a member, its factor is replaced.
void CQFunctionSum::addQFunction(CAbstractQFunction *qFunction, rlt_real factor)
{
	rlt_real &storedFactor = (*qFunctions)[qFunction];
	storedFactor = factor;
}

/// Removes a Q-function from the sum.
///
/// Erases by key, which is a no-op when qFunction is not a member. The
/// previous form passed the result of find() straight to erase(), which is
/// undefined behaviour when find() returns end().
///
/// @param qFunction  Q-function to remove; the object itself is not deleted.
void CQFunctionSum::removeQFunction(CAbstractQFunction *qFunction)
{
	qFunctions->erase(qFunction);
}

/// Rescales all factors so that they sum to `factor`.
///
/// The previous version reused the already exhausted iterator for the scaling
/// pass, so that loop never executed and the factors were left untouched.
/// Each pass now starts from begin(). When the current factors sum to zero
/// nothing is changed, since scaling would divide by zero.
///
/// @param factor  target sum of all factors after normalisation.
void CQFunctionSum::normFactors(rlt_real factor)
{
	std::map<CAbstractQFunction *, rlt_real>::iterator it;

	rlt_real sum = 0.0;
	for (it = qFunctions->begin(); it != qFunctions->end(); ++it)
	{
		sum += it->second;
	}

	if (sum == 0.0)
	{
		// Nothing to normalise (empty map or factors cancelling out).
		return;
	}

	const rlt_real scale = factor / sum;
	for (it = qFunctions->begin(); it != qFunctions->end(); ++it)
	{
		it->second *= scale;
	}
}


/// Creates the exception reported for a diverging Q-function.
/// The base CMyException is constructed with code 102 and the name
/// "DivergentQFunction"; the arguments are stored for later reporting.
///
/// @param qFunctionName  name used in the error message.
/// @param qFunction      the Q-function concerned.
/// @param state          state stored with the exception.
/// @param value          value stored with the exception and shown in the message.
CDivergentQFunctionException::CDivergentQFunctionException(string qFunctionName, CAbstractQFunction *qFunction, CState *state, rlt_real value) : CMyException(102, "DivergentQFunction")
{
	this->qFunctionName = qFunctionName;
	this->qFunction = qFunction;
	this->value = value;
	this->state = state;
}

/// Builds the detail message: the Q-function name, the stored value and the
/// fixed note "|value| > 1000000".
string CDivergentQFunctionException::getInnerErrorMsg()
{
	stringstream message;

	message << qFunctionName.c_str();
	message << " diverges (value = " << value;
	message << ", |value| > 1000000).";

	return message.str();
}

/// Creates a gradient Q-function over the given action set: allocates the
/// local feature list and sets the GRADIENTQFUNCTION type flag.
CGradientQFunction::CGradientQFunction(CActionSet *actions) : CAbstractQFunction(actions)
{
	localGradientQFunctionFeatures = new CFeatureList();

	addType(GRADIENTQFUNCTION);
}

/// Frees the local feature list allocated in the constructor.
CGradientQFunction::~CGradientQFunction()
{
	delete this->localGradientQFunctionFeatures;
}

/// Updates the Q-function for one state-action pair along its gradient.
///
/// The gradient is computed into localGradientFeatureBuffer (cleared first)
/// and then handed to updateGradient() together with td.
///
/// @param state   state collection the gradient is evaluated in.
/// @param action  action the gradient is evaluated for.
/// @param td      factor passed to updateGradient().
/// @param data    action data passed to getGradient().
void CGradientQFunction::updateValue(CStateCollection *state, CAction *action,rlt_real td, CActionData *data)
{
	// Start from an empty buffer so no features from an earlier call remain.
	this->localGradientFeatureBuffer->clear();

	this->getGradient(state, action, data, this->localGradientFeatureBuffer);
	this->updateGradient(this->localGradientFeatureBuffer, td);
}

/// Wraps an existing gradient Q-function.
/// Both bases are constructed from the wrapped function (its action set for
/// CGradientQFunction, the function itself for CGradientDelayedUpdateFunction)
/// and the pointer is kept for forwarding getValue() and getGradient().
CGradientDelayedUpdateQFunction::CGradientDelayedUpdateQFunction(CGradientQFunction *qFunction) :  CGradientQFunction(qFunction->getActions()), CGradientDelayedUpdateFunction(qFunction)
{
	// Remember the wrapped function; it is not owned by this constructor.
	this->qFunction = qFunction;
}

/// Returns the value of the wrapped Q-function for the given arguments.
rlt_real CGradientDelayedUpdateQFunction::getValue(CStateCollection *state, CAction *action, CActionData *data )
{
	const rlt_real wrappedValue = qFunction->getValue(state, action, data);
	return wrappedValue;
}

/// Delegates the gradient computation to the wrapped Q-function, which
/// writes its result into gradientFeatures.
void CGradientDelayedUpdateQFunction::getGradient(CStateCollection *state, CAction *action, CActionData *data, CFeatureList *gradientFeatures)
{
	this->qFunction->getGradient(state, action, data, gradientFeatures);
}


/// Creates a new CGradientQETraces object bound to this Q-function.
/// The object is allocated with new and returned to the caller.
CAbstractQETraces *CGradientQFunction::getStandardETraces()
{
	CAbstractQETraces *traces = new CGradientQETraces(this);
	return traces;
}

/// Creates a Q-function composed of one V-function per action.
/// The action-to-V-function map starts empty; entries are registered through
/// setVFunction().
CQFunction::CQFunction(CActionSet *act) : CGradientQFunction(act)
{
	vFunctions = new std::map<CAction *, CAbstractVFunction *>();
}

/// Frees the action-to-V-function map. The V-functions stored in the map are
/// not deleted here.
CQFunction::~CQFunction()
{
	delete this->vFunctions;
}

/// Applies an update of td to the V-function registered for `action`,
/// evaluated at a single state.
///
/// A V-function must be registered for the action (asserted). Note that the
/// map lookup inserts a NULL entry for an unknown action.
/// `data` is not used.
void CQFunction::updateValue(CState *state, CAction *action, rlt_real td, CActionData *data)
{
	CAbstractVFunction *vFunction = (*vFunctions)[action];
	assert(vFunction);

	vFunction->updateValue(state, td);
}

/// Sets the value of the V-function registered for `action` at a single
/// state.
///
/// A V-function must be registered for the action (asserted). Note that the
/// map lookup inserts a NULL entry for an unknown action.
/// `data` is not used.
void CQFunction::setValue(CState *state, CAction *action, rlt_real value, CActionData *data)
{
	CAbstractVFunction *vFunction = (*vFunctions)[action];
	assert(vFunction);

	vFunction->setValue(state, value);
}

/// Returns the value of the V-function registered for `action` at a single
/// state.
///
/// A V-function must be registered for the action (asserted). Note that the
/// map lookup inserts a NULL entry for an unknown action.
/// `data` is not used.
rlt_real CQFunction::getValue(CState *state, CAction *action, CActionData *data)
{
	CAbstractVFunction *vFunction = (*vFunctions)[action];
	assert(vFunction);

	return vFunction->getValue(state);
}

/// Applies an update of td to the V-function registered for `action`,
/// evaluated on a state collection.
///
/// A V-function must be registered for the action (asserted). Note that the
/// map lookup inserts a NULL entry for an unknown action.
/// `data` is not used.
void CQFunction::updateValue(CStateCollection *state, CAction *action, rlt_real td, CActionData *data)
{
	CAbstractVFunction *vFunction = (*vFunctions)[action];
	assert(vFunction);

	vFunction->updateValue(state, td);
}

/// Sets the value of the V-function registered for `action` on a state
/// collection.
///
/// A V-function must be registered for the action (asserted). Note that the
/// map lookup inserts a NULL entry for an unknown action.
/// `data` is not used.
void CQFunction::setValue(CStateCollection *state, CAction *action, rlt_real value, CActionData *data)
{
	CAbstractVFunction *vFunction = (*vFunctions)[action];
	assert(vFunction);

	vFunction->setValue(state, value);
}

/// Returns the value of the V-function registered for `action` on a state
/// collection.
///
/// A V-function must be registered for the action (asserted). Note that the
/// map lookup inserts a NULL entry for an unknown action.
/// `data` is not used.
rlt_real CQFunction::getValue(CStateCollection *state, CAction *action, CActionData *data)
{
	CAbstractVFunction *vFunction = (*vFunctions)[action];
	assert(vFunction);

	return vFunction->getValue(state);
}


/// Returns the V-function registered for `action`.
/// The lookup uses the map's operator[], so an unknown action yields NULL
/// and a NULL entry is inserted for it.
CAbstractVFunction *CQFunction::getVFunction(CAction *action)
{
	CAbstractVFunction *registered = (*vFunctions)[action];
	return registered;
}

/// Returns the V-function registered for the action at position `index` of
/// this Q-function's action set.
CAbstractVFunction *CQFunction::getVFunction(int index)
{
	CAction *action = actions->get(index);
	return (*vFunctions)[action];
}


/// Registers the V-function used for `action`.
///
/// @param action     action the V-function belongs to.
/// @param vfunction  V-function to register; it is dereferenced, so it must
///                   not be NULL.
/// @param bDelete    when true, a previously registered V-function for this
///                   action is deleted before being replaced.
///
/// If the new V-function is not of type GRADIENTVFUNCTION, the
/// GRADIENTQFUNCTION flag of this Q-function is cleared. The V-function is
/// then passed to addParameters().
void CQFunction::setVFunction(CAction *action, CAbstractVFunction *vfunction, bool bDelete)
{
	// Reference to the map slot; created (as NULL) if the action is new.
	CAbstractVFunction *&slot = (*vFunctions)[action];

	if (bDelete && slot != NULL)
	{
		delete slot;
	}
	slot = vfunction;

	const bool hasGradient = vfunction->isType(GRADIENTVFUNCTION);
	if (!hasGradient)
	{
		type &= ~GRADIENTQFUNCTION;
	}

	addParameters(vfunction);
}

void CQFunction::setVFunction(int index, CAbstractVFunction *vfunction, bool bDelete)
{
	setVFunction(actions->get(index), vfunction, bDelete);
}

/// Returns the number of entries in the action-to-V-function map.
int CQFunction::getNumVFunctions()
{
	return static_cast<int>(vFunctions->size());
}

void CQFunction::saveData(FILE *file)
{
	CAbstractQFunction::saveData(file);

	CActionSet::iterator it = actions->begin();

	for (; it != actions->end(); it++)
	{
		(*vFunctions)[(*it)]->saveData(file);
	}
}

/// Reads the Q-function from a file: first the base-class header, then the
/// data of each action's V-function in action-set order, and finally any
/// trailing whitespace.
///
/// The trailing fscanf is performed outside assert(), so the stream is still
/// consumed when assertions are compiled out (NDEBUG).
///
/// @param file  open, readable stream positioned at the Q-function header.
void CQFunction::loadData(FILE *file)
{
	CAbstractQFunction::loadData(file);

	for (CActionSet::iterator it = actions->begin(); it != actions->end(); ++it)
	{
		(*vFunctions)[(*it)]->loadData(file);
	}
	//assert(fscanf(file, "Lambda: %f\n", &lambda) == 1);

	const int matched = fscanf(file, "\n");
	assert(matched == 0);

	// Silence unused-variable warnings when assert() expands to nothing.
	(void) matched;
}

void CQFunction::printValues()
{
	CAbstractQFunction::printValues();

	CActionSet::iterator it = actions->begin();

	for (; it != actions->end(); it++)
	{
		(*vFunctions)[(*it)]->printValues();
	}
}

/// Creates a new CQETraces object bound to this Q-function.
/// The object is allocated with new and returned to the caller.
CAbstractQETraces *CQFunction::getStandardETraces()
{
	CAbstractQETraces *traces = new CQETraces(this);
	return traces;
}

void CQFunction::resetData()
{
	CActionSet::iterator it = actions->begin();

	for (; it != actions->end(); it++)
	{
		(*vFunctions)[(*it)]->resetData();
	}
}
/*CStateProperties *CQFunction::getGradientCalculator(CAction *action)
{
    int index = actions->getIndex(action);

	if ((*vFunctions)[action]->isType(GRADIENTVFUNCTION))
	{
		CGradientVFunction *gradVFunc = dynamic_cast<CGradientVFunction *>((*vFunctions)[action]);

		return gradVFunc->getGradientCalculator();
	}
	else
	{
		return NULL;
	}
}*/

/// Computes the gradient of the Q-function for one action.
///
/// When the action's V-function is of type GRADIENTVFUNCTION, its gradient is
/// written into `gradient` and the feature indices are shifted by the weight
/// offset of the action. Otherwise `gradient` is left untouched.
///
/// @param stateCol  state collection the gradient is evaluated in.
/// @param action    action whose V-function is queried.
/// @param data      not used by this implementation.
/// @param gradient  feature list receiving the gradient.
void CQFunction::getGradient(CStateCollection *stateCol, CAction *action, CActionData *data, CFeatureList *gradient)
{
	CAbstractVFunction *vFunction = (*vFunctions)[action];
	if (!vFunction->isType(GRADIENTVFUNCTION))
	{
		return;
	}

	CGradientVFunction *gradVFunc = dynamic_cast<CGradientVFunction *>(vFunction);
	gradVFunc->getGradient(stateCol, gradient);

	// Shift from V-function-local indices by this action's weight offset.
	gradient->addIndexOffset(getWeightsOffset(action));
}

void CQFunction::updateWeights(CFeatureList *features)
{
	unsigned int featureBegin = 0;
	unsigned int featureEnd = 0;

	if (DebugIsEnabled('q'))
	{
		DebugPrint('q', "Updating Features: ");
		features->saveASCII(DebugGetFileHandle('q'));
		DebugPrint('q', "\n");
	}
	
	if(isType(GRADIENTQFUNCTION))
	{
		std::map<CAction *, CAbstractVFunction *>::iterator it = vFunctions->begin();
		CFeatureList::iterator itFeat;

		for (int i = 0; it != vFunctions->end();it++, i++)
		{
			CGradientVFunction *gradVFunction = dynamic_cast<CGradientVFunction *>((*it).second);
			featureEnd += gradVFunction->getNumWeights();
	
			localGradientQFunctionFeatures->clear();

			for (itFeat = features->begin(); itFeat != features->end(); itFeat++)
			{
				if ((*itFeat)->featureIndex >= featureBegin && (*itFeat)->featureIndex < featureEnd)
				{
					localGradientQFunctionFeatures->update((*itFeat)->featureIndex - featureBegin, (*itFeat)->factor);
				}
			}

			if (DebugIsEnabled('q'))
			{
				DebugPrint('q', "Updating Features for Action %d: ",i);
				localGradientQFunctionFeatures->saveASCII(DebugGetFileHandle('q'));
				DebugPrint('q', "\n");
			}
			gradVFunction->updateGradient(localGradientQFunctionFeatures, 1.0);

			featureBegin += gradVFunction->getNumWeights();
		}
	}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -