⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 digitrecogview.cpp

📁 用bp实现的数字识别程序
💻 CPP
字号:
// DigitRecogView.cpp : implementation of the CDigitRecogView class
//


#include "stdafx.h"
#include "DigitRecog.h"
#include <fstream.h>
#include <math.h>
#include "DigitRecogDoc.h"
#include "DigitRecogView.h"

#ifdef _DEBUG
#define new DEBUG_NEW
#undef THIS_FILE
static char THIS_FILE[] = __FILE__;
#endif

/////////////////////////////////////////////////////////////////////////////
// CDigitRecogView

// Allows the document template to create CDigitRecogView objects at run time.
IMPLEMENT_DYNCREATE(CDigitRecogView, CView)

// Command routing for the view.
// IDC_P* commands drive the single-layer weight-matrix classifier (RunNet,
// m_fWeights); IDC_CREATE and the IDC_B* commands drive the bpnn
// back-propagation network (net). The entries between the AFX_MSG_MAP
// markers are maintained by ClassWizard - edit them with care.
BEGIN_MESSAGE_MAP(CDigitRecogView, CView)
	//{{AFX_MSG_MAP(CDigitRecogView)
	ON_COMMAND(IDC_PINPUT, OnPinput)
	ON_COMMAND(IDC_PNOISY, OnPnoisy)
	ON_COMMAND(IDC_PWEIGHTS, OnPweights)
	ON_COMMAND(IDC_PTRAIN, OnPtrain)
	ON_COMMAND(IDC_PERROR, OnPerror)
	ON_COMMAND(IDC_PRECOG, OnPrecog)
	ON_COMMAND(IDC_BINPUT, OnBinput)
	ON_COMMAND(IDC_BNOISY, OnBnoisy)
	ON_COMMAND(IDC_CREATE, OnCreate)
	ON_COMMAND(IDC_BWEIGHTS, OnBweights)
	ON_COMMAND(IDC_BTRAIN, OnBtrain)
	ON_COMMAND(IDC_BRECOG, OnBrecog)
	ON_COMMAND(IDC_BERROR, OnBerror)
	//}}AFX_MSG_MAP
	// Standard printing commands
	ON_COMMAND(ID_FILE_PRINT, CView::OnFilePrint)
	ON_COMMAND(ID_FILE_PRINT_DIRECT, CView::OnFilePrint)
	ON_COMMAND(ID_FILE_PRINT_PREVIEW, CView::OnFilePrintPreview)
END_MESSAGE_MAP()

/////////////////////////////////////////////////////////////////////////////
// CDigitRecogView construction/destruction

// Constructs the view with an empty training history, no pattern selected
// for drawing and no back-propagation network created yet.
CDigitRecogView::CDigitRecogView()
{
	// Training history shown by the error-curve commands.
	// NOTE(review): assumes m_out holds at least 10000 entries (OnBtrain also
	// caps m_icycles at 10000) - confirm against the header.
	m_icycles = 0;
	for (int i = 0; i < 10000; i++)
		m_out[i] = 0.0;

	// OnDraw() tests `draw` and passes `m_ipDrawNum` to DrawNumber(), and the
	// IDC_B* handlers dereference `net`. Any of these can run before the
	// command that normally sets them, so give them defined
	// "nothing loaded / no network" values instead of leaving them indeterminate.
	draw = FALSE;
	m_ipDrawNum = NULL;
	net = NULL;
}

CDigitRecogView::~CDigitRecogView()
{
	// No explicit cleanup is performed here.
	// NOTE(review): the network returned by bpnn_create() in OnCreate() is
	// never released; confirm whether the bpnn library's free routine should
	// be called here.
}

BOOL CDigitRecogView::PreCreateWindow(CREATESTRUCT& cs)
{
	// This view uses no custom window class or styles; the base class decides.
	const BOOL bOk = CView::PreCreateWindow(cs);
	return bOk;
}

/////////////////////////////////////////////////////////////////////////////
// CDigitRecogView drawing

void CDigitRecogView::OnDraw(CDC* pDC)
{
	CDigitRecogDoc* pDoc = GetDocument();
	ASSERT_VALID(pDoc);

	// Nothing to repaint until one of the input/recognition commands has
	// selected a pattern table to show.
	if (draw != TRUE)
		return;

	// NOTE(review): DrawNumber() acquires its own client DC instead of using
	// pDC, so the digits will not appear when printing / in print preview.
	DrawNumber(m_ipDrawNum);
}

/////////////////////////////////////////////////////////////////////////////
// CDigitRecogView printing

BOOL CDigitRecogView::OnPreparePrinting(CPrintInfo* pInfo)
{
	// No custom print setup: use the framework's default preparation.
	return CView::DoPreparePrinting(pInfo);
}

void CDigitRecogView::OnBeginPrinting(CDC* /*pDC*/, CPrintInfo* /*pInfo*/)
{
	// Intentionally empty: no printer resources need to be allocated.
}

void CDigitRecogView::OnEndPrinting(CDC* /*pDC*/, CPrintInfo* /*pInfo*/)
{
	// Intentionally empty: nothing was allocated in OnBeginPrinting().
}

/////////////////////////////////////////////////////////////////////////////
// CDigitRecogView diagnostics

#ifdef _DEBUG
// Debug-build validity check: the view adds no invariants of its own.
void CDigitRecogView::AssertValid() const
{
	CView::AssertValid();
}

// Debug-build diagnostic dump: only the base-class state is dumped.
void CDigitRecogView::Dump(CDumpContext& dc) const
{
	CView::Dump(dc);
}

// Debug-build accessor: checks that the attached document really is a
// CDigitRecogDoc before returning a typed pointer (the non-debug version
// is inline).
CDigitRecogDoc* CDigitRecogView::GetDocument()
{
	ASSERT_KINDOF(CDigitRecogDoc, m_pDocument);
	return static_cast<CDigitRecogDoc*>(m_pDocument);
}
#endif //_DEBUG

/////////////////////////////////////////////////////////////////////////////
// CDigitRecogView message handlers

void CDigitRecogView::OnPinput() 
{
	// TODO: Add your command handler code here
	// TODO: Add your command handler code here	// 打开文件
	ifstream data("learn.dat", ios::nocreate);
	ifstream test("test.dat", ios::nocreate);

	// 检查文件的有效性
	if (!test || !data) 
	{
		MessageBox("No learning or test data present.", "Cannot run...",
			MB_OK | MB_ICONERROR);
	}

	// 读文件
	for(int i=0;i<NN_NUMBERS;i++) 
	{
		for(int j=0;j<NN_RESX * NN_RESY;j++) 
		{
			int onoff;
			test >> onoff;
			m_bTestData[i][j] = onoff;

			data >> onoff;
			m_bNumbers[i][j] = onoff;
		}
	}

	// 设置初始要画的数字指针
	m_ipDrawNum = &m_bNumbers[0][0];
    DrawNumber(	m_ipDrawNum);
	draw=true;	
}

// Writes the back-propagation target for digit `num` into t[1..4] as a
// 4-bit binary (8421) code, least significant bit first:
// t[1] = 1s bit, t[2] = 2s bit, t[3] = 4s bit, t[4] = 8s bit.
// Only codes 1..9 are written; t[0] is never touched, and any other value
// of `num` leaves `t` unchanged.
void CDigitRecogView::SetTarget(double *t,double num)
{
	const int code = int(num);
	if (code < 1 || code > 9)
		return;

	for (int bit = 0; bit < 4; bit++)
		t[bit + 1] = (code >> bit) & 1;
}
	
// Runs the single-layer weight-matrix classifier over every noisy sample
// in m_iNoisy.
//
// Sample i belongs to class i / NN_NOISY. It is classified as the row of
// m_fWeights with the largest dot product. When `training` is TRUE, each
// misclassified sample is subtracted from the winning row and added to the
// correct row. Passes repeat until a whole pass is error-free or NN_MAXITER
// passes have run. When `training` is FALSE, exactly one pass is made.
//
// After each pass, m_out[pass] receives the activation of row 0 on noisy
// sample 0, which OnPerror() plots. m_icycles is left at the number of
// passes run.
void CDigitRecogView::RunNet(BOOL training)
{
	const int nInputs = NN_RESX * NN_RESY;
	bool correct = true;

	m_icycles = 0;
	do 
	{
		correct = true;
		for (int i = 0; i < NN_NUMBERS * NN_NOISY; i++) 
		{
			// `j` is declared once per pass so the code compiles under both
			// standard for-scope rules and VC6's legacy ones. The original
			// reused a loop variable after its for statement had ended.
			int j;
			float d[NN_NUMBERS];

			for (j = 0; j < NN_NUMBERS; j++) 
			{
				d[j] = 0;
				for (int k = 0; k < nInputs; k++) 
					d[j] += m_fWeights[j][k] * m_iNoisy[i][k];
			}

			int bestind = 0;
			for (j = 1; j < NN_NUMBERS; j++) 
				if (d[j] > d[bestind]) 
					bestind = j;

			const int realval = i / NN_NOISY;
			if (bestind == realval)
				continue;

			if (training) 
			{
				correct = false;
				for (j = 0; j < nInputs; j++) 
				{
					m_fWeights[bestind][j] -= m_iNoisy[i][j];
					m_fWeights[realval][j] += m_iNoisy[i][j];
				}
			}
		}

		// Store (not accumulate) this pass's sample. The old `+=` added onto
		// whatever an earlier run had left in m_out, because m_out is only
		// zeroed in the constructor.
		double activation = 0.0;
		for (int k = 0; k < nInputs; k++) 
			activation += m_fWeights[0][k] * m_iNoisy[0][k];
		m_out[m_icycles] = activation;

		m_icycles++;
	// `<` rather than `<=`: at most NN_MAXITER passes, so m_out is written
	// only at indices 0..NN_MAXITER-1.
	} while (!correct && m_icycles < NN_MAXITER);

	// Report a timeout only if the last pass still misclassified something.
	// The old `m_icycles >= NN_MAXITER` test also fired when the final pass
	// had just converged.
	if (!correct) 
	{
		MessageBox("Training has timed-out.",
			"Error in Training", MB_OK | MB_ICONINFORMATION);
		return;
	}
}
	
void CDigitRecogView::DrawNumber(int *cell)
{
	CPoint tl(100,100);
	int ix=100;
	
	CDC *dc;
	dc=this->GetDC();
	CPen pen;
	pen.CreatePen(PS_DOT, 1, RGB(127,0,0));
	CPen *pOldPen = dc->SelectObject(&pen);
	//画边框
	for(int i=0;i<9;i++)
	{
		pOldPen=dc->SelectObject(&pen);
  
		dc->MoveTo(tl.x, tl.y);
		dc->LineTo(tl.x+50, tl.y);
		
		dc->LineTo(tl.x+50, tl.y+70);
		
		dc->LineTo(tl.x, tl.y+70);
		
		dc->LineTo(tl.x, tl.y);
        
          
		dc->SelectObject(pOldPen);
			// 画数字
		CPoint p;
		p.x=tl.x;
		p.y=tl.y;
		for(int j=0;j<NN_RESX*NN_RESY;j++) 
		{
		
			if ((*cell) == 1) 
			{
				dc->FillSolidRect(tl.x,tl.y,10,10,RGB(0,0,0));
			}
			tl.x += 10;
			if (tl.x == ix + 10*NN_RESX) 
			{
				tl.x = ix;
				tl.y += 10;
			}
			cell++;	
		}
		tl.x=p.x;
		tl.y=p.y;
		tl.x=tl.x+100;
		ix=ix+100;
	}  
}

// Handler for IDC_PNOISY: builds NN_NOISY noisy copies of every clean
// pattern in m_bNumbers. Noisy sample s is derived from pattern
// s / NN_NOISY, and each of its pixels is inverted when rand() % 100 < 7
// (about 7% of pixels).
void CDigitRecogView::OnPnoisy() 
{
	const int nPixels = NN_RESX * NN_RESY;

	for (int s = 0; s < NN_NUMBERS * NN_NOISY; s++)
	{
		const int src = s / NN_NOISY;
		for (int p = 0; p < nPixels; p++)
		{
			const bool flip = (rand() % 100 < 7);
			m_iNoisy[s][p] = flip ? !m_bNumbers[src][p] : m_bNumbers[src][p];
		}
	}

	MessageBox("成功产生噪声样本!", "噪声样本", MB_OK);
}

// Handler for IDC_PWEIGHTS: clears every weight of the single-layer
// classifier and reseeds the C random generator (rand() is used by the
// noise and BP weight initialisers) from the current time.
void CDigitRecogView::OnPweights() 
{
	const unsigned seed = (unsigned)time(NULL);
	srand(seed);

	::memset(&m_fWeights, 0, sizeof(m_fWeights));

	MessageBox("成功初始化网络权值!", "网络权值", MB_OK);
}

void CDigitRecogView::OnPtrain() 
{
	// TODO: Add your command handler code here
	RunNet(TRUE);
	MessageBox("训练完毕!", "训练完毕", MB_OK);
}

void CDigitRecogView::OnPerror() 
{
	
	// TODO: Add your command handler code here
	float	out;
	int		delta;

	out = 0;
	for(int k=0;k<NN_RESX * NN_RESY;k++) 		
		out+=(int)m_fWeights[0][k]*m_bNumbers[0][k];

	CClientDC dc(this);
	dc.MoveTo(100,600);
	dc.LineTo(300,600);
	dc.MoveTo(100,600);
	dc.LineTo(100,300);
	dc.TextOut(300,575,"训练次数");
	dc.TextOut(110,300," 误差曲线");
	CPen newpen(PS_SOLID,1,RGB(0,0, 255));
    CPen* old=dc.SelectObject(&newpen);
	dc.MoveTo(100,600);
	for(int j=0;j<m_icycles;j++)
    {		
		delta=int(fabs(out-m_out[j]));
		dc.LineTo(100+j*10,600-delta*10);
	}
	
	CString str;
	str.Format("训练次数为%d",m_icycles);
	MessageBox(str);
}

// Handler for IDC_PRECOG: clears the view, draws the test patterns and
// labels each one with the class picked by the single-layer classifier.
// The class is the index of the m_fWeights row with the largest dot
// product, displayed 1-based.
void CDigitRecogView::OnPrecog() 
{
	CClientDC dc(this);
	CRect rect;
	GetClientRect(&rect);

	// Clear the client area to white, then reselect the DC's original brush
	// so that `brush` is not destroyed while still selected (a GDI leak).
	CBrush brush(RGB(255,255,255));
	CBrush *pOldBrush = dc.SelectObject(&brush);
	dc.PatBlt(rect.left, rect.top, rect.Width(), rect.Height(), PATCOPY);
	dc.SelectObject(pOldBrush);

	m_ipDrawNum = &m_bTestData[0][0];
	DrawNumber(m_ipDrawNum);
	draw = true;

	for (int i = 0; i < NN_NUMBERS; i++)
	{
		// `j` is declared once so this compiles under both standard and VC6
		// for-scope rules. The original reused it after its for statement.
		int j;
		float d[NN_NUMBERS];

		for (j = 0; j < NN_NUMBERS; j++)
		{
			d[j] = 0;
			for (int k = 0; k < NN_RESX * NN_RESY; k++)
				d[j] += m_fWeights[j][k] * m_bTestData[i][k];
		}

		int bestind = 0;
		for (j = 1; j < NN_NUMBERS; j++) 
			if (d[j] > d[bestind]) 
				bestind = j;

		CString str;
		str.Format("识别结果为%d", bestind + 1);
		dc.TextOut(100 + i*100, 200, str);
	}
}

void CDigitRecogView::OnBinput() 
{
	// TODO: Add your command handler code here
	ifstream data("learn.dat", ios::nocreate);
	ifstream test("test.dat", ios::nocreate);

	// 检查文件的有效性
	if (!test || !data) 
	{
		MessageBox("No learning or test data present.", "Cannot run...",
			MB_OK | MB_ICONERROR);
	}

	// 读文件
	for(int i=0;i<NN_NUMBERS;i++) 
	{
		for(int j=0;j<NN_RESX * NN_RESY;j++) 
		{
			int onoff;
			test >> onoff;
			m_bTestData[i][j] = onoff;

			data >> onoff;
			m_bNumbers[i][j] = onoff;
		}
	}

	// 设置初始要画的数字指针
	m_ipDrawNum = &m_bNumbers[0][0];
    DrawNumber(	m_ipDrawNum);
	draw=true;	
}

// Handler for IDC_BNOISY: builds NN_NOISY noisy copies of every clean
// pattern in m_bNumbers for back-propagation training. Noisy sample s
// comes from pattern s / NN_NOISY, and each pixel is inverted when
// rand() % 100 < 7 (about 7% of pixels).
void CDigitRecogView::OnBnoisy() 
{
	const int nPixels = NN_RESX * NN_RESY;

	for (int s = 0; s < NN_NUMBERS * NN_NOISY; s++)
	{
		const int src = s / NN_NOISY;
		for (int p = 0; p < nPixels; p++)
		{
			const bool flip = (rand() % 100 < 7);
			m_iNoisy[s][p] = flip ? !m_bNumbers[src][p] : m_bNumbers[src][p];
		}
	}

	MessageBox("成功产生噪声样本!", "噪声样本", MB_OK);
}

void CDigitRecogView::OnCreate() 
{
	// TODO: Add your command handler code here
	net = bpnn_create(35, 4, 4);
	MessageBox("网络创建完毕!", "网络创建完毕", MB_OK);
}

void CDigitRecogView::OnBweights() 
{
	// TODO: Add your command handler code here
	int m,n;
	for( m=0;m<=35;m++)							//初始化权值		
		for( n=0;n<=4;n++)
		{
			net->input_weights[m][n]=(double)((float)(rand())/(32767/2) - 1);
		}

	for(m=0;m<=4;m++)
		for(n=0;n<=4;n++)
		{
			net->hidden_weights[m][n]=(double)((float)(rand())/(32767/2) - 1);
		}

	MessageBox("成功初始化网络权值!", "网络权值", MB_OK);
}

void CDigitRecogView::OnBtrain() 
{
	// TODO: Add your command handler code here

	double eta,momentum;
	double	eo,eh;
	int		num=1;
	momentum = 0.1;
	eta =	0.3;

	MessageBox("确定开始训练,请耐心等待...", "训练", MB_OK);

	do
	{
		//对标准的0到9数字进行网络训练
		for (int i=0;i<NN_NUMBERS;i++) 
		{
            for(int k=1;k<=NN_RESX * NN_RESY;k++)
			{
				net->input_units[k]=(double )m_bNumbers[i][k-1];

			}
			SetTarget(net->target,i+1);		//用8421码表示1~9输出
		    bpnn_train(net,eta,momentum,&eo,&eh);   
		}

		//对随机产生的25个躁声加入到网络中进行训练,用来提高对躁声的识别精度		
		for(int m1=0;m1<25;m1++)
			for(int i1=0;i1<NN_NUMBERS;i1++) 
			{
				for(int k1=1;k1<=NN_RESX * NN_RESY;k1++)
				{
					net->input_units[k1]=(double )m_iNoisy[i1*25+m1][k1-1];
				}
				SetTarget(net->target,i1+1);
				bpnn_train(net,eta,momentum,&eo,&eh);
			}
		
		m_out[m_icycles] = eo;
	
		m_icycles++;
	}
	while(eo > 0.00001 && m_icycles < 10000);

	MessageBox("训练完毕!", "训练完毕", MB_OK);
}

void CDigitRecogView::OnBrecog() 
{
	// TODO: Add your command handler code here
	CClientDC dc(this);
	CRect rect;
	GetClientRect(&rect);
	CBrush brush(RGB(255,255,255));
	CBrush *pOldBrush=dc.SelectObject(&brush);
	
    dc.PatBlt(rect.left,rect.top,rect.Width(),rect.Height(),PATCOPY);

	m_ipDrawNum = &m_bTestData[0][0];
    DrawNumber(m_ipDrawNum);
	draw=true;
	int num=1;
	double t;
    double d[4]={0,0,0,0};
    for(int i=0;i<NN_NUMBERS;i++)
	{
		for(int k=1;k<=NN_RESX * NN_RESY;k++)
		{
			net->input_units[k]=(double)m_bTestData[i][k-1];
		}
                     
		bpnn_feedforward(net);
					 				 
		
		for(int m=0; m<4; m++)
			d[m]=net->output_units[m+1];
        t=d[0]+d[1]*2+d[2]*4+d[3]*8;

		double min;
		min=fabs(t-1);
		for(int n=1;n<9;n++)				//找出最接近n+1的数
		{    					 
			if(min>fabs((t-(n+1))))
			{
				min=fabs((t-(n+1)));
				num=n+1;
			}						
		}
		
		CString str;
		str.Format("识别结果为%d",num);
		dc.TextOut(100+i*100,200,str);
	}					 
}

void CDigitRecogView::OnBerror() 
{
	// TODO: Add your command handler code here
	
	CClientDC dc(this);
	dc.MoveTo(100,600);
	dc.LineTo(800,600);
	dc.MoveTo(100,600);
	dc.LineTo(100,300);
	dc.TextOut(800,575,"训练次数");
	dc.TextOut(110,300," 误差曲线");
	CPen newpen(PS_SOLID,1,RGB(0,0, 255));
    CPen* old=dc.SelectObject(&newpen);
	dc.MoveTo(100,600);
	for(int j=0;j<m_icycles;j++)
    {		
		dc.LineTo((int)(100+0.05*j),(int)(600-m_out[j]*100000));
	}
	
	CString str;
	str.Format("训练次数为%d",m_icycles);
	MessageBox(str);
}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -