⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 neuralnet.h

📁 神经网络的结构实现
💻 H
字号:
/*
filename: NeuralNet.h
function: define the NeuralNet class
BPNet, PerceptionNet (a Hopfield network is mentioned but not defined in this header)
@copyright
date	: 2008-12-22
author  : loop111
e-mail  : loop111@gmail.com
*/

#ifndef NEURALNET_H
#define NEURALNET_H

#include "NerveCell.h"
#include "PR_unit.h"

#include <iostream>
using namespace std;

#include<fstream>
#include<string>
#include <vector>

#include <time.h>

class NeuralNet//abstract product
{
public:
	
	NeuralNet():num_layer(2){}
	NeuralNet(int *a,int n,bool hidden = false):num_layer(n),num_input(a[0]),
		num_output(a[n-1]),yita(1),thresh_error(0.01),alpha(0.8)//build a n layer network with the cell num a[i] in layer[i]
	{
		for (int i=0;i<n-1;i++)
		{
			Layer *layer = new Layer(a[i]+1,hidden); // add 1 hidden node
			layers.push_back(layer);
		}

		Layer *layer = new Layer(a[n-1]);
		layers.push_back(layer);

		dest = new PR_unit(num_output);
	}
	~NeuralNet()
	{
		for(int i=0;i<samples.size();i++)
		{
			delete samples[i];
		}
		if (dest != NULL)
		{
			delete dest;
		}
	}
	
	virtual void Connect() = 0;				//connect each pair of cell
	
	virtual bool LoadData(ifstream &fin);//load the train samples data
	//virtual void SaveData();

	virtual void LoadWeightData();//ifstream &fin
	virtual void SaveWeightData();//ofstream &fout
	
	virtual void InitialWeight();
	double GetWeight(int m,int i,int j);	//get the j weight of the i cell in m layer
	double GetLastWeight(int m,int i,int j);
	void SetWeight(int m,int i,int j,double value);		//set the j weight of the i cell in m layer
	double GetOutput(int m,int i);			//get the i output in m layer
	

	virtual void ChooseSample(PR_unit *unit);
	virtual void CalcOutput();
	virtual double CalcError();
	virtual double CalcSampleError();
	virtual void CalcDelta(){};
	virtual void Train(){};
    virtual void UpdateWight(){};

	void PrintStruct();
	void PrintSamples();
	void PrintWeight();
	void PrintOutput();
	void PrintDelta();
	void PrintError();
	


public:
	vector< PR_unit* > samples;
	PR_unit *dest;		//the destination of the sample

	static const int max_iterator; 
	int dimension;
	int d_output;		//the dimension of desire output vector
	double yita;
	double alpha;

	double thresh_error;//the error to detimine when it ends
	double error;
	

protected:
	int num_layer;
	int num_input;
	int num_output;
	int num_samples;
	// 	int	num_vector;

	vector<Layer*> layers;

private:
	
};

/*
 * PerceptionNet - perceptron network.
 * The base is built with hidden == false, and the layers are wired
 * together as soon as they exist.
 */
class PerceptionNet : public NeuralNet
{
public:
	// a[i]: cell count of layer i; n: number of layers.
	PerceptionNet(int *a, int n) : NeuralNet(a, n, false)
	{
		// During construction this dispatches to PerceptionNet::Connect.
		this->Connect();
	}

	// Implements NeuralNet::Connect (defined out of line).
	virtual void Connect();
};

/*
 * BPNet - back-propagation network.
 * The base is built with hidden == true, and the layers are wired
 * together as soon as they exist.
 */
class BPNet : public NeuralNet
{
public:
	// a[i]: cell count of layer i; n: number of layers.
	BPNet(int *a, int n) : NeuralNet(a, n, /*hidden=*/true)
	{
		// During construction this dispatches to BPNet::Connect.
		this->Connect();
	}

	// Overrides of the NeuralNet hooks (all defined out of line).
	virtual void Connect();
	virtual void CalcDelta();
	virtual void Train();
	virtual void UpdateWight();
};

#endif

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -