⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 neural_network.h

📁 本文介绍了一个用java语言编写的用于分析消费行为的数据挖掘应用软件
💻 H
字号:
// Neural_NetWork.h: interface for the CNeural_NetWork class.
//
//////////////////////////////////////////////////////////////////////

#if !defined(AFX_NEURAL_NETWORK_H__9C8B80B4_A05F_4604_9177_CAA21780A136__INCLUDED_)
#define AFX_NEURAL_NETWORK_H__9C8B80B4_A05F_4604_9177_CAA21780A136__INCLUDED_

#if _MSC_VER > 1000
#pragma once
#endif // _MSC_VER > 1000
#include "H_Layer.h"
#include "I_Layer.h"
#include "O_Layer.h"
#include "Weight.h"
#include <iostream.h>
class CNeural_NetWork  
{
public:
	CNeural_NetWork();
    CNeural_NetWork(short num_h); //num_h表示隐藏层数
	void InitNeural_NetWorkLayers(short num_h);
	void InitNeural_NetWork(int m_i_n,int m_o_n,int *p_o_h,float *p_hbasvalue,float *p_obasvalue,CWeight *m_wp); //初始化神经网络,参数说明:
	                       //m_i_n,输入层结点数,m_o_n输出层结点数,*p_o_h指向隐藏层结点数的指针,
	                       // *p_hbasvalue 指向隐藏层结点偏倚值的数组指针
	                       // *p_obasvalue 指向输出层结点偏倚值的数组指针
	                       // *m_wp 神经网络权重 
	void Train(float *p_data,double num_data,double m_train_t_num); //利用后向传播算法进行训练,参数1是训练数据集,
	                                                                  //参数2是数据总数,参数3是训练的最大周期数
	void SK_Train(float *p_data,double num_data,double m_min_w);  //参数1是训练数据集,参数2是数据总数,参数3是指定的最小权重增量阈值

	void TD_Train(float *p_data,double num_data,double m_err_kind);//参数1是训练数据集,参数2是数据总数,参数3是指定的误分类百分比阈值
	CString GetResult(); //输出结果
	CH_Layer& GetH_layer(short n); //取相应隐藏层
	CI_Layer& GetI_layer(); //取输入层
	CO_Layer& GetO_layer(); //取输出层
	short GetNum_h(); //取隐藏层数
	short GetNum_i(); //取输入层数
	short GetNum_o(); //取输出层数
	void SetNum_h(int n); //设置隐藏层数
	void ParametersSave(int n); //保存模型终止训练时各参数包括偏倚和权重
	BOOL WriteToSaveFile(CString m_str,CString FileName);//将参数写入文件
	int ReadFromSaveFile(CString FileName);
	float* ReadPFromSaveFile(CString FileName);
	CWeight* ReadWFromSaveFile(CString FileName);
	int *ReadPHFromSaveFile(CString FileName);
	float GetOvalue(float *p_data);
	virtual ~CNeural_NetWork();
private:
	CH_Layer *h_layers; //隐藏层
	CI_Layer *i_layers; //输入层
	CO_Layer *o_layers; //输出层
	CWeight *wp;//权重
	short Num_h;//隐藏层数
	short Num_i;//输入层数
	short Num_o;//输出层数
	float l; //学习率
	short count_of_w;//网络权重总数
	short count_of_hbas;//网络隐藏层偏倚总数
	bool isStop;

};

#endif // !defined(AFX_NEURAL_NETWORK_H__9C8B80B4_A05F_4604_9177_CAA21780A136__INCLUDED_)

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -