📄 neuralnet.h
字号:
/*
filename: NeuralNet.h
function: defines the abstract NeuralNet class and its concrete networks
          BPNet and PerceptionNet (perceptron); Hopfield is listed but not implemented here
@copyright
date : 2008-12-22
author : loop111
e-mail : loop111@gmail.com
*/
#ifndef NEURALNET_H
#define NEURALNET_H
#include "NerveCell.h"
#include "PR_unit.h"
#include <iostream>
using namespace std;
#include<fstream>
#include<string>
#include <vector>
#include <time.h>
class NeuralNet//abstract product
{
public:
NeuralNet():num_layer(2){}
NeuralNet(int *a,int n,bool hidden = false):num_layer(n),num_input(a[0]),
num_output(a[n-1]),yita(1),thresh_error(0.01),alpha(0.8)//build a n layer network with the cell num a[i] in layer[i]
{
for (int i=0;i<n-1;i++)
{
Layer *layer = new Layer(a[i]+1,hidden); // add 1 hidden node
layers.push_back(layer);
}
Layer *layer = new Layer(a[n-1]);
layers.push_back(layer);
dest = new PR_unit(num_output);
}
~NeuralNet()
{
for(int i=0;i<samples.size();i++)
{
delete samples[i];
}
if (dest != NULL)
{
delete dest;
}
}
virtual void Connect() = 0; //connect each pair of cell
virtual bool LoadData(ifstream &fin);//load the train samples data
//virtual void SaveData();
virtual void LoadWeightData();//ifstream &fin
virtual void SaveWeightData();//ofstream &fout
virtual void InitialWeight();
double GetWeight(int m,int i,int j); //get the j weight of the i cell in m layer
double GetLastWeight(int m,int i,int j);
void SetWeight(int m,int i,int j,double value); //set the j weight of the i cell in m layer
double GetOutput(int m,int i); //get the i output in m layer
virtual void ChooseSample(PR_unit *unit);
virtual void CalcOutput();
virtual double CalcError();
virtual double CalcSampleError();
virtual void CalcDelta(){};
virtual void Train(){};
virtual void UpdateWight(){};
void PrintStruct();
void PrintSamples();
void PrintWeight();
void PrintOutput();
void PrintDelta();
void PrintError();
public:
vector< PR_unit* > samples;
PR_unit *dest; //the destination of the sample
static const int max_iterator;
int dimension;
int d_output; //the dimension of desire output vector
double yita;
double alpha;
double thresh_error;//the error to detimine when it ends
double error;
protected:
int num_layer;
int num_input;
int num_output;
int num_samples;
// int num_vector;
vector<Layer*> layers;
private:
};
// PerceptionNet: perceptron-style network (the file header calls it "Perseption").
// Its layers are built by NeuralNet without the `hidden` flag, then wired together
// by this class's own Connect().
class PerceptionNet : public NeuralNet
{
public:
    // a[i] is the cell count of layer i and n is the number of layers.
    PerceptionNet(int *a, int n)
        : NeuralNet(a, n)
    {
        this->Connect();
    }

    // Overrides NeuralNet::Connect(); defined outside this header.
    virtual void Connect();
};
// BPNet: BP (presumably back-propagation) network. Its layers are built by NeuralNet
// with the `hidden` flag set. It wires them in its own Connect() and supplies the
// training hooks that NeuralNet leaves as no-ops.
class BPNet : public NeuralNet
{
public:
    // a[i] is the cell count of layer i and n is the number of layers.
    BPNet(int *a, int n)
        : NeuralNet(a, n, true)
    {
        this->Connect();
    }

    // Overrides of NeuralNet's virtual interface; defined outside this header.
    virtual void Connect();
    virtual void CalcDelta();
    virtual void Train();
    virtual void UpdateWight(); // (sic) must match NeuralNet::UpdateWight
};
#endif
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -