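// [Download-page residue - not part of the original header]
// Welcome to the Chongchong download site! | Resource downloads | Resource collections | About us
// Chongchong download site
//
// File: cadaptivesoftmaxnetwork.h
//
// Collection: Reinforcement learning algorithms (R-Learning) - rare and valuable material
// File type: H
// Font size: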
// Copyright (C) 2003
// Gerhard Neumann (gerhard@igi.tu-graz.ac.at)

//                
// This file is part of RL Toolbox.
// http://www.igi.tugraz.at/ril_toolbox
//
// All rights reserved.
// 
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// 1. Redistributions of source code must retain the above copyright
//    notice, this list of conditions and the following disclaimer.
// 2. Redistributions in binary form must reproduce the above copyright
//    notice, this list of conditions and the following disclaimer in the
//    documentation and/or other materials provided with the distribution.
// 3. The name of the author may not be used to endorse or promote products
//    derived from this software without specific prior written permission.
// 
// THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
// IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
// OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
// IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT,
// INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
// NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
// THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

#ifndef C_ADAPTIVESOFTMAX__H
#define C_ADAPTIVESOFTMAX__H

#include <stdio.h> 
#include <vector>
#include <list>
#include <map>
#include "cfeaturefunction.h"
#include "cgradientfunction.h"
#include "cvfunction.h"
#include "cerrorlistener.h"
#include "clinearfafeaturecalculator.h"

#include "ril_debug.h"

// A single center used by the adaptive soft-max network (declaration only;
// the member functions are defined outside this header).
// NOTE(review): the names suggest a radial basis function described by a
// position vector (`centers`) and a width vector (`sigmas`) - confirm against
// the implementation.
class CRBFCenter
{
public:
	// Creates a center for `numDim` dimensions from the supplied vectors.
	// NOTE(review): this header does not show whether the vectors are copied
	// or adopted - check the definition before freeing them at the call site.
	CRBFCenter(int numDim, CMyVector *centers, CMyVector *sigmas);
	// Creates a center for `numDim` dimensions without caller-supplied vectors.
	CRBFCenter(int numDim);
	// Non-virtual destructor; this class declares no virtual functions.
	~CRBFCenter();

	// Publicly accessible vector pointers (see the review note on the class).
	CMyVector *centers;
	CMyVector *sigmas;

	// NOTE(review): meaning is not visible in this header - possibly the
	// index/id assigned to this center; confirm in the implementation.
	int numCenter;

	// Returns a scalar computed for `state`.
	// NOTE(review): presumably the activation of this center at `state` - confirm.
	rlt_real getFactor(CState *state);

};

/*
class CRBFCenterSortedList
{
protected:
	std::list<CRBFCenter *> *sortedList;
	int dimension;

	
public:
	CRBFCenterSortedList(int dimension);
	~CRBFCenterSortedList();

	void addRBFCenter(CRBFCenter *center);
	void changeRBFCenter(CRBFCenter *center, rlt_real difference);

	std::list<CRBFCenter *>::iterator getLeftRBFCenter(rlt_real value);
	std::list<CRBFCenter *>::iterator getRightRBFCenter(rlt_real value);
};*/


// Eta calculator specialised for the adaptive soft-max network. Derives
// publicly from CAdaptiveEtaCalculator and remembers a dimension count
// (declaration only; the member functions are defined outside this header).
class CAdaptiveSoftMaxNetworkEtaCalculator : public CAdaptiveEtaCalculator
{
protected:
	// Dimension count handed to the constructor.
	// NOTE(review): presumably the state-space dimension of the network - confirm.
	int numDim;

public:
	// Creates the calculator for `numDim` dimensions.
	CAdaptiveSoftMaxNetworkEtaCalculator(int numDim);

	// Virtual hook operating on the feature list `updates`.
	// NOTE(review): presumably overrides CAdaptiveEtaCalculator and rescales
	// the entries of `updates` in place - confirm against the base class.
	virtual void getWeightUpdates(CFeatureList *updates);
	
};


// Adaptive soft-max network built from CRBFCenter objects. The class is both a
// CGradientUpdateFunction and a CFeatureCalculator (declaration only; apart
// from getMaxCenters() every member function is defined outside this header).
class CAdaptiveSoftMaxNetwork : public CGradientUpdateFunction, public CFeatureCalculator
{
protected:
	// Eta calculator specific to this network type.
	// NOTE(review): creation and ownership are not visible here - confirm.
	CAdaptiveSoftMaxNetworkEtaCalculator *softMaxEtaCalc;
	
	// Center bookkeeping. `maxCenters` is what getMaxCenters() returns.
	// NOTE(review): presumably `maxActiveCenters` limits how many centers are
	// active for one state and `currentCenters` counts the centers added so
	// far - confirm in the implementation.
	int maxCenters;
	int maxActiveCenters;
	int currentCenters;

	// NOTE(review): presumably the number of state dimensions, derived from
	// the `stateProperties` constructor argument - confirm.
	int numDim;


	// Centers addressed by an integer key.
	std::map<int, CRBFCenter *> *centers;
	// Centers registered as start centers (see addStartCenter / addStartCenters).
	std::list<CRBFCenter *> *startCenters;

	// Vectors with the same names as the constructor arguments.
	// NOTE(review): whether they are copied or stored as passed is not visible here.
	CMyVector *startSigma;
	CMyVector *epsilon;

	// Integer offset with the same name as the constructor argument.
	// NOTE(review): presumably added to the feature indices produced by this network - confirm.
	int featureOffset;


	// Two scratch lists of centers.
	// NOTE(review): presumably working buffers for searching nearby centers - confirm.
	std::list<CRBFCenter *> *searchList1;
	std::list<CRBFCenter *> *searchList2;

	// Feature list kept by the network.
	// NOTE(review): the name suggests it is held in sorted order - confirm.
	CFeatureList *sortedList;
public:

	// Creates the network from the state properties, the two center limits,
	// a feature offset and the start-sigma / epsilon vectors.
	CAdaptiveSoftMaxNetwork(CStateProperties *stateProperties, int maxCenters, int maxActiveCenters, int featureOffset, CMyVector *startSigma, CMyVector *epsilon);
	// NOTE(review): declared without `virtual` although the class has virtual
	// functions. That is only safe if a base class already declares a virtual
	// destructor - confirm in the base-class headers.
	~CAdaptiveSoftMaxNetwork();

	// NOTE(review): presumably removes all centers currently in the network;
	// whether the CRBFCenter objects are deleted is not visible here.
	void clearCenters();

	// Registers `center` as a start center.
	void addStartCenter(CRBFCenter *center);
	// Adds `center` to the network and returns an int.
	// NOTE(review): presumably the index/key assigned to the center - confirm.
	int addRBFCenter(CRBFCenter *center);

	// NOTE(review): presumably inserts the registered start centers into the network - confirm.
	void addStartCenters();

	// Persistence hooks working on a C stdio stream; the data format is
	// defined by the implementation.
	virtual void saveData(FILE *stream);
	virtual void loadData(FILE *stream);

	// Adds centers based on `grid`, using `sigmas`.
	// NOTE(review): presumably one center per grid cell, each with width `sigmas` - confirm.
	virtual void addCenterGrid(CGridFeatureCalculator *grid, CMyVector *sigmas);

	// Gradient interface: writes into `gradientFeatures` for the given state
	// and feature index, and applies a gradient feature list to the weights.
	virtual void getGradient(CStateCollection *state, int featureIndex, CFeatureList *gradientFeatures);
	virtual void updateWeights(CFeatureList *gradientFeatures);

	// Returns the number of weights of this function.
	virtual int getNumWeights();

	// NOTE(review): presumably restores the initial configuration - confirm.
	virtual void resetData();

	// Copy the weights to / from a caller-provided array.
	// NOTE(review): the array is presumably expected to hold getNumWeights() entries - confirm.
	virtual void getWeights(rlt_real *parameters);
	virtual void setWeights(rlt_real *parameters);

	// May add a center for `state` depending on `error`; returns an int.
	// NOTE(review): the meaning of the return value (index of the new center,
	// or a "nothing added" sentinel) is not visible here - confirm.
	virtual int addCenterOnError(rlt_real error, CStateCollection *state);

	// Writes the modified state for `state` into `targetState`.
	virtual void getModifiedState(CStateCollection *state, CState *targetState);

	// Returns the `maxCenters` member.
	virtual int getMaxCenters() {return maxCenters;};
};

// Value function on top of a CAdaptiveSoftMaxNetwork. The class is both a
// CFeatureVFunction and a CErrorListener (declaration only; the member
// functions are defined outside this header).
class CAdaptiveSoftMaxVFunction : public CFeatureVFunction, public CErrorListener
{
protected:
	// Network handed to the constructor.
	// NOTE(review): ownership is not visible here - confirm before deleting
	// the network elsewhere.
	CAdaptiveSoftMaxNetwork *adaptiveSoftMaxNetwork;

	// Two feature lists.
	// NOTE(review): presumably scratch buffers for gradient computation - confirm.
	CFeatureList *gradient1List;
	CFeatureList *gradient2List;

public:
	// Creates the value function for the given network.
	CAdaptiveSoftMaxVFunction(CAdaptiveSoftMaxNetwork *adaptiveSoftMaxNetwork);
	// NOTE(review): declared without `virtual` although the class has virtual
	// functions. That is only safe if a base class already declares a virtual
	// destructor - confirm in the base-class headers.
	~CAdaptiveSoftMaxVFunction();

	// Applies a gradient feature list to the weights.
	virtual void updateWeights(CFeatureList *gradientFeatures);

	// Returns the e-traces object to use with this value function.
	// NOTE(review): whether the caller owns the returned object is not visible here.
	virtual CAbstractVETraces *getStandardETraces();

	// Writes the gradient for `state` into `gradientFeatures`.
	virtual void getGradient(CStateCollection *state, CFeatureList *gradientFeatures);

	// Returns the number of weights of this function.
	virtual int getNumWeights();

	// NOTE(review): presumably restores the initial configuration - confirm.
	virtual void resetData();

	// Persistence hooks working on a C stdio stream; the data format is
	// defined by the implementation.
	virtual void saveData(FILE *stream);
	virtual void loadData(FILE *stream);

	// Copy the weights to / from a caller-provided array.
	// NOTE(review): the array is presumably expected to hold getNumWeights() entries - confirm.
	virtual void getWeights(rlt_real *parameters);
	virtual void setWeights(rlt_real *parameters);

	// Error-listener callback; `data` is optional and defaults to NULL.
	// NOTE(review): presumably forwards the error to the network so it can
	// add a center (see CAdaptiveSoftMaxNetwork::addCenterOnError) - confirm.
	virtual void receiveError(rlt_real error, CStateCollection *state, CAction *action, CActionData *data = NULL);
};


#endif

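// [Download-page residue - not part of the original header]
// Keyboard shortcuts
//
// Copy code: Ctrl + C
// Search code: Ctrl + F
// Full-screen mode: F11
// Toggle theme: Ctrl + Shift + D
// Show shortcuts: ?
// Increase font size: Ctrl + =
// Decrease font size: Ctrl + -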