📄 ctorchvfunction.cpp
字号:
}
}
// Copies a flat weight array into the Torch machine's parameter blocks.
//
// Blocks params->data[0..n_data-1] are filled in storage order, each consuming
// params->size[block] consecutive entries of `parameters`. When the 't' debug
// channel is enabled, the resulting weights are dumped via saveData().
void CTorchGradientFunction::setWeights(rlt_real *parameters)
{
    Parameters *torchParams = gradientMachine->params;
    if (torchParams)
    {
        rlt_real *source = parameters;
        for (int block = 0; block < torchParams->n_data; block++)
        {
            real *target = torchParams->data[block];
            const int blockSize = torchParams->size[block];
            for (int k = 0; k < blockSize; k++)
            {
                target[k] = *source;
                ++source;
            }
        }
    }

    // The debug dump runs whether or not a parameter set was present.
    if (DebugIsEnabled('t'))
    {
        DebugPrint('t', "Setting Torch Weights: ");
        saveData(DebugGetFileHandle('t'));
    }
}
// Initializes one learning-rate scaling factor (eta) per Torch parameter.
//
// Parameter blocks params->data[0..n_data-1] are visited in storage order:
// - Every parameter of block 0 gets factor 1.0.
// - After block i is visited, its size is divided by the current `inputs` count
//   (initially n_inputs + 1) to obtain `neurons`.
// - The next block then uses inputs = neurons + 1 and factor = 1 / sqrt(neurons).
//
// NOTE(review): this presumes each block is a fully connected layer laid out as
// neurons x (inputs + bias). TODO: confirm for the machines used with this class.
//
// Fix: the base-class initializer used to dereference gradientMachine->params
// unconditionally. That made the `if (params)` guard below unreachable for a null
// parameter set. A null set now yields zero etas instead of a crash.
CTorchGradientEtaCalculator::CTorchGradientEtaCalculator(GradientMachine *gradientMachine) : CIndividualEtaCalculator(gradientMachine->params ? gradientMachine->params->n_params : 0)
{
    Parameters *params = gradientMachine->params;
    int inputs = gradientMachine->n_inputs + 1;   // +1 for the bias input
    int neurons = 1;
    int parameterIndex = 0;
    rlt_real factor = 1.0;                        // factor applied to the current block
    if (params)
    {
        for (int i = 0; i < params->n_data; i++)
        {
            for (int j = 0; j < params->size[i]; j++)
            {
                this->etas[parameterIndex] = factor;
                parameterIndex++;
            }
            // Derive the fan-in used for the *next* block from this block's size.
            inputs = params->size[i] / inputs + 1;
            neurons = inputs - 1;
            factor = 1 / sqrt((rlt_real) neurons);
        }
    }
}
// Wraps a Torch function as a V-function over the given state space.
// Allocates a single-frame input Sequence wide enough for all continuous
// followed by all discrete state variables. The destructor frees it.
CTorchVFunction::CTorchVFunction(CTorchFunction *torchFunction, CStateProperties *properties) : CAbstractVFunction(properties)
{
    this->torchFunction = torchFunction;
    int frameWidth = properties->getNumContinuousStates() + properties->getNumDiscreteStates();
    input = new Sequence(1, frameWidth);
}
// Frees the input Sequence allocated by the constructor.
// The Torch function itself is not deleted here.
CTorchVFunction::~CTorchVFunction()
{
    delete this->input;
}
// Fills frame 0 of `sequence` with the state's variables.
// - Active continuous states go into slots [0, numActiveContinuous).
// - Active discrete states go into slots starting at getNumContinuousStates().
//
// Fix: the continuous loop was previously bounded by getNumActiveDiscreteStates().
// That read the wrong number of continuous values: too few, or past the
// continuous range into the discrete slots.
void CTorchVFunction::getInputSequence(CState *state, Sequence *sequence)
{
    for (unsigned int i = 0; i < state->getNumActiveContinuousStates(); i++)
    {
        sequence->frames[0][i] = state->getContinuousState(i);
    }
    for (unsigned int i = 0; i < state->getNumActiveDiscreteStates(); i++)
    {
        sequence->frames[0][i + state->getNumContinuousStates()] = state->getDiscreteState(i);
    }
}
// Evaluates the Torch function for `state`.
// Throws a heap-allocated CDivergentVFunctionException when the value leaves
// [-DIVERGENTVFUNCTIONVALUE, DIVERGENTVFUNCTIONVALUE] and divergence is not
// permitted (mayDiverge is false).
rlt_real CTorchVFunction::getValue(CState *state)
{
    getInputSequence(state, input);
    const rlt_real result = torchFunction->getValueFromMachine(input);

    const bool outOfRange = result < -DIVERGENTVFUNCTIONVALUE || result > DIVERGENTVFUNCTIONVALUE;
    if (outOfRange && !mayDiverge)
    {
        throw new CDivergentVFunctionException("Torch VFunction", this, state, result);
    }
    return result;
}
// Adapts a generic gradient function to a V-function.
//
// The wrapped function must take exactly (continuous + discrete) state inputs
// and produce a single output (asserted).
//
// Allocates, all freed by the destructor:
// - an input vector;
// - a 1-element output/error buffer initialized to 1.0;
// - a 1 x numInputs input-derivation matrix.
//
// Finally registers the wrapped function's parameters with addParameters().
CVFunctionFromGradientFunction::CVFunctionFromGradientFunction(CGradientFunction *l_gradientFunction, CStateProperties *properties) : CGradientVFunction(properties) , CVFunctionInputDerivationCalculator(properties)
{
    this->gradientFunction = l_gradientFunction;
    assert(properties->getNumContinuousStates() + properties->getNumDiscreteStates() == gradientFunction->getNumInputs() && gradientFunction->getNumOutputs() == 1);

    int stateWidth = properties->getNumContinuousStates() + properties->getNumDiscreteStates();
    input = new CMyVector(stateWidth);
    outputError = new CMyVector(1);
    outputError->setElement(0, 1.0);
    this->inputDerivation = new CMyMatrix(1, stateWidth);

    addParameters(l_gradientFunction);
}
// Frees the buffers allocated by the constructor.
// The wrapped gradient function is not owned and is not deleted.
CVFunctionFromGradientFunction::~CVFunctionFromGradientFunction()
{
    delete this->inputDerivation;
    delete this->outputError;
    delete this->input;
}
// Writes the state's variables into `sequence`.
// - Active continuous states go into slots [0, numActiveContinuous).
// - Active discrete states go into slots starting at getNumContinuousStates().
//
// Fix: the discrete loop previously copied getContinuousState(i) into the
// discrete slots. The gradient function therefore never saw the discrete state
// values (cf. CTorchVFunction::getInputSequence and
// CQFunctionFromGradientFunction::getInputSequence, which use getDiscreteState).
void CVFunctionFromGradientFunction::getInputSequence(CState *state, CMyVector *sequence)
{
    for (unsigned int i = 0; i < state->getNumActiveContinuousStates(); i++)
    {
        sequence->setElement(i, state->getContinuousState(i));
    }
    for (unsigned int i = 0; i < state->getNumActiveDiscreteStates(); i++)
    {
        sequence->setElement(i + state->getNumContinuousStates(), state->getDiscreteState(i));
    }
}
// Moves the function toward `value` by passing the difference between the
// target and the current prediction to updateValue().
void CVFunctionFromGradientFunction::setValue(CState *state, rlt_real value)
{
    const rlt_real difference = value - getValue(state);
    updateValue(state, difference);
}
void CVFunctionFromGradientFunction::resetData()
{
gradientFunction->resetData();
}
// Evaluates the wrapped gradient function for `state`.
//
// The shared outputError buffer is reused as the output target, so its
// element 0 holds the value afterwards.
//
// Throws a heap-allocated CDivergentVFunctionException when the value leaves
// [-DIVERGENTVFUNCTIONVALUE, DIVERGENTVFUNCTIONVALUE] and mayDiverge is false.
rlt_real CVFunctionFromGradientFunction::getValue(CState *state)
{
    getInputSequence(state, input);
    gradientFunction->getFunctionValue(input, outputError);
    const rlt_real result = outputError->getElement(0);

    const bool outOfRange = result < -DIVERGENTVFUNCTIONVALUE || result > DIVERGENTVFUNCTIONVALUE;
    if (outOfRange && !mayDiverge)
    {
        throw new CDivergentVFunctionException("Torch VFunction", this, state, result);
    }
    return result;
}
// Forwards a weight update (given as a feature list) to the wrapped function.
void CVFunctionFromGradientFunction::updateWeights(CFeatureList *gradientFeatures)
{
    this->gradientFunction->updateWeights(gradientFeatures);
}
int CVFunctionFromGradientFunction::getNumWeights()
{
return gradientFunction->getNumWeights();
}
// Computes the weight gradient for the state of this function's state
// properties taken from `originalState`. The result is written into
// `modifiedState`.
void CVFunctionFromGradientFunction::getGradient(CStateCollection *originalState, CFeatureList *modifiedState)
{
    getInputSequence(originalState->getState(this->getStateProperties()), input);
    // getValue() reuses outputError as an output buffer, so re-seed it with 1.0.
    outputError->setElement(0, 1.0);
    gradientFunction->getGradient(input, outputError, modifiedState);
}
// Computes the derivative of the output with respect to the inputs.
// Copies the single row of the 1 x numInputs inputDerivation matrix into
// `targetVector`'s raw storage.
void CVFunctionFromGradientFunction::getInputDerivation(CStateCollection *originalState, CMyVector *targetVector)
{
    CState *currentState = originalState->getState(this->getStateProperties());
    getInputSequence(currentState, input);
    gradientFunction->getInputDerivation(input, inputDerivation);

    const size_t byteCount = sizeof(rlt_real) * gradientFunction->getNumInputs();
    memcpy(targetVector->getData(), inputDerivation->getRow(0), byteCount);
}
// Creates a new gradient-based eligibility-trace object bound to this function.
// The caller receives the heap-allocated object.
CAbstractVETraces *CVFunctionFromGradientFunction::getStandardETraces()
{
    CAbstractVETraces *traces = new CGradientVETraces(this);
    return traces;
}
// Reads the wrapped function's weights into `parameters`.
void CVFunctionFromGradientFunction::getWeights(rlt_real *parameters)
{
    this->gradientFunction->getWeights(parameters);
}
// Writes `parameters` into the wrapped function's weights.
void CVFunctionFromGradientFunction::setWeights(rlt_real *parameters)
{
    this->gradientFunction->setWeights(parameters);
}
// Adapts a gradient function to a continuous-action Q-function.
//
// The wrapped function must take (continuous + discrete state + action
// dimensions) inputs and produce one output (asserted). `actions` is the static
// action set searched by getBestContinuousAction().
//
// Allocates an input vector and a 1-element output/error buffer initialized to
// 1.0; the destructor frees both.
CQFunctionFromGradientFunction::CQFunctionFromGradientFunction(CContinuousAction *contAction, CGradientFunction *gradientFunction, CActionSet *actions, CStateProperties *properties) : CContinuousActionQFunction(contAction), CStateObject(properties)
{
    assert(properties->getNumContinuousStates() + properties->getNumDiscreteStates() + contAction->getNumDimensions() == gradientFunction->getNumInputs() && gradientFunction->getNumOutputs() == 1);

    this->gradientFunction = gradientFunction;
    staticActions = actions;

    int inputWidth = properties->getNumContinuousStates() + properties->getNumDiscreteStates() + contAction->getNumDimensions();
    input = new CMyVector(inputWidth);
    outputError = new CMyVector(1);
    outputError->setElement(0, 1.0);
}
// Frees the buffers allocated by the constructor.
// The wrapped gradient function is not owned.
CQFunctionFromGradientFunction::~CQFunctionFromGradientFunction()
{
    delete this->outputError;
    delete this->input;
}
// Builds the network input for a state/action pair.
//
// Layout of `sequence`:
//   [continuous states | discrete states | action dims]
//
// Each action component is linearly rescaled from [min, max] (taken from the
// continuous action's properties) to [-1, 1].
void CQFunctionFromGradientFunction::getInputSequence(CMyVector *sequence, CState *state, CContinuousActionData *data)
{
    unsigned int slot = 0;
    for (; slot < state->getNumContinuousStates(); slot++)
    {
        sequence->setElement(slot, state->getContinuousState(slot));
    }
    for (unsigned int d = 0; d < state->getNumDiscreteStates(); d++)
    {
        sequence->setElement(d + state->getNumContinuousStates(), state->getDiscreteState(d));
    }
    for (unsigned int dim = 0; dim < data->getNumDimensions(); dim++)
    {
        rlt_real lower = contAction->getContinuousActionProperties()->getMinActionValue(dim);
        rlt_real width = contAction->getContinuousActionProperties()->getMaxActionValue(dim) - lower;
        unsigned int target = dim + state->getNumContinuousStates() + state->getNumDiscreteStates();
        sequence->setElement(target, ((data->getActionValue(dim) - lower) / width) * 2 - 1.0);
    }
}
// Picks the best action among the static action set (via
// CAbstractQFunction::getMax) and copies its action data into `actionData`.
void CQFunctionFromGradientFunction::getBestContinuousAction(CStateCollection *state, CContinuousActionData *actionData)
{
    actionData->setData(CAbstractQFunction::getMax(state, staticActions)->getActionData());
}
// Applies a temporal-difference update for (state, data).
// The gradient is computed into the cleared local feature buffer, then passed
// to updateGradient() together with `td`.
void CQFunctionFromGradientFunction::updateCAValue(CStateCollection *state, CContinuousActionData *data, rlt_real td)
{
    this->localGradientFeatureBuffer->clear();
    this->getCAGradient(state, data, this->localGradientFeatureBuffer);
    this->updateGradient(this->localGradientFeatureBuffer, td);
}
// Moves Q(state, data) toward `qValue` by applying the difference between the
// target and the current estimate as an update.
void CQFunctionFromGradientFunction::setCAValue(CStateCollection *state, CContinuousActionData *data, rlt_real qValue)
{
    const rlt_real difference = qValue - getCAValue(state, data);
    updateCAValue(state, data, difference);
}
// Evaluates Q(state, data).
// The shared outputError buffer receives the output; its element 0 is returned.
rlt_real CQFunctionFromGradientFunction::getCAValue(CStateCollection *state, CContinuousActionData *data)
{
    CState *currentState = state->getState(properties);
    getInputSequence(input, currentState, data);
    gradientFunction->getFunctionValue(input, outputError);
    const rlt_real qValue = outputError->getElement(0);
    return qValue;
}
// Computes the weight gradient of Q(state, data) into `gradient`.
void CQFunctionFromGradientFunction::getCAGradient(CStateCollection *state, CContinuousActionData *data, CFeatureList *gradient)
{
    CState *currentState = state->getState(properties);
    getInputSequence(input, currentState, data);
    // getCAValue() reuses outputError as an output buffer, so re-seed it with 1.0.
    outputError->setElement(0, 1.0);
    gradientFunction->getGradient(input, outputError, gradient);
}
// Forwards a weight update (given as a feature list) to the wrapped function.
void CQFunctionFromGradientFunction::updateWeights(CFeatureList *gradientFeatures)
{
    this->gradientFunction->updateWeights(gradientFeatures);
}
int CQFunctionFromGradientFunction::getNumWeights()
{
return gradientFunction->getNumWeights();
}
void CQFunctionFromGradientFunction::resetData()
{
gradientFunction->resetData();
}
// Reads the wrapped function's weights into `weights`.
void CQFunctionFromGradientFunction::getWeights(rlt_real *weights)
{
    this->gradientFunction->getWeights(weights);
}
// Writes `parameters` into the wrapped function's weights.
void CQFunctionFromGradientFunction::setWeights(rlt_real *parameters)
{
    this->gradientFunction->setWeights(parameters);
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -