📄 trainpagedlg.cpp
字号:
// TrainPageDlg.cpp : implementation file
//
#include "stdafx.h"
#include <process.h>
#include <cmath>
#include <iostream>
#include <fstream>
#include "ANNApp.h"
#include "TrainPageDlg.h"
#include "ANNAppPropSheet.h"
#ifdef _DEBUG
#define new DEBUG_NEW
#undef THIS_FILE
static char THIS_FILE[] = __FILE__;
#endif
/////////////////////////////////////////////////////////////////////////////
// CTrainPageDlg property page
// Name of the training-pattern file; OnInitDialog() appends it to the
// directory of the running .exe to build the full path it loads from.
static const char sPatternsFileName[] = "Patterns.txt";
// Name of the file the trained network is written to; trainNetwork() appends
// it to the directory of the running .exe.
static const char sNetSaveFileName[] = "bpnet.nn";
// MFC dynamic-creation boilerplate for CTrainPageDlg (base class CPropertyPage).
IMPLEMENT_DYNCREATE(CTrainPageDlg, CPropertyPage)
// Constructs the page from the CTrainPageDlg::IDD dialog template and sets
// the default training parameters: learning rate 0.4, momentum 0.7, automatic
// learning rate on, tolerance 0.1, no training thread and no stop request.
CTrainPageDlg::CTrainPageDlg()
    : CPropertyPage(CTrainPageDlg::IDD),
      _tolerance(0.1),
      _hThread(NULL),
      _stopThread(false)
{
    // Defaults for the DDX-bound members; this section is delimited by the
    // ClassWizard markers and must stay between them.
    //{{AFX_DATA_INIT(CTrainPageDlg)
    m_learningRate = 0.4;
    m_momentum = 0.7;
    m_boolAutoLr = TRUE;
    //}}AFX_DATA_INIT
}
// Destructor. The body is empty: cleanup of the image list and the handling
// of the training thread on shutdown live in OnDestroy().
CTrainPageDlg::~CTrainPageDlg()
{
}
// Exchanges and validates data between the dialog controls and the member
// variables; the framework calls it through UpdateData().
//
// pDX - exchange context supplied by the framework.
//
// Control bindings: each slider member is attached to the control ID of the
// same name (IDC_SLIDER_INT -> m_sliderNPCInt, IDC_SLIDER_EXP ->
// m_sliderNPCExp), matching how OnHScroll() interprets those IDs.
// Validation: learning rate must be in [1e-10, 1], momentum in [0, 0.9].
void CTrainPageDlg::DoDataExchange(CDataExchange* pDX)
{
    CPropertyPage::DoDataExchange(pDX);
    //{{AFX_DATA_MAP(CTrainPageDlg)
    DDX_Control(pDX, IDC_SLIDER_INT, m_sliderNPCInt);
    DDX_Control(pDX, IDC_SLIDER_EXP, m_sliderNPCExp);
    DDX_Control(pDX, IDC_LIST_TRAINSET, m_listTrainSet);
    DDX_Text(pDX, IDC_EDIT_LR, m_learningRate);
    DDV_MinMaxDouble(pDX, m_learningRate, 1.e-010, 1.);
    DDX_Text(pDX, IDC_EDIT_MT, m_momentum);
    DDV_MinMaxDouble(pDX, m_momentum, 0., 0.9);
    DDX_Check(pDX, IDC_CHECK_AUTOLR, m_boolAutoLr);
    //}}AFX_DATA_MAP
}
// Message map for CTrainPageDlg: routes WM_DESTROY to OnDestroy, WM_HSCROLL
// to OnHScroll, and the BN_CLICKED notifications of IDC_CHECK_AUTOLR and
// IDC_BUTTON_TRAIN to OnCheckAutoLR and OnButtonTrain respectively.
// The entries sit between the ClassWizard markers.
BEGIN_MESSAGE_MAP(CTrainPageDlg, CPropertyPage)
//{{AFX_MSG_MAP(CTrainPageDlg)
ON_WM_DESTROY()
ON_BN_CLICKED(IDC_CHECK_AUTOLR, OnCheckAutoLR)
ON_WM_HSCROLL()
ON_BN_CLICKED(IDC_BUTTON_TRAIN, OnButtonTrain)
//}}AFX_MSG_MAP
END_MESSAGE_MAP()
/////////////////////////////////////////////////////////////////////////////
// CTrainPageDlg message handlers
// Initialises the page when the dialog is created:
//   - builds the two-entry status image list (index 0 = IDI_FAILED,
//     index 1 = IDI_OK) and the columns of the training-set list,
//   - sets both sliders to the range 1..3 at position 3,
//   - derives the application directory from the module file name,
//   - loads the training patterns from <app dir>Patterns.txt and lists them.
// If the patterns cannot be loaded an error box is shown and application
// shutdown is requested.
//
// Returns TRUE (the page does not set the focus to a control itself).
BOOL CTrainPageDlg::OnInitDialog()
{
    CPropertyPage::OnInitDialog();

    // init training set list-box
    _imageList.Create(16, 16, FALSE, 2, 0);
    HICON Icon = ::LoadIcon(AfxGetResourceHandle(), MAKEINTRESOURCE(IDI_FAILED));
    _imageList.Add(Icon);
    Icon = ::LoadIcon(AfxGetResourceHandle(), MAKEINTRESOURCE(IDI_OK));
    _imageList.Add(Icon);

    m_listTrainSet.InsertColumn(0, "Id", LVCFMT_LEFT, 40, 0);
    m_listTrainSet.InsertColumn(1, "Clothing", LVCFMT_CENTER, 55, 1);
    m_listTrainSet.InsertColumn(2, "Weapon", LVCFMT_CENTER, 55, 2);
    m_listTrainSet.InsertColumn(3, "", LVCFMT_CENTER, 20, 3);
    m_listTrainSet.InsertColumn(4, "Fighter", LVCFMT_CENTER, 52, 4);
    m_listTrainSet.InsertColumn(5, "Wizard", LVCFMT_CENTER, 52, 5);
    m_listTrainSet.InsertColumn(6, "Thief", LVCFMT_CENTER, 52, 6);
    m_listTrainSet.SetImageList(&_imageList, LVSIL_SMALL);

    // init sliders
    m_sliderNPCInt.SetRange(1, 3);
    m_sliderNPCExp.SetRange(1, 3);
    m_sliderNPCInt.SetPos(3);
    m_sliderNPCExp.SetPos(3);

    // get .exe path
    // The buffer is zero-filled so that it is a valid (possibly empty) string
    // even if GetModuleFileName fails or truncates without writing a
    // terminator; an empty path makes the file names below relative.
    char moduleFileName[_MAX_PATH + 1] = { 0 };
    ::GetModuleFileName(AfxGetInstanceHandle(), moduleFileName, _MAX_PATH);
    _appPath = moduleFileName;
    // keep everything up to and including the last backslash
    _appPath = _appPath.Left(_appPath.ReverseFind('\\') + 1);

    // load training patterns
    const CString patternsPath = _appPath + sPatternsFileName;
    if (!loadTrainingSet(patternsPath))
    {
        MessageBox("Error Loading Training Set", "Error", MB_ICONERROR);
        // WM_QUIT must be raised through PostQuitMessage rather than posted
        // to a window with PostMessage; exit code stays 0 as before.
        ::PostQuitMessage(0);
    }
    else
    {
        fillTrainingSetList((int)_patterns.size());
    }

    return TRUE; // return TRUE unless you set the focus to a control
                 // EXCEPTION: OCX Property Pages should return FALSE
}
// Loads the training patterns stored in the file at fullPath into _patterns,
// replacing whatever _patterns held before.
//
// fullPath - full path of the pattern file.
//
// Returns true when the file could be opened and at least one pattern was
// read; false otherwise (the caller reports this as a load error).
bool CTrainPageDlg::loadTrainingSet(const CString& fullPath)
{
    ifstream ist(fullPath);
    if (!ist || !ist.good())
        return false;

    // start from a clean set so a repeated load does not duplicate patterns
    _patterns.clear();

    while (!ist.eof() && !ist.fail())
    {
        // create pattern
        Pattern pattern(gNumInputNodes, gNumOutputNodes);
        // load pattern from file stream
        pattern.load(ist);
        // Testing eof() before reading does not detect the end of the data:
        // with trailing whitespace the final load() fails and would otherwise
        // append a pattern that was never read.
        // NOTE(review): assumes Pattern::load leaves the stream in the failed
        // state when it cannot read a complete pattern - confirm.
        if (ist.fail())
            break;
        // insert into _patterns vector
        _patterns.push_back(pattern);
    }
    // the stream is closed by its destructor

    return !_patterns.empty();
}
void CTrainPageDlg::fillTrainingSetList(int numPatterns)
{
m_listTrainSet.DeleteAllItems();
CString str;
for(int i = 0; i != numPatterns; ++i)
{
if (i >= _patterns.size()) // prevent overflow
break;
// Id
str.Format("%d", _patterns[i].getId());
m_listTrainSet.InsertItem(i, str, 0);
// Clothing
int idx = (int)(_patterns[i].getInput(0) * 10);
m_listTrainSet.SetItemText(i, 1, gClothingStr[idx]);
// Weapon
idx = (int)(_patterns[i].getInput(1) * 10);
m_listTrainSet.SetItemText(i, 2, gWeaponStr[idx]);
// Fighter
str.Format("%2.2f", _patterns[i].getOutput(0));
m_listTrainSet.SetItemText(i, 4, str);
// Wizard
str.Format("%2.2f", _patterns[i].getOutput(1));
m_listTrainSet.SetItemText(i, 5, str);
// Thief
str.Format("%2.2f", _patterns[i].getOutput(2));
m_listTrainSet.SetItemText(i, 6, str);
}
}
void CTrainPageDlg::OnButtonTrain()
{
if (_hThread == NULL)
{
if (UpdateData() == 0)
return;
_stopThread = false;
// create separate thread for training
unsigned int threadId = 0;
_hThread = (HANDLE)_beginthreadex(0, 0, threadProc, this, 0, &threadId);
if (_hThread == NULL)
{
MessageBox("Error starting training thread", "Error", MB_ICONERROR);
}
}
else
_stopThread = true;
}
// Entry point of the training thread.
//
// context - the CTrainPageDlg instance passed to _beginthreadex.
//
// Runs trainNetwork() on that instance, then closes the instance's thread
// handle and resets it to NULL. Returns 0 on completion, or -1 (converted to
// unsigned) when no instance was supplied.
unsigned int __stdcall CTrainPageDlg::threadProc(void* context)
{
    CTrainPageDlg* const page = (CTrainPageDlg*)context;
    if (page != NULL)
    {
        page->trainNetwork();

        // the thread releases its own handle and marks training as finished
        CloseHandle(page->_hThread);
        page->_hThread = NULL;
        return 0;
    }
    return -1;
}
// Trains a back-propagation network on the patterns currently shown in the
// training-set list and reports progress in the dialog.
//
// Called from threadProc(), i.e. on the thread started by OnButtonTrain().
// Every pass presents each listed pattern to the network (forward pass,
// error, backward pass) and counts the pattern as "good" when all of its
// outputs are within _tolerance of the target values. The loop ends when
// every listed pattern is good or when _stopThread is set.
// Whenever a pass yields more good patterns than any earlier pass, the
// network is saved to <_appPath> + sNetSaveFileName.
// _outFullPath is cleared on entry and set to that path on exit; it is set
// even when no pass ever improved on zero good patterns, in which case
// nothing was written during this run.
//
// NOTE(review): the function calls GetDlgItem()/SetWindowText()/SetItem() on
// the dialog's controls from the worker thread and polls _stopThread with no
// synchronization visible in this file - confirm how _stopThread is declared
// and that cross-thread control access is acceptable here.
void CTrainPageDlg::trainNetwork()
{
_outFullPath = "";
// lock the parameter controls and turn the train button into a stop button
enableControls(FALSE);
GetDlgItem(IDC_BUTTON_TRAIN)->SetWindowText("Stop Training");
double totalError = 0.0;
int iteration = 0;
// number of patterns within tolerance in the current pass
int good = 0;
// highest value of 'good' reached so far
int best = 0;
// train only on the patterns currently listed (see OnHScroll)
int numPatterns = m_listTrainSet.GetItemCount();
CString outPath = _appPath + sNetSaveFileName;
CString strAutoLR = "";
// create neural-network
BPNet bpnet;
vector<int> layers(3);
layers[0] = gNumInputNodes;
layers[1] = gNumHiddenNodes;
layers[2] = gNumOutputNodes;
bpnet.createNetwork(m_learningRate, m_momentum, layers);
// train
// per-pattern icon index for the list: 1 = good, 0 = bad
vector<int> listIconsVec(numPatterns);
while (good < numPatterns && !_stopThread)
{
good = 0;
totalError = 0.0;
// present training patterns to network
// NOTE:
// It is recommended to randomly shuffle the patterns
// instead of presenting them in the same order each time
// (to prevent cyclic effects)
for (int i = 0; i < numPatterns; ++i)
{
bpnet.setInput(&_patterns[i]); // set input values
bpnet.run(); // forward pass
bpnet.setError(&_patterns[i]); // store output pattern value for error computation
bpnet.learn(); // backward pass
// test if network output is under tolerance
int goodOutputs = 0;
for (int j = 0; j < gNumOutputNodes; ++j)
{
if (fabs( bpnet.getOutput(j) - _patterns[i].getOutput(j) ) < _tolerance)
goodOutputs++;
// accumulate the absolute error of every output over the whole pass
totalError += fabs(bpnet.getError(j));
}
// if all outputs are under error tolerance - this pattern is good
if (goodOutputs == gNumOutputNodes)
{
good++;
listIconsVec[i] = 1; // good icon (V)
}
else
listIconsVec[i] = 0; // bad icon (X)
}
// save best network so far
if (best < good)
{
best = good;
// save network to file
ofstream ost(outPath);
bpnet.save(ost);
ost.close();
}
// if user selected Auto learning-rate
if (m_boolAutoLr)
{
// adjust learning-rate according to number of BAD patterns
// (fraction of patterns still outside tolerance)
double newLR = (double)(numPatterns - good) / (double)numPatterns;
bpnet.setLearningRate( newLR );
}
// display status every 500 passes and on the final, fully-good pass
if ( (iteration%500 == 0 || good == numPatterns) )
{
_strGood.Format("%d / %d", good, numPatterns);
_strIter.Format("%d", iteration);
_strTotError.Format("%4.8f", totalError);
GetDlgItem(IDC_STATIC_GOOD)->SetWindowText(_strGood);
GetDlgItem(IDC_STATIC_ITER)->SetWindowText(_strIter);
GetDlgItem(IDC_STATIC_TOTERROR)->SetWindowText(_strTotError);
if (m_boolAutoLr)
{
// show the learning rate currently used by the network
strAutoLR.Format("%f", bpnet.getLearningRate());
GetDlgItem(IDC_EDIT_LR)->SetWindowText(strAutoLR);
}
// refresh the good/bad icon of every listed pattern
LV_ITEM item;
ZeroMemory(&item, sizeof(item));
for(int i = 0; i != numPatterns; ++i)
{
item.iItem = i;
item.mask = LVIF_IMAGE;
item.iImage = listIconsVec[i];
m_listTrainSet.SetItem(&item);
}
}
iteration++;
} // END WHILE loop
// training finished or stopped: restore the button caption and the controls
GetDlgItem(IDC_BUTTON_TRAIN)->SetWindowText("Train Network");
enableControls(TRUE);
// put the user's own learning-rate value back into the edit box
strAutoLR.Format("%f", m_learningRate);
GetDlgItem(IDC_EDIT_LR)->SetWindowText(strAutoLR);
// publish the save path; OnKillActive() hands it to the property sheet
_outFullPath = outPath;
}
void CTrainPageDlg::OnCheckAutoLR()
{
GetDlgItem(IDC_EDIT_LR)->EnableWindow(m_boolAutoLr);
m_boolAutoLr = !m_boolAutoLr;
}
// Handler for WM_HSCROLL; reacts to movements of the two NPC sliders.
//
// nSBCode, nPos - scroll request code and position, forwarded unchanged to
//                 the base class.
// pScrollBar    - control that sent the message; NULL when the message comes
//                 from the window's own scroll bar rather than a control.
//
// The slider position is turned into a fraction pct = pos / rangeMax:
//   IDC_SLIDER_EXP - lists the first pct share of the loaded patterns,
//   IDC_SLIDER_INT - sets _tolerance to 1.1 - pct.
void CTrainPageDlg::OnHScroll(UINT nSBCode, UINT nPos, CScrollBar* pScrollBar)
{
    if (pScrollBar != NULL)
    {
        const int sliderId = pScrollBar->GetDlgCtrlID();

        // Only treat the sender as a slider once its ID identifies it as one.
        if (sliderId == IDC_SLIDER_EXP || sliderId == IDC_SLIDER_INT)
        {
            CSliderCtrl* pSlider = (CSliderCtrl*)pScrollBar;
            const int rangeMax = pSlider->GetRangeMax();

            if (rangeMax > 0) // avoid dividing by zero on a degenerate range
            {
                const double pct = (double)pSlider->GetPos() / rangeMax;

                if (sliderId == IDC_SLIDER_EXP)
                {
                    // we use the value from the NPC Experience slider to control
                    // the number of patterns presented to the network
                    const int numPatterns = (int)((double)_patterns.size() * pct);
                    fillTrainingSetList(numPatterns);
                }
                else
                {
                    // we use the value from the NPC Intelligence slider to compute the tolerance
                    // High Int. gives us a tolerance of 0.1
                    // Medium Int. gives us a tolerance of 0.43
                    // Low Int. gives us a tolerance of 0.77
                    _tolerance = (1.1 - pct);
                }
            }
        }
    }

    CPropertyPage::OnHScroll(nSBCode, nPos, pScrollBar);
}
// Enables or disables the controls that must not change during training:
// both sliders, the momentum edit box and the auto-learning-rate check box.
// The learning-rate edit box is included only while m_boolAutoLr is FALSE;
// otherwise its state is left as it is.
//
// enable - TRUE to enable the controls, FALSE to disable them.
void CTrainPageDlg::enableControls(BOOL enable)
{
    // processed in this order
    const int controlIds[] =
    {
        IDC_SLIDER_INT,
        IDC_SLIDER_EXP,
        IDC_EDIT_LR,
        IDC_EDIT_MT,
        IDC_CHECK_AUTOLR
    };
    const int numIds = sizeof(controlIds) / sizeof(controlIds[0]);

    for (int k = 0; k < numIds; ++k)
    {
        // skip the learning-rate box while the automatic mode is active
        if (controlIds[k] == IDC_EDIT_LR && m_boolAutoLr != FALSE)
            continue;

        GetDlgItem(controlIds[k])->EnableWindow(enable);
    }
}
void CTrainPageDlg::OnDestroy()
{
_imageList.DeleteImageList();
if (_hThread != NULL)
CloseHandle(_hThread);
CPropertyPage::OnDestroy();
}
// Called by the framework when this page is about to lose its active status.
// Passes the path of the saved network (_outFullPath, empty while training is
// running) to the parent property sheet, then defers to the base class.
//
// Returns the result of CPropertyPage::OnKillActive().
BOOL CTrainPageDlg::OnKillActive()
{
    CWnd* const parentWnd = GetParent();
    ((CANNAppPropSheet*)parentWnd)->setNNFilename(_outFullPath);

    const BOOL allowLeave = CPropertyPage::OnKillActive();
    return allowLeave;
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -