📄 bpnet.cpp
字号:
else
{
CArchive myar(&file,CArchive::store);
Serialize(myar);
myar.Close();
}
file.Close();
return(true);
}
//网络学习
// Total training error of the most recent epoch, used by learn() as its loop
// condition and as the divisor when adapting the learning rates. This is a
// static member, so one value is shared by every CBpNet instance.
// LoadPattern() resets it to 1.0.
double CBpNet::dblError = 999999.0;
void CBpNet::learn()
{
int iSample=1;
double dblTotal;
MSG msg;
m_IsStop=false;
//数据正规化处理
normalize();
while(dblError>dblMse&&!m_IsStop)
{
dblTotal=0.0;
for(iSample=0;iSample<mSampleInput.rows();iSample++)
{
forward(iSample);
backward(iSample);
dblTotal+=dblErr;//总误差
}
dblError = dblError>100000?100000:dblError;
dblError = dblError<0.0000001?0.0000001:dblError;
if (dblLearnRate1>=0.95)
{
dblLearnRate1 = 0.5;
dblLearnRate2 = 0.5;
}
if(dblTotal/dblError>1.04)
{
//动态改变学习速率
dblLearnRate1*=0.7;
dblLearnRate2*=0.7;
}
else
{
dblLearnRate1*=1.05;
dblLearnRate2*=1.05;
}
dblLearnRate1 = dblLearnRate1>0.95?0.95:dblLearnRate1;
dblLearnRate1 = dblLearnRate1<0.1?0.1:dblLearnRate1;
dblLearnRate2 = dblLearnRate2>0.95?0.95:dblLearnRate2;
dblLearnRate2 = dblLearnRate2<0.1?0.1:dblLearnRate2;
lEpochs++;
dblError=dblTotal;
::PeekMessage(&msg,NULL,0,0,PM_REMOVE);
::DispatchMessage(&msg);
msg.message=-1;
::DispatchMessage(&msg);//这样可以消除屏闪和假死机
}
if(dblError<=dblMse)
m_isOK=true;
else
m_isOK=false;
if (m_isOK)
{
AfxMessageBox("训练成功收敛");
}
}
void CBpNet::stop()
{
m_IsStop=true;
}
// Returns a pseudo-random double in [a, b], derived from rand() / RAND_MAX.
// NOTE: when the matrix library is used, its header matlib.h redefines rand()
// to return values in (0,1) only. The formula below then presumably no longer
// spans [a, b] and yields values very close to a. Verify which rand() is in
// effect.
double CBpNet::randab(double a, double b)
{
    const double span = b - a;
    return span * rand() / RAND_MAX + a;
}
//将数据转化到(0,1)区间
void CBpNet::normalize()
{
int i,j;
//输入数据范围
mInputNormFactor=scope(mSampleInput);
//目标数据范围
mTargetNormFactor=scope(mSampleTarget);
for(i=0;i<mSampleInput.rows();i++)
{
for(j=0;j<mSampleInput.cols();j++)
{
if (fabs(mInputNormFactor(j,1)-mInputNormFactor(j,0))<0.000001)
{
mSampleInput.Set(i,j,mInputNormFactor(j,1));
}
else
{
mSampleInput.Set(i,j,(mSampleInput(i,j)-mInputNormFactor(j,0))/(mInputNormFactor(j,1)-mInputNormFactor(j,0)));
}
}
}
for(i=0;i<mSampleTarget.rows();i++)
{
for(j=0;j<mSampleTarget.cols();j++)
{
if (fabs(mTargetNormFactor(j,1)-mTargetNormFactor(j,0))<0.0000001)
{
mSampleTarget.Set(i,j,mTargetNormFactor(j,1));
}
else
{
mSampleTarget.Set(i,j,(mSampleTarget(i,j)-mTargetNormFactor(j,0))/(mTargetNormFactor(j,1)-mTargetNormFactor(j,0)));
}
}
}
}
// Logistic activation used by forward(): 1 / (1 + exp(-x)).
// exp(-x) is capped at 100000, so the result never drops below 1/100001
// (about 1e-5).
static double BpCappedSigmoid(double x)
{
    double dExp = exp(-x);
    if (dExp > 100000)
        dExp = 100000;
    return 1.0/(1.0+dExp);
}

// Forward pass for sample number iSample, which is a row of mSampleInput.
// It copies the sample into mInput, then fills mHidden and mOutput. Each
// neuron outputs BpCappedSigmoid(weighted sum of the previous layer minus its
// threshold). An index outside [0, rows) shows an error box and leaves the
// layer values unchanged.
void CBpNet::forward(int iSample)
{
    if(iSample<0||iSample>=mSampleInput.rows())
    {
        MessageBox(NULL,"无此样本数据:索引出界!","无此样本数据:索引出界!",MB_OK);
        return;
    }
    int i,j;
    double sum;
    // Input layer: the sample's values.
    for(i=0;i<iInput;i++)
        mInput.m_pData[i]=mSampleInput(iSample,i);
    // Hidden layer, fed by the input layer through mWeighti.
    for(j=0;j<iHidden;j++)
    {
        sum=0.0;
        for(i=0;i<iInput;i++)
            sum+=mInput.m_pData[i]*mWeighti(i,j);
        mHidden.m_pData[j]=BpCappedSigmoid(sum-mThresholdi.m_pData[j]);
    }
    // Output layer, fed by the hidden layer through mWeighto.
    for(j=0;j<iOutput;j++)
    {
        sum=0.0;
        for(i=0;i<iHidden;i++)
            sum+=mHidden.m_pData[i]*mWeighto(i,j);
        mOutput.m_pData[j]=BpCappedSigmoid(sum-mThresholdo.m_pData[j]);
    }
}
// Backward pass for sample iSample. It reads the mInput/mHidden/mOutput
// values left by forward(), so it expects forward(iSample) to have been
// called just before (as learn() does).
// It computes the output and hidden deltas, then updates both weight matrices
// (learning rate plus a momentum term scaled by dblMomentumFactor) and both
// threshold vectors. Finally it stores 0.5 * sum((target - output)^2) for
// this sample in dblErr. An index outside either sample matrix shows an error
// box and changes nothing.
void CBpNet::backward(int iSample)
{
    if(iSample<0||iSample>=mSampleInput.rows()||iSample>=mSampleTarget.rows()){
        MessageBox(NULL,"无此样本数据:索引出界!","无此样本数据:索引出界!",MB_OK);
        return;
    }
    int i,j;
    // Output deltas: logistic derivative o*(1-o) times the error (target - o).
    for(i=0;i<iOutput;i++)
    {
        mOutputDeltas.m_pData[i] = mOutput.m_pData[i]*(1-mOutput.m_pData[i])*(mSampleTarget(iSample,i)-mOutput.m_pData[i]);
    }
    // Hidden deltas: output deltas propagated back through mWeighto. This uses
    // the weights from *before* this sample's update, which happens below.
    double sum=0.0;
    for(j=0;j<iHidden;j++)
    {
        sum=0.0;
        for(i=0;i<iOutput;i++)
            sum+=mOutputDeltas.m_pData[i]*mWeighto(j,i);
        mHiddenDeltas.m_pData[j]=mHidden.m_pData[j]*(1-mHidden.m_pData[j])*sum;
    }
    // Update the hidden->output weights.
    // NOTE(review): mChangeo/mChangei store the raw gradient term
    // (delta * activation), not the step actually applied. The momentum term
    // therefore re-applies the previous raw gradient, unlike classic momentum,
    // which reuses the previous weight change. Confirm this is intended.
    double dblChange;
    for(j=0;j<iHidden;j++)
    {
        for(i=0;i<iOutput;i++)
        {
            dblChange=mOutputDeltas.m_pData[i]*mHidden.m_pData[j];
            mWeighto.Set(j,i,mWeighto(j,i)+dblLearnRate2*dblChange+dblMomentumFactor*mChangeo(j,i));
            mChangeo.Set(j,i,dblChange);
        }
    }
    // Update the input->hidden weights.
    for(i=0;i<iInput;i++)
    {
        for(j=0;j<iHidden;j++)
        {
            dblChange=mHiddenDeltas.m_pData[j]*mInput.m_pData[i];
            mWeighti.Set(i,j,mWeighti(i,j)+dblLearnRate1*dblChange+dblMomentumFactor*mChangei(i,j));
            mChangei.Set(i,j,dblChange);
        }
    }
    // Update the thresholds. They enter the net input with a minus sign
    // (sum - threshold in forward()), hence the subtraction here.
    for(j=0;j<iOutput;j++)
        mThresholdo.m_pData[j]-=dblLearnRate2*mOutputDeltas.m_pData[j];
    for(i=0;i<iHidden;i++)
        mThresholdi.m_pData[i]-=dblLearnRate1*mHiddenDeltas.m_pData[i];
    // Sample error, computed from the outputs of the preceding forward() pass,
    // i.e. before the weight updates above.
    dblErr=0.0;
    for(i=0;i<iOutput;i++)
        dblErr+=0.5*(mSampleTarget(iSample,i)-mOutput.m_pData[i])*(mSampleTarget(iSample,i)-mOutput.m_pData[i]);
}
// Returns a (cols x 2) matrix describing each column of mData: column 0 holds
// the column's minimum and column 1 its maximum.
// For a constant column the minimum is reported as 0, so normalize() does not
// divide by a zero range (unless the constant itself is about 0).
// A matrix with no rows yields all-zero ranges.
CMatrix CBpNet::scope(CMatrix &mData)
{
    CMatrix mScope;
    mScope.Zeros(mData.cols(),2);
    const int nRows = mData.rows();
    if(nRows<1)
        return(mScope);   // no samples: every range stays [0, 0]
    double dMin,dMax;
    for(int i=0;i<mData.cols();i++)
    {
        // Seed from row 0 so single-row data never indexes past the end.
        dMin=dMax=mData(0,i);
        for(int j=1;j<nRows;j++)
        {
            const double dVal=mData(j,i);
            if(dVal>dMax)
                dMax=dVal;
            if(dVal<dMin)
                dMin=dVal;
        }
        if(dMin==dMax)
            dMin=0.0;
        mScope.Set(i,0,dMin);
        mScope.Set(i,1,dMax);
    }
    return(mScope);
}
// Debug helper: shows `data` in a message box with one matrix row per line.
// Each element is formatted with "%.3f " (three decimals and a trailing space).
void CBpNet::display(CMatrix &data)
{
    CString strData;
    const int nRows = data.rows();
    const int nCols = data.cols();
    for (int r = 0; r < nRows; ++r)
    {
        for (int c = 0; c < nCols; ++c)
        {
            CString strCell;
            strCell.Format("%.3f ", data(r, c));
            strData += strCell;
        }
        strData += "\r\n";
    }
    ::MessageBox(NULL, strData, "", MB_OK);
}
// MFC serialization of the network.
// Storing writes, in this exact order:
//   1. the tag string "Bp";
//   2. iInput, iHidden, iOutput;
//   3. mWeighti (iInput x iHidden, row by row) and mWeighto (iHidden x iOutput);
//   4. the last weight changes, mChangeo then mChangei;
//   5. the thresholds, mThresholdi then mThresholdo;
//   6. the per-column [min, max] normalization ranges of the inputs, then of
//      the targets;
//   7. dblMse, then dblLearnRate1 and dblLearnRate2.
// Loading reads the same sequence back and re-allocates every matrix to the
// stored sizes. It also allocates the per-sample work matrices (mInput,
// mHidden, mOutput, mOutputDeltas, mHiddenDeltas).
// The field order IS the file format: keep the store and load branches in
// lockstep when changing either.
// NOTE(review): on load, the tag is not compared against "Bp" and the stored
// counts are not range-checked, so a foreign or corrupt file is read blindly.
// The training samples and dblMomentumFactor are not serialized; a restored
// network needs LoadPattern() before training.
void CBpNet::Serialize(CArchive &ar)
{
CObject::Serialize(ar);
/////////////////////////////////////
if(ar.IsStoring())
{
int i,j;
double dblData;
CString strTemp="Bp";
ar<<strTemp;// write the format tag
// neuron counts per layer
ar<<iInput<<iHidden<<iOutput;
// weights: input->hidden, then hidden->output
for(i=0;i<iInput;i++)
{
for(j=0;j<iHidden;j++)
{
dblData=mWeighti(i,j);
ar<<dblData;
}
}
for(i=0;i<iHidden;i++)
{
for(j=0;j<iOutput;j++)
{
dblData=mWeighto(i,j);
ar<<dblData;
}
}
// last weight changes: hidden->output
for(j=0;j<iHidden;j++)
{
for(i=0;i<iOutput;i++)
{
ar<<mChangeo(j,i);
}
}
// last weight changes: input->hidden
for(i=0;i<iInput;i++)
{
for(j=0;j<iHidden;j++)
{
ar<<mChangei(i,j);
}
}
// thresholds: hidden layer, then output layer
for(i=0;i<iHidden;i++)
{
dblData=mThresholdi(i);
ar<<dblData;
}
for(i=0;i<iOutput;i++)
{
dblData=mThresholdo(i);
ar<<dblData;
}
// normalization ranges of the inputs and the targets
for(i=0;i<iInput;i++)
{
dblData=mInputNormFactor(i,0);
ar<<dblData; // minimum
dblData=mInputNormFactor(i,1);
ar<<dblData; // maximum
}
for(i=0;i<iOutput;i++)
{
dblData=mTargetNormFactor(i,0);
ar<<dblData; // target minimum
dblData=mTargetNormFactor(i,1);
ar<<dblData;
}
// target error
ar<<dblMse;
// learning rates
ar<<dblLearnRate1<<dblLearnRate2;
}
else
{
int i,j;
CString strTemp="";
double dblTemp;
ar>>strTemp;// read the format tag (not validated)
// neuron counts per layer; every matrix below is sized from these
ar>>iInput>>iHidden>>iOutput;
mChangei.Zeros(iInput,iHidden);
mChangeo.Zeros(iHidden,iOutput);
mWeighti.Zeros(iInput,iHidden);
mWeighto.Zeros(iHidden,iOutput);
// weights: input->hidden, then hidden->output
for(i=0;i<iInput;i++)
{
for(j=0;j<iHidden;j++)
{
ar>>dblTemp;
mWeighti.Set(i,j,dblTemp);
}
}
for(i=0;i<iHidden;i++)
{
for(j=0;j<iOutput;j++)
{
ar>>dblTemp;
mWeighto.Set(i,j,dblTemp);
}
}
// last weight changes: hidden->output
for(j=0;j<iHidden;j++)
{
for(i=0;i<iOutput;i++)
{
ar>>dblTemp;
mChangeo.Set(j,i,dblTemp);
}
}
// last weight changes: input->hidden
for(i=0;i<iInput;i++)
{
for(j=0;j<iHidden;j++)
{
ar>>dblTemp;
mChangei.Set(i,j,dblTemp);
}
}
// thresholds: hidden layer, then output layer
mThresholdi.Zeros(1,iHidden);
for(i=0;i<iHidden;i++)
{
ar>>dblTemp;
mThresholdi.m_pData[i]=dblTemp;
}
mThresholdo.Zeros(1,iOutput);
for(i=0;i<iOutput;i++)
{
ar>>dblTemp;
mThresholdo.m_pData[i]=dblTemp;
}
// normalization ranges of the inputs and the targets
mInputNormFactor.Zeros(iInput,2);
for(i=0;i<iInput;i++)
{
ar>>dblTemp;
mInputNormFactor.Set(i,0,dblTemp); // minimum
ar>>dblTemp;
mInputNormFactor.Set(i,1,dblTemp); // maximum
}
mTargetNormFactor.Zeros(iOutput,2);
for(i=0;i<iOutput;i++)
{
ar>>dblTemp;
mTargetNormFactor.Set(i,0,dblTemp); // target minimum
ar>>dblTemp;
mTargetNormFactor.Set(i,1,dblTemp);
}
// target error
ar>>dblMse;
// learning rates
ar>>dblLearnRate1>>dblLearnRate2;
// single-sample work matrices used by forward()/backward()
mInput.Zeros(1,iInput);
mHidden.Zeros(1,iHidden);
mOutput.Zeros(1,iOutput);
// delta (error) matrices
mOutputDeltas.Zeros(1,iOutput);
mHiddenDeltas.Zeros(1,iHidden);
}
}
// Installs a training set before learn(). This is intended for networks that
// were not freshly created, e.g. ones restored from a file via Serialize().
// mIn must have iInput columns and mOut iOutput columns, with one row per
// sample and the same number of rows in both. Otherwise an error box is shown
// and nothing changes.
// On success, the training state is reset: epoch counter, stop/OK flags and
// the shared dblError. The momentum factor is set to 0.95.
void CBpNet::LoadPattern(CMatrix &mIn, CMatrix &mOut)
{
    // forward()/backward() index both matrices with the same sample row, so
    // the row counts must agree as well as the column counts.
    if(mIn.cols()!=iInput||mOut.cols()!=iOutput||mIn.rows()!=mOut.rows()){
        ::MessageBox( NULL,"学习样本格式错误!","错误",MB_OK);
        return;
    }
    mSampleInput.Zeros(mIn.rows(),mIn.cols());
    mSampleTarget.Zeros(mOut.rows(),mOut.cols());
    mSampleInput=mIn;
    mSampleTarget=mOut;
    m_isOK=false;
    m_IsStop=false;
    lEpochs=0;
    dblMomentumFactor=0.95;
    dblError=1.0;
}
// Re-allocates the matrix as nRows x nCols with every element set to 0.0
// (Create() zero-fills the buffer). Returns Create()'s success flag.
BOOL CMatrix::Zeros(int nRows, int nCols)
{
    const CSize dims(nCols, nRows);   // CSize is (cx = columns, cy = rows)
    return Create(dims);
}
// Writes dVal to element (nRos, nCol) of the row-major buffer.
// The indices are checked with ASSERT only: both bounds are verified in debug
// builds, and nothing is verified in release builds.
void CMatrix::Set(int nRos,int nCol,double dVal)
{
    ASSERT(nRos>=0 && nRos<m_nRows);
    ASSERT(nCol>=0 && nCol<m_nCols);
    m_pData[nRos*m_nCols+nCol] = dVal;
}
// (Re)allocates storage for a szM.cy x szM.cx matrix (cx = columns,
// cy = rows) and zero-fills it. Any previous contents are released first.
// Returns FALSE for a negative dimension (the matrix is then left empty, 0 x 0)
// or when the allocation yields NULL. Otherwise returns TRUE.
BOOL CMatrix::Create(CSize szM)
{
    Destroy();
    // A negative dimension, e.g. a corrupt count read from an archive, would
    // otherwise reach new[] as a meaningless element count.
    if (szM.cx < 0 || szM.cy < 0)
    {
        return FALSE;
    }
    m_pData = new double[szM.cx*szM.cy];
    // With a throwing operator new this check never fires. It is kept for
    // allocators that return NULL on failure.
    if (!m_pData)
    {
        return FALSE;
    }
    memset(m_pData,0,sizeof(double)*szM.cx*szM.cy);
    m_nCols = szM.cx;
    m_nRows = szM.cy;
    return TRUE;
}
// Frees the element buffer and resets the matrix to an empty 0 x 0 state.
// delete[] on NULL is a no-op and the pointer is cleared afterwards, so
// calling this repeatedly is harmless.
void CMatrix::Destroy()
{
    delete []m_pData;
    m_pData = NULL;
    m_nCols = 0;
    m_nRows = 0;
}
// Replaces the contents with an nRow x nCol matrix copied from pData.
// pData must point to at least nRow*nCol doubles in row-major order.
// Returns FALSE, without copying, if the storage could not be created.
BOOL CMatrix::SetData(int nRow, int nCol, double *pData)
{
    const BOOL bCreated = Create(CSize(nCol, nRow));
    if (bCreated)
    {
        memcpy(m_pData, pData, sizeof(double)*m_nRows*m_nCols);
    }
    return bCreated;
}
// Replaces the contents with an nRow x nCol matrix whose elements are the
// bytes at pData (row-major), each converted to double.
// Returns FALSE, without copying, if the storage could not be created.
BOOL CMatrix::SetData(int nRow,int nCol,BYTE *pData)
{
    if (!Create(CSize(nCol, nRow)))
        return FALSE;
    // Fill from the last element backwards; the source and the freshly
    // allocated destination never overlap, so the order does not matter.
    int k = m_nRows*m_nCols;
    while (k-- > 0)
        m_pData[k] = pData[k];
    return TRUE;
}
// Returns the largest element and stores its flat (row-major) index in nNo.
// On ties, the first occurrence wins. An empty matrix returns 0 with nNo = 0.
double CMatrix::Max(int &nNo)
{
    nNo = 0;
    const int nCount = m_nRows*m_nCols;
    if (nCount < 1)
        return 0;
    double dBest = m_pData[0];
    for (int k = 1; k < nCount; ++k)
    {
        if (m_pData[k] > dBest)
        {
            dBest = m_pData[k];
            nNo = k;
        }
    }
    return dBest;
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -