// ctorchvfunction.cpp
// Copyright (C) 2003
// Gerhard Neumann (gerhard@igi.tu-graz.ac.at)
//
// This file is part of RL Toolbox.
// http://www.igi.tugraz.at/ril_toolbox
//
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// 1. Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// 2. Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// 3. The name of the author may not be used to endorse or promote products
// derived from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
// IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
// OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
// IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT,
// INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
// NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
// THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include "ril_debug.h"
#include "ctorchvfunction.h"
#include <assert.h>
#include <math.h>
// Wraps an existing torch Machine. The machine is only referenced, not
// copied; ownership stays with the caller (the destructor does not free it).
CTorchFunction::CTorchFunction(Machine *machine)
{
this->machine = machine;
}
// Destructor. Intentionally empty: the wrapped machine is not owned by this
// class, so whoever created it remains responsible for deleting it.
CTorchFunction::~CTorchFunction()
{
}
/// Runs a forward pass of the wrapped torch machine on the given input
/// sequence and returns the first value of the machine's first output frame.
rlt_real CTorchFunction::getValueFromMachine(Sequence *input)
{
    machine->forward(input);
    // frames[0][0] is the first value of the first output frame
    // (equivalent to dereferencing the frame pointer twice).
    return machine->outputs->frames[0][0];
}
// Returns the wrapped torch machine (non-owning pointer).
Machine *CTorchFunction::getMachine()
{
return machine;
}
/// Creates a new value function learning with a torch gradient machine
/// Creates a new gradient-based function approximator from a torch gradient
/// machine. Allocates the single-frame input sequence, the single-frame
/// output-error (alpha) sequence, and a local eta calculator for the machine.
CTorchGradientFunction::CTorchGradientFunction(GradientMachine *machine) : CTorchFunction(machine)
{
    this->gradientMachine = machine;
    // One frame of n_inputs values for the machine input, one frame of
    // n_outputs values for the backpropagated output errors.
    input = new Sequence(1, machine->n_inputs);
    alpha = new Sequence(1, machine->n_outputs);
    localEtaCalc = new CTorchGradientEtaCalculator(gradientMachine);
    //addParameter("InitWeightVarianceFactor", 1.0);
    //addParameter("TorchNormalizeWeights", 0.0);
}
/// Frees the helper objects allocated by the constructor. The gradient
/// machine itself is not deleted here (it is not owned by this class).
CTorchGradientFunction::~CTorchGradientFunction()
{
    delete localEtaCalc;
    delete input;
    delete alpha;
}
/// Evaluates the gradient machine on inputVector and writes all machine
/// outputs into the given output vector.
void CTorchGradientFunction::getFunctionValue(CMyVector *inputVector, CMyVector *output)
{
    // Copy the input vector into the machine's one-frame input sequence.
    const int numIn = getNumInputs();
    for (int in = 0; in < numIn; in++)
    {
        input->frames[0][in] = inputVector->getElement(in);
    }
    gradientMachine->forward(input);
    // Copy the machine's output frame back into the caller's vector.
    const int numOut = getNumOutputs();
    for (int out = 0; out < numOut; out++)
    {
        output->setElement(out, gradientMachine->outputs->frames[0][out]);
    }
}
/// Applies a gradient update to the weights of the torch gradient machine.
///
/// Each feature in gradientFeatures addresses one weight by a global (flat)
/// index. Torch stores the parameters as several data blocks, so the flat
/// index is first translated into a (block, offset) pair; the feature's
/// factor is then added to that weight.
///
/// @param gradientFeatures sparse gradient: featureIndex selects the weight,
///        factor is the (already learning-rate-scaled) update to add.
void CTorchGradientFunction::updateWeights(CFeatureList *gradientFeatures)
{
    Parameters *params = gradientMachine->params;
    if (DebugIsEnabled('v'))
    {
        DebugPrint('v', "Updating Torch Function, Gradient: ");
        gradientFeatures->saveASCII(DebugGetFileHandle('v'));
        DebugPrint('v',"\n");
    }
    if (params)
    {
        CFeatureList::iterator it = gradientFeatures->begin();
        for (; it != gradientFeatures->end(); it++)
        {
            assert((*it)->featureIndex < (unsigned int) params->n_params);
            // Translate the flat feature index into a block index and an
            // offset within that block. Valid offsets for block k are
            // 0 .. size[k]-1, so we must advance while the remaining index
            // is >= size[k]. (The previous '>' comparison was an off-by-one
            // that could address one element past the end of a block when
            // the remaining index equalled the block size.)
            int param_index = 0;
            int param_offset = 0;
            while ((unsigned int) ((*it)->featureIndex - param_offset) >= (unsigned int) (params->size[param_index]))
            {
                assert(param_index + 1 < params->n_data);
                param_offset += params->size[param_index];
                param_index += 1;
            }
            params->data[param_index][(*it)->featureIndex - param_offset] += (*it)->factor;
        }
    }
}
/// Re-initializes all weights of the gradient machine with samples from a
/// normal distribution. The variance of each block's samples shrinks with
/// an estimated fan-in, which is updated from block to block.
void CTorchGradientFunction::resetData()
{
    Parameters *params = gradientMachine->params;
    int fanIn = getNumInputs() + 1; // +1, presumably for a bias weight
    DebugPrint('v', "Torch Gradient Function InitValues:\n");
    rlt_real sigma = 1.0;// getParameter("InitWeightVarianceFactor");
    if (params)
    {
        for (int block = 0; block < params->n_data; block++)
        {
            for (int w = 0; w < params->size[block]; w++)
            {
                params->data[block][w] = CDistributions::getNormalDistributionSample(0.0, (1.0 / fanIn) * sigma);
                DebugPrint('v', "%f ", params->data[block][w]);
            }
            DebugPrint('v', "\n");
            // Estimated fan-in of the next block: units in this block
            // (block size / current fan-in) plus one — TODO confirm this
            // matches the machine's actual layer layout.
            fanIn = params->size[block] / fanIn + 1;
        }
    }
}
// Number of input values expected by the gradient machine.
int CTorchGradientFunction::getNumInputs()
{
return gradientMachine->n_inputs;
}
// Number of output values produced by the gradient machine.
int CTorchGradientFunction::getNumOutputs()
{
return gradientMachine->n_outputs;
}
/// Computes the gradient of the machine's output error with respect to its
/// weights and stores it in gradientFeatures.
///
/// The input vector is copied into the torch input sequence and the output
/// errors into the alpha sequence; a forward/backward pass then leaves the
/// weight derivatives in der_params, which are copied into gradientFeatures
/// using the same flat (block-concatenated) indexing as updateWeights.
///
/// @param inputVector      machine input, length getNumInputs()
/// @param outputErrors     error signal per output, length getNumOutputs()
/// @param gradientFeatures receives one feature per weight (flat index)
void CTorchGradientFunction::getGradient(CMyVector *inputVector, CMyVector *outputErrors, CFeatureList *gradientFeatures)
{
    Parameters *params = gradientMachine->params;
    Parameters *der_params = gradientMachine->der_params;
    for (int i = 0; i < getNumInputs(); i++)
    {
        input->frames[0][i] = inputVector->getElement(i);
    }
    for (int i = 0; i < getNumOutputs(); i++)
    {
        alpha->frames[0][i] = outputErrors->getElement(i);
    }
    gradientMachine->iterInitialize();
    // Clear the derivative buffers before the backward pass accumulates into
    // them. Note: sizes are taken from der_params itself (the original
    // indexed params->size while iterating der_params' blocks).
    if (der_params)
    {
        for (int i = 0; i < der_params->n_data; i++)
        {
            for (int j = 0; j < der_params->size[i]; j++)
            {
                der_params->data[i][j] = 0.0;
            }
        }
    }
    gradientMachine->forward(input);
    gradientMachine->backward(input, alpha);
    DebugPrint('v', "\n Getting Torch Gradient Params Size: %d\n", getNumWeights());
    // Both containers are dereferenced below, so guard both (the original
    // checked only params and would have crashed on a null der_params).
    if (params && der_params)
    {
        int param_offset = 0;
        for (int i = 0; i < params->n_data; i++)
        {
            real *ptr_der_params = der_params->data[i];
            real *ptr_params = params->data[i];
            for (int j = 0; j < params->size[i]; j++)
            {
                gradientFeatures->set(param_offset + j, ptr_der_params[j]);
                DebugPrint('v', "%f (%f)", ptr_der_params[j], ptr_params[j]);
            }
            DebugPrint('v', "\n");
            param_offset += params->size[i];
        }
    }
}
// Total number of adjustable weights (torch parameters) of the machine.
// NOTE(review): dereferences params without a null check, unlike the other
// methods in this class — confirm the machine always has parameters here.
int CTorchGradientFunction::getNumWeights()
{
return gradientMachine->params->n_params;
}
/// Fills targetMatrix with the derivatives of each machine output with
/// respect to each input: one forward/backward pass per output dimension,
/// reading the input derivatives from the machine's beta sequence.
void CTorchGradientFunction::getInputDerivation(CMyVector *inputVector, CMyMatrix *targetMatrix)
{
    targetMatrix->initMatrix(0.0);
    Sequence *beta = gradientMachine->beta;
    const int numIn = getNumInputs();
    const int numOut = getNumOutputs();
    for (int in = 0; in < numIn; in++)
    {
        input->frames[0][in] = inputVector->getElement(in);
    }
    for (int row = 0; row < numOut; row++)
    {
        // One-hot error signal: backpropagating a unit error for exactly
        // output "row" yields that output's derivative w.r.t. the inputs.
        for (int out = 0; out < numOut; out++)
        {
            alpha->frames[0][out] = (out == row) ? 1.0 : 0.0;
        }
        gradientMachine->iterInitialize();
        gradientMachine->forward(input);
        gradientMachine->backward(input, alpha);
        if (beta)
        {
            for (int in = 0; in < numIn; in++)
            {
                targetMatrix->setElement(row, in, beta->frames[0][in]);
            }
        }
    }
}
void CTorchGradientFunction::getWeights(rlt_real *parameters)
{
Parameters *params = (gradientMachine)->params;
if(params)
{
int paramIndex = 0;
for(int i = 0; i < params->n_data; i++)
{
real *ptr_params = params->data[i];
for(int j = 0; j < params->size[i]; j++)
{
parameters[paramIndex] = ptr_params[j];
paramIndex ++;
}
}
// NOTE(review): the file was truncated here by web-scrape residue
// (keyboard-shortcut help text from the hosting page). The closing braces of
// CTorchGradientFunction::getWeights and any remaining code are missing and
// must be restored from the original RL Toolbox source.