📄 radialbasisnetwork.h
字号:
#ifndef _RADIALBASISNETWORK_H#define _RADIALBASISNETWORK_H#include "Network.h"#include "SimpleNeuron.h"#include "CenterNeuron.h"#include "InputLayer.h"#include "TrainingSet.h"namespace annie{/** A Radial Basis Function Neural Network. * * The network consists of a layer of N-inputs, then h-(N-dimensional) centers, and * some outputs. * The default activation function for the centers is gaussian(), which is the gaussian * distribution function with sigma = 1. If you want to change that then you'll have * to write your own activation function and change it with setCenterActivationFunction(). * The output neurons use the identity() function as the activation function, thus the * output is simply the weighted sum of the outputs of the centers. * * There are atleast two ways to train a radial basis network. * * One involves fixed centers and training the weights only in order to minimize * the error over a training set. The other involves using the gradient descent * rule to adjust both centers and weights. Currently only the former is implemented * in annie using trainWeights(). * * @see trainWeights * \todo Implement gradient-descent rule based updation of centers and weights * \todo The copy constructor */class RadialBasisNetwork : public Network{protected: /** Number of centers in the network. * If you plan to extend this class, then the onus of keeping this value * consistent lies on you */ int _nCenters; /// Layer of input. Each member is an InputNeuron InputLayer *_inputLayer; /// Layer of centers, each member is a CenterNEuron Layer *_centerLayer; /// Layer of output, each member if a SimpleNeuron Layer *_outputLayer;public: /** Creates a Radial basis function network. All the outputs will have a bias. * @param inputs Number of inputs taken in by the network * @param centers Number of centers the network has. Each center will be * an inputs-dimensional point * @param outputs The number of outputs given by the neuron. All of them will have * a bias */ RadialBasisNetwork(int inputs, int centers, int outputs); /** Loads a network from a text file * @see save * @param filename Name of the file from which to load network structure * @throws Exception On any error */ RadialBasisNetwork(const char *filename); /// Copy constructor, NOT YET IMPLEMENTED RadialBasisNetwork(RadialBasisNetwork &srcNet); virtual ~RadialBasisNetwork(); /** Sets the ith center point to the given point. * @param i The center that is to be changed * @param center The getInputCount() dimensional point */ virtual void setCenter(uint i, Vector ¢er); /** Sets the ith center point to the given point. * @param i The center that is to be changed * @param center The getInputCount() dimensional point */ virtual void setCenter(uint i, real *center); /** Returns the point corresponding to the ith center. * @param i The center whose point is wanted * @return The getInputCount() dimensional point corresponding to the * ith center */ virtual Vector getCenter(int i) const; /// The number of centers in the network virtual uint getCenterCount() const; /** Sets the bias of the ith output. * @param i The index of the output (0<=i<getOutputCount()). * @param bias The bias to be given to that output. * @throws Exception if that neuron isn't allowed to be biased. */ virtual void setBias(uint i, real bias); /** Returns the bias of the ith output * @param i The index of the output (0<=i<getOutputCount()). * @return The bias of the output. If there is no bias, it returns 0.0 */ virtual real getBias(uint i) const; /** Prevents the ith output from having any bias. * @param i The index of the output (0<=i<getOutputCount()). * @throws Exception if the index given is invalid */ virtual void removeBias(uint i); /** Sets the weight between the given center and output * @param center Index of the center (0<=center<getCenterCount()). * @param output Index of the output (0<=output<getOutputCount()). * @param weight The weight to give to the link between the center and output. * @throws Exception if any of the parameters given is invalid */ virtual void setWeight(int center, int output, real weight); /** Returns the weight of the link between the given center and output * @param center Index of the center (0<=center<getCenterCount()). * @param output Index of the output (0<=output<getOutputCount()). * @throws Exception if any of the parameters given is invalid */ virtual real getWeight(int center, int output) const; /** Returns the output of the network for the given input. * @param input A vector of getDimension() reals * @return The corresponding output of the network */ virtual Vector getOutput(Vector &input); /** Wrapper function to allow getOutput() to work for an array * of real as input as well. * Does exactly the same thing as Network::getOutput(real*). */ virtual Vector getOutput(real *input); /** Sets the activation function of the center neurons. * (The activation function is gaussian by default) * @param f The activation function to be used. * @param df The derivation of the activation function, used in gradient descent training */ void setCenterActivationFunction(ActivationFunction f, ActivationFunction df); /** Trains the weights of the network, centers are kept fixed. * @param T The TrainingSet from which input/desired-output pairs will be obtained */ void trainWeights(TrainingSet &T); //void trainCentersAndWeights(TrainingSet &T, int epochs, real learningRate); ///Returns "RadialBasisNetwork" virtual const char *getClassName() const; ///Save the network to a text file virtual void save(const char *filename);};}; //namespace annie#endif // define _RADIALBASISNETWORK_H
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -