⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 mlp_control.cpp

📁 用BP算法实现数字识别功能
💻 CPP
字号:
#include "Formula.h"
#include "mlp_control.h"
#include <time.h>


/*class CMlp_control*/
//Construct an idle controller in learning mode. The network itself is built
//later by Initialize(); the worker thread is only created by Run().
CMlp_control::CMlp_control()
{
	run_status     = LEARN_STATUS;
	db_output_diff = 0.0;

	// Start with NULL handles so they never hold indeterminate values before
	// Run() creates the event and the thread.
	event_t  = NULL;
	thread_t = NULL;

	// Callbacks and the user pointer are supplied later by Initialize().
	pfn_process = NULL;
	pfn_display = NULL;
	p_void      = NULL;

	// Seed rand(), which Get_rand uses to draw initial weights.
	srand( (unsigned)time( NULL ) );
}

//Destroy the controller together with every nerve cell of the network.
CMlp_control::~CMlp_control()
{
	// The layers own their CNerve_cell objects (allocated with new in Initialize).
	this->Release_vector ();
}
//ini the class
void CMlp_control::Initialize (int layers,
							   int input,
							   int hides,
							   int output,
							   const char * param_file,
							   PFN_DISPLAY_PROCESS pfn,
							   PFN_DISPLAY pfn_dsp,
							   int nstatus,
							   void * p)
{
	//
	int    i, j, k;
	double dbw  = 0.0;
	FILE * file = 0;
	
	run_status  = nstatus;
	sparam_file = param_file;
	pfn_process = pfn;
	pfn_display = pfn_dsp;
	p_void      = p;

	vec_samples.clear ();
	vec_teacher.clear ();
	CNerve_cell * pcell = NULL;

	if ( (file = fopen ( param_file, "r+b" )) != NULL )
		fseek ( file, 0L, SEEK_SET);
	for (i = 0; i < layers; i++)
	{
		CELL_VECTOR cell_vec;
		if (i == 0)   //输入层
		{
			for (j = 0; j < input+1; j++)
			{
				pcell = new CNerve_cell (j == 0 ? CNerve_cell::VIRTUAL_TYPE : CNerve_cell::INPUT_TYPE);
				for (k = 0; k < hides; k++) //输入层到隐层的权值
				{
					if (file)
						fread ( &dbw, 1, sizeof(dbw), file );
					else
						dbw = Get_rand (0.1, 0.8);
					pcell->Put_weight (k, dbw);
				}
				cell_vec.push_back (pcell);
			}
		}
		else if (i == (layers-1)) //输出层
		{
			for (j = 0; j < output; j++)
				cell_vec.push_back (new CNerve_cell (CNerve_cell::OUTPUT_TYPE) );
		}else
		{
			for (j = 0; j < hides+1; j++)
			{
				pcell = new CNerve_cell ( j == 0 ? CNerve_cell::VIRTUAL_TYPE : CNerve_cell::HIDE_TYPE);
				for (k = 0; k < output; k++)
				{
					if (file)
						fread ( &dbw, 1, sizeof(dbw), file );
					else
						dbw = Get_rand (0.1, 0.8);
					pcell->Put_weight (k, dbw);
				}
				cell_vec.push_back (pcell);
			}
		}
		vec_layers.push_back (cell_vec);
	}
	if ( file )
		fclose(file);
}
//释放
void CMlp_control::Release_vector()
{
	int i, j;

	for (i = 0; i < vec_layers.size (); i++ )
		for (j = 0; j < vec_layers[i].size (); j++)
			delete vec_layers[i][j];
	vec_layers.clear ();
}

//运行
void CMlp_control::Run()
{
	event_t  = CreateEvent (NULL, TRUE, FALSE, NULL );
	thread_t = CreateThread (NULL, 0, pro_contrl, this, 0, NULL);
}
//停止
void CMlp_control::Stop ()
{
	SetEvent (event_t);
	WaitForSingleObject (thread_t, INFINITE);
	CloseHandle (event_t);
	CloseHandle (thread_t);
}
void CMlp_control::Thread_control()
{
	switch (run_status)
	{
	case LEARN_STATUS:
		Run_teacher_mode ();
		break;
	case APP_STATUS:
		Run_app_mode ();
		break;
	default:
		break;
	} 
}
void CMlp_control::Run_teacher_mode()
{
	int i, j, num = 0, pos;
	do
	{
		if (WaitForSingleObject (event_t, 0) == WAIT_OBJECT_0 )
			break;

		db_output_diff = 0.0;  //初始化总误差数
		for ( i = 0; i < Get_samples_size(); i++)
		{
			pos = Choice_rand_sample ( i == 0 ? true : false);
			//初始化样本
			for ( j = 1; j < vec_layers[0].size (); j++ )
				vec_layers[0][j]->Put_input_value (Get_sample_val (j-1) );
			
			Calculate ();
			//显示中间过程
			Display (pos, vec_layers[vec_layers.size ()-1], false);
			Calculate_diff ();
			Modify_grads ();
			Modify_weight (); //修改各权重
			Sleep(1);
		}
		if (Evaluate(num) == 0)
			break;
		//评估
	} while ( ++num );
}

void CMlp_control::Run_app_mode()
{
	int i, j;
	
	for ( i = 0; i < Get_samples_size (); i++ )
	{
		Choice_sample (i);
		//初始化样本
		for ( j = 1; j < vec_layers[0].size (); j++ )
			vec_layers[0][j]->Put_input_value (Get_sample_val (j-1) );
		
		Calculate();
		Display (i, vec_layers[vec_layers.size ()-1], true);
	}
}
void CMlp_control::Save_param ()
{
	int i, j, k, pos;
	FILE * file;
	double dbw;

	if ( ( file = fopen ( sparam_file.c_str (), "w+b")) != NULL )
	{
		for (i = 0; i < vec_layers.size ()-1; i++ )
		{
			for (j = 0; j < vec_layers[i].size (); j++ )
			{
				for (k = 0, pos = 0; k < vec_layers[i+1].size (); k++, pos++)
				{
					if (vec_layers[i+1][k]->cell_type == CNerve_cell::VIRTUAL_TYPE)
					{
						pos--;
						continue;
					}
					dbw = vec_layers[i][j]->Get_weight (pos);
					fwrite (&dbw, 1, sizeof (dbw), file );
				}
			}
		}
		fclose (file);
	}
}

template<class T>
T CMlp_control::Get_rand (T minval, T maxval)
{
	T val;
	do { val = (T)rand ()*((double)maxval/(double)RAND_MAX); } while (val < minval);
	return val;
}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -