⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 ctorchvfunction.cpp

📁 强化学习算法(R-Learning)难得的珍贵资料
💻 CPP
📖 第 1 页 / 共 2 页
字号:
// Copyright (C) 2003
// Gerhard Neumann (gerhard@igi.tu-graz.ac.at)

//                
// This file is part of RL Toolbox.
// http://www.igi.tugraz.at/ril_toolbox
//
// All rights reserved.
// 
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// 1. Redistributions of source code must retain the above copyright
//    notice, this list of conditions and the following disclaimer.
// 2. Redistributions in binary form must reproduce the above copyright
//    notice, this list of conditions and the following disclaimer in the
//    documentation and/or other materials provided with the distribution.
// 3. The name of the author may not be used to endorse or promote products
//    derived from this software without specific prior written permission.
// 
// THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
// IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
// OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
// IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT,
// INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
// NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
// THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

#include "ril_debug.h"
#include "ctorchvfunction.h"
#include <assert.h>
#include <math.h>

/// Builds a function wrapper around a Torch Machine.
/// @param torchMachine machine whose forward pass supplies the function value;
///        only the pointer is stored here (nothing is copied or deleted).
CTorchFunction::CTorchFunction(Machine *torchMachine)
{
	this->machine = torchMachine;
}

/// Destructor. Releases nothing: the wrapped Machine pointer is not deleted here.
CTorchFunction::~CTorchFunction()
{
	// intentionally empty
}

/// Evaluates the wrapped machine on one input sequence.
/// @param inputSequence sequence passed unchanged to Machine::forward.
/// @return the first element of the first output frame after the forward pass.
rlt_real CTorchFunction::getValueFromMachine(Sequence *inputSequence)
{
	machine->forward(inputSequence);

	// Only the very first output value is treated as the function value.
	return machine->outputs->frames[0][0];
}

/// Accessor for the wrapped Torch machine (the pointer passed to the constructor).
Machine *CTorchFunction::getMachine()
{
	return this->machine;
}


	/// Creates a learnable function backed by a Torch GradientMachine.
/// Allocates single-frame input/error buffers sized to the machine's
/// inputs/outputs and a gradient eta calculator for this machine; all three
/// are deleted again in the destructor.
CTorchGradientFunction::CTorchGradientFunction(GradientMachine *machine) : CTorchFunction(machine)
{
	this->gradientMachine = machine;

	// One-frame buffers reused by every forward/backward call.
	input = new Sequence(1, machine->n_inputs);
	alpha = new Sequence(1, machine->n_outputs);

	localEtaCalc = new CTorchGradientEtaCalculator(gradientMachine);

	// Parameters currently disabled:
	//addParameter("InitWeightVarianceFactor", 1.0);
	//addParameter("TorchNormalizeWeights", 0.0);
}

/// Frees the buffers and the eta calculator allocated by the constructor.
/// The GradientMachine itself is not deleted here.
CTorchGradientFunction::~CTorchGradientFunction()
{
	// single-frame sequences
	delete alpha;
	delete input;

	// helper created for this machine
	delete localEtaCalc;
}

/// Computes the network outputs for one input vector.
/// @param inputVector values copied into the first (only) frame of the input buffer.
/// @param output receives every output of the forward pass, element by element.
void CTorchGradientFunction::getFunctionValue(CMyVector *inputVector, CMyVector *output)
{
	int inIdx = 0;
	while (inIdx < getNumInputs())
	{
		input->frames[0][inIdx] = inputVector->getElement(inIdx);
		++inIdx;
	}

	gradientMachine->forward(input);

	int outIdx = 0;
	while (outIdx < getNumOutputs())
	{
		output->setElement(outIdx, gradientMachine->outputs->frames[0][outIdx]);
		++outIdx;
	}
}

/// Adds a sparse update to the machine's parameters.
/// Each feature in `gradientFeatures` addresses one parameter by its flat index
/// over all parameter blocks (block 0 first, then block 1, ...); its `factor`
/// is added to that parameter.
/// @param gradientFeatures list of (featureIndex, factor) pairs; every index must
///        be smaller than params->n_params (checked with assert).
void CTorchGradientFunction::updateWeights(CFeatureList *gradientFeatures)
{
	Parameters *params = (gradientMachine)->params;

	if (DebugIsEnabled('v'))
	{
		DebugPrint('v', "Updating Torch Function, Gradient: ");
		gradientFeatures->saveASCII(DebugGetFileHandle('v'));
		DebugPrint('v',"\n");
	}

	if(params)
	{
		CFeatureList::iterator it = gradientFeatures->begin();

		for(; it != gradientFeatures->end(); it++)
		{
			unsigned int featureIndex = (*it)->featureIndex;
			assert(featureIndex < (unsigned int) params->n_params);

			int param_index = 0;
			unsigned int param_offset = 0;

			// Find the block containing featureIndex, i.e. the one with
			// param_offset <= featureIndex < param_offset + size[param_index].
			// The comparison must be '>=': a local index equal to the block size
			// belongs to the NEXT block (the former '>' wrote one element past
			// the end of the current block, and also wrote into empty blocks).
			while (featureIndex - param_offset >= (unsigned int) params->size[param_index])
			{
				assert(param_index + 1 < params->n_data);
				param_offset += params->size[param_index];
				param_index += 1;
			}

			params->data[param_index][featureIndex - param_offset] += (*it)->factor;
		}
	}
	// NOTE(review): disabled weight normalization. If re-enabled it must be
	// guarded by `params != NULL`, and it currently divides by 2.0 instead of
	// by the computed norm `sum`.
	/*if (getParameter("TorchNormalizeWeights") > 0.5)
	{
		for (int i = 0; i < params->n_data - 1; i ++)
		{
			double sum = 0;

			for (int j = 0; j < params->size[i]; j++)
			{
				sum += pow(params->data[i][j], 2);
			}
			sum = sqrt(sum);
			
			for (int j = 0; j < params->size[i]; j++)
			{
				params->data[i][j] /= 2.0;
			}
		}
	}*/
}

/// Re-initializes every machine parameter with a random normal sample.
/// Each value is drawn from CDistributions::getNormalDistributionSample with
/// mean 0.0 and second argument sigma / fanIn. Whether that argument is a
/// standard deviation or a variance depends on CDistributions (TODO confirm).
/// fanIn starts at getNumInputs() + 1 and is recomputed after every block
/// as size / fanIn + 1. This presumably assumes each block is a fully
/// connected layer with bias (fanIn * n_outputs entries); verify for other layouts.
void CTorchGradientFunction::resetData()
{
	Parameters *params = gradientMachine->params;

	int fanIn = getNumInputs() + 1;
	DebugPrint('v', "Torch Gradient Function InitValues:\n");

	// Scale factor; formerly read from parameter "InitWeightVarianceFactor".
	rlt_real sigma = 1.0;

	if (!params)
	{
		return;
	}

	for (int block = 0; block < params->n_data; block++)
	{
		int blockSize = params->size[block];
		for (int k = 0; k < blockSize; k++)
		{
			params->data[block][k] = CDistributions::getNormalDistributionSample(0.0, (1.0 / fanIn) * sigma);
			DebugPrint('v', "%f ", params->data[block][k]);
		}
		DebugPrint('v',  "\n");

		// Fan-in for the following block (outputs of this block + 1 bias).
		fanIn = blockSize / fanIn + 1;
	}
}

int CTorchGradientFunction::getNumInputs()
{
	return gradientMachine->n_inputs;
}

int CTorchGradientFunction::getNumOutputs()
{
	return gradientMachine->n_outputs;
}


/// Computes the parameter gradient of the machine for one input and output error.
/// The steps are:
///  1. Load `inputVector` into the input buffer and `outputErrors` into `alpha`.
///  2. Call iterInitialize() and zero all derivative blocks.
///  3. Run forward and backward passes.
///  4. Copy every derivative into `gradientFeatures`, using the parameter's
///     flat index (blocks concatenated in order) as the feature index.
/// @param inputVector network input (getNumInputs() elements are read).
/// @param outputErrors error signal fed to backward (getNumOutputs() elements are read).
/// @param gradientFeatures receives one entry per parameter via set(index, value).
void CTorchGradientFunction::getGradient(CMyVector *inputVector, CMyVector *outputErrors, CFeatureList *gradientFeatures)
{
	Parameters *params = (gradientMachine)->params;
	Parameters *der_params = (gradientMachine)->der_params;

	for (int i = 0; i < getNumInputs(); i ++)
	{
		input->frames[0][i] = inputVector->getElement(i);
	}

	for (int i = 0; i < getNumOutputs(); i ++)
	{
		alpha->frames[0][i] = outputErrors->getElement(i);
	}

	gradientMachine->iterInitialize();

	// Clear derivatives left over from previous calls. Each derivative block is
	// cleared using its own size (previously params->size[i] was used here).
	if (der_params)
	{
		for(int i = 0; i < der_params->n_data; i++)
		{
			for(int j = 0; j < der_params->size[i]; j++)
			{
				der_params->data[i][j] = 0.0;
			}
		}
	}

	gradientMachine->forward(input);
	gradientMachine->backward(input, alpha);

	DebugPrint('v', "\n Getting Torch Gradient Params Size: %d\n", getNumWeights());

	// Both structures are read below, so both must exist. The loop assumes
	// der_params has the same block layout as params.
	if (params && der_params)
	{
		int param_offset = 0;
		for(int i = 0; i < params->n_data; i++)
		{
			real *ptr_der_params = der_params->data[i];
			real *ptr_params = params->data[i];

			for(int j = 0; j < params->size[i]; j++)
			{
				gradientFeatures->set(param_offset + j, ptr_der_params[j]);
				DebugPrint('v', "%f (%f)", ptr_der_params[j], ptr_params[j]);
			}
			DebugPrint('v',  "\n");

			param_offset += params->size[i];
		}
	}
}

/// Total number of trainable parameters of the machine.
/// @return params->n_params, or 0 if the machine has no parameter object
///         (other methods in this class already treat a null params as
///         "nothing to do"; previously this dereferenced it unconditionally).
int CTorchGradientFunction::getNumWeights()
{
	return gradientMachine->params ? gradientMachine->params->n_params : 0;
}

/// Fills targetMatrix with the machine's input derivatives, one row per output.
/// `targetMatrix` is first zeroed. Then, for each output `nout`:
///  - `alpha` is set to the one-hot vector selecting `nout`;
///  - iterInitialize, forward and backward are run;
///  - the machine's `beta` frame is copied into row `nout` (one column per input).
/// `beta` is presumably the derivative w.r.t. the input frame (Torch convention;
/// TODO confirm). If the machine has no `beta`, the matrix stays all zeros.
void CTorchGradientFunction::getInputDerivation(CMyVector *inputVector, CMyMatrix *targetMatrix)
{
	targetMatrix->initMatrix(0.0);

	Sequence *beta = (gradientMachine)->beta;

	int col = 0;
	while (col < getNumInputs())
	{
		input->frames[0][col] = inputVector->getElement(col);
		col++;
	}

	for (int row = 0; row < getNumOutputs(); row++)
	{
		// One-hot error signal selecting output `row`.
		for (int k = 0; k < getNumOutputs(); k++)
		{
			alpha->frames[0][k] = 0.0;
		}
		alpha->frames[0][row] = 1.0;

		gradientMachine->iterInitialize();
		gradientMachine->forward(input);
		gradientMachine->backward(input, alpha);

		if (!beta)
		{
			continue;
		}

		for (int c = 0; c < getNumInputs(); c++)
		{
			targetMatrix->setElement(row, c, beta->frames[0][c]);
		}
	}
}

void CTorchGradientFunction::getWeights(rlt_real *parameters)
{
	Parameters *params = (gradientMachine)->params;

	if(params)
	{
		int paramIndex = 0;
		for(int i = 0; i < params->n_data; i++)
		{
			real *ptr_params = params->data[i];

			for(int j = 0; j < params->size[i]; j++)
			{
				parameters[paramIndex] = ptr_params[j];
				paramIndex ++;
			}
		}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -