⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 patternhw4view.cpp

📁 Neural Network program for pattern classification
💻 CPP
📖 第 1 页 / 共 2 页
字号:
// PatternHW4View.cpp : implementation of the CPatternHW4View class
//

#include "stdafx.h"
#include "PatternHW4.h"

#include "PatternHW4Doc.h"
#include "PatternHW4View.h"
#include "Chart.h"
#include "PrmDlg.h"
#include <math.h>

// Debug builds only: route `new` in this file through MFC's DEBUG_NEW so each
// allocation records its source location for leak reports. THIS_FILE is the
// legacy wizard-generated file-name string used by that idiom.
#ifdef _DEBUG
#define new DEBUG_NEW
#undef THIS_FILE
static char THIS_FILE[] = __FILE__;
#endif

/////////////////////////////////////////////////////////////////////////////
// CPatternHW4View

// Enables MFC dynamic creation (CRuntimeClass::CreateObject) of this view,
// whose base class is CScrollView.
IMPLEMENT_DYNCREATE(CPatternHW4View, CScrollView)

// Command routing for this view.
// - Each ON_COMMAND maps a command ID to its handler.
// - Each ON_UPDATE_COMMAND_UI maps the same ID to the handler that
//   enables/disables the matching menu item or toolbar button.
// NOTE(review): IDC_TESTTRAIN has an ON_COMMAND handler but no
// ON_UPDATE_COMMAND_UI entry, so its enabled state is not controlled here.
// Confirm this is intended.
// The //{{AFX_MSG_MAP ... //}}AFX_MSG_MAP section is maintained by ClassWizard;
// keep hand-written comments outside it.
BEGIN_MESSAGE_MAP(CPatternHW4View, CScrollView)
	//{{AFX_MSG_MAP(CPatternHW4View)
	ON_COMMAND(IDC_TRAINING1, OnTraining1)
	ON_UPDATE_COMMAND_UI(IDC_TRAINING1, OnUpdateTraining1)
	ON_COMMAND(IDC_TESTING1, OnTesting1)
	ON_UPDATE_COMMAND_UI(IDC_TESTING1, OnUpdateTesting1)
	ON_COMMAND(IDC_REGRESSTRN, OnRegresstrn)
	ON_COMMAND(IDC_REGRESSTST, OnRegresstst)
	ON_UPDATE_COMMAND_UI(IDC_REGRESSTRN, OnUpdateRegresstrn)
	ON_UPDATE_COMMAND_UI(IDC_REGRESSTST, OnUpdateRegresstst)
	ON_COMMAND(IDC_TESTTRAIN, OnTesttrain)
	//}}AFX_MSG_MAP
	// Standard printing commands: print, print-direct and print-preview are
	// forwarded to the CScrollView base-class implementations.
	ON_COMMAND(ID_FILE_PRINT, CScrollView::OnFilePrint)
	ON_COMMAND(ID_FILE_PRINT_DIRECT, CScrollView::OnFilePrint)
	ON_COMMAND(ID_FILE_PRINT_PREVIEW, CScrollView::OnFilePrintPreview)
END_MESSAGE_MAP()

/////////////////////////////////////////////////////////////////////////////
// CPatternHW4View construction/destruction

// Constructor: configures the default ("learning curve") appearance of the
// embedded chart and resets the view's state flags. The chart window itself
// is not created here; CreateMode == 0 tells OnDraw to create it later.
CPatternHW4View::CPatternHW4View()
{
	CreateMode = 0;

	// Default titles, axis range and grid shown before any run.
	m_Chart2d.SetChartTitle("Pattern Recognition Homework #4");
	m_Chart2d.SetChartLabel("Epoch","Learning Curve");
	m_Chart2d.SetRange(0,5000,-1,1);
	m_Chart2d.SetGridXYNumber(8,15);

	// First series: green line on a white background.
	m_Chart2d.mpSerie[0].m_plotColor = RGB(30,255,30);
	m_Chart2d.mpSerie[0].IsLine = TRUE;
	m_Chart2d.m_BGColor = RGB(255,255,255);

	// Give every series the same placeholder legend name.
	int nSerie = 0;
	while (nSerie < m_Chart2d.nSerieCount)
	{
		m_Chart2d.Graph_Name[nSerie].Format("Graph");
		++nSerie;
	}

	// Nothing trained, regressed or loaded yet.
	TrainFinish = FALSE;
	TrRgFinish  = FALSE;
	EnterFile   = FALSE;
}

// Destructor: no explicit cleanup is performed here; member objects such as
// m_Chart2d are destroyed automatically.
CPatternHW4View::~CPatternHW4View()
{
	// intentionally empty
}

// PreCreateWindow: the window class and styles are not customized; the
// CREATESTRUCT is handed unchanged to CScrollView and its result returned.
BOOL CPatternHW4View::PreCreateWindow(CREATESTRUCT& cs)
{
	const BOOL bBaseResult = CScrollView::PreCreateWindow(cs);
	return bBaseResult;
}

/////////////////////////////////////////////////////////////////////////////
// CPatternHW4View drawing

// OnDraw: does not paint through pDC itself. It resets the current series
// index, asks the chart to allocate series storage, and on the first call
// creates the chart as a child window inset 5 pixels inside the client area.
// The chart child window then draws its own contents.
void CPatternHW4View::OnDraw(CDC* pDC)
{
	CPatternHW4Doc* pDoc = GetDocument();
	ASSERT_VALID(pDoc);

	series = 0;

	// Client rectangle shrunk by a 5-pixel margin on every side.
	CRect rcChart;
	GetClientRect(rcChart);
	rcChart.DeflateRect(5, 5);

	// NOTE(review): this runs on every repaint, not only the first one.
	// Confirm in Chart.cpp that repeated AllocSerie calls neither leak nor
	// reset previously plotted data.
	if (!m_Chart2d.AllocSerie(MAX_INDEX))
	{
		AfxMessageBox("Error allocating chart serie") ;
		return;
	}

	// Create the chart window only once.
	if (CreateMode != 0)
		return;

	m_Chart2d.Create(WS_CHILD|WS_VISIBLE,rcChart,this,0);
	CreateMode = 1;
}

// OnInitialUpdate: after the base-class update, sets a fixed 100 x 100
// scrollable extent in MM_TEXT (device-pixel) mapping mode.
void CPatternHW4View::OnInitialUpdate()
{
	CScrollView::OnInitialUpdate();

	SetScrollSizes(MM_TEXT, CSize(100, 100));
}

/////////////////////////////////////////////////////////////////////////////
// CPatternHW4View printing

// OnPreparePrinting: uses the framework's standard print preparation
// (DoPreparePrinting); no page count or other print info is customized.
BOOL CPatternHW4View::OnPreparePrinting(CPrintInfo* pInfo)
{
	const BOOL bReady = DoPreparePrinting(pInfo);
	return bReady;
}

// OnBeginPrinting: this view performs no extra setup before a print job.
void CPatternHW4View::OnBeginPrinting(CDC* /*pDC*/, CPrintInfo* /*pInfo*/)
{
	// intentionally empty
}

// OnEndPrinting: nothing was set up in OnBeginPrinting, so there is nothing
// to release after a print job.
void CPatternHW4View::OnEndPrinting(CDC* /*pDC*/, CPrintInfo* /*pInfo*/)
{
	// intentionally empty
}

/////////////////////////////////////////////////////////////////////////////
// CPatternHW4View diagnostics

#ifdef _DEBUG
// AssertValid (debug only): checks no view-specific invariants beyond those
// verified by CScrollView.
void CPatternHW4View::AssertValid() const
{
	CScrollView::AssertValid();
}

// Dump (debug only): writes only the CScrollView diagnostic state.
void CPatternHW4View::Dump(CDumpContext& dc) const
{
	CScrollView::Dump(dc);
}

// GetDocument (debug version; the release version is inline in the header):
// returns the attached document. It first asserts that the document really is
// a CPatternHW4Doc, so a mismatched document type is caught before the downcast.
CPatternHW4Doc* CPatternHW4View::GetDocument() // non-debug version is inline
{
	ASSERT(m_pDocument->IsKindOf(RUNTIME_CLASS(CPatternHW4Doc)));
	// static_cast rather than a C-style cast: identical result for this
	// checked downcast, but it cannot silently turn into a reinterpret/const cast.
	return static_cast<CPatternHW4Doc*>(m_pDocument);
}
#endif //_DEBUG

/////////////////////////////////////////////////////////////////////////////
// CPatternHW4View message handlers
/************************************************************************/
/* OnTraining1()														*/
/* Name: OnTraining1													*/
/* Parameter: No														*/
/* Return: No															*/
/* Explain: Training													*/
/************************************************************************/
void CPatternHW4View::OnTraining1() 
{
	CPatternHW4Doc* pDoc = GetDocument();
	CPrmDlg pDlg;
	ASSERT_VALID(pDoc);

	if(pDlg.DoModal() == IDOK){
		register int i;
		pDoc->EnterData(TRAINMODE);								// Enter data into the train data array

		m_Chart2d.ClearChart();
		m_Chart2d.SetChartTitle("PR HW #4 Training Result");
		m_Chart2d.SetChartLabel("Epoch","Sum of Squared Error");
		m_Chart2d.SetRange(0,MAXEPOCH,0,0.5);
		m_Chart2d.SetGridXYNumber(10,5);
		
		m_Chart2d.mpSerie[0].m_plotColor = RGB(255,0,0);
		m_Chart2d.mpSerie[1].m_plotColor = RGB(0,255,0);
		m_Chart2d.mpSerie[2].m_plotColor = RGB(0,0,255);
		m_Chart2d.mpSerie[3].m_plotColor = RGB(255,255,0);
		m_Chart2d.mpSerie[4].m_plotColor = RGB(255,0,255);
		m_Chart2d.mpSerie[5].m_plotColor = RGB(0,255,255);
		
		m_Chart2d.mpSerie[0].IsLine = TRUE;
		m_Chart2d.mpSerie[1].IsLine = TRUE;
		m_Chart2d.mpSerie[2].IsLine = TRUE;
		m_Chart2d.mpSerie[3].IsLine = TRUE;
		m_Chart2d.mpSerie[4].IsLine = TRUE;
		m_Chart2d.mpSerie[5].IsLine = TRUE;

		m_Chart2d.Graph_Name[0].Format("MMT 0.0");
		m_Chart2d.Graph_Name[1].Format("MMT 0.1");
		m_Chart2d.Graph_Name[2].Format("MMT 0.3");
		m_Chart2d.Graph_Name[3].Format("MMT 0.5");
		m_Chart2d.Graph_Name[4].Format("MMT 0.7");
		m_Chart2d.Graph_Name[5].Format("MMT 0.8");
/*		
		m_Chart2d.Graph_Name[0].Format("HN 3");
		m_Chart2d.Graph_Name[1].Format("HN 4");
		m_Chart2d.Graph_Name[2].Format("HN 5");
		m_Chart2d.Graph_Name[3].Format("HN 6");
		m_Chart2d.Graph_Name[4].Format("HN 7");
		m_Chart2d.Graph_Name[5].Format("HN 8");
		
		pDoc->MLP(pDlg.m_hid_lay,pDlg.m_learn_rate,3,pDlg.m_momentum,pDlg.m_threshold);
		series = 0;
		for(i=0;i<MAXEPOCH;i++){
			m_Chart2d.SetXYValue(i,pDoc->m_SSE[i],i,series);
		}
		pDoc->MLP(pDlg.m_hid_lay,pDlg.m_learn_rate,4,pDlg.m_momentum,pDlg.m_threshold);
		series = 1;
		for(i=0;i<MAXEPOCH;i++){
			m_Chart2d.SetXYValue(i,pDoc->m_SSE[i],i,series);
		}
		pDoc->MLP(pDlg.m_hid_lay,pDlg.m_learn_rate,5,pDlg.m_momentum,pDlg.m_threshold);
		series = 2;
		for(i=0;i<MAXEPOCH;i++){
			m_Chart2d.SetXYValue(i,pDoc->m_SSE[i],i,series);
		}
		pDoc->MLP(pDlg.m_hid_lay,pDlg.m_learn_rate,6,pDlg.m_momentum,pDlg.m_threshold);
		series = 3;
		for(i=0;i<MAXEPOCH;i++){
			m_Chart2d.SetXYValue(i,pDoc->m_SSE[i],i,series);
		}
		pDoc->MLP(pDlg.m_hid_lay,pDlg.m_learn_rate,7,pDlg.m_momentum,pDlg.m_threshold);
		series = 4;
		for(i=0;i<MAXEPOCH;i++){
			m_Chart2d.SetXYValue(i,pDoc->m_SSE[i],i,series);
		}
		pDoc->MLP(pDlg.m_hid_lay,pDlg.m_learn_rate,8,pDlg.m_momentum,pDlg.m_threshold);
		series = 5;
		for(i=0;i<MAXEPOCH;i++){
			m_Chart2d.SetXYValue(i,pDoc->m_SSE[i],i,series);
		}
*/
		pDoc->MLP(pDlg.m_hid_lay,pDlg.m_learn_rate,pDlg.m_hid_node,0.0,pDlg.m_threshold);
		series = 0;
		for(i=0;i<MAXEPOCH;i++){
			m_Chart2d.SetXYValue(i,pDoc->m_SSE[i],i,series);
		}
		pDoc->MLP(pDlg.m_hid_lay,pDlg.m_learn_rate,pDlg.m_hid_node,0.1,pDlg.m_threshold);
		series = 1;
		for(i=0;i<MAXEPOCH;i++){
			m_Chart2d.SetXYValue(i,pDoc->m_SSE[i],i,series);
		}
		pDoc->MLP(pDlg.m_hid_lay,pDlg.m_learn_rate,pDlg.m_hid_node,0.3,pDlg.m_threshold);
		series = 2;
		for(i=0;i<MAXEPOCH;i++){
			m_Chart2d.SetXYValue(i,pDoc->m_SSE[i],i,series);
		}
		pDoc->MLP(pDlg.m_hid_lay,pDlg.m_learn_rate,pDlg.m_hid_node,0.5,pDlg.m_threshold);
		series = 3;
		for(i=0;i<MAXEPOCH;i++){
			m_Chart2d.SetXYValue(i,pDoc->m_SSE[i],i,series);
		}
		pDoc->MLP(pDlg.m_hid_lay,pDlg.m_learn_rate,pDlg.m_hid_node,0.7,pDlg.m_threshold);
		series = 4;
		for(i=0;i<MAXEPOCH;i++){
			m_Chart2d.SetXYValue(i,pDoc->m_SSE[i],i,series);
		}
		pDoc->MLP(pDlg.m_hid_lay,pDlg.m_learn_rate,pDlg.m_hid_node,0.8,pDlg.m_threshold);
		series = 5;
		for(i=0;i<MAXEPOCH;i++){
			m_Chart2d.SetXYValue(i,pDoc->m_SSE[i],i,series);
		}
	}
	TrainFinish = TRUE;
}

/************************************************************************/
/* OnUpdateTraining1(CCmdUI* pCmdUI)									*/
/* Name: OnUpdateTraining1												*/
/* Parameter: CCmdUI* pCmdUI - UI element (menu item / button) to update	*/
/* Return: No															*/
/* Explain: Enables training only when input data has been loaded		*/
/************************************************************************/
void CPatternHW4View::OnUpdateTraining1(CCmdUI* pCmdUI) 
{
	CPatternHW4Doc* pDoc = GetDocument();
	ASSERT_VALID(pDoc);

	// Disabled only while m_Data1, m_Data2 and m_Data3 are all NULL;
	// any one of them being set is enough to allow training.
	const BOOL bNoInput = (pDoc->m_Data1 == NULL)
	                   && (pDoc->m_Data2 == NULL)
	                   && (pDoc->m_Data3 == NULL);
	pCmdUI->Enable(!bNoInput);
}

/************************************************************************/
/* OnTesting1()															*/
/* Name: OnTesting1														*/
/* Parameter: No														*/
/* Return: No															*/
/* Explain: Testing														*/
/************************************************************************/
void CPatternHW4View::OnTesting1() 
{
	CPatternHW4Doc* pDoc = GetDocument();
	CPrmDlg pDlg;
	ASSERT_VALID(pDoc);
	
	if(pDlg.DoModal() == IDOK){
		pDoc->EnterData(TESTMODE);							// Enter the data into test data array
		pDoc->Test(pDlg.m_hid_lay,pDlg.m_learn_rate,pDlg.m_hid_node,pDlg.m_momentum,pDlg.m_threshold);
		
		m_Chart2d.SetChartTitle("PR HW #4 Test Result");
		m_Chart2d.SetChartLabel("Number of Data","Classification");
		m_Chart2d.SetRange(0,75,0,1);
		m_Chart2d.SetGridXYNumber(3,10);
		
		m_Chart2d.mpSerie[0].m_plotColor = RGB(255,0,0);
		m_Chart2d.mpSerie[1].m_plotColor = RGB(0,255,0);
		m_Chart2d.mpSerie[2].m_plotColor = RGB(0,0,255);
		
		m_Chart2d.mpSerie[0].IsLine = FALSE;
		m_Chart2d.mpSerie[1].IsLine = FALSE;
		m_Chart2d.mpSerie[2].IsLine = FALSE;
		
		m_Chart2d.Graph_Name[0].Format("Setosa");
		m_Chart2d.Graph_Name[1].Format("Versicolour");
		m_Chart2d.Graph_Name[2].Format("Virginica");
		
		m_Chart2d.ClearChart();
		series = 0;
		
		int maxIdx;
		double maxVal;
		numOfErr = 0;
		/************************************************************************/
		/* Result Showing														*/
		/************************************************************************/
		for(int i=0;i<75;i++){
			maxIdx = -100;
			maxVal = -100.0;
			for(int j=0;j<3;j++){
				if(pDoc->Test_Result[i][j] > maxVal){

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -