⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 annbp.cpp

📁 annbp_BP神经网络原代码,经过一些算法改动,希望大家多多支持.
💻 CPP
📖 第 1 页 / 共 2 页
字号:
// Annbp.cpp: implementation of the Annbp class.
//
//////////////////////////////////////////////////////////////////////

#include "stdafx.h"
#include "Annbp.h"
#include "math.h"
#include <process.h>
#include <errno.h> 

#ifdef _DEBUG
#undef THIS_FILE
static char THIS_FILE[]=__FILE__;
#define new DEBUG_NEW
#endif


// Registers Annbp with MFC's run-time class / serialization machinery
// (base class CObject, schema version 1) so instances can be stored to and
// loaded from a CArchive through Annbp::Serialize().
IMPLEMENT_SERIAL(Annbp,CObject,1);
//////////////////////////////////////////////////////////////////////
// Construction/Destruction
//////////////////////////////////////////////////////////////////////

// Constructs an unconfigured, untrained network with default training
// hyper-parameters. The topology (InputNum, OutputNum, HideCoverNum, HideNum)
// and the sample sets must be configured before Study() is called; Study()
// refuses to run while InputNum == 0.
Annbp::Annbp()
{
	// Topology: everything zero until the network is configured. These used to
	// be left uninitialized, so e.g. serializing a fresh object wrote garbage
	// and looped over HideNum with a garbage HideCoverNum.
	InputNum = 0;
	OutputNum = 0;
	HideCoverNum = 0;
	for (int i = 0; i < int(sizeof(HideNum) / sizeof(HideNum[0])); i++)
		HideNum[i] = 0;

	// Sample sets: none until configured.
	TrainingSetNum = 0;
	TestSetNum = 0;

	// Training state.
	IsTrained = FALSE;
	StudyTimes = 0;
	StudyTol = 0;

	OutputFunctionType = 1;		// 0 = sigmoid, 1 = linear, 2 = multi-segment (not implemented in Study), 3 = 0/1 step
	HideStep = 0.15;			// learning rate of the hidden layers
	OutputStep = 0.1;			// learning rate of the output layer
	a = 0.6;					// momentum: weight of the previous weight change
	Tol = 0.01;					// training stops once the mean epoch error is <= Tol
	FuncSlope = 1;				// slope factor inside the sigmoid

	maxStudyTimes = 200;		// epoch limit
	rateOfTolChg = 0.0001;		// stop when the relative error change between epochs drops below this
}

// Releases every dynamically sized array owned by the network: samples (x),
// targets (o), weights (w), previous weight changes (dw0), the sample order
// (SamNo) and the per-layer node counts (CoverNodeNum).
Annbp::~Annbp()
{
	// Declared once so the loops below do not rely on the non-standard VC6
	// rule that leaks a for-init variable into the enclosing scope.
	int i;

	// Previously stopped at GetSize()-1, skipping the last row.
	for (i=0;i<x.GetSize();i++)
		x[i].RemoveAll();
	x.RemoveAll();

	for (i=0;i<o.GetSize();i++)
		o[i].RemoveAll();
	o.RemoveAll();

	for (i=0;i<w.GetSize();i++)
	{
		for (int j=0;j<w[i].GetSize();j++)
			w[i][j].RemoveAll();
		w[i].RemoveAll();
	}
	w.RemoveAll();

	// dw0 has the same layer/node/weight layout as w (see WeightSize()).
	for (i=0;i<dw0.GetSize();i++)
	{
		for (int j=0;j<dw0[i].GetSize();j++)
			dw0[i][j].RemoveAll();
		dw0[i].RemoveAll();
	}
	dw0.RemoveAll();

	SamNo.RemoveAll();
	CoverNodeNum.RemoveAll();
}

// Stores or loads the network configuration: name, training/test set file
// names and sizes, topology, output function type, learning rates, momentum
// and tolerance. The on-disk field order is unchanged.
// NOTE(review): weights, IsTrained, FuncSlope, maxStudyTimes and rateOfTolChg
// are not part of the archive, so a loaded network must be retrained.
// Throws CArchiveException(badIndex) when a loaded HideCoverNum does not fit
// in HideNum.
void Annbp::Serialize(CArchive& ar)
{
	CObject::Serialize(ar);

	// Capacity of the fixed-size HideNum array; HideCoverNum indexes into it.
	const int nMaxHideCover = int(sizeof(HideNum) / sizeof(HideNum[0]));
	int i;

	if (ar.IsStoring())
	{
		ASSERT(HideCoverNum >= 0 && HideCoverNum <= nMaxHideCover);
		ar<<Name;
		ar<<TrainingSet;
		ar<<TrainingSetNum;
		ar<<TestSet;
		ar<<TestSetNum;
		ar<<InputNum;
		ar<<OutputNum;
		ar<<HideCoverNum;
		for (i=0;i<HideCoverNum;i++)
			ar<<HideNum[i];
		ar<<OutputFunctionType;
		ar<<HideStep;
		ar<<OutputStep;
		ar<<a;
		ar<<Tol;
	}
	else
	{
		ar>>Name;
		ar>>TrainingSet;
		ar>>TrainingSetNum;
		ar>>TestSet;
		ar>>TestSetNum;
		ar>>InputNum;
		ar>>OutputNum;
		ar>>HideCoverNum;
		// The count comes from the file: reject it before it is used to write
		// past the end of HideNum.
		if (HideCoverNum < 0 || HideCoverNum > nMaxHideCover)
			AfxThrowArchiveException(CArchiveException::badIndex);
		for (i=0;i<HideCoverNum;i++)
			ar>>HideNum[i];
		ar>>OutputFunctionType;
		ar>>HideStep;
		ar>>OutputStep;
		ar>>a;
		ar>>Tol;
	}
}

BOOL Annbp::Study()
{
	if (InputNum == 0 )
	{
		MessageBox(NULL,"请先进行神经网络设置!","提示",MB_ICONINFORMATION|MB_OK);
		return FALSE;
	}
	//动态数组设定
	for (int i=0;i<x.GetSize();i++)
		x[i].RemoveAll();
	x.RemoveAll();
	for (i=0;i<o.GetSize();i++)
		o[i].RemoveAll();
	o.RemoveAll();
	SamNo.RemoveAll();

	x.SetSize(TrainingSetNum);
	for (i=0;i<TrainingSetNum;i++)
		x[i].SetSize(InputNum+1);
	o.SetSize(TrainingSetNum);
	for (i=0;i<TrainingSetNum;i++)
		o[i].SetSize(OutputNum);

	SamNo.SetSize(TrainingSetNum);
	
	SetCoverNodeNum();
	WeightSize();
	InitWeight();
	ReadSample();

	CArray<CArray<double,double>,CArray<double,double>&> y;
	y.SetSize(HideCoverNum+2);
	for (i=0;i<HideCoverNum+1;i++)
		y[i].SetSize(CoverNodeNum[i]+1);
	y[HideCoverNum+1].SetSize(OutputNum);
	double se,se0=0;
//
	FILE *fp;
	if((fp = fopen("D:\\jlm\\ANN\\BPNN\\学习过程.txt","w"))==NULL)
	{
		MessageBox(NULL,"打开 学习过程 文件错误!","错误",MB_OK);
		return 9999;
	}
//

	//val for display
	int Top_y = 0;
	double se_ori;
	///
	StudyTimes = 0;
	do
	{
		se = 0;
		RandSamNo();
		for (int n=0;n<TrainingSetNum;n++)
		{
			//forward
			for (int in=0;in<InputNum+1;in++)
				y[0][in] = x[SamNo[n]][in];
			for(int c=1;c<HideCoverNum+1;c++)//隐层
			{
				for(int j=1;j<CoverNodeNum[c];j++)
				{
					double v=0;
					for(int i=0;i<CoverNodeNum[c-1];i++)
						v += w[c][j][i]*y[c-1][i];
					//隐层单元激发
					y[c][j] = 1.0/(1+exp(-v*FuncSlope));
				}
			}
			for(int k=0;k<OutputNum;k++)//输出层
			{
				double v=0;
				for(int i=0;i<CoverNodeNum[HideCoverNum];i++)
					v += w[HideCoverNum+1][k][i]*y[HideCoverNum][i];

				/*if(OutputFunctionType == 1)//线性
					y[HideCoverNum+1][k] = v;
				else//sigmoid
					y[HideCoverNum+1][k] = 1.0/(1+exp(-v*FuncSlope));*/

				switch (OutputFunctionType)//0---sigmoid函数;1---线性阈值函数;2---线性多分段函数;3---线性01分段函数;
				{
					case 0:
						y[HideCoverNum+1][k] = 1.0/(1+exp(-v*FuncSlope));
						break;
					case 1:
						y[HideCoverNum+1][k] = v;
						break;
					case 2:
						break;
					case 3:
						if (v <= 0)
							y[HideCoverNum+1][k] = 0.0;
						else
							y[HideCoverNum+1][k] = 1.0;
						break;
				}
			}
			//backward
			CArray<double,double> e;
			e.SetSize(OutputNum);
			CArray<CArray<double,double>,CArray<double,double>&> s;
			s.SetSize(HideCoverNum+2);
			s[0].SetSize(0);
			for(int i=1;i<HideCoverNum+2;i++)
				s[i].SetSize(CoverNodeNum[i]);

			for(k=0;k<OutputNum;k++)
			{
				e[k] = o[SamNo[n]][k]-y[HideCoverNum+1][k];
				if(OutputFunctionType != 0)//线性
					s[HideCoverNum+1][k] = e[k];
				else
					s[HideCoverNum+1][k] = e[k]*y[HideCoverNum+1][k]*(1-y[HideCoverNum+1][k])*FuncSlope;
				se = se + 0.5*pow(e[k],2);
			}

			c = HideCoverNum;//last hide cover
			for(int j=0;j<CoverNodeNum[HideCoverNum];j++)
			{
				double sw=0;
				for(int i=0;i<OutputNum;i++)
					sw += s[c+1][i]*w[c+1][i][j];//原为 = ,补加 + 号,为 +=。2004.2.26
				s[c][j] = y[c][j]*(1-y[c][j])*sw*FuncSlope;
			}
			for(c=HideCoverNum-1;c>0;c--)//other hide covers
				for(int j=0;j<CoverNodeNum[c];j++)
				{
					double sw=0;
					for(int i=1;i<CoverNodeNum[c+1];i++)//0--1
						sw += s[c+1][i]*w[c+1][i][j];//原为 = ,补加 + 号,为 +=。2004.2.26
					s[c][j] = y[c][j]*(1-y[c][j])*sw*FuncSlope;
				}
			//修正权值
			double dw=0;
			for(k=0;k<OutputNum;k++)
				for(i=0;i<CoverNodeNum[HideCoverNum];i++)
				{
					dw = OutputStep*s[HideCoverNum+1][k]*y[HideCoverNum][i];
					w[HideCoverNum+1][k][i] += a*dw0[HideCoverNum+1][k][i]+dw;
					dw0[HideCoverNum+1][k][i] = dw;
				}
			for(c=HideCoverNum;c>0;c--)
				for(int j=1;j<CoverNodeNum[c];j++)
					for(i=0;i<CoverNodeNum[c-1];i++)
					{
						dw = HideStep*s[c][j]*y[c-1][i];
						w[c][j][i] += a*dw0[c][j][i]+dw;
						dw0[c][j][i] = dw;
					}

			e.RemoveAll();
			for(i=0;i<HideCoverNum+2;i++)
				s[i].RemoveAll();
			s.RemoveAll();
		}
		se = se/TrainingSetNum;

		StudyTimes++;
		StudyTol = se;
		//////////
		//Display
		CString str1,str2;
		str1.Format("学习次数:%16d",StudyTimes);
		str2.Format("均方误差:%16.6f",StudyTol);
		CWnd* pWnd = AfxGetMainWnd();
        CDC* pDC = pWnd->GetDC();
		CRect* pRect = new CRect;
		pWnd->GetClientRect(pRect);
		ASSERT(pRect);

		pDC->SetBkColor(RGB(0,255,255));
		LOGFONT logFont;
		logFont.lfHeight=13;
		logFont.lfWeight=0;
		logFont.lfEscapement=0;
		logFont.lfOrientation=0;
		logFont.lfWeight=FW_MEDIUM;
		logFont.lfItalic=0;
		logFont.lfUnderline=0;
		logFont.lfStrikeOut=0;
		logFont.lfCharSet=GB2312_CHARSET;
		logFont.lfQuality=PROOF_QUALITY;
		logFont.lfPitchAndFamily=VARIABLE_PITCH|FF_DONTCARE;
		strcpy(logFont.lfFaceName,"宋体");
		CFont font;
		font.CreateFontIndirect(&logFont);
		CFont* poldFont=pDC->SelectObject(&font);

		int m_Left = int(pRect->Width()*2/3.0); 
		int m_Right = int(pRect->Width()*0.99); 
		int m_Top = int(pRect->Height()*2/15.0);
		int m_Bot = int(pRect->Height()*13/15.0);
		int m_Width = m_Right - m_Left;
		int m_Height = m_Bot - m_Top;
		pDC->TextOut(int(pRect->Width()/1.4),int(pRect->Height()*13.4/15.0),str1);
		pDC->TextOut(int(pRect->Width()/1.4),int(pRect->Height()*13.8/15.0),str2);
		//pWnd->InvalidateRect(pRect,true);

		pDC->SelectObject(poldFont);

		CPen m_pen,m_pen1;
		CPen* poldpen;

		RECT r;
		int cell = 10;
		r.left = m_Left;
		int col = int(m_Width/cell);
		r.right = r.left+cell*col;
		r.top = m_Top;
		int row = int(m_Height/cell);
		r.bottom = r.top+cell*row;

//		InvalidateRect(pWnd->m_hWnd,&r,TRUE);
		for (int cn=0;cn<=col;cn++)
		{
			pDC->MoveTo(r.left+cell*cn,r.top);
			pDC->LineTo(r.left+cell*cn,r.bottom);
		}
		for (int rn=0;rn<=row;rn++)
		{
			pDC->MoveTo(r.left,r.top+cell*rn);
			pDC->LineTo(r.right,r.top+cell*rn);
		}

		m_pen.CreatePen(PS_SOLID,3,GetSysColor(COLOR_WINDOW));
		poldpen = pDC->SelectObject(&m_pen);
		pDC->MoveTo((m_Left+m_Right)/2,Top_y);
		if(StudyTimes>1)
			pDC->LineTo((m_Left+m_Right)/2,m_Bot);
		pDC->SelectObject(poldpen);

		m_pen1.CreatePen(PS_SOLID,3,RGB(255,0,0));
		poldpen = pDC->SelectObject(&m_pen1);

		if (StudyTimes==1)
			se_ori = se;

		if (StudyTimes==1)
		{
			Top_y = m_Top;
		}
		else
		{
			Top_y = int(m_Bot-se/se_ori*m_Top);
		}
		pDC->MoveTo((m_Left+m_Right)/2,Top_y);
		pDC->LineTo((m_Left+m_Right)/2,m_Bot);
   	    
		pDC->SelectObject(poldpen);
		delete pRect;
		ReleaseDC(pWnd->m_hWnd,pDC->m_hDC);//Release!!
		//End of Display
		/////////
		fprintf(fp,"%d %f\n",StudyTimes,StudyTol);
		if(StudyTimes==maxStudyTimes) break;
		if(fabs((se-se0)/se)<rateOfTolChg) break;
		//////////
		se0=se;
	}while(se>Tol);
	fclose(fp);

	for (i=0;i<y.GetSize();i++)
		y[i].RemoveAll();
	y.RemoveAll();

	//beep
	//Beep(20000,500);
	Beep(10300,200);
	//Beep(800,300);
	//Beep(10300,500);
	//Beep(800,600);
	IsTrained = TRUE;

	::ShellExecute(NULL,NULL,"学习过程.txt",NULL,"D:\\jlm\\ANN\\BPNN",SW_SHOW);//%SystemRoot%\system32\notepad.exe
	return TRUE;
}

// Loads the training samples from "D:\jlm\ANN\BPNN\" + TrainingSet into x and o.
//
// The file holds TrainingSetNum whitespace-separated records. Each record is
// InputNum input values (stored in x[n][1..InputNum]) followed by OutputNum
// target values (stored in o[n][0..OutputNum-1]). x[n][0] is set to -1, the
// bias input. x and o must already be sized (Study() does this).
//
// Returns TRUE on success. Returns FALSE, after a message box, if the file
// cannot be opened or contains fewer numbers than expected.
BOOL Annbp::ReadSample()
{
	FILE *fp;
	if((fp = fopen("D:\\jlm\\ANN\\BPNN\\"+TrainingSet,"r"))==NULL)
	{
		MessageBox(NULL,"打开训练集样本文件错误","错误",MB_OK);
		return FALSE;
	}

	// Stop at the first value fscanf cannot read, instead of silently leaving
	// the remaining samples at zero.
	BOOL bOk = TRUE;
	for (int n=0;n<TrainingSetNum && bOk;n++)
	{
		x[n][0] = -1;	// bias input
		for (int i=1;i<InputNum+1 && bOk;i++)
			bOk = (fscanf(fp,"%lf",&x[n][i]) == 1);
		for (int j=0;j<OutputNum && bOk;j++)
			bOk = (fscanf(fp,"%lf",&o[n][j]) == 1);
	}
	fclose(fp);

	if (!bOk)
	{
		MessageBox(NULL,"训练集样本文件数据不足或格式错误","错误",MB_OK);
		return FALSE;
	}
	return TRUE;
}

// Returns a pseudo-random weight drawn uniformly from [-bound, bound].
static double AnnbpRandomWeight(double bound)
{
	return (2.0*rand()/RAND_MAX - 1.0)*bound;
}

// Seeds rand() from GetTickCount() and initialises every weight of w uniformly
// in [-2.4/fanIn, 2.4/fanIn], where fanIn counts the bias node. All previous
// weight changes (dw0) are reset to 0.
//
// The weights are also logged to D:\jlm\ANN\BPNN\权重.txt in three lines:
// first hidden layer, remaining hidden layers, output layer. Logging is
// skipped if that file cannot be created.
//
// The sign trick `if (r < RAND_MAX/2) r = -r` previously used for layers 2+
// and the output produced values only in (-b/2, 0] and [b/2, b]. Every layer
// now uses the same symmetric distribution.
//
// Requires the arrays to have been sized by WeightSize().
void Annbp::InitWeight()
{
	FILE *fp = fopen("D:\\jlm\\ANN\\BPNN\\权重.txt","w+");

	srand(GetTickCount());

	int i, j, k;
	double bound;

	// First hidden layer: neurons 1..HideNum[0], inputs 0..InputNum (0 = bias).
	bound = 2.4/(InputNum+1);
	for (i=1;i<HideNum[0]+1;i++)
		for (j=0;j<InputNum+1;j++)
		{
			w[1][i][j] = AnnbpRandomWeight(bound);
			dw0[1][i][j] = 0.0;
			if (fp != NULL)
				fprintf(fp,"%f ",w[1][i][j]);
		}
	if (fp != NULL)
		fprintf(fp,"\n");

	// Remaining hidden layers.
	for (i=2;i<HideCoverNum+1;i++)
	{
		bound = 2.4/(HideNum[i-2]+1);
		for (j=1;j<HideNum[i-1]+1;j++)
			for (k=0;k<HideNum[i-2]+1;k++)
			{
				w[i][j][k] = AnnbpRandomWeight(bound);
				dw0[i][j][k] = 0.0;
				if (fp != NULL)
					fprintf(fp,"%f ",w[i][j][k]);
			}
	}
	if (fp != NULL)
		fprintf(fp,"\n");

	// Output layer: fed by the last hidden layer (including its bias node).
	bound = 2.4/(HideNum[HideCoverNum-1]+1);
	for (i=0;i<OutputNum;i++)
		for (j=0;j<HideNum[HideCoverNum-1]+1;j++)
		{
			w[HideCoverNum+1][i][j] = AnnbpRandomWeight(bound);
			dw0[HideCoverNum+1][i][j] = 0.0;
			if (fp != NULL)
				fprintf(fp,"%f ",w[HideCoverNum+1][i][j]);
		}

	if (fp != NULL)
		fclose(fp);
}

// Clears one per-layer weight container (w or dw0) and resizes it for the
// given topology:
//   [0]              unused (size 0);
//   [1..nHideCover]  hidden layers; node 0 is the bias/threshold node with no
//                    incoming weights, nodes 1..hideNum[l-1] each hold one
//                    weight per node of the previous layer (incl. its bias);
//   [nHideCover+1]   output layer, nOutput nodes fed by the last hidden layer.
// NOTE(review): assumes nHideCover >= 1 (hideNum[nHideCover-1] is read).
template <class TLayers, class THideNum>
static void AnnbpSizeLayerArrays(TLayers& arr, int nHideCover,
								 const THideNum& hideNum, int nInput, int nOutput)
{
	int i, j;

	for (i=0;i<arr.GetSize();i++)
	{
		for (j=0;j<arr[i].GetSize();j++)
			arr[i][j].RemoveAll();
		arr[i].RemoveAll();
	}
	arr.RemoveAll();

	arr.SetSize(nHideCover+2);
	arr[0].SetSize(0);
	for (i=1;i<nHideCover+1;i++)
		arr[i].SetSize(hideNum[i-1]+1);
	arr[nHideCover+1].SetSize(nOutput);

	for (j=1;j<hideNum[0]+1;j++)			// first hidden layer: fed by the inputs + bias
		arr[1][j].SetSize(nInput+1);
	for (i=1;i<nHideCover+1;i++)			// bias nodes have no incoming weights
		arr[i][0].SetSize(0);
	for (i=2;i<nHideCover+1;i++)
		for (j=1;j<hideNum[i-1]+1;j++)
			arr[i][j].SetSize(hideNum[i-2]+1);
	for (i=0;i<nOutput;i++)
		arr[nHideCover+1][i].SetSize(hideNum[nHideCover-1]+1);
}

// Resizes the weight array w and the previous-change array dw0 to the current
// topology (InputNum, HideCoverNum, HideNum, OutputNum). Both share the same
// layout; only indices 1..HideCoverNum+1 hold real layers. Contents are reset
// to zero-filled, freshly sized arrays.
void Annbp::WeightSize()
{
	AnnbpSizeLayerArrays(w,HideCoverNum,HideNum,InputNum,OutputNum);
	AnnbpSizeLayerArrays(dw0,HideCoverNum,HideNum,InputNum,OutputNum);
}

// Fills SamNo with a uniformly random permutation of 0..TrainingSetNum-1
// (Fisher-Yates shuffle), giving the sample presentation order for one epoch.
//
// The generator is no longer reseeded here. Reseeding from GetTickCount()
// every epoch repeated the same order whenever two epochs started within one
// timer tick. Within Study() the generator has already been seeded by
// InitWeight().
//
// The old index formula rand()*m/RAND_MAX was also fixed. It almost never
// produced m, which biased the shuffle, and rand()*m can overflow int for
// large training sets.
void Annbp::RandSamNo()
{
	int i;
	for (i=0;i<TrainingSetNum;i++)
		SamNo[i] = i;

	for (i=TrainingSetNum-1;i>0;i--)
	{
		// r uniform in [0, i], computed in double to avoid integer overflow.
		const int r = int(rand()/(RAND_MAX+1.0)*(i+1));
		const int temp = SamNo[r];
		SamNo[r] = SamNo[i];
		SamNo[i] = temp;
	}
}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -