// File: digitrecogview.cpp
// (web code-viewer residue removed: "font size" control)
// DigitRecogView.cpp : implementation of the CDigitRecogView class
//
#include "stdafx.h"
#include "DigitRecog.h"
#include <fstream.h>
#include <math.h>
#include "DigitRecogDoc.h"
#include "DigitRecogView.h"
#ifdef _DEBUG
#define new DEBUG_NEW
#undef THIS_FILE
static char THIS_FILE[] = __FILE__;
#endif
/////////////////////////////////////////////////////////////////////////////
// CDigitRecogView
// Registers CDigitRecogView for dynamic creation by the MFC framework.
IMPLEMENT_DYNCREATE(CDigitRecogView, CView)
// Message map: binds each command ID to the handler of the same name below.
// The OnP* handlers work on the weight matrix m_fWeights (via RunNet); the
// OnB* handlers together with OnCreate work on the bpnn_* network stored in
// `net`. The "input" and "noisy" handlers exist once in each family.
BEGIN_MESSAGE_MAP(CDigitRecogView, CView)
//{{AFX_MSG_MAP(CDigitRecogView)
ON_COMMAND(IDC_PINPUT, OnPinput)
ON_COMMAND(IDC_PNOISY, OnPnoisy)
ON_COMMAND(IDC_PWEIGHTS, OnPweights)
ON_COMMAND(IDC_PTRAIN, OnPtrain)
ON_COMMAND(IDC_PERROR, OnPerror)
ON_COMMAND(IDC_PRECOG, OnPrecog)
ON_COMMAND(IDC_BINPUT, OnBinput)
ON_COMMAND(IDC_BNOISY, OnBnoisy)
ON_COMMAND(IDC_CREATE, OnCreate)
ON_COMMAND(IDC_BWEIGHTS, OnBweights)
ON_COMMAND(IDC_BTRAIN, OnBtrain)
ON_COMMAND(IDC_BRECOG, OnBrecog)
ON_COMMAND(IDC_BERROR, OnBerror)
//}}AFX_MSG_MAP
// Standard printing commands
ON_COMMAND(ID_FILE_PRINT, CView::OnFilePrint)
ON_COMMAND(ID_FILE_PRINT_DIRECT, CView::OnFilePrint)
ON_COMMAND(ID_FILE_PRINT_PREVIEW, CView::OnFilePrintPreview)
END_MESSAGE_MAP()
/////////////////////////////////////////////////////////////////////////////
// CDigitRecogView construction/destruction
// Constructs the view in a well-defined "nothing loaded, nothing trained"
// state.
CDigitRecogView::CDigitRecogView()
{
    m_icycles = 0;

    // OnDraw() tests `draw` and hands m_ipDrawNum to DrawNumber(); both must
    // hold defined values before the first paint, otherwise an indeterminate
    // flag could send a garbage pointer to the drawing code.
    draw = FALSE;
    m_ipDrawNum = NULL;

    // No network exists until OnCreate() is invoked.
    net = NULL;

    // Clear the per-iteration history that the error-curve handlers plot.
    for (int i = 0; i < 10000; i++)
        m_out[i] = 0.0;
}
// Destructor. Performs no cleanup of its own.
// NOTE(review): the network that OnCreate() stores in `net` is not released
// here - confirm whether the bpnn library provides a matching free routine.
CDigitRecogView::~CDigitRecogView()
{
// delete m_out;   (left disabled: the constructor never allocates m_out with new)
}
// Called before the window is created. No window class or style is
// customised here; the decision is left entirely to the CView base class.
BOOL CDigitRecogView::PreCreateWindow(CREATESTRUCT& cs)
{
    const BOOL bBaseResult = CView::PreCreateWindow(cs);
    return bBaseResult;
}
/////////////////////////////////////////////////////////////////////////////
// CDigitRecogView drawing
// Repaints the view. The digit bitmaps are redrawn only once one of the
// handlers has set `draw`; the drawing goes through DrawNumber(), which
// obtains its own device context, so pDC is not used here.
void CDigitRecogView::OnDraw(CDC* pDC)
{
    CDigitRecogDoc* pDoc = GetDocument();
    ASSERT_VALID(pDoc);

    // Nothing to show until a data set has been selected for display.
    if (draw != TRUE)
        return;

    DrawNumber(m_ipDrawNum);
}
/////////////////////////////////////////////////////////////////////////////
// CDigitRecogView printing
// Print preparation: uses the framework's default preparation unchanged and
// returns its result.
BOOL CDigitRecogView::OnPreparePrinting(CPrintInfo* pInfo)
{
    const BOOL bPrepared = DoPreparePrinting(pInfo);
    return bPrepared;
}
// Hook invoked before printing starts. Intentionally empty: this view does
// no extra initialization, and both parameters are unused.
void CDigitRecogView::OnBeginPrinting(CDC* /*pDC*/, CPrintInfo* /*pInfo*/)
{
// TODO: add extra initialization before printing
}
// Hook invoked after printing ends. Intentionally empty: this view has
// nothing to clean up, and both parameters are unused.
void CDigitRecogView::OnEndPrinting(CDC* /*pDC*/, CPrintInfo* /*pInfo*/)
{
// TODO: add cleanup after printing
}
/////////////////////////////////////////////////////////////////////////////
// CDigitRecogView diagnostics
#ifdef _DEBUG
// Debug-only validity check; adds nothing beyond the CView base-class check.
void CDigitRecogView::AssertValid() const
{
CView::AssertValid();
}
// Debug-only dump; writes only what the CView base class writes to dc.
void CDigitRecogView::Dump(CDumpContext& dc) const
{
CView::Dump(dc);
}
// Debug build of GetDocument(): asserts that the attached document really is
// a CDigitRecogDoc before returning it as one.
CDigitRecogDoc* CDigitRecogView::GetDocument() // non-debug version is inline
{
    CDocument* pAttached = m_pDocument;
    ASSERT(pAttached->IsKindOf(RUNTIME_CLASS(CDigitRecogDoc)));
    return static_cast<CDigitRecogDoc*>(pAttached);
}
#endif //_DEBUG
/////////////////////////////////////////////////////////////////////////////
// CDigitRecogView message handlers
// Command handler: loads the learning patterns ("learn.dat") into m_bNumbers
// and the test patterns ("test.dat") into m_bTestData, then displays the
// learning patterns.
//
// Each file is read as NN_NUMBERS * (NN_RESX * NN_RESY) whitespace-separated
// integers. If either file cannot be opened, an error box is shown and the
// handler returns without touching the stored patterns or the display state.
void CDigitRecogView::OnPinput()
{
    // Open both files; ios::nocreate makes the open fail if a file is missing.
    ifstream data("learn.dat", ios::nocreate);
    ifstream test("test.dat", ios::nocreate);

    // Validate the streams and stop here on failure. Reading from a failed
    // stream would otherwise store indeterminate values into the patterns.
    if (!test || !data)
    {
        MessageBox("No learning or test data present.", "Cannot run...",
            MB_OK | MB_ICONERROR);
        return;
    }

    // Read the pixels, interleaving one test value and one learning value.
    for (int i = 0; i < NN_NUMBERS; i++)
    {
        for (int j = 0; j < NN_RESX * NN_RESY; j++)
        {
            // Start from 0 so a short or malformed file yields a blank pixel
            // instead of an uninitialised or stale value.
            int onoff = 0;
            test >> onoff;
            m_bTestData[i][j] = onoff;

            onoff = 0;
            data >> onoff;
            m_bNumbers[i][j] = onoff;
        }
    }

    // Point the display at the learning patterns and draw them.
    m_ipDrawNum = &m_bNumbers[0][0];
    DrawNumber(m_ipDrawNum);
    draw = true;
}
// Writes the 4-bit binary code of a digit into t[1]..t[4], least significant
// bit first (t[1] = bit 0, ..., t[4] = bit 3).
//
// t   - target array; only elements 1..4 are written, t[0] is left alone.
// num - digit to encode; truncated to int. Only 1..9 are encoded, for any
//       other value t is left unchanged.
void CDigitRecogView::SetTarget(double *t,double num)
{
    const int digit = int(num);
    if (digit < 1 || digit > 9)
        return;

    for (int bit = 0; bit < 4; bit++)
        t[bit + 1] = (digit >> bit) & 1;
}
// Classifies every noisy sample with the weight matrix m_fWeights and, when
// `training` is set, corrects the weights after each misclassification.
//
// One pass scores each of the NN_NUMBERS * NN_NOISY rows of m_iNoisy against
// every digit's weight row (dot product) and picks the highest score. Sample
// i is expected to be digit i / NN_NOISY. On a mistake in training mode the
// sample is subtracted from the wrongly chosen row and added to the expected
// row. Passes repeat until one completes without a mistake or NN_MAXITER
// passes have run; without training exactly one pass runs.
//
// After each pass the response of weight row 0 to noisy sample 0 is stored
// in m_out[pass]; m_icycles ends up holding the number of passes performed.
// A message box reports the case where training stopped without converging.
void CDigitRecogView::RunNet(BOOL training)
{
    float d[NN_NUMBERS];    // score of each digit for the current sample
    bool correct;
    int j, k;               // declared once so the loops below compile under
                            // both pre-standard and standard for-scope rules

    m_icycles = 0;          // restart the pass counter
    do
    {
        correct = true;
        for (int i = 0; i < NN_NUMBERS * NN_NOISY; i++)
        {
            // Score the sample against every digit.
            for (j = 0; j < NN_NUMBERS; j++)
            {
                d[j] = 0;
                for (k = 0; k < NN_RESX * NN_RESY; k++)
                    d[j] += m_fWeights[j][k] * m_iNoisy[i][k];
            }

            // Winner = highest score (the first one wins a tie).
            int bestind = 0;
            for (j = 1; j < NN_NUMBERS; j++)
                if (d[j] > d[bestind])
                    bestind = j;

            int realval = (int)(i / NN_NOISY);
            if (bestind == realval)
                continue;

            if (training)
            {
                correct = false;
                for (j = 0; j < NN_RESX * NN_RESY; j++)
                {
                    m_fWeights[bestind][j] -= m_iNoisy[i][j];
                    m_fWeights[realval][j] += m_iNoisy[i][j];
                }
            }
        }

        // Record this pass's response. The slot is cleared first so values
        // left by an earlier training run are not added in, and the write is
        // skipped beyond the 10000 entries the constructor initialises.
        if (m_icycles < 10000)
        {
            m_out[m_icycles] = 0;
            for (k = 0; k < NN_RESX * NN_RESY; k++)
                m_out[m_icycles] += m_fWeights[0][k] * m_iNoisy[0][k];
        }
        m_icycles++;
    } while (!correct && m_icycles < NN_MAXITER);

    // Report a timeout only when the last pass still contained a mistake; a
    // run that converges exactly on the final allowed pass is a success.
    if (!correct)
    {
        MessageBox("Training has timed-out.",
            "Error in Training", MB_OK | MB_ICONINFORMATION);
    }
}
void CDigitRecogView::DrawNumber(int *cell)
{
CPoint tl(100,100);
int ix=100;
CDC *dc;
dc=this->GetDC();
CPen pen;
pen.CreatePen(PS_DOT, 1, RGB(127,0,0));
CPen *pOldPen = dc->SelectObject(&pen);
//画边框
for(int i=0;i<9;i++)
{
pOldPen=dc->SelectObject(&pen);
dc->MoveTo(tl.x, tl.y);
dc->LineTo(tl.x+50, tl.y);
dc->LineTo(tl.x+50, tl.y+70);
dc->LineTo(tl.x, tl.y+70);
dc->LineTo(tl.x, tl.y);
dc->SelectObject(pOldPen);
// 画数字
CPoint p;
p.x=tl.x;
p.y=tl.y;
for(int j=0;j<NN_RESX*NN_RESY;j++)
{
if ((*cell) == 1)
{
dc->FillSolidRect(tl.x,tl.y,10,10,RGB(0,0,0));
}
tl.x += 10;
if (tl.x == ix + 10*NN_RESX)
{
tl.x = ix;
tl.y += 10;
}
cell++;
}
tl.x=p.x;
tl.y=p.y;
tl.x=tl.x+100;
ix=ix+100;
}
}
// Command handler: fills m_iNoisy with NN_NOISY noisy copies of every
// learning pattern. Row s of m_iNoisy is derived from digit s / NN_NOISY;
// each pixel is logically inverted when rand() % 100 < 7 and copied
// unchanged otherwise. A confirmation box is shown when done.
void CDigitRecogView::OnPnoisy()
{
    const int sampleCount = NN_NUMBERS * NN_NOISY;
    const int pixelCount  = NN_RESX * NN_RESY;

    for (int sample = 0; sample < sampleCount; ++sample)
    {
        const int digit = sample / NN_NOISY;   // source pattern for this row
        for (int px = 0; px < pixelCount; ++px)
        {
            const bool flip = (rand() % 100 < 7);
            if (flip)
                m_iNoisy[sample][px] = !m_bNumbers[digit][px];
            else
                m_iNoisy[sample][px] = m_bNumbers[digit][px];
        }
    }

    MessageBox("成功产生噪声样本!", "噪声样本", MB_OK);
}
// Command handler: seeds the C random generator from the current time and
// clears the weight matrix m_fWeights to all-zero bytes, then confirms with
// a message box.
void CDigitRecogView::OnPweights()
{
    const unsigned seed = (unsigned)time(NULL);
    srand(seed);

    memset(&m_fWeights, 0, sizeof(m_fWeights));

    MessageBox("成功初始化网络权值!", "网络权值", MB_OK);
}
// Command handler: runs RunNet() in training mode, then shows a
// "training finished" box.
// NOTE(review): RunNet() returns nothing, so this box also appears after
// RunNet() has already reported a training timeout.
void CDigitRecogView::OnPtrain()
{
// TODO: Add your command handler code here
RunNet(TRUE);
MessageBox("训练完毕!", "训练完毕", MB_OK);
}
// Command handler: plots the error curve recorded by RunNet() and reports
// the number of training passes.
//
// The reference value is the response of weight row 0 (each weight truncated
// to int) to the clean pattern 0. For every recorded pass j the plotted
// error is |reference - m_out[j]|, truncated to int and drawn at
// (100 + 10*j, 600 - 10*error).
void CDigitRecogView::OnPerror()
{
    // Reference response.
    float out = 0;
    for (int k = 0; k < NN_RESX * NN_RESY; k++)
        out += (int)m_fWeights[0][k] * m_bNumbers[0][k];

    CClientDC dc(this);

    // Axes and labels.
    dc.MoveTo(100,600);
    dc.LineTo(300,600);
    dc.MoveTo(100,600);
    dc.LineTo(100,300);
    dc.TextOut(300,575,"训练次数");
    dc.TextOut(110,300," 误差曲线");

    // Curve, drawn with a blue pen.
    CPen newpen(PS_SOLID,1,RGB(0,0, 255));
    CPen* old = dc.SelectObject(&newpen);
    dc.MoveTo(100,600);
    for (int j = 0; j < m_icycles; j++)
    {
        int delta = int(fabs(out - m_out[j]));
        dc.LineTo(100 + j*10, 600 - delta*10);
    }

    // Restore the previous pen so `newpen` is not destroyed while it is
    // still selected into the device context.
    dc.SelectObject(old);

    CString str;
    str.Format("训练次数为%d",m_icycles);
    MessageBox(str);
}
// Command handler: classifies every test pattern with the weight matrix
// m_fWeights and shows the results.
//
// The client area is cleared to white, the test patterns are drawn, and for
// each pattern i the digit whose weight row gives the highest dot product is
// printed (1-based) at (100 + 100*i, 200).
void CDigitRecogView::OnPrecog()
{
    CClientDC dc(this);
    CRect rect;
    GetClientRect(&rect);

    // Clear the background. The white brush is needed only for PatBlt, so
    // the previous brush is restored right away; `brush` must not be
    // destroyed while it is still selected into the device context.
    CBrush brush(RGB(255,255,255));
    CBrush* pOldBrush = dc.SelectObject(&brush);
    dc.PatBlt(rect.left,rect.top,rect.Width(),rect.Height(),PATCOPY);
    dc.SelectObject(pOldBrush);

    // Show the test patterns from now on (also on repaint).
    m_ipDrawNum = &m_bTestData[0][0];
    DrawNumber(m_ipDrawNum);
    draw = true;

    for (int i = 0; i < NN_NUMBERS; i++)
    {
        float d[NN_NUMBERS];
        int j;      // shared by both loops; valid under either for-scope rule

        // Score the pattern against every digit.
        for (j = 0; j < NN_NUMBERS; j++)
        {
            d[j] = 0;
            for (int k = 0; k < NN_RESX * NN_RESY; k++)
                d[j] += m_fWeights[j][k] * m_bTestData[i][k];
        }

        // Winner = highest score (the first one wins a tie).
        int bestind = 0;
        for (j = 1; j < NN_NUMBERS; j++)
            if (d[j] > d[bestind])
                bestind = j;

        // Rows are 0-based, the displayed digits are 1-based.
        CString str;
        str.Format("识别结果为%d", bestind + 1);
        dc.TextOut(100 + i*100, 200, str);
    }
}
// Command handler: loads the learning patterns ("learn.dat") into m_bNumbers
// and the test patterns ("test.dat") into m_bTestData, then displays the
// learning patterns. Same procedure as OnPinput().
//
// Each file is read as NN_NUMBERS * (NN_RESX * NN_RESY) whitespace-separated
// integers. If either file cannot be opened, an error box is shown and the
// handler returns without touching the stored patterns or the display state.
void CDigitRecogView::OnBinput()
{
    // Open both files; ios::nocreate makes the open fail if a file is missing.
    ifstream data("learn.dat", ios::nocreate);
    ifstream test("test.dat", ios::nocreate);

    // Validate the streams and stop here on failure. Reading from a failed
    // stream would otherwise store indeterminate values into the patterns.
    if (!test || !data)
    {
        MessageBox("No learning or test data present.", "Cannot run...",
            MB_OK | MB_ICONERROR);
        return;
    }

    // Read the pixels, interleaving one test value and one learning value.
    for (int i = 0; i < NN_NUMBERS; i++)
    {
        for (int j = 0; j < NN_RESX * NN_RESY; j++)
        {
            // Start from 0 so a short or malformed file yields a blank pixel
            // instead of an uninitialised or stale value.
            int onoff = 0;
            test >> onoff;
            m_bTestData[i][j] = onoff;

            onoff = 0;
            data >> onoff;
            m_bNumbers[i][j] = onoff;
        }
    }

    // Point the display at the learning patterns and draw them.
    m_ipDrawNum = &m_bNumbers[0][0];
    DrawNumber(m_ipDrawNum);
    draw = true;
}
// Command handler: builds the noisy training set in m_iNoisy. Every learning
// pattern contributes NN_NOISY consecutive rows; in each row a pixel is
// logically inverted when rand() % 100 < 7 and copied as-is otherwise.
// Same procedure as OnPnoisy(). Ends with a confirmation box.
void CDigitRecogView::OnBnoisy()
{
    int row = 0;    // next row of m_iNoisy to fill
    for (int digit = 0; digit < NN_NUMBERS; digit++)
    {
        for (int copy = 0; copy < NN_NOISY; copy++, row++)
        {
            for (int px = 0; px < NN_RESX * NN_RESY; px++)
            {
                m_iNoisy[row][px] = (rand() % 100 < 7)
                    ? !m_bNumbers[digit][px]
                    : m_bNumbers[digit][px];
            }
        }
    }
    MessageBox("成功产生噪声样本!", "噪声样本", MB_OK);
}
// Command handler: creates the network used by the OnB* handlers via
// bpnn_create(35, 4, 4), stores it in `net` and confirms with a message box.
// NOTE(review): the arguments are presumably the input / hidden / output
// unit counts (35 would match NN_RESX * NN_RESY pixels) - confirm against
// the bpnn library. A network already held in `net` is overwritten here
// without being released.
void CDigitRecogView::OnCreate()
{
    const int kFirstArg  = 35;
    const int kSecondArg = 4;
    const int kThirdArg  = 4;

    net = bpnn_create(kFirstArg, kSecondArg, kThirdArg);

    MessageBox("网络创建完毕!", "网络创建完毕", MB_OK);
}
// Command handler: fills both weight tables of `net` with pseudo-random
// values computed as rand() / 16383 - 1 (roughly -1..+1 when rand() tops out
// at 32767), then confirms with a message box.
//
// input_weights is filled for indices [0..35][0..4] and hidden_weights for
// [0..4][0..4], both bounds inclusive.
// NOTE(review): assumes `net` was created by OnCreate() beforehand; there is
// no check here.
void CDigitRecogView::OnBweights()
{
    const int kLastInputRow  = 35;
    const int kLastHiddenRow = 4;
    const int kLastColumn    = 4;
    int row, col;

    // First weight table.
    for (row = 0; row <= kLastInputRow; ++row)
    {
        for (col = 0; col <= kLastColumn; ++col)
            net->input_weights[row][col] = (double)((float)(rand())/(32767/2) - 1);
    }

    // Second weight table.
    for (row = 0; row <= kLastHiddenRow; ++row)
    {
        for (col = 0; col <= kLastColumn; ++col)
            net->hidden_weights[row][col] = (double)((float)(rand())/(32767/2) - 1);
    }

    MessageBox("成功初始化网络权值!", "网络权值", MB_OK);
}
void CDigitRecogView::OnBtrain()
{
// TODO: Add your command handler code here
double eta,momentum;
double eo,eh;
int num=1;
momentum = 0.1;
eta = 0.3;
MessageBox("确定开始训练,请耐心等待...", "训练", MB_OK);
do
{
//对标准的0到9数字进行网络训练
for (int i=0;i<NN_NUMBERS;i++)
{
for(int k=1;k<=NN_RESX * NN_RESY;k++)
{
net->input_units[k]=(double )m_bNumbers[i][k-1];
}
SetTarget(net->target,i+1); //用8421码表示1~9输出
bpnn_train(net,eta,momentum,&eo,&eh);
}
//对随机产生的25个躁声加入到网络中进行训练,用来提高对躁声的识别精度
for(int m1=0;m1<25;m1++)
for(int i1=0;i1<NN_NUMBERS;i1++)
{
for(int k1=1;k1<=NN_RESX * NN_RESY;k1++)
{
net->input_units[k1]=(double )m_iNoisy[i1*25+m1][k1-1];
}
SetTarget(net->target,i1+1);
bpnn_train(net,eta,momentum,&eo,&eh);
}
m_out[m_icycles] = eo;
m_icycles++;
}
while(eo > 0.00001 && m_icycles < 10000);
MessageBox("训练完毕!", "训练完毕", MB_OK);
}
// Command handler: classifies every test pattern with the bpnn network in
// `net` and shows the results.
//
// The client area is cleared to white and the test patterns are drawn. Each
// pattern is fed forward; output units 1..4 are combined with weights
// 1, 2, 4 and 8 into one value, and the integer in 1..9 closest to that
// value is printed at (100 + 100*i, 200).
// NOTE(review): assumes `net` was created and trained beforehand; there is
// no check here.
void CDigitRecogView::OnBrecog()
{
    CClientDC dc(this);
    CRect rect;
    GetClientRect(&rect);

    // Clear the background. The white brush is needed only for PatBlt, so
    // the previous brush is restored right away; `brush` must not be
    // destroyed while it is still selected into the device context.
    CBrush brush(RGB(255,255,255));
    CBrush* pOldBrush = dc.SelectObject(&brush);
    dc.PatBlt(rect.left,rect.top,rect.Width(),rect.Height(),PATCOPY);
    dc.SelectObject(pOldBrush);

    // Show the test patterns from now on (also on repaint).
    m_ipDrawNum = &m_bTestData[0][0];
    DrawNumber(m_ipDrawNum);
    draw = true;

    for (int i = 0; i < NN_NUMBERS; i++)
    {
        for (int k = 1; k <= NN_RESX * NN_RESY; k++)
            net->input_units[k] = (double)m_bTestData[i][k-1];
        bpnn_feedforward(net);

        // Combine the four outputs into a single value.
        double d[4];
        for (int m = 0; m < 4; m++)
            d[m] = net->output_units[m+1];
        double t = d[0] + d[1]*2 + d[2]*4 + d[3]*8;

        // Nearest digit in 1..9. The search restarts from 1 for every
        // pattern; carrying the previous pattern's result over would
        // mislabel any pattern whose nearest digit is 1.
        int num = 1;
        double min = fabs(t - 1);
        for (int n = 2; n <= 9; n++)
        {
            const double dist = fabs(t - n);
            if (min > dist)
            {
                min = dist;
                num = n;
            }
        }

        CString str;
        str.Format("识别结果为%d",num);
        dc.TextOut(100 + i*100, 200, str);
    }
}
// Command handler: plots the values recorded in m_out by OnBtrain() and
// reports the number of training epochs.
//
// Epoch j is drawn at x = 100 + 0.05*j and y = 600 - m_out[j]*100000, both
// truncated to int.
void CDigitRecogView::OnBerror()
{
    CClientDC dc(this);

    // Axes and labels.
    dc.MoveTo(100,600);
    dc.LineTo(800,600);
    dc.MoveTo(100,600);
    dc.LineTo(100,300);
    dc.TextOut(800,575,"训练次数");
    dc.TextOut(110,300," 误差曲线");

    // Curve, drawn with a blue pen.
    CPen newpen(PS_SOLID,1,RGB(0,0, 255));
    CPen* old = dc.SelectObject(&newpen);
    dc.MoveTo(100,600);
    for (int j = 0; j < m_icycles; j++)
    {
        dc.LineTo((int)(100 + 0.05*j), (int)(600 - m_out[j]*100000));
    }

    // Restore the previous pen so `newpen` is not destroyed while it is
    // still selected into the device context.
    dc.SelectObject(old);

    CString str;
    str.Format("训练次数为%d",m_icycles);
    MessageBox(str);
}
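// ---------------------------------------------------------------------------
// Web code-viewer residue (not part of the program): keyboard shortcuts
//   Copy code:          Ctrl + C
//   Search code:        Ctrl + F
//   Full-screen mode:   F11
//   Switch theme:       Ctrl + Shift + D
//   Show shortcuts:     ?
//   Increase font size: Ctrl + =
//   Decrease font size: Ctrl + -
// ---------------------------------------------------------------------------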