⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 cqfunction.cpp

📁 强化学习算法(R-Learning)难得的珍贵资料
💻 CPP
📖 第 1 页 / 共 2 页
字号:
}

int CQFunction::getNumWeights()
{
	int nparams = 0;
	std::map<CAction *, CAbstractVFunction *>::iterator it = vFunctions->begin();
	for (; it != vFunctions->end();it++)
	{
		CGradientVFunction *gradVFunction = dynamic_cast<CGradientVFunction *>((*it).second);
		nparams += gradVFunction->getNumWeights();
	}
	return nparams;
}

// Returns the number of weights belonging to the V-functions that precede
// `action` in the iteration order of vFunctions. If `action` is not a key of
// the map, the result equals the total weight count.
int CQFunction::getWeightsOffset(CAction *action)
{
	typedef std::map<CAction *, CAbstractVFunction *> VFunctionMap;

	int offset = 0;
	VFunctionMap::iterator entry = vFunctions->begin();
	// Stop at the requested action; its own weights are not counted.
	while (entry != vFunctions->end() && entry->first != action)
	{
		CGradientVFunction *gradient = dynamic_cast<CGradientVFunction *>(entry->second);
		offset += gradient->getNumWeights();
		++entry;
	}
	return offset;
}

// Asks every V-function, in the iteration order of vFunctions, to write its
// weights via getWeights(). The destination pointer starts at `weights` and
// advances by that V-function's getNumWeights() after each call.
void CQFunction::getWeights(rlt_real *weights)
{
	typedef std::map<CAction *, CAbstractVFunction *> VFunctionMap;

	rlt_real *cursor = weights;
	VFunctionMap::iterator entry = vFunctions->begin();
	while (entry != vFunctions->end())
	{
		CGradientVFunction *gradient = dynamic_cast<CGradientVFunction *>(entry->second);
		gradient->getWeights(cursor);
		cursor += gradient->getNumWeights();
		++entry;
	}
}

// Hands every V-function, in the iteration order of vFunctions, its slice of
// `weights` via setWeights(). The source pointer starts at `weights` and
// advances by that V-function's getNumWeights() after each call.
void CQFunction::setWeights(rlt_real *weights)
{
	typedef std::map<CAction *, CAbstractVFunction *> VFunctionMap;

	rlt_real *cursor = weights;
	VFunctionMap::iterator entry = vFunctions->begin();
	while (entry != vFunctions->end())
	{
		CGradientVFunction *gradient = dynamic_cast<CGradientVFunction *>(entry->second);
		gradient->setWeights(cursor);
		cursor += gradient->getNumWeights();
		++entry;
	}
}

// Builds a Q-function that is backed by a V-function, a stochastic model and
// a reward function. The base classes are initialised with the model's action
// set (model->getActions()) and the V-function's state properties.
CQFunctionFromStochasticModel::CQFunctionFromStochasticModel(CFeatureVFunction *vfunction, CAbstractFeatureStochasticModel *model, CFeatureRewardFunction *rewardfunction) : CAbstractQFunction(model->getActions()), CStateObject(vfunction->getStateProperties())
{
	this->vfunction = vfunction;
	this->model = model;
	this->discretizer = vfunction->getStateProperties();
	this->rewardfunction = rewardfunction;

	// Internal state object that the int overload of getValue() fills in before
	// querying the model. Discrete dimension 0 is sized to the V-function's
	// feature count.
	// NOTE(review): the arguments (0,1,DISCRETESTATE) presumably mean "0
	// continuous, 1 discrete state variable" - confirm against CStateProperties.
	discState = new CState(new CStateProperties(0,1,DISCRETESTATE));
	discState->getStateProperties()->setDiscreteStateSize(0, vfunction->getNumFeatures());

	// getValue() reads this parameter back through getParameter("DiscountFactor").
	addParameter("DiscountFactor", 0.95);
}

// Releases the internal state object allocated in the constructor, together
// with the properties object that was allocated for it.
CQFunctionFromStochasticModel::~CQFunctionFromStochasticModel()
{
	// The properties are fetched from the state, so they have to be deleted
	// while the state object still exists.
	delete discState->getStateProperties();
	delete discState;
}

// Fetches the state selected by `properties` from the collection and forwards
// it, together with `action`, to another getValue() overload. `data` is not
// passed on.
rlt_real CQFunctionFromStochasticModel::getValue(CStateCollection *state, CAction *action, CActionData *data)
{
	return getValue(state->getState(properties), action);
}

// Computes the action value for a single discrete state index: the index is
// written into the internal discState (dimension 0) and the result of
// CDynamicProgramming::getActionValue() for that state is returned, using the
// "DiscountFactor" parameter. `data` is unused.
rlt_real CQFunctionFromStochasticModel::getValue(int state, CAction *action, CActionData *data)
{
	discState->setDiscreteState(0, state);

	return CDynamicProgramming::getActionValue(model, rewardfunction, vfunction, discState, action, getParameter("DiscountFactor"));
}


// Computes the action value for a discrete or a feature state.
//  - DISCRETESTATE: the state is passed straight to
//    CDynamicProgramming::getActionValue().
//  - FEATURESTATE: sums getValue(discrete state i) * continuous state i over
//    i < getNumContinuousStates().
// Any other masked type yields 0.0. `data` is unused.
rlt_real CQFunctionFromStochasticModel::getValue(CState *featState, CAction *action, CActionData *data)
{
	// Only the discrete/feature bits of the type decide which branch is taken.
	const int stateKind = featState->getStateProperties()->getType() & (DISCRETESTATE | FEATURESTATE);

	if (stateKind == DISCRETESTATE)
	{
		return CDynamicProgramming::getActionValue(model, this->rewardfunction, this->vfunction, featState, action, getParameter("DiscountFactor"));
	}

	rlt_real weightedSum = 0.0;
	if (stateKind == FEATURESTATE)
	{
		for (unsigned int idx = 0; idx < featState->getNumContinuousStates(); idx++)
		{
			weightedSum += getValue(featState->getDiscreteState(idx), action) * featState->getContinuousState(idx);
		}
	}
	return weightedSum;
}


/*
CQTable::CQTable(CActionSet *actions, CAbstractStateDiscretizer *discretizer) : CQFunction(actions), CAbstractQFunction(actions)
{
	this->discretizer = discretizer;
	init(discretizer->getDiscreteStateSize());
}


void CQTable::init(int states)
{
	this->states = states;
	CVTable *table = NULL;

	for (CActionSet::iterator it = actions->begin(); it != actions->end(); it++)
	{
		if (discretizer != NULL) table = new CVTable(discretizer);
		else table = new CVTable(discretizer);
		this->setVFunction(*it, table);
	}
}
	
CQTable::~CQTable()
{
	for (std::list<CAbstractVFunction *>::iterator it = vFunctions->begin(); it != vFunctions->end(); it ++)
	{
		delete *it;
	}
}

void CQTable::setDiscretizer(CAbstractStateDiscretizer *discretizer)
{
	assert(discretizer == NULL || discretizer->getDiscreteStateSize() == states);

	this->discretizer = discretizer;
}

CAbstractStateDiscretizer *CQTable::getDiscretizer()
{
	return discretizer;
}


int CQTable::getNumStates()
{
	return states;
}*/

// Creates a feature Q-function for `actions`. The feature count is taken from
// discretizer->getDiscreteStateSize(), and init() then creates one feature
// V-function per action.
CFeatureQFunction::CFeatureQFunction(CActionSet *actions, CStateModifier *discretizer) : CQFunction(actions)
{
	features = discretizer->getDiscreteStateSize();
	this->discretizer = discretizer;

	init();
}

// Creates a feature Q-function over the model's action set and fills it from
// an existing V-function: after init() has created the per-action V-functions,
// initVFunctions() sets every feature/action entry.
CFeatureQFunction::CFeatureQFunction(CFeatureVFunction *vfunction, CAbstractFeatureStochasticModel *model,  CFeatureRewardFunction *rewardFunction, rlt_real gamma) : CQFunction(model->getActions())
{
	// The V-function's state properties are reused as the feature calculator.
	CStateModifier *featureCalculator = (CStateModifier *) vfunction->getStateProperties();
	discretizer = featureCalculator;
	features = featureCalculator->getDiscreteStateSize();

	init();
	initVFunctions(vfunction, model, rewardFunction, gamma);
}

void CFeatureQFunction::init()
{
	CFeatureVFunction *vFunction = NULL;

	featureVFunctions = new std::list<CFeatureVFunction *>();
	for (CActionSet::iterator it = actions->begin(); it != actions->end(); it++)
	{
		vFunction = new CFeatureVFunction(discretizer);
		featureVFunctions->push_back(vFunction);
		this->setVFunction(*it, vFunction);
	}
}
	
// Deletes every feature V-function recorded in featureVFunctions and then the
// list itself.
CFeatureQFunction::~CFeatureQFunction()
{
	typedef std::list<CFeatureVFunction *> FeatureVFunctionList;

	FeatureVFunctionList::iterator owned = featureVFunctions->begin();
	while (owned != featureVFunctions->end())
	{
		delete *owned;
		++owned;
	}
	delete featureVFunctions;
}

// Replaces the feature calculator (stored in `discretizer`).
// A NULL argument is accepted. A non-NULL one must report a discrete state
// size equal to `features`; this is enforced by assert only.
void CFeatureQFunction::setFeatureCalculator(CStateModifier *discretizer)
{
	assert(discretizer == NULL || discretizer->getDiscreteStateSize() == features);

	this->discretizer = discretizer;
}

// Returns the currently stored feature calculator (may be NULL if one was set
// through setFeatureCalculator(NULL)).
CStateModifier *CFeatureQFunction::getFeatureCalculator()
{
	return this->discretizer;
}

int CFeatureQFunction::getNumFeatures()
{
	return features;
}

// Fills every (feature, action) entry of this Q-function with the value that
// CDynamicProgramming::getActionValue() returns for that feature, given
// `vfunction`, `model`, `rewardFunction` and the discount `gamma`.
void CFeatureQFunction::initVFunctions(CFeatureVFunction *vfunction, CAbstractFeatureStochasticModel *model, CFeatureRewardFunction *rewardFunction, rlt_real gamma)
{
	// Temporary state whose discrete value 0 is set to the feature being evaluated.
	CState *discState = new CState(new CStateProperties(0,1));

	int feature = 0;
	while (feature < getNumFeatures())
	{
		discState->setDiscreteState(0, feature);

		for (std::list<CAction *>::iterator itAction = actions->begin(); itAction != actions->end(); ++itAction)
		{
			CFeatureVFunction *target = (CFeatureVFunction *)(*vFunctions)[*itAction];
			target->setFeature(feature, CDynamicProgramming::getActionValue(model, rewardFunction, vfunction, discState, *itAction, gamma));
		}
		++feature;
	}

	// The properties were allocated here as well, so both objects are released.
	delete discState->getStateProperties();
	delete discState;
}


// Applies the update `td` for feature `state` to the V-function registered
// for `action`. `data` is unused.
void CFeatureQFunction::updateValue(CFeature *state, CAction *action, rlt_real td, CActionData *data)
{
	CFeatureVFunction *actionVFunction = (CFeatureVFunction *) getVFunction(action);
	actionVFunction->updateFeature(state, td);
}

void CFeatureQFunction::setValue(int state, CAction *action, rlt_real qValue, CActionData *data)
{
	((CFeatureVFunction *) getVFunction(action))->setFeature(state, qValue);
}

// Returns the value stored for `feature` in the V-function registered for
// `action`. `data` is unused.
rlt_real CFeatureQFunction::getValue(int feature, CAction *action, CActionData *data)
{
	CFeatureVFunction *actionVFunction = (CFeatureVFunction *) getVFunction(action);
	return actionVFunction->getFeature(feature);
}


void CFeatureQFunction::saveFeatureActionValueTable(FILE *stream)
{
	fprintf(stream, "Q-FeatureActionValue Table\n");
	CActionSet::iterator it;

	for (unsigned int i = 0; i < discretizer->getDiscreteStateSize(); i++)
	{
		fprintf(stream,"State %d: ", i);
		for (it = actions->begin(); it != actions->end(); it++)
		{
			fprintf(stream,"%f ", ((CFeatureVFunction *) (*vFunctions)[*it])->getFeature(i));
		}
		fprintf(stream, "\n");
	}
}

void CFeatureQFunction::saveFeatureActionTable(FILE *stream)
{
	fprintf(stream, "Q-FeatureAction Table\n");
	CActionSet::iterator it;
	rlt_real max = 0.0;
	unsigned int maxIndex = 0;

	for (unsigned int i = 0; i < discretizer->getDiscreteStateSize(); i++)
	{
		fprintf(stream,"State %d: ", i);
		
		it = actions->begin();
		max = ((CFeatureVFunction *)(*vFunctions)[*it])->getFeature(i);
		it ++;
		maxIndex = 0;
		for (unsigned int j = 1; it != actions->end(); it++, j++)
		{
			rlt_real qValue = ((CFeatureVFunction*)(*vFunctions)[*it])->getFeature(i);
			if (max < qValue)
			{
				max  = qValue;
				maxIndex = j;
			}
		}
		fprintf(stream, "%d", maxIndex);
		fprintf(stream, "\n");
	}
}


// Creates an empty composed Q-function: the base class receives a freshly
// allocated, empty CActionSet and the list of contained Q-functions starts
// out empty.
CComposedQFunction::CComposedQFunction() : CGradientQFunction(new CActionSet())
{
	qFunctions = new std::list<CAbstractQFunction *>;
	//gradientFeatures = new CFeatureList();
}

// Frees the list object only; the Q-functions it points to are not deleted
// here.
CComposedQFunction::~CComposedQFunction()
{
	delete this->qFunctions;
	//delete gradientFeatures;
}

// Writes a header line containing the number of contained Q-functions and
// then lets every contained Q-function append its own data, in list order.
void CComposedQFunction::saveData(FILE *file)
{
	// size() returns an unsigned size type; convert explicitly so the argument
	// matches the %d conversion.
	fprintf(file, "Composed QFunction (containing %d QFunctions)\n", static_cast<int>(qFunctions->size()));

	// Iterate up to end(): comparing against begin() ends the loop before the
	// first element, so no sub-function was ever saved.
	std::list<CAbstractQFunction *>::iterator it = qFunctions->begin();
	for (; it != qFunctions->end(); ++it)
	{
		(*it)->saveData(file);
	}
}

// Reads the header line written by saveData() and then lets every contained
// Q-function load its own data, in list order. The count in the header is
// parsed only to consume the line; it is not compared with the current number
// of contained Q-functions.
void CComposedQFunction::loadData(FILE *file)
{
	int buf = 0;
	fscanf(file, "Composed QFunction (containing %d QFunctions)\n", &buf);

	// Iterate up to end(): comparing against begin() ends the loop before the
	// first element, so no sub-function was ever loaded.
	std::list<CAbstractQFunction *>::iterator it = qFunctions->begin();
	for (; it != qFunctions->end(); ++it)
	{
		(*it)->loadData(file);
	}
}

// Calls printValues() on every contained Q-function, in list order.
void CComposedQFunction::printValues()
{
	// Iterate up to end(): comparing against begin() ends the loop before the
	// first element, so nothing was ever printed.
	std::list<CAbstractQFunction *>::iterator it = qFunctions->begin();
	for (; it != qFunctions->end(); ++it)
	{
		(*it)->printValues();
	}
}


// Forwards the statistics request to every contained Q-function whose action
// set contains `action`.
void CComposedQFunction::getStatistics(CStateCollection *state, CAction *action, CActionSet *actions, CActionStatistics* statistics)
{
	// Iterate up to end(): comparing against begin() ends the loop before the
	// first element, so no sub-function was ever consulted.
	std::list<CAbstractQFunction *>::iterator it = qFunctions->begin();
	for (; it != qFunctions->end(); ++it)
	{
		if ((*it)->getActions()->isMember(action))
		{
			(*it)->getStatistics(state, action, actions, statistics);
		}
	}
}

// Forwards the update `td` to every contained Q-function whose action set
// contains `action`.
void CComposedQFunction::updateValue(CStateCollection *state, CAction *action, rlt_real td, CActionData *data)
{
	// Iterate up to end(): comparing against begin() ends the loop before the
	// first element, so the update was silently dropped.
	std::list<CAbstractQFunction *>::iterator it = qFunctions->begin();
	for (; it != qFunctions->end(); ++it)
	{
		if ((*it)->getActions()->isMember(action))
		{
			(*it)->updateValue(state, action, td, data);
		}
	}
}


// Forwards `qValue` to every contained Q-function whose action set contains
// `action`.
void CComposedQFunction::setValue(CStateCollection *state, CAction *action, rlt_real qValue, CActionData *data)
{
	// Iterate up to end(): comparing against begin() ends the loop before the
	// first element, so the value was silently dropped.
	std::list<CAbstractQFunction *>::iterator it = qFunctions->begin();
	for (; it != qFunctions->end(); ++it)
	{
		if ((*it)->getActions()->isMember(action))
		{
			(*it)->setValue(state, action, qValue, data);
		}
	}
}

// Returns the value reported by the first contained Q-function (in list
// order) whose action set contains `action`, or 0 if no contained Q-function
// handles the action.
rlt_real CComposedQFunction::getValue(CStateCollection *state, CAction *action, CActionData *data)
{
	// Iterate up to end(): comparing against begin() ends the loop before the
	// first element, so the function always fell through to "return 0".
	std::list<CAbstractQFunction *>::iterator it = qFunctions->begin();
	for (; it != qFunctions->end(); ++it)
	{
		if ((*it)->getActions()->isMember(action))
		{
			return (*it)->getValue(state, action, data);
		}
	}
	return 0;
}


// Appends `qFunction` to the composition: it is added to the list, its
// actions are added to this object's action set and its parameters are taken
// over via addParameters(). If the new member is not a gradient Q-function,
// the GRADIENTQFUNCTION bit is cleared from this object's type.
void CComposedQFunction::addQFunction(CAbstractQFunction *qFunction)
{
	qFunctions->push_back(qFunction);
	actions->add(qFunction->getActions());

	const bool memberHasGradient = qFunction->isType(GRADIENTQFUNCTION);
	if (!memberHasGradient)
	{
		type &= ~ GRADIENTQFUNCTION;
	}

	addParameters(qFunction);
}


std::list<CAbstractQFunction *> *CComposedQFunction::getQFunctions()
{
	return qFunctions;
}

// Returns how many Q-functions the composition currently contains.
int CComposedQFunction::getNumQFunctions()
{
	// The list size is converted to the int return type.
	return static_cast<int>(qFunctions->size());
}

// Allocates and returns a new CComposedQETraces object constructed from this
// Q-function. The object is created with `new` and not stored here.
CAbstractQETraces *CComposedQFunction::getStandardETraces()
{
	CAbstractQETraces *traces = new CComposedQETraces(this);
	return traces;
}


// For every contained Q-function that handles `action` and is flagged as
// GRADIENTQFUNCTION: lets it add its gradient to `gradient`, then calls
// gradient->addIndexOffset() with this object's weight offset for `action`.
void CComposedQFunction::getGradient(CStateCollection *stateCol, CAction *action, CActionData *data, CFeatureList *gradient)
{
	typedef std::list<CAbstractQFunction *> QFunctionList;

	for (QFunctionList::iterator member = qFunctions->begin(); member != qFunctions->end(); ++member)
	{
		CAbstractQFunction *qFunction = *member;

		// Skip members that do not handle the action or provide no gradient.
		if (!qFunction->getActions()->isMember(action) || !qFunction->isType(GRADIENTQFUNCTION))
		{
			continue;
		}

		CGradientQFunction *gradQFunc = dynamic_cast<CGradientQFunction *>(qFunction);
		gradQFunc->getGradient(stateCol, action, data, gradient);
		gradient->addIndexOffset(getWeightsOffset(action));
	}
}



// Distributes a feature list over the contained gradient Q-functions.
// Each contained Q-function owns the index window [featureBegin, featureEnd),
// where the window length is its getNumWeights(). The features whose
// featureIndex falls into that window are collected in
// localGradientQFunctionFeatures and passed to its updateGradient().
// Does nothing unless this object is flagged as GRADIENTQFUNCTION.
void CComposedQFunction::updateWeights(CFeatureList *features)
{
	if (!isType(GRADIENTQFUNCTION))
	{
		return;
	}

	unsigned int featureBegin = 0;
	unsigned int featureEnd = 0;

	std::list<CAbstractQFunction *>::iterator it = qFunctions->begin();
	for (; it != qFunctions->end(); ++it)
	{
		CGradientQFunction *gradQFunction = dynamic_cast<CGradientQFunction *>(*it);
		featureEnd += gradQFunction->getNumWeights();

		localGradientQFunctionFeatures->clear();

		// The feature iterator itself must be advanced here. Advancing the
		// outer iterator instead left this loop without progress and pushed the
		// outer iterator past the end of the list.
		CFeatureList::iterator itFeat = features->begin();
		for (; itFeat != features->end(); ++itFeat)
		{
			if ((*itFeat)->featureIndex >= featureBegin && (*itFeat)->featureIndex < featureEnd)
			{
				// NOTE(review): the feature is handed over with its index
				// unchanged; confirm whether updateGradient() expects indices
				// relative to featureBegin.
				localGradientQFunctionFeatures->add(*itFeat);
			}
		}
		gradQFunction->updateGradient(localGradientQFunctionFeatures);

		// The next Q-function's window starts where this one ended; without
		// this step every window started at 0 and later Q-functions also
		// received the features of earlier ones.
		featureBegin = featureEnd;
	}
}

// Returns the summed getNumWeights() of all contained gradient Q-functions.
// Contained Q-functions that are not CGradientQFunction instances contribute
// no weights.
int CComposedQFunction::getNumWeights()
{
	int nparams = 0;
	std::list<CAbstractQFunction *>::iterator it = qFunctions->begin();
	for (; it != qFunctions->end(); ++it)
	{
		// addQFunction() accepts non-gradient Q-functions, so the cast can
		// yield NULL and must be checked before use.
		CGradientQFunction *gradQFunction = dynamic_cast<CGradientQFunction *>(*it);
		if (gradQFunction != NULL)
		{
			nparams += gradQFunction->getNumWeights();
		}
	}
	return nparams;
}

// Returns the number of weights owned by the contained Q-functions that come
// before the first one handling `action`. If no contained Q-function handles
// the action, the result equals the total weight count. Contained Q-functions
// that are not CGradientQFunction instances contribute no weights.
int CComposedQFunction::getWeightsOffset(CAction *action)
{
	int nparams = 0;
	std::list<CAbstractQFunction *>::iterator it = qFunctions->begin();
	for (; it != qFunctions->end(); ++it)
	{
		// Stop at the first Q-function that handles the action; its own
		// weights are not part of the offset.
		if ((*it)->getActions()->isMember(action))
		{
			break;
		}

		// addQFunction() accepts non-gradient Q-functions, so the cast can
		// yield NULL and must be checked before use.
		CGradientQFunction *gradQFunction = dynamic_cast<CGradientQFunction *>(*it);
		if (gradQFunction != NULL)
		{
			nparams += gradQFunction->getNumWeights();
		}
	}
	return nparams;
}

// Asks every contained gradient Q-function, in list order, to write its
// weights via getWeights(). The destination pointer starts at `weights` and
// advances by that Q-function's getNumWeights() after each call. Contained
// Q-functions that are not CGradientQFunction instances are skipped and take
// up no space in the buffer.
void CComposedQFunction::getWeights(rlt_real *weights)
{
	rlt_real *qFuncWeights = weights;
	std::list<CAbstractQFunction *>::iterator it = qFunctions->begin();
	for (; it != qFunctions->end(); ++it)
	{
		// addQFunction() accepts non-gradient Q-functions, so the cast can
		// yield NULL and must be checked before use.
		CGradientQFunction *gradQFunction = dynamic_cast<CGradientQFunction *>(*it);
		if (gradQFunction == NULL)
		{
			continue;
		}
		gradQFunction->getWeights(qFuncWeights);
		qFuncWeights += gradQFunction->getNumWeights();
	}
}

// Hands every contained gradient Q-function, in list order, its slice of
// `weights` via setWeights(). The source pointer starts at `weights` and
// advances by that Q-function's getNumWeights() after each call. Contained
// Q-functions that are not CGradientQFunction instances are skipped and
// consume nothing from the buffer.
void CComposedQFunction::setWeights(rlt_real *weights)
{
	rlt_real *qFuncWeights = weights;
	std::list<CAbstractQFunction *>::iterator it = qFunctions->begin();
	for (; it != qFunctions->end(); ++it)
	{
		// addQFunction() accepts non-gradient Q-functions, so the cast can
		// yield NULL and must be checked before use.
		CGradientQFunction *gradQFunction = dynamic_cast<CGradientQFunction *>(*it);
		if (gradQFunction == NULL)
		{
			continue;
		}
		gradQFunction->setWeights(qFuncWeights);
		qFuncWeights += gradQFunction->getNumWeights();
	}
}

void CComposedQFunction::resetData()
{
	std::list<CAbstractQFunction *>::iterator it = qFunctions->begin();
	for (; it != qFunctions->end();it++)
	{
		(*it)->resetData();
	}
}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -