⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 bp_netview.cpp

📁 测试神经网络异或问题
💻 CPP
字号:
// Bp_netView.cpp : implementation of the CBp_netView class
//

#include "stdafx.h"
#include "Bp_net.h"
#include "bp.h"
#include "Bp_netDoc.h"
#include "Bp_netView.h"
#include <string>
#ifdef _DEBUG
#define new DEBUG_NEW
#undef THIS_FILE
static char THIS_FILE[] = __FILE__;
#endif

/////////////////////////////////////////////////////////////////////////////
// CBp_netView

IMPLEMENT_DYNCREATE(CBp_netView, CView)

BEGIN_MESSAGE_MAP(CBp_netView, CView)
	//{{AFX_MSG_MAP(CBp_netView)
	ON_COMMAND(ID_XUNLIA, OnXunlia)
	//}}AFX_MSG_MAP
	// Standard printing commands
	ON_COMMAND(ID_FILE_PRINT, CView::OnFilePrint)
	ON_COMMAND(ID_FILE_PRINT_DIRECT, CView::OnFilePrint)
	ON_COMMAND(ID_FILE_PRINT_PREVIEW, CView::OnFilePrintPreview)
END_MESSAGE_MAP()

/////////////////////////////////////////////////////////////////////////////
// CBp_netView construction/destruction

CBp_netView::CBp_netView()
{
	// TODO: add construction code here

}

CBp_netView::~CBp_netView()
{
}

BOOL CBp_netView::PreCreateWindow(CREATESTRUCT& cs)
{
	// TODO: Modify the Window class or styles here by modifying
	//  the CREATESTRUCT cs

	return CView::PreCreateWindow(cs);
}

/////////////////////////////////////////////////////////////////////////////
// CBp_netView drawing

void CBp_netView::OnDraw(CDC* pDC)
{
	CBp_netDoc* pDoc = GetDocument();
	ASSERT_VALID(pDoc);
	// TODO: add draw code for native data here
		CClientDC  dc(this);
	COLORREF  color=RGB(0,0,0);
	CPen newPen(PS_SOLID,2,color);
	CPen *oldPen;
	oldPen=dc.SelectObject(&newPen);
	CPoint starp,endp,starp1,endp1,endp2;
	starp.x=200;
	starp.y=80;
	endp.x=300;
	endp.y=80;
	int b=20,c=40; 
	starp1.x=200;
	starp1.y=160;
	endp1.x=300;
	endp1.y=160;
	int i,j;
	/////////////////画训练图////////
	for (i=0;i<1;i++) 
	{
		dc.MoveTo(starp);
		dc.LineTo(endp);
		dc.MoveTo(starp1);
		dc.LineTo(endp1);
		dc.MoveTo(starp);
		dc.LineTo(endp1);
		dc.MoveTo(starp1);
		dc.LineTo(endp);
		dc.Ellipse(endp1.x,endp1.y-b,endp1.x+c,endp1.y+b);
		dc.Ellipse(endp.x,endp.y-b,endp.x+c,endp.y+b);
		starp1.x=endp1.x+40;
		endp1.x=starp1.x+100;
		starp.x=endp.x+40;
		endp.x=starp.x+100;
	}
	starp1.x=endp1.x;
	endp1.x=starp1.x-100;
	starp.x=endp.x;
	endp.x=starp.x-100;
	endp2.x=endp.x+100;
	endp2.y=endp.y+40;
	dc.MoveTo(endp);
	dc.LineTo(endp2);
	dc.MoveTo(endp1);
	dc.LineTo(endp2);
	dc.Ellipse(endp2.x,endp2.y-b,endp2.x+c,endp2.y+b);
	dc.MoveTo(endp2.x+40,endp2.y);
	dc.LineTo(endp2.x+100,endp2.y);
		
	dc.TextOut(340,200,"训练网络");	
	dc.TextOut(190,80,"X1");
	dc.TextOut(190,160,"X2");
	dc.TextOut(310,70,"Y1");
	dc.TextOut(310,150,"Y2");
	dc.TextOut(450,110,"Y3");
	
	dc.TextOut(530,120,"O");
	dc.TextOut(230,70,"W1");
	dc.TextOut(230,100,"W2");
	dc.TextOut(230,130,"W3");
	dc.TextOut(230,150,"W4");
	dc.TextOut(370,90,"W5");
	dc.TextOut(370,135,"W6");
	dc.TextOut(160,120,"输入");
	dc.TextOut(550,120,"输出");
	///////////////画坐标///////////
	dc.MoveTo(220,250);
	dc.LineTo(220,750);
	dc.MoveTo(210,260);
	dc.LineTo(220,250);
	dc.MoveTo(230,260);
	dc.LineTo(220,250);
dc.TextOut(250,260,"E(网络能量)");
dc.TextOut(720,750,"T(迭代次数)");

	dc.MoveTo(220,750);
	dc.LineTo(720,750);
	dc.MoveTo(710,740);
	dc.LineTo(720,750);
	dc.MoveTo(710,760);
	dc.LineTo(720,750);
	dc.SelectObject(oldPen);
	CPoint WW;//权值输出的位置初始坐标
	WW.x=750;
	WW.y=50;
	dc.TextOut(700,50,"W1=");
	dc.TextOut(700,80,"W2=");
	dc.TextOut(700,110,"W3=");
	dc.TextOut(700,140,"W4=");
	dc.TextOut(700,170,"W5=");
	dc.TextOut(700,200,"W6=");
	//
	dc.TextOut(900,440,"异或问题");
	dc.TextOut(1000,480,"输出结果");
	dc.TextOut(820,480,"初始输入");
	dc.TextOut(820,500,"0       0");
	dc.TextOut(820,540,"0       1");
	dc.TextOut(820,580,"1       0");
	dc.TextOut(820,620,"1       1");
}

/////////////////////////////////////////////////////////////////////////////
// CBp_netView printing

BOOL CBp_netView::OnPreparePrinting(CPrintInfo* pInfo)
{
	// default preparation
	return DoPreparePrinting(pInfo);
}

void CBp_netView::OnBeginPrinting(CDC* /*pDC*/, CPrintInfo* /*pInfo*/)
{
	// TODO: add extra initialization before printing
}

void CBp_netView::OnEndPrinting(CDC* /*pDC*/, CPrintInfo* /*pInfo*/)
{
	// TODO: add cleanup after printing
}

/////////////////////////////////////////////////////////////////////////////
// CBp_netView diagnostics

#ifdef _DEBUG
void CBp_netView::AssertValid() const
{
	CView::AssertValid();
}

void CBp_netView::Dump(CDumpContext& dc) const
{
	CView::Dump(dc);
}

CBp_netDoc* CBp_netView::GetDocument() // non-debug version is inline
{
	ASSERT(m_pDocument->IsKindOf(RUNTIME_CLASS(CBp_netDoc)));
	return (CBp_netDoc*)m_pDocument;
}
#endif //_DEBUG

/////////////////////////////////////////////////////////////////////////////
// CBp_netView message handlers

void CBp_netView::OnXunlia() 
{
	// TODO: Add your command handler code here
	CClientDC dc(this);
	bp BP;
	int n_in=2;
	int n_out=1;
	int n_hidden=2;
	int i,j,k;
	int num=4;
	double min_ex=0.001;
	double eta=0.08,momentum=0.65;
	double **data_in,**data_out;
	data_in=BP.alloc_2d_dbl(4,2);
	data_out=BP.alloc_2d_dbl(4,1);
	data_in[0][0]=0.0;	data_in[0][1]=0.0;
	data_in[1][0]=0.0;	data_in[1][1]=1.0;
	data_in[2][0]=1.0;	data_in[2][1]=0.0;
	data_in[3][0]=1.0;	data_in[3][1]=1.0;
	data_out[0][0]=0.0;
	data_out[1][0]=1.0;
	data_out[2][0]=1.0;
	data_out[3][0]=0.0;
	double *output;
		//指向输入层数据的指针
	double* input_unites; 
	//指向隐层数据的指针
	double* hidden_unites;
	//指向输出层数据的指针
	double* output_unites; 
	//指向隐层误差数据的指针
	double* hidden_deltas;
	//指向输出层误差数剧的指针
	double* output_deltas;  
	//指向理想目标输出的指针
	double* target;    
	//指向输入层于隐层之间权值的指针
	double** input_weights;
	//指向隐层与输出层之间的权值的指针
	double** hidden_weights;
	//指向上一此输入层于隐层之间权值的指针
	double** input_prev_weights ;
	//指向上一此隐层与输出层之间的权值的指针
	double** hidden_prev_weights;
	//每次循环后的均方误差误差值 
	double ex;
	
	//为各个数据结构申请内存空间
	input_unites= BP.alloc_1d_dbl(n_in + 1);
	hidden_unites=BP.alloc_1d_dbl(n_hidden + 1);
	output_unites=BP.alloc_1d_dbl(n_out + 1);
	output=BP.alloc_1d_dbl(n_out+1);
	hidden_deltas = BP.alloc_1d_dbl(n_hidden + 1);
	output_deltas = BP.alloc_1d_dbl(n_out + 1);
	target = BP.alloc_1d_dbl(n_out + 1);
	input_weights=BP.alloc_2d_dbl(n_in + 1, n_hidden + 1);
	input_prev_weights =BP. alloc_2d_dbl(n_in + 1, n_hidden + 1);
	hidden_prev_weights = BP.alloc_2d_dbl(n_hidden + 1, n_out + 1);
	hidden_weights = BP.alloc_2d_dbl(n_hidden + 1, n_out + 1);
	
	//为产生随机序列撒种
	time_t t; 
	BP.bpnn_initialize((unsigned)time(&t));
	
	//对各种权值进行初始化初始化
	BP.bpnn_randomize_weights( input_weights,n_in,n_hidden);
	BP.bpnn_randomize_weights( hidden_weights,n_hidden,n_out);
	BP.bpnn_zero_weights(input_prev_weights, n_in,n_hidden );
	BP.bpnn_zero_weights(hidden_prev_weights,n_hidden,n_out );
	CString str_w;
			str_w.Format("%f",input_weights[1][1]);
			dc.TextOut(750,50,str_w);
			str_w.Format("%f",input_weights[1][2]);
			dc.TextOut(750,80,str_w);
			str_w.Format("%f",input_weights[2][1]);
			dc.TextOut(750,110,str_w);
			str_w.Format("%f",input_weights[2][2]);
			dc.TextOut(750,140,str_w);
			str_w.Format("%f",hidden_weights[1][1]);
			dc.TextOut(750,170,str_w);
			str_w.Format("%f",hidden_weights[2][1]);
			dc.TextOut(750,200,str_w);
	//开始进行BP网络训练
	//这里设定最大的迭代次数为15000次
				double x[15000],y[15000];
	for (i=0;i<15000;i++)
	{
			//对均方误差置零
		ex=0;
		//对样本进行逐个的扫描
		
		for(j=0;j<num;j++)  
		{ 
			//将提取的样本的特征向量输送到输入层上
			for(k=1;k<=n_in;k++)
				input_unites[k] = data_in[j][k-1];
			
			//将预定的理想输出输送到BP网络的理想输出单元
			for(k=1;k<=n_out;k++)
				target[k]=data_out[j][k-1];
			//dc.TextOut()
		
		
			//前向传输激活
			
			//将数据由输入层传到隐层 
			BP.bpnn_layerforward(input_unites,hidden_unites,
				input_weights, n_in,n_hidden);
			//将隐层的输出传到输出层
			BP.bpnn_layerforward(hidden_unites, output_unites,
				hidden_weights,n_hidden,n_out);
			
			//误差计算
			
			//将输出层的输出与理想输出比较计算输出层每个结点上的误差
			BP.bpnn_output_error(output_deltas,target,output_unites,n_out);
			//根据输出层结点上的误差计算隐层每个节点上的误差
			BP.bpnn_hidden_error(hidden_deltas,n_hidden, output_deltas, n_out,hidden_weights, hidden_unites);
			
			//权值调整
			//根据输出层每个节点上的误差来调整隐层与输出层之间的权值    
			BP.bpnn_adjust_weights(output_deltas,n_out, hidden_unites,n_hidden,
				hidden_weights, hidden_prev_weights, eta, momentum); 
			//根据隐层每个节点上的误差来调整隐层与输入层之间的权值    	
			BP.bpnn_adjust_weights(hidden_deltas, n_hidden, input_unites, n_in,
				input_weights, input_prev_weights, eta, momentum); 
			output[j]=output_unites[1];
			//误差统计		
			for(k=1;k<=n_out;k++)
				ex+=(output_unites[k]-data_out[j][k-1])*(output_unites[k]-data_out[j][k-1]);
		}
		str_w.Format("%f",input_weights[1][1]);
			dc.TextOut(750,50,str_w);
			str_w.Format("%f",input_weights[1][2]);
			dc.TextOut(750,80,str_w);
			str_w.Format("%f",input_weights[2][1]);
			dc.TextOut(750,110,str_w);
			str_w.Format("%f",input_weights[2][2]);
			dc.TextOut(750,140,str_w);
			str_w.Format("%f",hidden_weights[1][1]);
			dc.TextOut(750,170,str_w);
			str_w.Format("%f",hidden_weights[2][1]);
			dc.TextOut(750,200,str_w);
			//计算均方误差
		ex=ex/double(num*n_out);
			x[i]=i/30+220;
			y[i]=750-ex*500;	
			
		//如果均方误差已经足够的小,跳出循环,训练完毕  
		if(ex<min_ex)break;
		
	}
	for (k=0;k<i-1;k++) 
	{
		dc.MoveTo(x[k],y[k]);
		dc.LineTo(x[k+1],y[k+1]);
	}
	for (k=0;k<num;k++)
	{
			str_w.Format("%f",output[k]);
		dc.TextOut(1000,500+40*k,str_w);
	}
	CString str;
	if(ex<=min_ex)
	{
		str.Format ("迭代%d次,\n平均误差%.4f",i,ex);
		
		::MessageBox(NULL,str,"训练结果",NULL);
	}
	
	if(ex>min_ex)
	{
		
		str.Format("迭代%d次,平均误差%.4f\n我已经尽了最大努力了还是达不到您的要求\n请调整参数重新训练吧!",i,ex);
		::MessageBox(NULL,str,"训练结果",NULL);
	}
		free(input_unites);
	free(hidden_unites);
	free(output_unites);
	free(hidden_deltas);
	free(output_deltas);
	free(target);
	free(input_weights);
	free(hidden_weights);
	free(input_prev_weights);
	free(hidden_prev_weights);
}

void CBp_netView::OnInitialUpdate() 
{
	CView::OnInitialUpdate();
	
	// TODO: Add your specialized code here and/or call the base class
}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -