// bpnet.cpp
// BpNet.cpp : implementation file
////////////////////////////////////////////////////////////////////
/////////////////人工神经网络BP算法/////////////////////////////////
//1、动态改变学习速率
//2、加入动量项
//3、运用了Matcom4.5的矩阵运算库(可免费下载,头文件matlib.h),
// 方便矩阵运算,当然,也可自己写矩阵类
//4、可暂停运算
//5、可将网络以文件的形式保存、恢复
///////////////作者:同济大学材料学院 张纯禹//////////////////////
///////////////email:chunyu_79@hotmail.com//////////////////////////
///////////////QQ:53806186//////////////////////////////////////////
///////////////欢迎不断改进!欢迎讨论其他实用的算法!/////////////////
#include "stdafx.h"
#include "BpNet.h"
#ifdef _DEBUG
#define new DEBUG_NEW
#undef THIS_FILE
static char THIS_FILE[] = __FILE__;
#endif
// Returns the number of rows the matrix currently holds.
int CMatrix::rows()
{
	const int nRowCount = m_nRows;
	return nRowCount;
}
// Returns the number of columns the matrix currently holds.
int CMatrix::cols()
{
	const int nColCount = m_nCols;
	return nColCount;
}
// Returns the element at (nRow, nCol). Storage is row-major: element (r, c)
// lives at m_pData[r*m_nCols + c]. Debug builds assert that both indices are
// inside the matrix, including the lower bound (negative indices were not
// caught before).
double CMatrix::operator()(int nRow,int nCol)
{
	ASSERT(nRow>=0);
	ASSERT(nRow<m_nRows);
	ASSERT(nCol>=0);
	ASSERT(nCol<m_nCols);
	return m_pData[nRow*m_nCols+nCol];
}
// Returns the element at flat (row-major) position nNo. Debug builds assert
// that nNo lies in [0, rows*cols); negative positions were not caught before.
double CMatrix::operator()(int nNo)
{
	ASSERT(nNo>=0);
	ASSERT(nNo<m_nRows*m_nCols);
	return m_pData[nNo];
}
// Makes this matrix a copy of mIn (same shape, same element values).
// If storage for the new shape cannot be created, the function returns
// without copying any elements.
CMatrix& CMatrix::operator=(CMatrix &mIn)
{
	// Self-assignment guard. Create() sizes m_pData for the new shape (its
	// definition is not visible here; it presumably reallocates). Running it
	// on the source object itself could invalidate the data memcpy is about
	// to read.
	if (this == &mIn)
	{
		return *this;
	}
	if (!Create(CSize(mIn.m_nCols,mIn.m_nRows)))
	{
		return *this;
	}
	// Skip the copy for an empty matrix: memcpy with NULL pointers is
	// undefined even when the byte count is zero.
	const int nCount = m_nCols*m_nRows;
	if (nCount > 0)
	{
		memcpy(m_pData,mIn.m_pData,sizeof(double)*nCount);
	}
	return *this;
}
// Default constructor: an empty 0 x 0 matrix with no element buffer.
CMatrix::CMatrix()
{
	m_nRows = m_nCols = 0;
	m_pData = NULL;
}
void CMatrix::Display()
{
CString strData,strTemp;
int i,j;
for(i=0;i<m_nRows;i++)
{
for(j=0;j<m_nCols;j++)
{
strTemp.Format("%.3f ",m_pData[i*m_nCols+j]);
strData+=strTemp;
}
strData=strData+"\r\n";
}
::MessageBox(NULL,strData,"",MB_OK);
}
// Copy constructor: starts empty, then takes mIn's shape and elements.
// If storage cannot be created, the new matrix is left without element data.
CMatrix::CMatrix(CMatrix &mIn)
{
	m_nRows = m_nCols = 0;
	m_pData = NULL;
	if (Create(CSize(mIn.m_nCols,mIn.m_nRows)))
	{
		memcpy(m_pData,mIn.m_pData,sizeof(double)*m_nCols*m_nRows);
	}
}
// Destructor: hands cleanup to Destroy() (presumably frees m_pData; its
// definition is not in this file section).
CMatrix::~CMatrix()
{
	this->Destroy();
}
// Writes the matrix to strPath as text. Each row goes on one line, and each
// element is formatted as "%.3f " (three decimals plus a space).
// Returns FALSE if the file cannot be created, TRUE otherwise.
BOOL CMatrix::SaveAsText(CString strPath)
{
	CStdioFile file;
	// typeText is already CStdioFile's default; it is spelled out because the
	// line terminator below depends on it. In text mode WriteString turns
	// each '\n' into a CR/LF pair.
	if (!file.Open(strPath,CFile::modeCreate|CFile::modeWrite|CFile::typeText))
	{
		return FALSE;
	}
	CString strData,strTemp;
	int i,j;
	for(i=0;i<m_nRows;i++)
	{
		strData = "";
		for(j=0;j<m_nCols;j++)
		{
			strTemp.Format("%.3f ",m_pData[i*m_nCols+j]);
			strData+=strTemp;
		}
		// Only '\n': text mode adds the '\r'. Writing "\r\n" produced
		// "\r\r\n" in the file.
		strData+="\n";
		file.WriteString(strData);
	}
	file.Close();
	return TRUE;
}
/////////////////////////////////////////////////////////////////////////////
// CBpNet
IMPLEMENT_SERIAL( CBpNet, CObject, 1 )
// Class-wide stop flag. It starts out set and is cleared by Create().
// Initialised with the C++ literal `true` rather than the Win32 BOOL macro TRUE.
bool CBpNet::m_IsStop = true;
// Default constructor. Seeds the C runtime RNG (presumably used by randab()
// when Create() draws the initial weights). It also gives the scalar network
// state defined values until Create() or Serialize() fills it in. Before this,
// GetError() and simulate() could read uninitialised members.
CBpNet::CBpNet()
{
	srand((unsigned)time(NULL));

	iInput = 0;
	iHidden = 0;
	iOutput = 0;
	dblLearnRate1 = 0.0;
	dblLearnRate2 = 0.0;
	dblMomentumFactor = 0.0;
	m_isOK = false;
	dblMse = 0.0;
	dblError = 0.0;
	lEpochs = 0;
}
// Destructor: nothing to release by hand. Matrix members clean up in their
// own destructors.
CBpNet::~CBpNet()
{
	// intentionally empty
}
double CBpNet::GetError()
{
return dblError;
}
// Computes features of one nDim x nDim pixel block stored row-major in pData.
//
// nS receives the number of 0-valued pixels in the block. It is a BYTE, so
// the count wraps modulo 256, exactly as the previous nS++ did.
// nH, nV, nL and nR (horizontal, vertical, "/" and "\" adjacency counts) are
// only reset to 0 and not computed. Their earlier implementation sat behind an
// unconditional early return and was unreachable. Had it run, its diagonal
// scans would have indexed outside the block (row nDim, column -1). It was
// removed rather than silently re-enabled, because the feature layout the
// trained networks use depends on these outputs staying 0.
//
// The unused outputs are cleared BEFORE nS is written, so the result is still
// the pixel count when a caller passes the same byte for several outputs.
void CBpNet::GetFeaNN(BYTE &nH, BYTE &nV, BYTE &nL, BYTE &nR, BYTE &nS, BYTE * pData,int nDim)
{
	nH = 0;
	nV = 0;
	nL = 0;
	nR = 0;

	int nCount = 0;
	if (pData != NULL)
	{
		int i,j;   // i: row, j: column
		for(i=0;i<nDim;i++)
		{
			for(j=0;j<nDim;j++)
			{
				if (pData[i*nDim+j]==0)
				{
					nCount++;
				}
			}
		}
	}
	nS = static_cast<BYTE>(nCount);
}
// Recognises the character in ImgIn using the GetFeature() feature set.
// The features are fed to the network as one sample row. dRecVal receives
// the value returned by CMatrix::Max on the output, and the index Max
// reports is returned (presumably the recognised class).
int CBpNet::Recogn(CImage &ImgIn,double &dRecVal)
{
	BYTE featureBuf[512];
	int featureCount;
	GetFeature(ImgIn,featureBuf,featureCount);

	CMatrix sampleRow;
	sampleRow.SetData(1,featureCount,featureBuf);
	CMatrix netOut;
	netOut = simulate(sampleRow);

	int bestIndex;
	dRecVal = netOut.Max(bestIndex);
	return bestIndex;
}
// Recognises the character in ImgIn using a cell-density feature set.
//
// ImgIn is resized IN PLACE to 16x32, so the caller's image is modified.
// It is then cut into nS x nS cells. Each feature is the number of 0-valued
// pixels in one cell, in row-major cell order. The feature row is run through
// simulate(). dRecVal receives the value returned by CMatrix::Max, and the
// index Max reports is returned (presumably the recognised class).
int CBpNet::Recogn1(CImage &ImgIn,double &dRecVal)
{
	const int nMaxFea = 512;
	BYTE byFea[nMaxFea];
	ImgIn.Stretch(CSize(16,32));

	const int nS = 2;
	// Take the row stride from the actual image width instead of a literal 16.
	// This keeps the index consistent with the loop bounds whatever width
	// Stretch produced (assumes m_pR holds GetWidth()*GetHeight() pixels).
	const int nW = ImgIn.GetWidth();
	const int nCellRows = ImgIn.GetHeight()/nS;
	const int nCellCols = nW/nS;

	int nLenFea = 0;
	int i,j,k,l;
	for(i=0;i<nCellRows;i++)
	{
		for(j=0;j<nCellCols;j++)
		{
			// Never write past the fixed-size feature buffer.
			if (nLenFea >= nMaxFea)
			{
				break;
			}
			BYTE byCount = 0;
			for(k=0;k<nS;k++)
			{
				for(l=0;l<nS;l++)
				{
					if (ImgIn.m_pR[(i*nS+k)*nW+j*nS+l]==0)
					{
						byCount++;
					}
				}
			}
			byFea[nLenFea++] = byCount;
		}
	}

	CMatrix mIn;
	CMatrix mOut;
	mIn.SetData(1,nLenFea,byFea);
	mOut = simulate(mIn);
	int nResult;
	dRecVal = mOut.Max(nResult);
	return nResult;
}
/////////////////////////////////////////////////////////////////////////////
// CBpNet message handlers
// Feature extraction (GetFeature). The network-creation routine, Create, follows it.
// Extracts the feature vector of a character image into pFea and stores its
// length in nLen. A copy of ImgIn is resized to 16x32; ImgIn itself is not
// modified. Pixels equal to 0 are counted as foreground.
//
// Layout for a W x H image:
//   W bytes           number of foreground runs in each column
//   H bytes           number of foreground runs in each row
//   2W bytes          first / last foreground row of each column
//                     (0 / H-1 when the column is empty)
//   2H bytes          first / last foreground column of each row
//                     (0 / W-1 when the row is empty)
//   (H/4)*(W/4) bytes foreground pixel count of each 4x4 cell
// For 16x32 that is 176 bytes; pFea must be at least that large.
// Returns FALSE (with nLen == 0) when there is no output buffer or no pixel
// data, TRUE otherwise.
//
// NOTE: the per-row "last foreground column" used to be stored at offset +W
// instead of +H. That overwrote half of the "first" values and left 16 bytes
// always 0. Networks trained on the old layout must be retrained.
BOOL CBpNet::GetFeature(CImage &ImgIn,BYTE *pFea,int &nLen)
{
	nLen = 0;
	CImage ImgTemp = ImgIn;
	ImgTemp.Stretch(CSize(16,32));

	BYTE * pBuf = ImgTemp.m_pR;
	const int nW = ImgTemp.GetWidth();
	const int nH = ImgTemp.GetHeight();
	if (pFea == NULL || pBuf == NULL || nW <= 0 || nH <= 0)
	{
		return FALSE;
	}

	BYTE * pFeaT = pFea;   // write cursor into pFea
	BYTE byData[64];       // one nS x nS cell for the grid features
	int i,j,k,l;

	// The four profile sections together occupy 3*(W+H) bytes; clear them.
	memset(pFeaT,0,3*(nW+nH));

	// 1) Foreground runs per column: count each 0 -> non-0 step going down,
	//    plus a run that reaches the bottom edge.
	for(i=0;i<nW;i++)
	{
		for(j=0;j<nH-1;j++)
		{
			if (pBuf[j*nW+i]!=pBuf[(j+1)*nW+i]&&pBuf[j*nW+i]==0)
			{
				pFeaT[i]++;
			}
		}
		if (pBuf[(nH-1)*nW+i]==0)
		{
			pFeaT[i]++;
		}
	}
	pFeaT += nW;
	nLen += nW;

	// 2) Foreground runs per row, same rule going right.
	for(i=0;i<nH;i++)
	{
		for(j=0;j<nW-1;j++)
		{
			if (pBuf[i*nW+j]!=pBuf[i*nW+j+1]&&pBuf[i*nW+j]==0)
			{
				pFeaT[i]++;
			}
		}
		if (pBuf[i*nW+nW-1]==0)
		{
			pFeaT[i]++;
		}
	}
	pFeaT += nH;
	nLen += nH;

	// 3) First / last foreground row of each column: [0,W) then [W,2W).
	for(i=0;i<nW;i++)
	{
		pFeaT[i] = 0;
		pFeaT[i+nW] = nH-1;
		for(j=0;j<nH;j++)
		{
			if (pBuf[j*nW+i]==0)
			{
				pFeaT[i] = j;
				break;
			}
		}
		for(j=nH-1;j>=0;j--)
		{
			if (pBuf[j*nW+i]==0)
			{
				pFeaT[i+nW] = j;
				break;
			}
		}
	}
	pFeaT += 2*nW;
	nLen += 2*nW;

	// 4) First / last foreground column of each row: [0,H) then [H,2H).
	for(i=0;i<nH;i++)
	{
		pFeaT[i] = 0;
		pFeaT[i+nH] = nW-1;
		for(j=0;j<nW;j++)
		{
			if (pBuf[i*nW+j]==0)
			{
				pFeaT[i] = j;
				break;
			}
		}
		for(j=nW-1;j>=0;j--)
		{
			if (pBuf[i*nW+j]==0)
			{
				pFeaT[i+nH] = j;
				break;
			}
		}
	}
	pFeaT += 2*nH;
	nLen += 2*nH;

	// 5) Foreground count of each nS x nS cell, cells in row-major order.
	//    GetFeaNN reports the count through its nS output; the directional
	//    outputs go to scratch bytes instead of aliasing the feature slot.
	const int nS = 4;
	BYTE byUnusedH, byUnusedV, byUnusedL, byUnusedR;
	for(i=0;i<nH/nS;i++)
	{
		for(j=0;j<nW/nS;j++)
		{
			for(k=0;k<nS;k++)
			{
				for(l=0;l<nS;l++)
				{
					byData[k*nS+l] = pBuf[(i*nS+k)*nW+j*nS+l];
				}
			}
			GetFeaNN(byUnusedH,byUnusedV,byUnusedL,byUnusedR,pFeaT[0],byData,nS);
			pFeaT += 1;
			nLen += 1;
		}
	}
	return TRUE;
}
// Sets up a fresh iInput-iHidden-iOutput network for the training set
// (mInputData, mTarget). It stores the samples, sizes every work and weight
// matrix, fills weights and thresholds with randab(-1.0, 1.0), and resets the
// training parameters. The parameters shadow the members of the same name,
// which is why the size assignments go through this->.
void CBpNet::Create(CMatrix &mInputData, CMatrix &mTarget, int iInput, int iHidden, int iOutput)
{
	int r, c;

	// Training samples.
	mSampleInput = mInputData;
	mSampleTarget = mTarget;

	// Layer sizes.
	this->iInput = iInput;
	this->iHidden = iHidden;
	this->iOutput = iOutput;

	// Single-sample work rows.
	mInput.Zeros(1, iInput);
	mHidden.Zeros(1, iHidden);
	mOutput.Zeros(1, iOutput);

	// Weight matrices (input->hidden, hidden->output), then random start values,
	// drawn row by row.
	mWeighti.Zeros(iInput, iHidden);
	mWeighto.Zeros(iHidden, iOutput);
	for (r = 0; r < iInput; ++r)
	{
		for (c = 0; c < iHidden; ++c)
		{
			mWeighti.Set(r, c, randab(-1.0, 1.0));
		}
	}
	for (r = 0; r < iHidden; ++r)
	{
		for (c = 0; c < iOutput; ++c)
		{
			mWeighto.Set(r, c, randab(-1.0, 1.0));
		}
	}

	// Threshold rows for the hidden and output layers, also random.
	mThresholdi.Zeros(1, iHidden);
	for (c = 0; c < iHidden; ++c)
	{
		mThresholdi.m_pData[c] = randab(-1.0, 1.0);
	}
	mThresholdo.Zeros(1, iOutput);
	for (c = 0; c < iOutput; ++c)
	{
		mThresholdo.m_pData[c] = randab(-1.0, 1.0);
	}

	// Weight-change matrices (presumably for the momentum term).
	mChangei.Zeros(iInput, iHidden);
	mChangeo.Zeros(iHidden, iOutput);

	// Two normalisation factors per input / target variable.
	mInputNormFactor.Zeros(iInput, 2);
	mTargetNormFactor.Zeros(iOutput, 2);

	// Delta rows.
	mOutputDeltas.Zeros(1, iOutput);
	mHiddenDeltas.Zeros(1, iHidden);

	// Training parameters.
	dblLearnRate1 = 0.5;
	dblLearnRate2 = 0.5;
	dblMomentumFactor = 0.05;
	m_isOK = false;
	m_IsStop = false;
	dblMse = 0.00005*mSampleTarget.rows()*mSampleTarget.cols();   // error limit
	dblError = 1.0*mSampleTarget.rows()*mSampleTarget.cols();
	lEpochs = 0;
}
// Prediction with the existing (trained) network.

// Logistic activation shared by the hidden and output layers. The exp() term
// is capped at 100000, so strongly negative net inputs give a small positive
// result instead of overflowing.
static double BpNetSquash(double dNet)
{
	double dDecay = exp(-dNet);
	if (dDecay > 100000)
	{
		dDecay = 100000;
	}
	return 1.0/(1.0+dDecay);
}

// Runs the network forward on every row of mData. Each row must hold iInput
// values. The result has one row per input row and iOutput columns, mapped
// back through mTargetNormFactor. If the column count is wrong, a message box
// is shown and an empty matrix is returned.
CMatrix CBpNet::simulate(CMatrix& mData)
{
	CMatrix mResult;
	if (mData.cols()!=iInput)
	{
		::MessageBox(NULL,"输入数据变量个数错误!","输入数据变量个数错误!",MB_OK);
		return mResult;
	}

	CMatrix data;
	data = mData;
	const int nRows = data.rows();
	const int nCols = data.cols();
	mResult.Zeros(nRows,iOutput);

	int nRow, nNode, nSrc;

	// Scale the private copy with the two stored factors per column
	// (presumably column 0 = minimum, column 1 = maximum). A column whose
	// factor span is (nearly) zero is set to the second factor.
	for (nRow = 0; nRow < nRows; nRow++)
	{
		for (nNode = 0; nNode < nCols; nNode++)
		{
			const double dLow = mInputNormFactor(nNode,0);
			const double dHigh = mInputNormFactor(nNode,1);
			if (fabs(dHigh-dLow) < 0.0000001)
			{
				data.Set(nRow,nNode,dHigh);
			}
			else
			{
				data.Set(nRow,nNode,(data(nRow,nNode)-dLow)/(dHigh-dLow));
			}
		}
	}

	// Per-sample work rows.
	CMatrix mInputdata, mHiddendata, mOutputdata;
	mInputdata.Zeros(1,iInput);
	mHiddendata.Zeros(1,iHidden);
	mOutputdata.Zeros(1,iOutput);

	double dNet;
	for (nRow = 0; nRow < nRows; nRow++)
	{
		// Input layer.
		for (nNode = 0; nNode < iInput; nNode++)
		{
			mInputdata.m_pData[nNode] = data(nRow,nNode);
		}
		// Hidden layer: weighted sum minus threshold, then activation.
		for (nNode = 0; nNode < iHidden; nNode++)
		{
			dNet = 0.0;
			for (nSrc = 0; nSrc < iInput; nSrc++)
			{
				dNet += mInputdata.m_pData[nSrc]*mWeighti(nSrc,nNode);
			}
			dNet -= mThresholdi.m_pData[nNode];
			mHiddendata.m_pData[nNode] = BpNetSquash(dNet);
		}
		// Output layer.
		for (nNode = 0; nNode < iOutput; nNode++)
		{
			dNet = 0.0;
			for (nSrc = 0; nSrc < iHidden; nSrc++)
			{
				dNet += mHiddendata.m_pData[nSrc]*mWeighto(nSrc,nNode);
			}
			dNet -= mThresholdo.m_pData[nNode];
			mOutputdata.m_pData[nNode] = BpNetSquash(dNet);
		}
		// Map each output back to the target scale.
		for (nNode = 0; nNode < iOutput; nNode++)
		{
			mResult.Set(nRow,nNode,mOutputdata.m_pData[nNode]*(mTargetNormFactor(nNode,1)-mTargetNormFactor(nNode,0))+mTargetNormFactor(nNode,0));
		}
	}
	return mResult;
}
// Restores the network from the file strNetName through this object's
// Serialize() in load mode. If the file cannot be opened, a message box is
// shown and nothing is loaded.
void CBpNet::LoadBpNet(CString &strNetName)
{
	CFile netFile;
	if (!netFile.Open(strNetName,CFile::modeRead))
	{
		MessageBox(NULL,"无法打开文件!","错误",MB_OK);
		return;
	}
	{
		// The archive is closed and destroyed before the file is closed.
		CArchive loadAr(&netFile,CArchive::load);
		Serialize(loadAr);
		loadAr.Close();
	}
	netFile.Close();
}
bool CBpNet::SaveBpNet(CString &strNetName)
{
CFile file;
if(strNetName.GetLength()==0)
return(false);
if(file.Open(strNetName,CFile::modeCreate|CFile::modeWrite)==0)
{
MessageBox(NULL,"无法创建文件!","错误",MB_OK);
return(false);
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -