📄 bp_netview.cpp
字号:
// Bp_netView.cpp : implementation of the CBp_netView class
//
#include "stdafx.h"
#include "Bp_net.h"
#include "bp.h"
#include "Bp_netDoc.h"
#include "Bp_netView.h"
#include <string>
#ifdef _DEBUG
#define new DEBUG_NEW
#undef THIS_FILE
static char THIS_FILE[] = __FILE__;
#endif
/////////////////////////////////////////////////////////////////////////////
// CBp_netView
IMPLEMENT_DYNCREATE(CBp_netView, CView)
BEGIN_MESSAGE_MAP(CBp_netView, CView)
//{{AFX_MSG_MAP(CBp_netView)
ON_COMMAND(ID_XUNLIA, OnXunlia)
//}}AFX_MSG_MAP
// Standard printing commands
ON_COMMAND(ID_FILE_PRINT, CView::OnFilePrint)
ON_COMMAND(ID_FILE_PRINT_DIRECT, CView::OnFilePrint)
ON_COMMAND(ID_FILE_PRINT_PREVIEW, CView::OnFilePrintPreview)
END_MESSAGE_MAP()
/////////////////////////////////////////////////////////////////////////////
// CBp_netView construction/destruction
CBp_netView::CBp_netView()
{
// TODO: add construction code here
}
CBp_netView::~CBp_netView()
{
}
BOOL CBp_netView::PreCreateWindow(CREATESTRUCT& cs)
{
// TODO: Modify the Window class or styles here by modifying
// the CREATESTRUCT cs
return CView::PreCreateWindow(cs);
}
/////////////////////////////////////////////////////////////////////////////
// CBp_netView drawing
void CBp_netView::OnDraw(CDC* pDC)
{
CBp_netDoc* pDoc = GetDocument();
ASSERT_VALID(pDoc);
// TODO: add draw code for native data here
CClientDC dc(this);
COLORREF color=RGB(0,0,0);
CPen newPen(PS_SOLID,2,color);
CPen *oldPen;
oldPen=dc.SelectObject(&newPen);
CPoint starp,endp,starp1,endp1,endp2;
starp.x=200;
starp.y=80;
endp.x=300;
endp.y=80;
int b=20,c=40;
starp1.x=200;
starp1.y=160;
endp1.x=300;
endp1.y=160;
int i,j;
/////////////////画训练图////////
for (i=0;i<1;i++)
{
dc.MoveTo(starp);
dc.LineTo(endp);
dc.MoveTo(starp1);
dc.LineTo(endp1);
dc.MoveTo(starp);
dc.LineTo(endp1);
dc.MoveTo(starp1);
dc.LineTo(endp);
dc.Ellipse(endp1.x,endp1.y-b,endp1.x+c,endp1.y+b);
dc.Ellipse(endp.x,endp.y-b,endp.x+c,endp.y+b);
starp1.x=endp1.x+40;
endp1.x=starp1.x+100;
starp.x=endp.x+40;
endp.x=starp.x+100;
}
starp1.x=endp1.x;
endp1.x=starp1.x-100;
starp.x=endp.x;
endp.x=starp.x-100;
endp2.x=endp.x+100;
endp2.y=endp.y+40;
dc.MoveTo(endp);
dc.LineTo(endp2);
dc.MoveTo(endp1);
dc.LineTo(endp2);
dc.Ellipse(endp2.x,endp2.y-b,endp2.x+c,endp2.y+b);
dc.MoveTo(endp2.x+40,endp2.y);
dc.LineTo(endp2.x+100,endp2.y);
dc.TextOut(340,200,"训练网络");
dc.TextOut(190,80,"X1");
dc.TextOut(190,160,"X2");
dc.TextOut(310,70,"Y1");
dc.TextOut(310,150,"Y2");
dc.TextOut(450,110,"Y3");
dc.TextOut(530,120,"O");
dc.TextOut(230,70,"W1");
dc.TextOut(230,100,"W2");
dc.TextOut(230,130,"W3");
dc.TextOut(230,150,"W4");
dc.TextOut(370,90,"W5");
dc.TextOut(370,135,"W6");
dc.TextOut(160,120,"输入");
dc.TextOut(550,120,"输出");
///////////////画坐标///////////
dc.MoveTo(220,250);
dc.LineTo(220,750);
dc.MoveTo(210,260);
dc.LineTo(220,250);
dc.MoveTo(230,260);
dc.LineTo(220,250);
dc.TextOut(250,260,"E(网络能量)");
dc.TextOut(720,750,"T(迭代次数)");
dc.MoveTo(220,750);
dc.LineTo(720,750);
dc.MoveTo(710,740);
dc.LineTo(720,750);
dc.MoveTo(710,760);
dc.LineTo(720,750);
dc.SelectObject(oldPen);
CPoint WW;//权值输出的位置初始坐标
WW.x=750;
WW.y=50;
dc.TextOut(700,50,"W1=");
dc.TextOut(700,80,"W2=");
dc.TextOut(700,110,"W3=");
dc.TextOut(700,140,"W4=");
dc.TextOut(700,170,"W5=");
dc.TextOut(700,200,"W6=");
//
dc.TextOut(900,440,"异或问题");
dc.TextOut(1000,480,"输出结果");
dc.TextOut(820,480,"初始输入");
dc.TextOut(820,500,"0 0");
dc.TextOut(820,540,"0 1");
dc.TextOut(820,580,"1 0");
dc.TextOut(820,620,"1 1");
}
/////////////////////////////////////////////////////////////////////////////
// CBp_netView printing
BOOL CBp_netView::OnPreparePrinting(CPrintInfo* pInfo)
{
// default preparation
return DoPreparePrinting(pInfo);
}
void CBp_netView::OnBeginPrinting(CDC* /*pDC*/, CPrintInfo* /*pInfo*/)
{
// TODO: add extra initialization before printing
}
void CBp_netView::OnEndPrinting(CDC* /*pDC*/, CPrintInfo* /*pInfo*/)
{
// TODO: add cleanup after printing
}
/////////////////////////////////////////////////////////////////////////////
// CBp_netView diagnostics
#ifdef _DEBUG
void CBp_netView::AssertValid() const
{
CView::AssertValid();
}
void CBp_netView::Dump(CDumpContext& dc) const
{
CView::Dump(dc);
}
CBp_netDoc* CBp_netView::GetDocument() // non-debug version is inline
{
ASSERT(m_pDocument->IsKindOf(RUNTIME_CLASS(CBp_netDoc)));
return (CBp_netDoc*)m_pDocument;
}
#endif //_DEBUG
/////////////////////////////////////////////////////////////////////////////
// CBp_netView message handlers
void CBp_netView::OnXunlia()
{
// TODO: Add your command handler code here
CClientDC dc(this);
bp BP;
int n_in=2;
int n_out=1;
int n_hidden=2;
int i,j,k;
int num=4;
double min_ex=0.001;
double eta=0.08,momentum=0.65;
double **data_in,**data_out;
data_in=BP.alloc_2d_dbl(4,2);
data_out=BP.alloc_2d_dbl(4,1);
data_in[0][0]=0.0; data_in[0][1]=0.0;
data_in[1][0]=0.0; data_in[1][1]=1.0;
data_in[2][0]=1.0; data_in[2][1]=0.0;
data_in[3][0]=1.0; data_in[3][1]=1.0;
data_out[0][0]=0.0;
data_out[1][0]=1.0;
data_out[2][0]=1.0;
data_out[3][0]=0.0;
double *output;
//指向输入层数据的指针
double* input_unites;
//指向隐层数据的指针
double* hidden_unites;
//指向输出层数据的指针
double* output_unites;
//指向隐层误差数据的指针
double* hidden_deltas;
//指向输出层误差数剧的指针
double* output_deltas;
//指向理想目标输出的指针
double* target;
//指向输入层于隐层之间权值的指针
double** input_weights;
//指向隐层与输出层之间的权值的指针
double** hidden_weights;
//指向上一此输入层于隐层之间权值的指针
double** input_prev_weights ;
//指向上一此隐层与输出层之间的权值的指针
double** hidden_prev_weights;
//每次循环后的均方误差误差值
double ex;
//为各个数据结构申请内存空间
input_unites= BP.alloc_1d_dbl(n_in + 1);
hidden_unites=BP.alloc_1d_dbl(n_hidden + 1);
output_unites=BP.alloc_1d_dbl(n_out + 1);
output=BP.alloc_1d_dbl(n_out+1);
hidden_deltas = BP.alloc_1d_dbl(n_hidden + 1);
output_deltas = BP.alloc_1d_dbl(n_out + 1);
target = BP.alloc_1d_dbl(n_out + 1);
input_weights=BP.alloc_2d_dbl(n_in + 1, n_hidden + 1);
input_prev_weights =BP. alloc_2d_dbl(n_in + 1, n_hidden + 1);
hidden_prev_weights = BP.alloc_2d_dbl(n_hidden + 1, n_out + 1);
hidden_weights = BP.alloc_2d_dbl(n_hidden + 1, n_out + 1);
//为产生随机序列撒种
time_t t;
BP.bpnn_initialize((unsigned)time(&t));
//对各种权值进行初始化初始化
BP.bpnn_randomize_weights( input_weights,n_in,n_hidden);
BP.bpnn_randomize_weights( hidden_weights,n_hidden,n_out);
BP.bpnn_zero_weights(input_prev_weights, n_in,n_hidden );
BP.bpnn_zero_weights(hidden_prev_weights,n_hidden,n_out );
CString str_w;
str_w.Format("%f",input_weights[1][1]);
dc.TextOut(750,50,str_w);
str_w.Format("%f",input_weights[1][2]);
dc.TextOut(750,80,str_w);
str_w.Format("%f",input_weights[2][1]);
dc.TextOut(750,110,str_w);
str_w.Format("%f",input_weights[2][2]);
dc.TextOut(750,140,str_w);
str_w.Format("%f",hidden_weights[1][1]);
dc.TextOut(750,170,str_w);
str_w.Format("%f",hidden_weights[2][1]);
dc.TextOut(750,200,str_w);
//开始进行BP网络训练
//这里设定最大的迭代次数为15000次
double x[15000],y[15000];
for (i=0;i<15000;i++)
{
//对均方误差置零
ex=0;
//对样本进行逐个的扫描
for(j=0;j<num;j++)
{
//将提取的样本的特征向量输送到输入层上
for(k=1;k<=n_in;k++)
input_unites[k] = data_in[j][k-1];
//将预定的理想输出输送到BP网络的理想输出单元
for(k=1;k<=n_out;k++)
target[k]=data_out[j][k-1];
//dc.TextOut()
//前向传输激活
//将数据由输入层传到隐层
BP.bpnn_layerforward(input_unites,hidden_unites,
input_weights, n_in,n_hidden);
//将隐层的输出传到输出层
BP.bpnn_layerforward(hidden_unites, output_unites,
hidden_weights,n_hidden,n_out);
//误差计算
//将输出层的输出与理想输出比较计算输出层每个结点上的误差
BP.bpnn_output_error(output_deltas,target,output_unites,n_out);
//根据输出层结点上的误差计算隐层每个节点上的误差
BP.bpnn_hidden_error(hidden_deltas,n_hidden, output_deltas, n_out,hidden_weights, hidden_unites);
//权值调整
//根据输出层每个节点上的误差来调整隐层与输出层之间的权值
BP.bpnn_adjust_weights(output_deltas,n_out, hidden_unites,n_hidden,
hidden_weights, hidden_prev_weights, eta, momentum);
//根据隐层每个节点上的误差来调整隐层与输入层之间的权值
BP.bpnn_adjust_weights(hidden_deltas, n_hidden, input_unites, n_in,
input_weights, input_prev_weights, eta, momentum);
output[j]=output_unites[1];
//误差统计
for(k=1;k<=n_out;k++)
ex+=(output_unites[k]-data_out[j][k-1])*(output_unites[k]-data_out[j][k-1]);
}
str_w.Format("%f",input_weights[1][1]);
dc.TextOut(750,50,str_w);
str_w.Format("%f",input_weights[1][2]);
dc.TextOut(750,80,str_w);
str_w.Format("%f",input_weights[2][1]);
dc.TextOut(750,110,str_w);
str_w.Format("%f",input_weights[2][2]);
dc.TextOut(750,140,str_w);
str_w.Format("%f",hidden_weights[1][1]);
dc.TextOut(750,170,str_w);
str_w.Format("%f",hidden_weights[2][1]);
dc.TextOut(750,200,str_w);
//计算均方误差
ex=ex/double(num*n_out);
x[i]=i/30+220;
y[i]=750-ex*500;
//如果均方误差已经足够的小,跳出循环,训练完毕
if(ex<min_ex)break;
}
for (k=0;k<i-1;k++)
{
dc.MoveTo(x[k],y[k]);
dc.LineTo(x[k+1],y[k+1]);
}
for (k=0;k<num;k++)
{
str_w.Format("%f",output[k]);
dc.TextOut(1000,500+40*k,str_w);
}
CString str;
if(ex<=min_ex)
{
str.Format ("迭代%d次,\n平均误差%.4f",i,ex);
::MessageBox(NULL,str,"训练结果",NULL);
}
if(ex>min_ex)
{
str.Format("迭代%d次,平均误差%.4f\n我已经尽了最大努力了还是达不到您的要求\n请调整参数重新训练吧!",i,ex);
::MessageBox(NULL,str,"训练结果",NULL);
}
free(input_unites);
free(hidden_unites);
free(output_unites);
free(hidden_deltas);
free(output_deltas);
free(target);
free(input_weights);
free(hidden_weights);
free(input_prev_weights);
free(hidden_prev_weights);
}
void CBp_netView::OnInitialUpdate()
{
CView::OnInitialUpdate();
// TODO: Add your specialized code here and/or call the base class
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -