⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 trainpagedlg.cpp

📁 演示在角色扮演游戏中如何利用人工神经网络进行智能分类的训练
💻 CPP
字号:
// TrainPageDlg.cpp : implementation file
//

#include "stdafx.h"
#include <process.h>
#include <cmath>
#include <iostream>
#include <fstream>
#include "ANNApp.h"
#include "TrainPageDlg.h"
#include "ANNAppPropSheet.h"


#ifdef _DEBUG
#define new DEBUG_NEW
#undef THIS_FILE
static char THIS_FILE[] = __FILE__;
#endif

/////////////////////////////////////////////////////////////////////////////
// CTrainPageDlg property page

// File names that get appended to the executable's directory (_appPath):
// the training patterns read in OnInitDialog() and the file trainNetwork()
// saves the best network to.
static const char sPatternsFileName[]	= "Patterns.txt";
static const char sNetSaveFileName[]	= "bpnet.nn";

IMPLEMENT_DYNCREATE(CTrainPageDlg, CPropertyPage)

// Builds the page with its default training parameters; no training thread
// exists until the Train button is pressed.
CTrainPageDlg::CTrainPageDlg()
	: CPropertyPage(CTrainPageDlg::IDD)
{
	//{{AFX_DATA_INIT(CTrainPageDlg)
	m_learningRate = 0.4;
	m_momentum = 0.7;
	m_boolAutoLr = TRUE;
	//}}AFX_DATA_INIT

	// worker-thread bookkeeping: nothing running, no stop requested
	_hThread	= NULL;
	_stopThread = false;

	// same value OnHScroll() computes for a slider at its maximum (1.1 - 1.0)
	_tolerance	= 0.1;
}

// Intentionally empty: the image list and the training thread are handled
// in OnDestroy().
CTrainPageDlg::~CTrainPageDlg()
{
}

// Binds the dialog controls to their member variables and transfers/validates
// the editable values:
//   - learning rate must lie in [1e-10, 1]
//   - momentum must lie in [0, 0.9]
//   - the "auto learning rate" check box maps to m_boolAutoLr
// The slider members are bound to the control whose ID matches their name
// (IDC_SLIDER_INT -> m_sliderNPCInt, IDC_SLIDER_EXP -> m_sliderNPCExp); they
// were previously crossed, which only went unnoticed because both sliders are
// initialised identically in OnInitDialog().
void CTrainPageDlg::DoDataExchange(CDataExchange* pDX)
{
	CPropertyPage::DoDataExchange(pDX);
	//{{AFX_DATA_MAP(CTrainPageDlg)
	DDX_Control(pDX, IDC_SLIDER_INT, m_sliderNPCInt);
	DDX_Control(pDX, IDC_SLIDER_EXP, m_sliderNPCExp);
	DDX_Control(pDX, IDC_LIST_TRAINSET, m_listTrainSet);
	DDX_Text(pDX, IDC_EDIT_LR, m_learningRate);
	DDV_MinMaxDouble(pDX, m_learningRate, 1.e-010, 1.);
	DDX_Text(pDX, IDC_EDIT_MT, m_momentum);
	DDV_MinMaxDouble(pDX, m_momentum, 0., 0.9);
	DDX_Check(pDX, IDC_CHECK_AUTOLR, m_boolAutoLr);
	//}}AFX_DATA_MAP
}


// Message routing for the page: window destruction, the auto-learning-rate
// check box, horizontal scroll notifications (used by the two sliders, see
// OnHScroll) and the Train button.
BEGIN_MESSAGE_MAP(CTrainPageDlg, CPropertyPage)
	//{{AFX_MSG_MAP(CTrainPageDlg)
	ON_WM_DESTROY()
	ON_BN_CLICKED(IDC_CHECK_AUTOLR, OnCheckAutoLR)
	ON_WM_HSCROLL()
	ON_BN_CLICKED(IDC_BUTTON_TRAIN, OnButtonTrain)
	//}}AFX_MSG_MAP
END_MESSAGE_MAP()

/////////////////////////////////////////////////////////////////////////////
// CTrainPageDlg message handlers

// Prepares the page once its window exists: builds the status image list,
// creates the list-control columns, positions both sliders at their maximum,
// records the directory of the executable and loads the training patterns
// from it. If the patterns cannot be loaded an error box is shown and
// WM_QUIT is posted.
BOOL CTrainPageDlg::OnInitDialog() 
{
	CPropertyPage::OnInitDialog();

	// status icons for the list: image 0 = IDI_FAILED, image 1 = IDI_OK
	_imageList.Create(16, 16, FALSE, 2, 0 );
	const HICON failedIcon = ::LoadIcon(AfxGetResourceHandle(), MAKEINTRESOURCE(IDI_FAILED) );
	_imageList.Add(failedIcon);
	const HICON okIcon = ::LoadIcon(AfxGetResourceHandle(), MAKEINTRESOURCE(IDI_OK) );
	_imageList.Add(okIcon);

	// list columns; the sub-item index always equals the column index
	struct ColumnSpec
	{
		const char*	title;
		int			format;
		int			width;
	};
	static const ColumnSpec columns[] =
	{
		{ "Id",			LVCFMT_LEFT,	40 },
		{ "Clothing",	LVCFMT_CENTER,	55 },
		{ "Weapon",		LVCFMT_CENTER,	55 },
		{ "",			LVCFMT_CENTER,	20 },
		{ "Fighter",	LVCFMT_CENTER,	52 },
		{ "Wizard",		LVCFMT_CENTER,	52 },
		{ "Thief",		LVCFMT_CENTER,	52 }
	};
	const int columnCount = sizeof(columns) / sizeof(columns[0]);
	for (int col = 0; col < columnCount; ++col)
	{
		m_listTrainSet.InsertColumn( col, columns[col].title, columns[col].format,
									 columns[col].width, col );
	}

	m_listTrainSet.SetImageList( &_imageList, LVSIL_SMALL );

	// both sliders run from 1 to 3 and start at the top position
	m_sliderNPCInt.SetRange(1, 3);
	m_sliderNPCInt.SetPos(3);
	m_sliderNPCExp.SetRange(1, 3);
	m_sliderNPCExp.SetPos(3);

	// directory of the running .exe, including the trailing backslash
	char exePath[_MAX_PATH+1];
	::GetModuleFileName(AfxGetInstanceHandle(), exePath, _MAX_PATH);
	_appPath = exePath;
	const int lastSeparator = _appPath.ReverseFind('\\');
	_appPath = _appPath.Left(lastSeparator+1);

	// the training patterns live next to the executable
	const CString patternsPath = _appPath + sPatternsFileName;
	if ( loadTrainingSet(patternsPath) )
	{
		fillTrainingSetList(_patterns.size());
	}
	else
	{
		MessageBox( "Error Loading Training Set", "Error", MB_ICONERROR);
		PostMessage(WM_QUIT);
	}

	return TRUE;  // return TRUE unless you set the focus to a control
	              // EXCEPTION: OCX Property Pages should return FALSE
}



// Reads training patterns from the text file at fullPath and appends them to
// _patterns.
//
// fullPath - full path of the patterns file.
// Returns false when the file cannot be opened, true otherwise (even if the
// file yielded no patterns).
//
// A pattern is stored only when reading it did not put the stream into a
// failed state. The previous eof()-controlled loop appended one bogus pattern
// whenever the last read hit end-of-file without data (e.g. a trailing blank
// line).
// NOTE(review): this assumes Pattern::load() reads through the stream so that
// an incomplete pattern leaves failbit set - confirm against Pattern::load.
bool CTrainPageDlg::loadTrainingSet(const CString& fullPath)
{
	ifstream ist(fullPath);
	if (!ist || !ist.good())
		return false;

	while (!ist.eof() && !ist.fail())
	{
		// create pattern
		Pattern pattern(gNumInputNodes, gNumOutputNodes);

		// load pattern from file stream
		pattern.load(ist);

		// a failed read means there was no complete pattern left in the file
		if (ist.fail())
			break;

		// insert into _patterns vector
		_patterns.push_back(pattern);
	}	

	ist.close();

	return true;
}



void CTrainPageDlg::fillTrainingSetList(int numPatterns)
{
	m_listTrainSet.DeleteAllItems();

	CString str;
	for(int i = 0; i != numPatterns; ++i)
	{
		if (i >= _patterns.size()) // prevent overflow
			break;

		// Id
		str.Format("%d", _patterns[i].getId());
		m_listTrainSet.InsertItem(i, str, 0);
	
		// Clothing
		int idx = (int)(_patterns[i].getInput(0) * 10);
		m_listTrainSet.SetItemText(i, 1, gClothingStr[idx]);
		
		// Weapon
		idx = (int)(_patterns[i].getInput(1) * 10);
		m_listTrainSet.SetItemText(i, 2, gWeaponStr[idx]);


		// Fighter
		str.Format("%2.2f", _patterns[i].getOutput(0));
		m_listTrainSet.SetItemText(i, 4, str);

		// Wizard
		str.Format("%2.2f", _patterns[i].getOutput(1));
		m_listTrainSet.SetItemText(i, 5, str);

		// Thief
		str.Format("%2.2f", _patterns[i].getOutput(2));
		m_listTrainSet.SetItemText(i, 6, str);

	}


}


void CTrainPageDlg::OnButtonTrain() 
{
	if (_hThread == NULL)
	{
		if (UpdateData() == 0)
			return;

		_stopThread = false;

		// create separate thread for training
		unsigned int threadId = 0;
		_hThread = (HANDLE)_beginthreadex(0, 0, threadProc, this, 0, &threadId);
		if (_hThread == NULL)
		{
			MessageBox("Error starting training thread", "Error", MB_ICONERROR);
		}	
	}
	else
		_stopThread = true;
}


// Entry point of the training thread.
//
// context - the CTrainPageDlg that started the thread (passed as 'this' by
//           OnButtonTrain()).
// Returns 0 after training, or -1 (as unsigned) when no dialog was supplied.
//
// When training ends the thread closes its own handle and clears _hThread.
unsigned int __stdcall CTrainPageDlg::threadProc(void* context)
{
	CTrainPageDlg* const page = static_cast<CTrainPageDlg*>(context);
	if (page == NULL)
		return static_cast<unsigned int>(-1);

	page->trainNetwork();

	// release the handle and mark "no thread running"
	CloseHandle(page->_hThread);
	page->_hThread = NULL;

	return 0;
}


// Trains a 3-layer back-propagation network on the patterns currently shown
// in the list control, until every pattern is within _tolerance on all
// outputs or _stopThread is set.
//
// Called from threadProc(), i.e. on the training thread, while it updates
// dialog controls directly.
// NOTE(review): the UI is touched from a non-UI thread here - confirm this is
// acceptable for the controls involved.
//
// Side effects: disables/enables the input controls, rewrites the Train
// button caption and the status texts, updates the per-row status icons,
// writes the best network seen so far to <_appPath>bpnet.nn, and finally
// stores that path in _outFullPath (it is cleared while training runs).
void CTrainPageDlg::trainNetwork()
{
	// no usable network file while training is in progress
	_outFullPath = "";
	enableControls(FALSE);
	GetDlgItem(IDC_BUTTON_TRAIN)->SetWindowText("Stop Training");

	double		totalError		= 0.0;	// sum of |error| over all outputs in the current pass
	int			iteration		= 0;	// number of completed passes over the training set
	int			good			= 0;	// patterns within tolerance in the current pass
	int			best			= 0;	// highest 'good' count seen so far
	int			numPatterns		= m_listTrainSet.GetItemCount();	// only the listed patterns are trained
	CString		outPath			= _appPath + sNetSaveFileName;
	CString		strAutoLR		= "";
	
	// create neural-network

	BPNet bpnet;
	vector<int> layers(3);
	layers[0] = gNumInputNodes;
	layers[1] = gNumHiddenNodes;
	layers[2] = gNumOutputNodes;

	bpnet.createNetwork(m_learningRate, m_momentum, layers);

	// train
	// per-pattern icon index for the list: 1 = within tolerance, 0 = not
	vector<int> listIconsVec(numPatterns);
	while (good < numPatterns && !_stopThread) 
	{
		good		= 0;
		totalError	= 0.0;

		// present training patterns to network
		// NOTE: 
		// It is recommended to randomly shuffle the patterns 
		// instead of presenting them in the same order each time
		// (to prevent cyclic effects)
		for (int i = 0; i < numPatterns; ++i)
		{
			bpnet.setInput(&_patterns[i]);	// set input values

			bpnet.run();				 // forward pass

			bpnet.setError(&_patterns[i]); // store output pattern value for error computation

			bpnet.learn();            // backward pass

			// test if network output is under tolerance  
			int goodOutputs = 0;
			for (int j = 0; j < gNumOutputNodes; ++j)
			{
				if (fabs( bpnet.getOutput(j) - _patterns[i].getOutput(j) ) < _tolerance)
					goodOutputs++;

				totalError += fabs(bpnet.getError(j));
			}


			// if all outputs are under error tolerance - this pattern is good
			if (goodOutputs == gNumOutputNodes)
			{
				good++;
				listIconsVec[i] = 1; // good icon (V)
			}
			else
				listIconsVec[i] = 0; // bad icon (X)

			
		}

		// save best network so far
		// (nothing is written until at least one pattern has been within tolerance)
		if (best < good)
		{
			best = good;

			// save network to file
			ofstream ost(outPath);
			bpnet.save(ost);
			ost.close();
		}

		// if user selected Auto learning-rate
		if (m_boolAutoLr)
		{
			// adjust learning-rate according to number of BAD patterns
			// (fraction of patterns still outside tolerance)
			double newLR = (double)(numPatterns - good) / (double)numPatterns;
			bpnet.setLearningRate( newLR );
		}


		// display status
		// (only every 500th pass, and on the pass in which all patterns are good)
		if ( (iteration%500 == 0 || good == numPatterns) )
		{
			_strGood.Format("%d / %d", good, numPatterns);
			_strIter.Format("%d", iteration);
			_strTotError.Format("%4.8f", totalError);

			GetDlgItem(IDC_STATIC_GOOD)->SetWindowText(_strGood);
			GetDlgItem(IDC_STATIC_ITER)->SetWindowText(_strIter);
			GetDlgItem(IDC_STATIC_TOTERROR)->SetWindowText(_strTotError);

			if (m_boolAutoLr)
			{
				strAutoLR.Format("%f", bpnet.getLearningRate());
				GetDlgItem(IDC_EDIT_LR)->SetWindowText(strAutoLR);
			}

			// refresh the good/bad icon of every listed pattern
			LV_ITEM item;
			ZeroMemory(&item, sizeof(item));
			for(int i = 0; i != numPatterns; ++i)
			{
							
				item.iItem	= i;
				item.mask	= LVIF_IMAGE;
				item.iImage = listIconsVec[i];

				m_listTrainSet.SetItem(&item);
			}

		}

		iteration++;

	} // END WHILE loop

	GetDlgItem(IDC_BUTTON_TRAIN)->SetWindowText("Train Network");

	enableControls(TRUE);
	// put the user's own learning rate back into the edit box
	strAutoLR.Format("%f", m_learningRate);
	GetDlgItem(IDC_EDIT_LR)->SetWindowText(strAutoLR);


	// publish the network file path (read by OnKillActive)
	_outFullPath = outPath;
}




void CTrainPageDlg::OnCheckAutoLR() 
{
	GetDlgItem(IDC_EDIT_LR)->EnableWindow(m_boolAutoLr);

	m_boolAutoLr = !m_boolAutoLr;
}


// Reacts to movement of the two sliders.
//
// nSBCode, nPos - scroll code and position, forwarded unchanged to the base class.
// pScrollBar    - control that sent the notification; it may be NULL when the
//                 message does not come from a control, in which case only
//                 the base class is called (previously this was dereferenced
//                 unconditionally).
//
// The slider position is turned into a fraction of its maximum (pct):
//   - NPC Experience slider: shows pct of the loaded patterns in the list,
//     which is the set trainNetwork() works on.
//   - NPC Intelligence slider: sets the tolerance to 1.1 - pct.
// Controls with any other ID are no longer queried as if they were sliders.
void CTrainPageDlg::OnHScroll(UINT nSBCode, UINT nPos, CScrollBar* pScrollBar) 
{
	if (pScrollBar != NULL)
	{
		const int sliderId = pScrollBar->GetDlgCtrlID();
		if ( sliderId == IDC_SLIDER_EXP || sliderId == IDC_SLIDER_INT )
		{
			CSliderCtrl* pSlider = (CSliderCtrl*)pScrollBar;
			const int rangeMax = pSlider->GetRangeMax();

			// a zero range would make the fraction meaningless (division by zero)
			if (rangeMax > 0)
			{
				const double pct = (double)pSlider->GetPos() / rangeMax;

				if ( sliderId == IDC_SLIDER_EXP )
				{
					// we use the value from the NPC Experience slider to control
					// the number of patterns presented to the network
					const int numPatterns = (int)((double)_patterns.size() * pct);
					fillTrainingSetList(numPatterns);
				}
				else
				{
					// we use the value from the NPC Intelligence slider to compute the tolerance
					// High Int.   gives us a tolerance of 0.1  
					// Medium Int. gives us a tolerance of 0.43
					// Low Int.	   gives us a tolerance of 0.77
					_tolerance	= (1.1 - pct); 
				}
			}
		}
	}

	CPropertyPage::OnHScroll(nSBCode, nPos, pScrollBar);	
}


// Enables or disables the controls that must not change during training.
//
// enable - TRUE to enable, FALSE to disable.
//
// The learning-rate edit box is only touched when auto learning rate is off;
// in auto mode its enabled state is left as it is.
void CTrainPageDlg::enableControls(BOOL enable)
{
	// the two NPC sliders
	CWnd* intelligenceSlider = GetDlgItem(IDC_SLIDER_INT);
	intelligenceSlider->EnableWindow(enable);
	CWnd* experienceSlider = GetDlgItem(IDC_SLIDER_EXP);
	experienceSlider->EnableWindow(enable);

	// learning rate: only while the user controls it manually
	if (!m_boolAutoLr)
	{
		CWnd* learningRateEdit = GetDlgItem(IDC_EDIT_LR);
		learningRateEdit->EnableWindow(enable);
	}

	// momentum and the auto-learning-rate switch
	CWnd* momentumEdit = GetDlgItem(IDC_EDIT_MT);
	momentumEdit->EnableWindow(enable);
	CWnd* autoLrCheck = GetDlgItem(IDC_CHECK_AUTOLR);
	autoLrCheck->EnableWindow(enable);
}



void CTrainPageDlg::OnDestroy() 
{
	_imageList.DeleteImageList();	

	if (_hThread != NULL)
		CloseHandle(_hThread);

	CPropertyPage::OnDestroy();
}


// Called when the page is about to lose activation: passes the path of the
// saved network (_outFullPath, empty while training runs) to the owning
// property sheet, then defers to the base class for the result.
BOOL CTrainPageDlg::OnKillActive() 
{
	CANNAppPropSheet* const ownerSheet = (CANNAppPropSheet*)GetParent();
	ownerSheet->setNNFilename(_outFullPath);

	const BOOL mayLeave = CPropertyPage::OnKillActive();
	return mayLeave;
}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -