⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 mlp_control.h

📁 用BP算法实现数字识别功能
💻 H
字号:
#ifndef __MLP_CONTROL_INCLUDE
#define __MLP_CONTROL_INCLUDE

#include <windows.h>
#include <cstdlib>
#include <map>
#include <string>
#include <vector>
#include "cell.h"

#define DIFF_RANGE             (double)0.01  // error tolerance: Evaluate() returns 0 when the average output error is <= this value

//控制类
// Controller for a multi-layer perceptron trained with back-propagation.
// It holds the network layers, the sample patterns with their teacher
// (expected output) patterns, the UI callbacks, and the thread/event handles
// used to run training or application in the background.
class CMlp_control
{
public:
	// Progress callback; Evaluate() calls it as (t, average output error, user pointer).
	typedef void (PASCAL *PFN_DISPLAY_PROCESS)(int, double, void*);
	// Per-cell callback; Display() calls it as
	// (index, cell index, old input, output value, blast flag, user pointer).
	typedef void (PASCAL *PFN_DISPLAY)(int, int, double, double, bool, void*);
	
	typedef std::vector < std::string >     SAMPLE_VECTOR;  // patterns; each char is read as a digit (c - '0')
	typedef std::map < int, int >           POSITION_MAP;   // used as a set of sample indices already drawn
	typedef std::vector < CNerve_cell * >   CELL_VECTOR;    // the cells of one layer
	typedef std::vector < CELL_VECTOR >     LAYER_VECTOR;   // all layers; [0] is the input layer, back() the output layer
	enum { LEARN_STATUS, APP_STATUS };                      // run mode: training / application (testing)
public:
	CMlp_control ();
	~CMlp_control();

	int       run_status;  // presumably LEARN_STATUS or APP_STATUS -- TODO confirm against the .cpp

	// Build the network and register the callbacks.
	// NOTE(review): parameter meanings are inferred from their names (layer
	// count, cells per input/hidden/output layer, weight file, run mode, user
	// pointer handed back to the callbacks); confirm against the .cpp.
	void  Initialize ( int sum_layers,
		int input,
		int hides,
		int output,
		const char * param_file,
		PFN_DISPLAY_PROCESS,
		PFN_DISPLAY,
		int nstatus = CMlp_control::LEARN_STATUS,
		void* p = NULL );

	void Run();
	void Stop ();
	void Save_param ();
	void Release_vector();

	// Append a sample pattern. The n-th teacher pattern belongs to the n-th
	// sample (Choice_sample() reads both vectors with the same index).
	void Insert_sample ( const std::string & sin )
	{
		vec_samples.push_back (sin);
	}
	// Append the teacher (expected output) pattern for the next sample.
	void Insert_teacher ( const std::string & sth )
	{
		vec_teacher.push_back (sth);
	}
	// Number of sample patterns inserted so far.
	size_t Get_samples_size() const
	{
		return vec_samples.size();
	}
	int    Choice_rand_sample (bool bfirst = false);
	void   Choice_sample ( int pos );
	double Get_sample_val( int pos );
	double Get_teacher_val ( int pos );
	// Round to the nearest integer, halves rounding up. The (int) cast
	// truncates toward zero, so this is only correct for dbval >= -0.5.
	static int Adjust_value( double dbval )
	{
		return (int)(dbval + 0.5);
	}
private:
	int   Calculate ();       // forward pass
	int   Evaluate (int t);   // report and judge the average output error
	int   Modify_grads ();    // compute back-propagated gradients
	void  Modify_weight ();   // update the weights
	void  Calculate_diff ();  // accumulate the output error
	void  Display(int index, CELL_VECTOR & vec, bool blast = false); // show one layer
	void  Run_teacher_mode ();  // training
	void  Run_app_mode();       // testing
	template<class T> T Get_rand (T minval, T maxval);
	// Thread entry point: forwards to Thread_control() on the object passed as lp.
	friend DWORD WINAPI pro_contrl (LPVOID lp)
	{
		CMlp_control * p = static_cast<CMlp_control*>(lp);
		if (p)
			p->Thread_control();
		return 0;
	}
	void Thread_control ();
private:
	PFN_DISPLAY_PROCESS pfn_process;
	PFN_DISPLAY         pfn_display;

	void  *      p_void;           // user pointer handed back to both callbacks
	double       db_output_diff;   // accumulated output error (averaged in Evaluate)
	
	LAYER_VECTOR  vec_layers;
	std::string   sparam_file;
	SAMPLE_VECTOR vec_samples;
	SAMPLE_VECTOR vec_teacher;
	POSITION_MAP  map_pos;
	std::string   scurr_sample;
	std::string   scurr_teacher;

	HANDLE        thread_t;
	HANDLE        event_t;
private:
	// Non-copyable. The object holds thread/event HANDLEs and raw CNerve_cell
	// pointers (presumably released by the destructor). A member-wise copy
	// would share them between two owners. Declared but never defined.
	CMlp_control ( const CMlp_control & );
	CMlp_control & operator= ( const CMlp_control & );
};

// A friend defined only inside the class is found solely through
// argument-dependent lookup, which never applies to an LPVOID argument.
// This namespace-scope declaration makes pro_contrl visible to ordinary lookup,
// e.g. when its address is passed as a thread start routine.
DWORD WINAPI pro_contrl (LPVOID lp);

// Numeric value of character `pos` of the current sample (its char code minus '0').
// No bounds check: pos must be a valid index into scurr_sample.
inline double
CMlp_control::Get_sample_val( int pos )
{
	const char digit = scurr_sample[pos];
	return static_cast<double>(digit - '0');
}

// Numeric value of character `pos` of the current teacher pattern (its char code minus '0').
// No bounds check: pos must be a valid index into scurr_teacher.
inline double
CMlp_control::Get_teacher_val ( int pos )
{
	const char digit = scurr_teacher[pos];
	return static_cast<double>(digit - '0');
}
// Pick a random sample not yet drawn in the current round and make it, with
// its teacher pattern, the current pair (scurr_sample / scurr_teacher).
//   bfirst - true to start a new round (forget all previously drawn indices).
// Returns the chosen index, or -1 without touching the current pair when no
// samples have been inserted.
// NOTE(review): assumes vec_teacher has an entry for every sample index.
inline int
CMlp_control::Choice_rand_sample(bool bfirst)
{
	const size_t count = vec_samples.size ();
	if (count == 0)
		return -1;  // rand() % 0 would be a division by zero

	// Once every index has been drawn, the rejection loop below could never
	// terminate, so start a new round automatically.
	if (bfirst || map_pos.size () >= count)
		map_pos.clear ();

	int pos = 0;
	do {
		pos = static_cast<int>(rand () % count);
	} while ( map_pos.find (pos) != map_pos.end () );

	map_pos.insert ( std::make_pair (pos, pos) );
	scurr_sample  = vec_samples[pos];
	scurr_teacher = vec_teacher[pos];
	return pos;
}
inline void
CMlp_control::Choice_sample ( int pos )
{
	scurr_sample  = vec_samples[pos];
	scurr_teacher = vec_teacher[pos];
}
//计算
inline int
CMlp_control::Calculate ()
{
	int j,i,k,pos;
	//begin from hide layer
	for ( i = 1; i < vec_layers.size (); i++ ) {
		for (j = 0; j < vec_layers[i].size (); j++) {
			pos = (i != (vec_layers.size ()-1) ? j-1 : j );
			vec_layers[i][j]->Put_input_value (pos, vec_layers[i-1]);
			vec_layers[i][j]->Calculate ();
			Sleep(0);
		}
	}
	return 0;
}
//修改剃度
inline int
CMlp_control::Modify_grads ()
{
	int i, j;
	int cell_type;
	double dbreal, dbteach, dbgrads;
	
	for (i = vec_layers.size ()-1; i > 0; i-- )	{
		for (j = 0; j < vec_layers[i].size (); j++ ) {
			if ( vec_layers[i][j]->cell_type != CNerve_cell::VIRTUAL_TYPE )	{
				dbreal  = vec_layers[i][j]->Get_output_value ();
				if (vec_layers[i][j]->cell_type == CNerve_cell::OUTPUT_TYPE)
				{
					//修改输出层梯度
					dbteach = Get_teacher_val (j);
					dbgrads = CFormula::Instance()->Ask_output_grads (dbreal, dbteach);
				} else {
					//修改隐层梯度
					dbgrads = CFormula::Instance()->Ask_hide_grads (vec_layers[i][j], vec_layers[i+1]);
				}
				//保存修改好的梯度值
				vec_layers[i][j]->db_grads = dbgrads;
			}
			Sleep(0);
		}
	}
	return 0;
}
//△wt+1 = a△wt+ ...
inline void
CMlp_control::Modify_weight ()
{
	double dbgrads, dbreal, dbdiff_w;
	int i, j, k;
	
	for (i = vec_layers.size ()-2; i >= 0; i--)	{
		for (j = 0; j < vec_layers[i].size (); j++)	{
			dbreal   = vec_layers[i][j]->Get_output_value ();
			for (k = 0; k < vec_layers[i+1].size (); k++) {
				if (vec_layers[i+1][k]->cell_type != CNerve_cell::VIRTUAL_TYPE) {
					dbgrads  = vec_layers[i+1][k]->db_grads;
					dbdiff_w = vec_layers[i][j]->Get_diff_w (k);
					dbdiff_w = CFormula::Instance()->Ask_diff_weight (dbdiff_w, dbgrads, dbreal);
					vec_layers[i][j]->Update_weight (k, dbdiff_w);
				}
				Sleep(0);
			}
		}
	}
	
}
//统计输出层误差
inline void
CMlp_control::Calculate_diff ()
{
	int i;
	double dbreal, dbteach;
	CELL_VECTOR vcell = vec_layers[vec_layers.size ()-1];
	
	for ( i = 0; i < vcell.size (); i++ ) {
		dbreal  = vcell[i]->Get_output_value ();
		dbteach = Get_teacher_val(i);
		db_output_diff += CFormula::Ask_output_diff (dbreal, dbteach);
	}
}
// Compute the average output error (db_output_diff / number of samples),
// report it through the progress callback (if set), and judge it.
//   t - forwarded unchanged to the progress callback
//       (presumably the iteration count -- confirm against the caller).
// Returns 0 when the average error is within DIFF_RANGE, -1 otherwise.
inline int
CMlp_control::Evaluate (int t)
{
	const size_t nsamples = Get_samples_size ();
	// With no samples the division would yield NaN or inf. NaN would compare
	// as "converged" and be shown to the UI, so report an error of 0 instead.
	const double dbavg = ( nsamples != 0 )
		? db_output_diff / static_cast<double>(nsamples)
		: 0.0;
	if (pfn_process)
		pfn_process (t, dbavg, p_void);
	return (dbavg > DIFF_RANGE ? -1 : 0 );
}

inline void
CMlp_control::Display (int index, CELL_VECTOR & vec, bool blast)
{
	int i;
	for ( i = 0; i < vec.size(); i++ ) {
		if ( pfn_display )
			if (blast)
				pfn_display (index, i, vec[i]->Get_old_input(), Adjust_value(vec[i]->Get_output_value ()), blast, p_void);
			else
				pfn_display (index, i, vec[i]->Get_old_input (), vec[i]->Get_output_value (), blast, p_void);
	}
}
#endif

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -