⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 bpnet.cpp

📁 BP算法的运算矩阵
💻 CPP
📖 第 1 页 / 共 2 页
字号:
#if !defined(AFX_BPNET_H__7ACF7725_EE66_11D6_AAF0_00E04F29491B__INCLUDED_)
#define AFX_BPNET_H__7ACF7725_EE66_11D6_AAF0_00E04F29491B__INCLUDED_

#if _MSC_VER > 1000
#pragma once
#endif // _MSC_VER > 1000
// BpNet.h : header file
//

/////////////////////////////////////////////////////////////////////////////
// CBpNet window
#include<matlib.h>
class CBpNet : public CObject
{
// Construction
public:
	CBpNet();

// Attributes
public:

// Operations
public:

// Overrides
	// ClassWizard generated virtual function overrides
	//{{AFX_VIRTUAL(CBpNet)
	//}}AFX_VIRTUAL

// Implementation
public:
	void Serialize( CArchive& ar );
	void display(Mm data);
	Mm scope(Mm mData);
	long lEpochs;
	double dblMse;
	double dblError;
	
	double randab(double a,double b);
	void stop();
	void learn();
	bool SaveBpNet(CString &strNetName);
	void LoadBpNet(CString &strNetName);
	Mm simulate(Mm mData);
	void Create(Mm mInputData,Mm mTarget,int iInput,int iHidden,int iOutput);
	virtual ~CBpNet();

	// Generated message map functions
protected:
	//{{AFX_MSG(CBpNet)
		// NOTE - the ClassWizard will add and remove member functions here.
	//}}AFX_MSG

	DECLARE_SERIAL(CBpNet)

public:
	bool m_isOK;
	void LoadPattern(Mm mIn,Mm mOut);
	int iHidden;//隐层神经元个数
	int iInput;//输入个数
	int iOutput;//输出个数
protected:
	
	Mm mInput;//单个样本输入数据
	Mm mSampleInput;//全体样本输入数据
	Mm mSampleTarget;//全体目标数据
    Mm mHidden;//计算得到的隐层数据
	Mm mOutput;//计算输出
	Mm mWeighti;//输入-隐层权重
	Mm mWeighto;//隐层-输出权重
	Mm mChangei;//输入-隐层权重变化
	Mm mChangeo;//隐层-输出权重变化
public:
	Mm mInputNormFactor;//正规化因子,iInputx2
	Mm mTargetNormFactor;//输出正规化因子,iOutputx2
protected:
	Mm mThresholdi;//阙值
	Mm mThresholdo;
	Mm mOutputDeltas;//误差
	Mm mHiddenDeltas;
protected:
	bool m_IsStop;
	double dblMomentumFactor;
	double dblLearnRate1;
	double dblLearnRate2;
	void backward(int iSample);
	void forward(int iSample);
	void normalize();//将输入输出样本数据正规化处理

private:
	double dblErr;
};

/////////////////////////////////////////////////////////////////////////////

//{{AFX_INSERT_LOCATION}}
// Microsoft Visual C++ will insert additional declarations immediately before the previous line.

#endif // !defined(AFX_BPNET_H__7ACF7725_EE66_11D6_AAF0_00E04F29491B__INCLUDED_)

// BpNet.cpp : implementation file
////////////////////////////////////////////////////////////////////
/////////////////人工神经网络BP算法/////////////////////////////////
//1、动态改变学习速率
//2、加入动量项
//3、运用了Matcom4.5的矩阵运算库(可免费下载,头文件matlib.h),
//   方便矩阵运算,当然,也可自己写矩阵类
//4、可暂停运算
//5、可将网络以文件的形式保存、恢复
///////////////作者:江西理工大学   张克俊//////////////////////
///////////////email:chunyu_79@hotmail.com//////////////////////////
///////////////QQ:53806186//////////////////////////////////////////
///////////////欢迎不断改进!欢迎讨论其他实用的算法!/////////////////

#include "BpNet.h"

#ifdef _DEBUG
#define new DEBUG_NEW
#undef THIS_FILE
static char THIS_FILE[] = __FILE__;
#endif

/////////////////////////////////////////////////////////////////////////////
// CBpNet
IMPLEMENT_SERIAL( CBpNet, CObject, 1 )

CBpNet::CBpNet()
{initM(MATCOM_VERSION);//启用矩阵运算库
}

CBpNet::~CBpNet()
{exitM();
 delete this;
}




/////////////////////////////////////////////////////////////////////////////
// CBpNet message handlers
//创建新网络
void CBpNet::Create(Mm mInputData, Mm mTarget, int iInput, int iHidden, int iOutput)
{ int i,j;
  mSampleInput=zeros(mInput.rows(),mInput.cols());
  mSampleTarget=zeros(mTarget.rows(),mTarget.cols());  
  mSampleInput=mInputData;
  mSampleTarget=mTarget;
  this->iInput=iInput;
  this->iHidden=iHidden;
  this->iOutput=iOutput;
  //创建计算用的单个样本矩阵
  mInput=zeros(1,this->iInput);
  mHidden=zeros(1,this->iHidden);
  mOutput=zeros(1,this->iOutput);
  //创建权重矩阵,并赋初值
  mWeighti=zeros(this->iInput,this->iHidden);
  mWeighto=zeros(this->iHidden,this->iOutput); 
  //赋初值
  for(i=1;i<=this->iInput;i++)
	  for(j=1;j<=this->iHidden;j++)
		  mWeighti.r(i,j)=randab(-1.0,1.0);
  for(i=1;i<=this->iHidden;i++)
	  for(j=1;j<=this->iOutput;j++)
		  mWeighto.r(i,j)=randab(-1.0,1.0);
  
  //创建阙值矩阵,并赋值
  mThresholdi=zeros(1,this->iHidden);
  for(i=1;i<=this->iHidden;i++)
	  mThresholdi.r(i)=randab(-1.0,1.0);
  mThresholdo=zeros(1,this->iOutput);
  for(i=1;i<=this->iOutput;i++)
	  mThresholdo.r(i)=randab(-1.0,1.0);
  //创建权重变化矩阵
  mChangei=zeros(this->iInput,this->iHidden);
  mChangeo=zeros(this->iHidden,this->iOutput);
  
  mInputNormFactor=zeros(iInput,2);
  mTargetNormFactor=zeros(iOutput,2);
  //误差矩阵
   mOutputDeltas=zeros(iOutput);
   mHiddenDeltas=zeros(iHidden);
  //学习速率赋值
  dblLearnRate1=0.5;
  dblLearnRate2=0.5;
  dblMomentumFactor=0.95;
  
  m_isOK=false;
  m_IsStop=false;
  dblMse=1.0e-6;//误差限
  dblError=1.0;
  lEpochs=0;

}
//根据已有的网络进行预测
Mm CBpNet::simulate(Mm mData)
{int i,j;
 Mm mResult;
 Mm data=zeros(mData.rows(),mData.cols());
 data=mData;
 if(mData.cols()!=iInput)
 {::MessageBox(NULL,"输入数据变量个数错误!","输入数据变量个数错误!",MB_OK);
  return mResult;
 }
 mResult=zeros(data.rows(),iOutput); 
 //正规化数据
 for(i=1;i<=data.rows();i++)
	 for(j=1;j<=data.cols();j++) 
         data.r(i,j)=(data.r(i,j)-mInputNormFactor.r(j,1))/(mInputNormFactor.r(j,2)-mInputNormFactor.r(j,1));
 //计算
	 int iSample;
	 Mm mInputdata,mHiddendata,mOutputdata;
	 mInputdata=zeros(1,iInput);
	 mHiddendata=zeros(1,iHidden);
	 mOutputdata=zeros(1,iOutput);
	 double sum=0.0;
   for(iSample=1;iSample<=data.rows();iSample++){ 
	 //输入层数据
	    for(i=1;i<=iInput;i++)
		  mInputdata.r(i)=data.r(iSample,i);
	 //隐层数据	
		for(j=1;j<=iHidden;j++){
		  sum=0.0;
		 for(i=1;i<=iInput;i++)
			sum+=mInputdata.r(i)*mWeighti.r(i,j);
		 sum-=mThresholdi.r(j); 
		 mHiddendata.r(j)=1.0/(1.0+exp(-sum));
		}
    
  //输出数据
	for(j=1;j<=iOutput;j++){
		sum=0.0;
		for(i=1;i<=iHidden;i++)
			sum+=mHiddendata.r(i)*mWeighto.r(i,j);
		sum-=mThresholdo.r(j); 
		mOutputdata.r(j)=1.0/(1.0+exp(-sum));
	}
	
	//转换
	for(j=1;j<=iOutput;j++)
        mResult.r(iSample,j)=mOutputdata.r(j)*(mTargetNormFactor.r(j,2)-mTargetNormFactor.r(j,1))+mTargetNormFactor.r(j,1);
 }
  
 return (mResult);
}

void CBpNet::LoadBpNet(CString &strNetName)
{CFile file;
 
 if(file.Open(strNetName,CFile::modeRead)==0)   
 {MessageBox(NULL,"无法打开文件!","错误",MB_OK);
  return;
 }
 else{
	 CArchive myar(&file,CArchive::load);
	 Serialize(myar); 
	 myar.Close();
 }
 file.Close(); 
}

bool CBpNet::SaveBpNet(CString &strNetName)
{CFile file;
 if(strNetName.GetLength()==0)
	 return(false);
 if(file.Open(strNetName,CFile::modeCreate|CFile::modeWrite)==0)   
 {MessageBox(NULL,"无法创建文件!","错误",MB_OK);
  return(false);
 }
 else{
	 CArchive myar(&file,CArchive::store);
	 Serialize(myar); 
	 myar.Close();
 }
 file.Close(); 
 return(true);
}
//网络学习
void CBpNet::learn()
{ int iSample=1;
  double dblTotal;
  MSG msg;
  if(m_IsStop)
	  m_IsStop=false;
  //数据正规化处理
  normalize();

  while(dblError>dblMse&&!m_IsStop){
   dblTotal=0.0;
   for(iSample=1;iSample<=mSampleInput.rows();iSample++){
	forward(iSample);
    backward(iSample);
	dblTotal+=dblErr;//总误差
  }
   if(dblTotal/dblError>1.04){//动态改变学习速率
	   dblLearnRate1*=0.7;
       dblLearnRate2*=0.7;
   }
   else{
       dblLearnRate1*=1.05;
       dblLearnRate2*=1.05;
   }
   lEpochs++;
   dblError=dblTotal;

   ::PeekMessage(&msg,NULL,0,0,PM_REMOVE);
   ::DispatchMessage(&msg);
     msg.message=-1;
   ::DispatchMessage(&msg);//这样可以消除屏闪和假死机
  }
if(dblError<=dblMse)
 m_isOK=true;
else
 m_isOK=false;

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -