// File: cadaptivesoftmaxnetwork.h
// Copyright (C) 2003
// Gerhard Neumann (gerhard@igi.tu-graz.ac.at)
//
// This file is part of RL Toolbox.
// http://www.igi.tugraz.at/ril_toolbox
//
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// 1. Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// 2. Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// 3. The name of the author may not be used to endorse or promote products
// derived from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
// IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
// OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
// IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT,
// INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
// NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
// THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#ifndef C_ADAPTIVESOFTMAX__H
#define C_ADAPTIVESOFTMAX__H
#include <stdio.h>
#include <vector>
#include <list>
#include <map>
#include "cfeaturefunction.h"
#include "cgradientfunction.h"
#include "cvfunction.h"
#include "cerrorlistener.h"
#include "clinearfafeaturecalculator.h"
#include "ril_debug.h"
/// A radial-basis-function (RBF) center: a position vector (`centers`) and a
/// width vector (`sigmas`).
///
/// NOTE(review): member definitions live in the .cpp file. The destructor
/// presumably releases `centers` and `sigmas` — confirm there.
class CRBFCenter
{
public:
	/// Creates a center from a position and widths.
	/// @param numDim  number of dimensions (assumed to match the vectors' size — TODO confirm)
	/// @param centers position of the center
	/// @param sigmas  width of the center (presumably one entry per dimension — TODO confirm)
	CRBFCenter(int numDim, CMyVector *centers, CMyVector *sigmas);

	/// Creates a center with @p numDim dimensions.
	/// `explicit` so an `int` is never silently converted into a CRBFCenter.
	explicit CRBFCenter(int numDim);

	~CRBFCenter();

	/// Position of the center.
	CMyVector *centers;
	/// Width of the center.
	CMyVector *sigmas;
	/// Presumably an index/id of this center — set in the .cpp, TODO confirm.
	int numCenter;

	/// Returns the factor of this center for @p state
	/// (presumably the RBF activation — TODO confirm against the .cpp).
	rlt_real getFactor(CState *state);

private:
	// The class holds raw pointers and has a user-defined destructor, so the
	// compiler-generated copy operations would copy only the pointers
	// (shallow copy). They are declared private and intentionally left
	// undefined (pre-C++11 idiom, matching this file's language level) so an
	// accidental copy fails to compile instead of sharing the vectors between
	// two objects.
	CRBFCenter(const CRBFCenter &);
	CRBFCenter &operator=(const CRBFCenter &);
};
/*
class CRBFCenterSortedList
{
protected:
std::list<CRBFCenter *> *sortedList;
int dimension;
public:
CRBFCenterSortedList(int dimension);
~CRBFCenterSortedList();
void addRBFCenter(CRBFCenter *center);
void changeRBFCenter(CRBFCenter *center, rlt_real difference);
std::list<CRBFCenter *>::iterator getLeftRBFCenter(rlt_real value);
std::list<CRBFCenter *>::iterator getRightRBFCenter(rlt_real value);
};*/
/// Adaptive learning-rate (eta) calculator used by CAdaptiveSoftMaxNetwork
/// (see its `softMaxEtaCalc` member).
class CAdaptiveSoftMaxNetworkEtaCalculator : public CAdaptiveEtaCalculator
{
protected:
	/// Number of dimensions of the network's input space.
	int numDim;
public:
	/// @param numDim number of input dimensions.
	/// `explicit` so an `int` is never silently converted into a calculator.
	explicit CAdaptiveSoftMaxNetworkEtaCalculator(int numDim);

	/// Computes the weight updates into @p updates
	/// (presumably rescaling each entry by a per-weight eta — TODO confirm in the .cpp).
	virtual void getWeightUpdates(CFeatureList *updates);
};
/// Adaptive soft-max RBF network.
///
/// It is both a gradient-updatable function (CGradientUpdateFunction) and a
/// feature calculator (CFeatureCalculator). RBF centers can be added at
/// runtime: as start centers, from a grid, or on error.
///
/// NOTE(review): data-member order and virtual-method order are kept exactly
/// as in the original to preserve object layout, member-initialization order
/// and vtable layout.
class CAdaptiveSoftMaxNetwork : public CGradientUpdateFunction, public CFeatureCalculator
{
protected:
	/// Learning-rate calculator (ownership not visible here — TODO confirm).
	CAdaptiveSoftMaxNetworkEtaCalculator *softMaxEtaCalc;
	/// Upper bound on the number of centers (returned by getMaxCenters()).
	int maxCenters;
	/// Presumably the number of centers active for a single state — TODO confirm.
	int maxActiveCenters;
	/// Presumably the number of centers currently stored — TODO confirm.
	int currentCenters;
	/// Number of input dimensions.
	int numDim;
	/// All centers, keyed by an int (presumably the center index — TODO confirm).
	std::map<int, CRBFCenter *> *centers;
	/// Centers registered via addStartCenter(), presumably re-added on reset.
	std::list<CRBFCenter *> *startCenters;
	/// Constructor argument; presumably the default width of new centers.
	CMyVector *startSigma;
	/// Constructor argument; meaning defined in the .cpp — TODO document.
	CMyVector *epsilon;
	/// Offset added to feature indices (constructor argument).
	int featureOffset;
	/// Working lists used by the .cpp implementation.
	std::list<CRBFCenter *> *searchList1;
	std::list<CRBFCenter *> *searchList2;
	CFeatureList *sortedList;
public:
	CAdaptiveSoftMaxNetwork(CStateProperties *stateProperties, int maxCenters, int maxActiveCenters, int featureOffset, CMyVector *startSigma, CMyVector *epsilon);

	/// Spelled `virtual` explicitly: this class is polymorphic and may be
	/// deleted through a pointer to itself from a derived class.
	virtual ~CAdaptiveSoftMaxNetwork();

	/// Removes the stored centers.
	void clearCenters();
	/// Registers a center in the start-center list.
	void addStartCenter(CRBFCenter *center);
	/// Adds a center to the network; returns an int (presumably its index — TODO confirm).
	int addRBFCenter(CRBFCenter *center);
	/// Adds all registered start centers to the network.
	void addStartCenters();

	virtual void saveData(FILE *stream);
	virtual void loadData(FILE *stream);

	/// Adds centers at the positions of @p grid with widths @p sigmas.
	virtual void addCenterGrid(CGridFeatureCalculator *grid, CMyVector *sigmas);

	virtual void getGradient(CStateCollection *state, int featureIndex, CFeatureList *gradientFeatures);
	virtual void updateWeights(CFeatureList *gradientFeatures);
	virtual int getNumWeights();
	virtual void resetData();
	virtual void getWeights(rlt_real *parameters);
	virtual void setWeights(rlt_real *parameters);

	/// Possibly adds a center for @p state depending on @p error; returns an int
	/// (presumably the new center index or a sentinel — TODO confirm in the .cpp).
	virtual int addCenterOnError(rlt_real error, CStateCollection *state);

	virtual void getModifiedState(CStateCollection *state, CState *targetState);

	/// Returns the configured maximum number of centers.
	virtual int getMaxCenters() {return maxCenters;}

private:
	// The network holds many raw pointers and has a user-defined destructor,
	// so compiler-generated copies would be shallow. They are declared private
	// and intentionally left undefined (pre-C++11 idiom) so accidental copies
	// fail to compile.
	CAdaptiveSoftMaxNetwork(const CAdaptiveSoftMaxNetwork &);
	CAdaptiveSoftMaxNetwork &operator=(const CAdaptiveSoftMaxNetwork &);
};
/// Value function built on a CAdaptiveSoftMaxNetwork.
///
/// It is also a CErrorListener, so reported errors reach receiveError()
/// (presumably forwarded to the network's addCenterOnError — TODO confirm).
/// Member and virtual-method order are kept as in the original.
class CAdaptiveSoftMaxVFunction : public CFeatureVFunction, public CErrorListener
{
protected:
	/// Underlying network (ownership not visible here — TODO confirm).
	CAdaptiveSoftMaxNetwork *adaptiveSoftMaxNetwork;
	/// Working gradient lists used by the .cpp implementation.
	CFeatureList *gradient1List;
	CFeatureList *gradient2List;
public:
	/// `explicit` so a network pointer is never silently converted into a
	/// value function.
	explicit CAdaptiveSoftMaxVFunction(CAdaptiveSoftMaxNetwork *adaptiveSoftMaxNetwork);

	/// Spelled `virtual` explicitly for this polymorphic class.
	virtual ~CAdaptiveSoftMaxVFunction();

	virtual void updateWeights(CFeatureList *gradientFeatures);
	virtual CAbstractVETraces *getStandardETraces();
	virtual void getGradient(CStateCollection *state, CFeatureList *gradientFeatures);
	virtual int getNumWeights();
	virtual void resetData();
	virtual void saveData(FILE *stream);
	virtual void loadData(FILE *stream);
	virtual void getWeights(rlt_real *parameters);
	virtual void setWeights(rlt_real *parameters);

	/// Error callback from CErrorListener.
	/// NOTE(review): default arguments on virtual functions bind to the static
	/// type of the call expression; overrides should repeat `= NULL`.
	virtual void receiveError(rlt_real error, CStateCollection *state, CAction *action, CActionData *data = NULL);

private:
	// Raw pointer members plus a user-defined destructor make the implicit
	// copies shallow. They are declared private and intentionally left
	// undefined (pre-C++11 idiom) so accidental copies fail to compile.
	CAdaptiveSoftMaxVFunction(const CAdaptiveSoftMaxVFunction &);
	CAdaptiveSoftMaxVFunction &operator=(const CAdaptiveSoftMaxVFunction &);
};
#endif
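/*
 * NOTE(review): the text originally here was residue from a web code viewer,
 * not part of this header. Translated, it was the viewer's keyboard-shortcut
 * panel:
 *   Copy code: Ctrl + C
 *   Search code: Ctrl + F
 *   Full-screen mode: F11
 *   Toggle theme: Ctrl + Shift + D
 *   Show shortcuts: ?
 *   Increase font size: Ctrl + =
 *   Decrease font size: Ctrl + -
 */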