📄 neural_network.h
字号:
// Neural_NetWork.h: interface for the CNeural_NetWork class.
//
//////////////////////////////////////////////////////////////////////
#if !defined(AFX_NEURAL_NETWORK_H__9C8B80B4_A05F_4604_9177_CAA21780A136__INCLUDED_)
#define AFX_NEURAL_NETWORK_H__9C8B80B4_A05F_4604_9177_CAA21780A136__INCLUDED_
#if _MSC_VER > 1000
#pragma once
#endif // _MSC_VER > 1000
#include "H_Layer.h"
#include "I_Layer.h"
#include "O_Layer.h"
#include "Weight.h"
#include <iostream.h>
class CNeural_NetWork
{
public:
CNeural_NetWork();
CNeural_NetWork(short num_h); //num_h表示隐藏层数
void InitNeural_NetWorkLayers(short num_h);
void InitNeural_NetWork(int m_i_n,int m_o_n,int *p_o_h,float *p_hbasvalue,float *p_obasvalue,CWeight *m_wp); //初始化神经网络,参数说明:
//m_i_n,输入层结点数,m_o_n输出层结点数,*p_o_h指向隐藏层结点数的指针,
// *p_hbasvalue 指向隐藏层结点偏倚值的数组指针
// *p_obasvalue 指向输出层结点偏倚值的数组指针
// *m_wp 神经网络权重
void Train(float *p_data,double num_data,double m_train_t_num); //利用后向传播算法进行训练,参数1是训练数据集,
//参数2是数据总数,参数3是训练的最大周期数
void SK_Train(float *p_data,double num_data,double m_min_w); //参数1是训练数据集,参数2是数据总数,参数3是指定的最小权重增量阈值
void TD_Train(float *p_data,double num_data,double m_err_kind);//参数1是训练数据集,参数2是数据总数,参数3是指定的误分类百分比阈值
CString GetResult(); //输出结果
CH_Layer& GetH_layer(short n); //取相应隐藏层
CI_Layer& GetI_layer(); //取输入层
CO_Layer& GetO_layer(); //取输出层
short GetNum_h(); //取隐藏层数
short GetNum_i(); //取输入层数
short GetNum_o(); //取输出层数
void SetNum_h(int n); //设置隐藏层数
void ParametersSave(int n); //保存模型终止训练时各参数包括偏倚和权重
BOOL WriteToSaveFile(CString m_str,CString FileName);//将参数写入文件
int ReadFromSaveFile(CString FileName);
float* ReadPFromSaveFile(CString FileName);
CWeight* ReadWFromSaveFile(CString FileName);
int *ReadPHFromSaveFile(CString FileName);
float GetOvalue(float *p_data);
virtual ~CNeural_NetWork();
private:
CH_Layer *h_layers; //隐藏层
CI_Layer *i_layers; //输入层
CO_Layer *o_layers; //输出层
CWeight *wp;//权重
short Num_h;//隐藏层数
short Num_i;//输入层数
short Num_o;//输出层数
float l; //学习率
short count_of_w;//网络权重总数
short count_of_hbas;//网络隐藏层偏倚总数
bool isStop;
};
#endif // !defined(AFX_NEURAL_NETWORK_H__9C8B80B4_A05F_4604_9177_CAA21780A136__INCLUDED_)
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -