⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 ctorchvfunction.cpp

📁 强化学习算法(R-Learning)难得的珍贵资料
💻 CPP
📖 第 1 页 / 共 2 页
字号:
	}
}

/**
 * Writes a flat weight vector into the Torch gradient machine.
 *
 * The values of `parameters` are consumed sequentially and copied into the
 * machine's parameter blocks in block order (data[0], then data[1], ...).
 * The caller must supply at least sum(params->size[i]) values.
 *
 * @param parameters flat array of weights to install.
 *
 * When debug channel 't' is enabled, a header line is printed and saveData()
 * is called on the debug file handle (presumably dumping the new weights).
 */
void CTorchGradientFunction::setWeights(rlt_real *parameters)
{
	Parameters *machineParams = gradientMachine->params;

	if (machineParams)
	{
		rlt_real *source = parameters;
		for (int block = 0; block < machineParams->n_data; block++)
		{
			real *target = machineParams->data[block];
			const int blockSize = machineParams->size[block];
			for (int k = 0; k < blockSize; k++)
			{
				target[k] = *source;
				source++;
			}
		}
	}

	if (DebugIsEnabled('t'))
	{
		DebugPrint('t', "Setting Torch Weights: ");
		saveData(DebugGetFileHandle('t'));
	}
}

/**
 * Builds per-parameter learning-rate factors (etas) for a Torch gradient machine.
 *
 * All parameters of the first parameter block get eta 1.0. After each block,
 * the eta for the next block becomes 1 / sqrt(size[i] / layerInputs), where
 * layerInputs starts at n_inputs + 1 and becomes (size[i] / layerInputs) + 1.
 * NOTE(review): this assumes every parameter block is a fully connected layer
 * of (inputs + 1) * outputs values (weights plus bias), so size / inputs gives
 * the layer's output count. Confirm this against the Torch machine layout.
 *
 * NOTE(review): the base-class initializer already dereferences
 * gradientMachine->params, so the null check below cannot prevent a crash.
 */
CTorchGradientEtaCalculator::CTorchGradientEtaCalculator(GradientMachine *gradientMachine) : CIndividualEtaCalculator(gradientMachine->params->n_params)
{
	Parameters *machineParams = gradientMachine->params;
	if (!machineParams)
	{
		return;
	}

	int layerInputs = gradientMachine->n_inputs + 1;
	rlt_real layerEta = 1.0;
	int etaIndex = 0;

	for (int block = 0; block < machineParams->n_data; block++)
	{
		const int blockSize = machineParams->size[block];
		for (int k = 0; k < blockSize; k++)
		{
			this->etas[etaIndex] = layerEta;
			etaIndex++;
		}

		// Output count of this block feeds the next block (plus one bias input).
		const int layerOutputs = blockSize / layerInputs;
		layerInputs = layerOutputs + 1;
		layerEta = 1 / sqrt(static_cast<rlt_real>(layerOutputs));
	}
}

/**
 * V-function backed by a Torch function.
 *
 * @param torchFunction the Torch function evaluated for each state (not owned).
 * @param properties    state layout; the input sequence has one frame of
 *                      (continuous + discrete) state values.
 */
CTorchVFunction::CTorchVFunction(CTorchFunction *torchFunction, CStateProperties *properties) : CAbstractVFunction(properties)
{
	this->torchFunction = torchFunction;

	// Single-frame input buffer, reused by every getValue() call; freed in the destructor.
	input = new Sequence(1, properties->getNumContinuousStates() + properties->getNumDiscreteStates());
}

/** Releases the input sequence allocated by the constructor. */
CTorchVFunction::~CTorchVFunction()
{
	delete this->input;
}

/**
 * Writes the network input for `state` into frame 0 of `sequence`.
 *
 * Layout: the active continuous state values at indices
 * [0, getNumActiveContinuousStates()), followed by the active discrete state
 * values starting at offset getNumContinuousStates().
 *
 * @param state    state to encode.
 * @param sequence destination; frame 0 must hold at least
 *                 getNumContinuousStates() + getNumDiscreteStates() values.
 */
void CTorchVFunction::getInputSequence(CState *state, Sequence *sequence)
{
	// Continuous block: bound by the continuous count. Previously this loop was
	// bounded by the discrete count, which skipped or over-read continuous values.
	for (unsigned int i = 0; i < state->getNumActiveContinuousStates(); i ++)
	{
		sequence->frames[0][i] = state->getContinuousState(i);
	}
	// Discrete block starts after all continuous slots.
	for (unsigned int i = 0; i < state->getNumActiveDiscreteStates(); i++)
	{
		sequence->frames[0][i + state->getNumContinuousStates()] = state->getDiscreteState(i);
	}
}

/**
 * Evaluates the Torch function for `state`.
 *
 * @return the machine output.
 * @throws CDivergentVFunctionException* (heap-allocated) if the value lies
 *         outside [-DIVERGENTVFUNCTIONVALUE, DIVERGENTVFUNCTIONVALUE] and
 *         mayDiverge is false.
 */
rlt_real CTorchVFunction::getValue(CState *state)
{
	getInputSequence(state, input);
	const rlt_real machineOutput = torchFunction->getValueFromMachine(input);

	const bool outOfRange = machineOutput < - DIVERGENTVFUNCTIONVALUE || machineOutput > DIVERGENTVFUNCTIONVALUE;
	if (outOfRange && !mayDiverge)
	{
		throw new CDivergentVFunctionException("Torch VFunction", this, state, machineOutput);
	}
	return machineOutput;
}

/**
 * Wraps a generic gradient function as a V-function.
 *
 * @param l_gradientFunction function with (continuous + discrete) inputs and
 *                           exactly one output (checked by assert); not owned.
 * @param properties         state layout used to build the input vector.
 *
 * Allocates the input vector, a 1-element output/error vector (initialised to
 * 1.0) and a 1 x numInputs input-derivation matrix. All three are freed in
 * the destructor. The wrapped function is passed to addParameters().
 */
CVFunctionFromGradientFunction::CVFunctionFromGradientFunction(CGradientFunction *l_gradientFunction, CStateProperties *properties) : CGradientVFunction(properties) , CVFunctionInputDerivationCalculator(properties)
{
	gradientFunction = l_gradientFunction;

	assert(properties->getNumContinuousStates() + properties->getNumDiscreteStates() == gradientFunction->getNumInputs() && gradientFunction->getNumOutputs() == 1);

	// Scratch buffers reused across evaluations.
	input = new CMyVector(properties->getNumContinuousStates() + properties->getNumDiscreteStates());
	this->inputDerivation = new CMyMatrix(1, properties->getNumContinuousStates() + properties->getNumDiscreteStates());

	outputError = new CMyVector(1);
	outputError->setElement(0, 1.0);

	addParameters(l_gradientFunction);
}

/** Frees the scratch buffers allocated by the constructor. */
CVFunctionFromGradientFunction::~CVFunctionFromGradientFunction()
{
	delete inputDerivation;
	delete outputError;
	delete input;
}

/**
 * Encodes `state` into the input vector `sequence`.
 *
 * Layout: the active continuous state values at indices
 * [0, getNumActiveContinuousStates()), followed by the active discrete state
 * values starting at offset getNumContinuousStates(). This matches
 * CTorchVFunction::getInputSequence and
 * CQFunctionFromGradientFunction::getInputSequence.
 *
 * @param state    state to encode.
 * @param sequence destination vector; must have at least
 *                 getNumContinuousStates() + getNumDiscreteStates() elements.
 */
void CVFunctionFromGradientFunction::getInputSequence(CState *state, CMyVector *sequence)
{
	for (unsigned int i = 0; i < state->getNumActiveContinuousStates(); i ++)
	{
		sequence->setElement(i, state->getContinuousState(i));
	}
	// Discrete slots must hold the discrete values. Reading getContinuousState(i)
	// here gave wrong inputs and could index past the continuous states.
	for (unsigned int i = 0; i < state->getNumActiveDiscreteStates(); i++)
	{
		sequence->setElement(i + state->getNumContinuousStates(), state->getDiscreteState(i));
	}
}

/**
 * Moves the value of `state` toward `value` by calling updateValue() with the
 * difference between the target and the current estimate.
 */
void CVFunctionFromGradientFunction::setValue(CState *state, rlt_real value)
{
	const rlt_real current = getValue(state);
	updateValue(state, value - current);
}

void CVFunctionFromGradientFunction::resetData()
{
	gradientFunction->resetData();
}


/**
 * Evaluates the wrapped gradient function for `state`.
 *
 * The outputError buffer is reused as the output vector, so its element 0
 * is overwritten here. getGradient() resets it to 1.0 before use.
 *
 * @return the single function output.
 * @throws CDivergentVFunctionException* (heap-allocated) if the value lies
 *         outside [-DIVERGENTVFUNCTIONVALUE, DIVERGENTVFUNCTIONVALUE] and
 *         mayDiverge is false.
 */
rlt_real CVFunctionFromGradientFunction::getValue(CState *state)
{
	getInputSequence(state, input);
	gradientFunction->getFunctionValue(input, outputError);

	const rlt_real result = outputError->getElement(0);
	const bool outOfRange = result < - DIVERGENTVFUNCTIONVALUE || result > DIVERGENTVFUNCTIONVALUE;
	if (outOfRange && !mayDiverge)
	{
		throw new CDivergentVFunctionException("Torch VFunction", this, state, result);
	}
	return result;
}
	
/** Forwards the weight update to the wrapped gradient function. */
void CVFunctionFromGradientFunction::updateWeights(CFeatureList *gradientFeatures)
{
	this->gradientFunction->updateWeights(gradientFeatures);
}

int CVFunctionFromGradientFunction::getNumWeights()
{
	return gradientFunction->getNumWeights();
}

/**
 * Computes the gradient of the V-function with respect to its weights.
 *
 * The state matching this function's properties is extracted from
 * `originalState` and encoded. The output error is set to 1.0, and the
 * wrapped function's getGradient() writes the result into `modifiedState`.
 */
void CVFunctionFromGradientFunction::getGradient(CStateCollection *originalState, CFeatureList *modifiedState)
{
	CState *ownState = originalState->getState(this->getStateProperties());

	// getValue() reuses outputError as an output buffer, so restore the unit error first.
	outputError->setElement(0, 1.0);
	getInputSequence(ownState, input);

	gradientFunction->getGradient(input, outputError, modifiedState);
}

/**
 * Computes the derivative of the function output with respect to its inputs.
 *
 * The wrapped function fills the 1 x numInputs inputDerivation matrix. Row 0
 * is then copied into `targetVector`, which must hold at least
 * getNumInputs() values.
 * NOTE(review): the raw memcpy assumes getRow(0) and getData() expose
 * contiguous rlt_real storage.
 */
void CVFunctionFromGradientFunction::getInputDerivation(CStateCollection *originalState, CMyVector *targetVector)
{
	getInputSequence(originalState->getState(this->getStateProperties()), input);
	gradientFunction->getInputDerivation(input, inputDerivation);

	rlt_real *destination = targetVector->getData();
	memcpy(destination, inputDerivation->getRow(0), sizeof(rlt_real) * gradientFunction->getNumInputs());
}

/** @return a newly allocated gradient-based eligibility-trace object for this function; the caller owns it. */
CAbstractVETraces *CVFunctionFromGradientFunction::getStandardETraces()
{
	CAbstractVETraces *traces = new CGradientVETraces(this);
	return traces;
}

/** Copies the wrapped function's weights into `parameters` (delegates to the gradient function). */
void CVFunctionFromGradientFunction::getWeights(rlt_real *parameters)
{
	this->gradientFunction->getWeights(parameters);
}

/** Installs `parameters` as the wrapped function's weights (delegates to the gradient function). */
void CVFunctionFromGradientFunction::setWeights(rlt_real *parameters)
{
	this->gradientFunction->setWeights(parameters);
}

/**
 * Wraps a gradient function as a continuous-action Q-function.
 *
 * @param contAction       continuous action whose dimensions are appended to the state input.
 * @param gradientFunction function with (continuous + discrete + action-dims)
 *                         inputs and exactly one output (checked by assert); not owned.
 * @param actions          static action set searched by getBestContinuousAction().
 * @param properties       state layout used to build the input vector.
 */
CQFunctionFromGradientFunction::CQFunctionFromGradientFunction(CContinuousAction *contAction, CGradientFunction *gradientFunction, CActionSet *actions, CStateProperties *properties) : CContinuousActionQFunction(contAction), CStateObject(properties)
{
	this->gradientFunction = gradientFunction;
	staticActions = actions;

	assert(properties->getNumContinuousStates() + properties->getNumDiscreteStates() + contAction->getNumDimensions() == gradientFunction->getNumInputs() && gradientFunction->getNumOutputs() == 1);

	// Scratch buffers reused across evaluations; freed in the destructor.
	input = new CMyVector(properties->getNumContinuousStates() + properties->getNumDiscreteStates() + contAction->getNumDimensions());
	outputError = new CMyVector(1);
	outputError->setElement(0, 1.0);
}


/** Frees the input and output buffers allocated by the constructor. */
CQFunctionFromGradientFunction::~CQFunctionFromGradientFunction()
{
	delete outputError;
	delete input;
}

/**
 * Encodes (state, action) into `sequence`.
 *
 * Layout: all continuous state values, then all discrete state values, then
 * each action dimension. Each action value is mapped linearly from
 * [minActionValue, maxActionValue] to [-1, 1].
 * NOTE(review): a dimension with max == min divides by zero.
 */
void CQFunctionFromGradientFunction::getInputSequence(CMyVector *sequence, CState *state, CContinuousActionData *data)
{
	for (unsigned int c = 0; c < state->getNumContinuousStates(); c++)
	{
		sequence->setElement(c, state->getContinuousState(c));
	}

	for (unsigned int d = 0; d < state->getNumDiscreteStates(); d++)
	{
		sequence->setElement(state->getNumContinuousStates() + d, state->getDiscreteState(d));
	}

	for (unsigned int a = 0; a < data->getNumDimensions(); a++)
	{
		rlt_real lower = contAction->getContinuousActionProperties()->getMinActionValue(a);
		rlt_real range = contAction->getContinuousActionProperties()->getMaxActionValue(a) - lower;

		unsigned int slot = a + state->getNumContinuousStates() + state->getNumDiscreteStates();
		sequence->setElement(slot, ((data->getActionValue(a) - lower) / range) * 2  - 1.0);
	}
}

/**
 * Picks the action from the static action set that getMax() returns for
 * `state`, and copies its action data into `actionData`.
 */
void CQFunctionFromGradientFunction::getBestContinuousAction(CStateCollection *state, CContinuousActionData *actionData)
{
	actionData->setData(CAbstractQFunction::getMax(state, staticActions)->getActionData());
}

/**
 * Applies a value update for (state, action).
 *
 * The weight gradient is recomputed into the shared localGradientFeatureBuffer,
 * which is cleared first. It is then passed to updateGradient() together with `td`.
 */
void CQFunctionFromGradientFunction::updateCAValue(CStateCollection *state, CContinuousActionData *data, rlt_real td)
{
	localGradientFeatureBuffer->clear();
	getCAGradient(state, data, this->localGradientFeatureBuffer);
	updateGradient(this->localGradientFeatureBuffer, td);
}

/** Moves Q(state, action) toward `qValue` by updating with (qValue - current estimate). */
void CQFunctionFromGradientFunction::setCAValue(CStateCollection *state, CContinuousActionData *data, rlt_real qValue)
{
	const rlt_real current = getCAValue(state, data);
	updateCAValue(state, data, qValue - current);
}

/**
 * Evaluates Q(state, action) with the wrapped gradient function.
 * outputError is reused as the output buffer, so its element 0 is overwritten.
 */
rlt_real CQFunctionFromGradientFunction::getCAValue(CStateCollection *state, CContinuousActionData *data)
{
	CState *ownState = state->getState(properties);
	getInputSequence(input, ownState, data);

	gradientFunction->getFunctionValue(input, outputError);
	const rlt_real qValue = outputError->getElement(0);
	return qValue;
}


/**
 * Computes the weight gradient of Q(state, action) into `gradient`, using a
 * unit output error (reset to 1.0 because getCAValue() reuses the buffer).
 */
void CQFunctionFromGradientFunction::getCAGradient(CStateCollection *state, CContinuousActionData *data, CFeatureList *gradient)
{
	outputError->setElement(0, 1.0);
	getInputSequence(input, state->getState(properties), data);

	gradientFunction->getGradient(input, outputError, gradient);
}

/** Forwards the weight update to the wrapped gradient function. */
void CQFunctionFromGradientFunction::updateWeights(CFeatureList *gradientFeatures)
{
	this->gradientFunction->updateWeights(gradientFeatures);
}

int CQFunctionFromGradientFunction::getNumWeights()
{
	return gradientFunction->getNumWeights();
}

void CQFunctionFromGradientFunction::resetData()
{
	gradientFunction->resetData();
}


/** Copies the wrapped function's weights into `weights` (delegates to the gradient function). */
void CQFunctionFromGradientFunction::getWeights(rlt_real *weights)
{
	this->gradientFunction->getWeights(weights);
}

/** Installs `parameters` as the wrapped function's weights (delegates to the gradient function). */
void CQFunctionFromGradientFunction::setWeights(rlt_real *parameters)
{
	this->gradientFunction->setWeights(parameters);
}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -