📄 radialbasisnetwork.h

#ifndef _RADIALBASISNETWORK_H
#define _RADIALBASISNETWORK_H

#include "Network.h"
#include "SimpleNeuron.h"
#include "CenterNeuron.h"
#include "InputLayer.h"
#include "TrainingSet.h"

namespace annie
{

/** A Radial Basis Function Neural Network.
  *
  * The network consists of a layer of N inputs, then h N-dimensional centers, and
  * some outputs.
  * The default activation function for the centers is gaussian(), which is the gaussian
  * distribution function with sigma = 1. If you want to change that then you'll have
  * to write your own activation function and set it with setCenterActivationFunction().
  * The output neurons use the identity() function as the activation function, thus the
  * output is simply the weighted sum of the outputs of the centers.
  *
  * There are at least two ways to train a radial basis network.
  *
  * One involves fixed centers and training the weights only in order to minimize
  * the error over a training set. The other involves using the gradient descent
  * rule to adjust both centers and weights. Currently only the former is implemented
  * in annie, using trainWeights().
  *
  * @see trainWeights
  * \todo Implement gradient-descent rule based updating of centers and weights
  * \todo The copy constructor
  */
class RadialBasisNetwork : public Network
{
protected:
	/** Number of centers in the network.
	  * If you plan to extend this class, then the onus of keeping this value
	  * consistent lies on you
	  */
	int _nCenters;

	/// Layer of input. Each member is an InputNeuron
	InputLayer *_inputLayer;

	/// Layer of centers, each member is a CenterNeuron
	Layer *_centerLayer;

	/// Layer of output, each member is a SimpleNeuron
	Layer *_outputLayer;

public:
	/** Creates a Radial basis function network. All the outputs will have a bias.
	  * @param inputs Number of inputs taken in by the network
	  * @param centers Number of centers the network has. Each center will be
	  *			an inputs-dimensional point
	  * @param outputs The number of outputs given by the network. All of them will have
	  *			a bias
	  */
	RadialBasisNetwork(int inputs, int centers, int outputs);

	/** Loads a network from a text file
	  * @see save
	  * @param filename Name of the file from which to load network structure
	  * @throws Exception On any error
	  */
	RadialBasisNetwork(const char *filename);

	/// Copy constructor, NOT YET IMPLEMENTED
	RadialBasisNetwork(RadialBasisNetwork &srcNet);

	virtual ~RadialBasisNetwork();

	/** Sets the ith center point to the given point.
	  * @param i The center that is to be changed
	  * @param center The getInputCount() dimensional point
	  */
	virtual void setCenter(uint i, Vector &center);

	/** Sets the ith center point to the given point.
	  * @param i The center that is to be changed
	  * @param center The getInputCount() dimensional point
	  */
	virtual void setCenter(uint i, real *center);

	/** Returns the point corresponding to the ith center.
	  * @param i The center whose point is wanted
	  * @return The getInputCount() dimensional point corresponding to the
	  *		ith center
	  */
	virtual Vector getCenter(int i) const;

	/// The number of centers in the network
	virtual uint getCenterCount() const;

	/** Sets the bias of the ith output.
	  * @param i The index of the output (0<=i<getOutputCount()).
	  * @param bias The bias to be given to that output.
	  * @throws Exception if that neuron isn't allowed to be biased.
	  */
	virtual void setBias(uint i, real bias);

	/** Returns the bias of the ith output
	  * @param i The index of the output (0<=i<getOutputCount()).
	  * @return The bias of the output. If there is no bias, it returns 0.0
	  */
	virtual real getBias(uint i) const;

	/** Prevents the ith output from having any bias.
	  * @param i The index of the output (0<=i<getOutputCount()).
	  * @throws Exception if the index given is invalid
	  */
	virtual void removeBias(uint i);

	/** Sets the weight between the given center and output
	  * @param center Index of the center (0<=center<getCenterCount()).
	  * @param output Index of the output (0<=output<getOutputCount()).
	  * @param weight The weight to give to the link between the center and output.
	  * @throws Exception if any of the parameters given is invalid
	  */
	virtual void setWeight(int center, int output, real weight);

	/** Returns the weight of the link between the given center and output
	  * @param center Index of the center (0<=center<getCenterCount()).
	  * @param output Index of the output (0<=output<getOutputCount()).
	  * @throws Exception if any of the parameters given is invalid
	  */
	virtual real getWeight(int center, int output) const;

	/** Returns the output of the network for the given input.
	  * @param input A vector of getDimension() reals
	  * @return The corresponding output of the network
	  */
	virtual Vector getOutput(Vector &input);

	/** Wrapper function to allow getOutput() to work for an array
	  * of real as input as well.
	  * Does exactly the same thing as Network::getOutput(real*).
	  */
	virtual Vector getOutput(real *input);

	/** Sets the activation function of the center neurons.
	  * (The activation function is gaussian by default)
	  * @param f The activation function to be used.
	  * @param df The derivative of the activation function, used in gradient descent training
	  */
	void setCenterActivationFunction(ActivationFunction f, ActivationFunction df);

	/** Trains the weights of the network, centers are kept fixed.
	  * @param T The TrainingSet from which input/desired-output pairs will be obtained
	  */
	void trainWeights(TrainingSet &T);

	//void trainCentersAndWeights(TrainingSet &T, int epochs, real learningRate);

	/// Returns "RadialBasisNetwork"
	virtual const char *getClassName() const;

	/// Save the network to a text file
	virtual void save(const char *filename);
};

}; //namespace annie

#endif // define _RADIALBASISNETWORK_H
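
The class comment describes the forward pass only in words: each center applies a gaussian to its input, and each output is the weighted sum of the center activations plus a bias. The stand-alone sketch below illustrates that computation without depending on annie. Everything in it is hypothetical: the `rbf_sketch` namespace, the `forward()` and `euclideanDistance()` helpers, and the `weights[center][output]` layout are illustrative only and are not part of the library. It also assumes that each center feeds the Euclidean distance between the input and its point into the gaussian, and that the gaussian with sigma = 1 has the form exp(-r^2 / 2). The actual annie implementation may differ on both points.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical stand-alone sketch; none of these names belong to annie.
namespace rbf_sketch {

typedef double real;
typedef std::vector<real> Vector;

// Gaussian of the distance r with sigma = 1 (assumed form: exp(-r^2 / 2)).
inline real gaussian(real r) { return std::exp(-0.5 * r * r); }

// Euclidean distance between two points of equal dimension.
inline real euclideanDistance(const Vector &a, const Vector &b) {
    assert(a.size() == b.size());
    real sum = 0.0;
    for (std::size_t k = 0; k < a.size(); ++k) {
        const real d = a[k] - b[k];
        sum += d * d;
    }
    return std::sqrt(sum);
}

// Forward pass with identity output neurons:
// out[j] = bias[j] + sum over i of weights[i][j] * gaussian(|input - centers[i]|)
inline Vector forward(const std::vector<Vector> &centers,
                      const std::vector<Vector> &weights, // weights[center][output]
                      const Vector &bias,
                      const Vector &input) {
    assert(centers.size() == weights.size());
    Vector out(bias); // start each output at its bias
    for (std::size_t i = 0; i < centers.size(); ++i) {
        assert(weights[i].size() == out.size());
        const real activation = gaussian(euclideanDistance(input, centers[i]));
        for (std::size_t j = 0; j < out.size(); ++j)
            out[j] += weights[i][j] * activation;
    }
    return out;
}

} // namespace rbf_sketch
```

Because the outputs are linear in the weights once the centers are fixed, fitting the weights alone is a linear least-squares problem, which is the case the header says trainWeights() covers.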
