⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 bpnet.cpp

📁 车牌识别(改定位)武汉理工大学
💻 CPP
📖 第 1 页 / 共 2 页
字号:
// BpNet.cpp : implementation file
////////////////////////////////////////////////////////////////////
/////////////////人工神经网络BP算法/////////////////////////////////
//1、动态改变学习速率
//2、加入动量项
//3、运用了Matcom4.5的矩阵运算库(可免费下载,头文件matlib.h),
//   方便矩阵运算,当然,也可自己写矩阵类
//4、可暂停运算
//5、可将网络以文件的形式保存、恢复
///////////////作者:同济大学材料学院   张纯禹//////////////////////
///////////////email:chunyu_79@hotmail.com//////////////////////////
///////////////QQ:53806186//////////////////////////////////////////
///////////////欢迎不断改进!欢迎讨论其他实用的算法!/////////////////
#include "stdafx.h"
#include "BpNet.h"

#ifdef _DEBUG
#define new DEBUG_NEW
#undef THIS_FILE
static char THIS_FILE[] = __FILE__;
#endif

int CMatrix::rows()
{
	return m_nRows;
}
int CMatrix::cols()
{
	return m_nCols;
}
double CMatrix::operator()(int nRow,int nCol)
{
	ASSERT(nRow<m_nRows);
	ASSERT(nCol<m_nCols);
	return m_pData[nRow*m_nCols+nCol];
}
double CMatrix::operator()(int nNo)
{
	ASSERT(nNo<m_nRows*m_nCols);
	return m_pData[nNo];
}
CMatrix& CMatrix::operator=(CMatrix &mIn)
{
	if (!Create(CSize(mIn.m_nCols,mIn.m_nRows))) 
	{
		return *this;
	}
	memcpy(m_pData,mIn.m_pData,sizeof(double)*m_nCols*m_nRows);
	return *this;
}
CMatrix::CMatrix()
{
	m_nRows = 0;
	m_nCols = 0;
	m_pData = NULL;
}
void CMatrix::Display()
{
	CString strData,strTemp;
	int i,j;
	for(i=0;i<m_nRows;i++)
	{
		for(j=0;j<m_nCols;j++)
		{
			strTemp.Format("%.3f ",m_pData[i*m_nCols+j]);
			strData+=strTemp;
		}
		strData=strData+"\r\n";
	}
	::MessageBox(NULL,strData,"",MB_OK);
}
CMatrix::CMatrix(CMatrix &mIn)
{
	m_nRows = 0;
	m_nCols = 0;
	m_pData = NULL;
	if (!Create(CSize(mIn.m_nCols,mIn.m_nRows))) {
		return ;
	}
	memcpy(m_pData,mIn.m_pData,sizeof(double)*m_nCols*m_nRows);
}
CMatrix::~CMatrix()
{
	Destroy();
}
BOOL CMatrix::SaveAsText(CString strPath)
{
	CStdioFile file;
	if (!file.Open(strPath,CFile::modeCreate|CFile::modeWrite))
	{
		return FALSE;
	}
	CString strData,strTemp;
	int i,j;
	for(i=0;i<m_nRows;i++)
	{
		strData = "";
		for(j=0;j<m_nCols;j++)
		{
			strTemp.Format("%.3f ",m_pData[i*m_nCols+j]);
			strData+=strTemp;
		}
		strData=strData+"\r\n";
		file.WriteString(strData);
	}
	file.Close();
	return TRUE;
}
/////////////////////////////////////////////////////////////////////////////
// CBpNet
IMPLEMENT_SERIAL( CBpNet, CObject, 1 )
bool CBpNet::m_IsStop = TRUE;
CBpNet::CBpNet()
{
	srand((unsigned)time(NULL));	
}

CBpNet::~CBpNet()
{
}
double CBpNet::GetError()
{
	return dblError;
}
void CBpNet::GetFeaNN(BYTE &nH, BYTE &nV, BYTE &nL, BYTE &nR, BYTE &nS, BYTE * pData,int nDim)
{
	int i,j;//i表示高,j表示宽
	nS = 0;
	nH = 0;
	nV = 0;
	nL = 0;
	nR = 0;
	//总数
	for(i=0;i<nDim;i++)
	{
		for(j=0;j<nDim;j++)
		{
			if (pData[i*nDim+j]==0)
			{
				nS++;
			}
		}
	}
	return;
	//横
	for(i=0;i<nDim;i++)
	{
		for(j=0;j<nDim-1;j++)
		{
			if (pData[i*nDim+j]==0&&pData[i*nDim+j+1]==0) 
			{
				nH++;
			}
		}
	}
	//竖
	for(i=0;i<nDim-1;i++)
	{
		for(j=0;j<nDim;j++)
		{
			if (pData[i*nDim+j]==0&&pData[(i+1)*nDim+j]==0) 
			{
				nV++;
			}
		}
	}
	//撇
	for(i=1;i<nDim;i++)
	{
		for(j=0;j<i-1;j++)
		{
			if (pData[i*nDim+j]==0&&pData[(i+1)*nDim+j-1]==0) 
			{
				nL++;
			}
		}
	}
	for(i=nDim-2;i>0;i--)
	{
		for(j=nDim-1;j>i;j--)
		{
			if (pData[i*nDim+j]==0&&pData[(i-1)*nDim+j+1]==0) 
			{
				nL++;
			}
		}
	}
	//捺
	for(j=0;j<nDim;j++)
	{
		for(i=0;i<j-1;i++)
		{
			if (pData[i*nDim+j]==0&&pData[(i+1)*nDim+j+1]==0)
			{
				nR++;
			}
		}
	}
	for(j=0;j<nDim-1;j++)
	{
		for(i=nDim-1;i>j;i--)
		{
			if (pData[i*nDim+j]==0&&pData[(i-1)*nDim+j-1]==0) 
			{
				nR++;
			}
		}
	}
}
int CBpNet::Recogn(CImage &ImgIn,double &dRecVal)
{
	int nLen;
	BYTE Fea[512];	
    GetFeature(ImgIn,Fea,nLen);
	CMatrix mIn;
	CMatrix mOut;
	mIn.SetData(1,nLen,Fea);
	mOut = simulate(mIn);
	int nResult;
	dRecVal = mOut.Max(nResult);
	return nResult;
}
int CBpNet::Recogn1(CImage &ImgIn,double &dRecVal)
{
	int nLenFea;
	BYTE byFea[512];
	
	ImgIn.Stretch(CSize(16,32));
	int nS = 2;
	nLenFea = 0;
	int i,j,k,l;
	for(i=0;i<ImgIn.GetHeight()/nS;i++)
	{
		for(j=0;j<ImgIn.GetWidth()/nS;j++)
		{
			byFea[nLenFea] = 0;
			for(k=0;k<nS;k++)
			{
				for(l=0;l<nS;l++)
				{
					if (ImgIn.m_pR[(i*nS+k)*16+j*nS+l]==0) 
					{
						byFea[nLenFea]++;
					}
				}
			}
			nLenFea += 1;
		}
	}
	
	CMatrix mIn;
	CMatrix mOut;
	mIn.SetData(1,nLenFea,byFea);
	mOut = simulate(mIn);
	
	int nResult;
	dRecVal = mOut.Max(nResult);
	return nResult;
}
/////////////////////////////////////////////////////////////////////////////
// CBpNet message handlers
//创建新网络
BOOL CBpNet::GetFeature(CImage &ImgIn,BYTE *pFea,int &nLen)
{
	nLen = 0;
	CImage ImgTemp = ImgIn;
	ImgTemp.Stretch(CSize(16,32));
	//提取特征
	BYTE * pFeaT = pFea;
	BYTE byData[64];
	BYTE * pBuf = ImgTemp.m_pR;
	int i,j,k,l;

	int nWAndH = ImgTemp.GetHeight()+ImgTemp.GetWidth();
	memset(pFeaT,0,3*nWAndH);
	int nW = ImgTemp.GetWidth();
	//先宽后高
	//求垂直跳变次数
	for(i=0;i<ImgTemp.GetWidth();i++)
	{
		for(j=0;j<ImgTemp.GetHeight()-1;j++)
		{
			if (pBuf[j*nW+i]!=pBuf[(j+1)*nW+i]&&pBuf[j*nW+i]==0) 
			{
				pFeaT[i]++;
			}
		}
		if (pBuf[j*nW+i]==0) 
		{
			pFeaT[i]++;
		}
	}
	pFeaT += ImgTemp.GetWidth();
	nLen += ImgTemp.GetWidth();
	//水平跳变次数
	for(i=0;i<ImgTemp.GetHeight();i++)
	{
		for(j=0;j<ImgTemp.GetWidth()-1;j++)
		{
			if (pBuf[i*nW+j]!=pBuf[i*nW+j+1]&&pBuf[i*nW+j]==0) 
			{
				pFeaT[i]++;
			}
		}
		if (pBuf[i*nW+j]==0) 
		{
			pFeaT[i]++;
		}
	}
	pFeaT += ImgTemp.GetHeight();
	nLen += ImgTemp.GetHeight();

	//求垂直方向起始位置
	for(i=0;i<ImgTemp.GetWidth();i++)
	{
		pFeaT[i] = 0;
		pFeaT[i+nW] = ImgTemp.GetHeight()-1;
		for(j=0;j<ImgTemp.GetHeight();j++)
		{	
			if (pBuf[j*nW+i]==0) 
			{
				pFeaT[i] = j;
				break;
			}
		}		
		for(j=ImgTemp.GetHeight()-1;j>=0;j--)
		{
			if (pBuf[j*nW+i]==0) 
			{
				pFeaT[i+nW] = j;
				break;
			}
		}
	}
	pFeaT += 2*ImgTemp.GetWidth();
	nLen += 2*ImgTemp.GetWidth();
	//水平起始位置
	for(i=0;i<ImgTemp.GetHeight();i++)
	{
		pFeaT[i] = 0;
		pFeaT[i+nW] = ImgTemp.GetWidth()-1;
		for(j=0;j<ImgTemp.GetWidth();j++)
		{
			if (pBuf[i*nW+j]==0) 
			{
				pFeaT[i] = j;
				break;
			}
		}
		for(j=ImgTemp.GetWidth()-1;j>=0;j--)
		{
			if (pBuf[i*nW+j]==0) 
			{
				pFeaT[i+nW] = j;
				break;
			}
		}
	}
	pFeaT += 2*ImgTemp.GetHeight();
	nLen += 2*ImgTemp.GetHeight();

	int nS = 4;
	for(i=0;i<ImgTemp.GetHeight()/nS;i++)
	{
		for(j=0;j<ImgTemp.GetWidth()/nS;j++)
		{
			for(k=0;k<nS;k++)
			{
				for(l=0;l<nS;l++)
				{
					byData[k*nS+l] = pBuf[(i*nS+k)*16+j*nS+l];
				}
			}
		//	GetFeaNN(pFeaT[0],pFeaT[1],pFeaT[2],pFeaT[3],pFeaT[4],byData,nS);
			GetFeaNN(pFeaT[0],pFeaT[0],pFeaT[0],pFeaT[0],pFeaT[0],byData,nS);
			pFeaT += 1;
			nLen += 1;
		}
	}
	return TRUE;
}
void CBpNet::Create(CMatrix &mInputData, CMatrix &mTarget, int iInput, int iHidden, int iOutput)
{ 
	int i,j;
	mSampleInput=mInputData;
	mSampleTarget=mTarget;
	this->iInput=iInput;
	this->iHidden=iHidden;
	this->iOutput=iOutput;
	//创建计算用的单个样本矩阵
	mInput.Zeros(1,this->iInput);
	mHidden.Zeros(1,this->iHidden);
	mOutput.Zeros(1,this->iOutput);
	//创建权重矩阵,并赋初值
	mWeighti.Zeros(this->iInput,this->iHidden);
	mWeighto.Zeros(this->iHidden,this->iOutput); 
	//赋初值
	for(i=0;i<this->iInput;i++)
	{
		for(j=0;j<this->iHidden;j++)
		{
			mWeighti.Set(i,j,randab(-1.0,1.0));
		}
	}
	for(i=0;i<this->iHidden;i++)
	{
		for(j=0;j<this->iOutput;j++)
		{
			mWeighto.Set(i,j,randab(-1.0,1.0));
		}
	}
	
	//创建阙值矩阵,并赋值
	mThresholdi.Zeros(1,this->iHidden);
	for(i=0;i<this->iHidden;i++)
		mThresholdi.m_pData[i] = randab(-1.0,1.0);
	mThresholdo.Zeros(1,this->iOutput);
	for(i=0;i<this->iOutput;i++)
		mThresholdo.m_pData[i] = randab(-1.0,1.0);
	//创建权重变化矩阵
	mChangei.Zeros(this->iInput,this->iHidden);
	mChangeo.Zeros(this->iHidden,this->iOutput);
	
	mInputNormFactor.Zeros(iInput,2);
	mTargetNormFactor.Zeros(iOutput,2);
	//误差矩阵
	mOutputDeltas.Zeros(1,iOutput);
	mHiddenDeltas.Zeros(1,iHidden);
	//学习速率赋值
	dblLearnRate1=0.5;
	dblLearnRate2=0.5;
	dblMomentumFactor=0.05;
	
	m_isOK=false;
	m_IsStop=false;
	dblMse = 0.00005*mSampleTarget.rows()*mSampleTarget.cols();//误差限
	dblError=1.0*mSampleTarget.rows()*mSampleTarget.cols();
	lEpochs=0;	
}
//根据已有的网络进行预测
CMatrix CBpNet::simulate(CMatrix& mData)
{
	int i,j;
	CMatrix mResult;
	CMatrix data;
	data=mData;
	if(mData.cols()!=iInput)
	{
		::MessageBox(NULL,"输入数据变量个数错误!","输入数据变量个数错误!",MB_OK);
		return mResult;
	}
	mResult.Zeros(data.rows(),iOutput); 
	//正规化数据
	for(i=0;i<data.rows();i++)
	{
		for(j=0;j<data.cols();j++) 
		{
			if (fabs(mInputNormFactor(j,1)-mInputNormFactor(j,0))<0.0000001)
			{
				data.Set(i,j,mInputNormFactor(j,1));
			}
			else
			{
				data.Set(i,j,(data(i,j)-mInputNormFactor(j,0))/(mInputNormFactor(j,1)-mInputNormFactor(j,0)));
			}			
		}
	}
	//计算
	int iSample;
	CMatrix mInputdata,mHiddendata,mOutputdata;
	mInputdata.Zeros(1,iInput);
	mHiddendata.Zeros(1,iHidden);
	mOutputdata.Zeros(1,iOutput);
	double sum=0.0;
	double dTemp;
	for(iSample=0;iSample<data.rows();iSample++)
	{ 
		//输入层数据
		for(i=0;i<iInput;i++)
			mInputdata.m_pData[i]=data(iSample,i);
		//隐层数据	
		
		for(j=0;j<iHidden;j++)
		{
			sum=0.0;
			for(i=0;i<iInput;i++)
				sum+=mInputdata.m_pData[i]*mWeighti(i,j);
			sum-=mThresholdi.m_pData[j]; 
			dTemp = exp(-sum);
			dTemp = dTemp>100000?100000:dTemp;
			mHiddendata.m_pData[j]=1.0/(1.0+dTemp);
		}
		
		//输出数据
		for(j=0;j<iOutput;j++){
			sum=0.0;
			for(i=0;i<iHidden;i++)
				sum+=mHiddendata.m_pData[i]*mWeighto(i,j);
			sum-=mThresholdo.m_pData[j]; 
			dTemp = exp(-sum);
			dTemp = dTemp>100000?100000:dTemp;
			mOutputdata.m_pData[j]=1.0/(1.0+dTemp);
		}
		
		//转换
		for(j=0;j<iOutput;j++)
			mResult.Set(iSample,j,mOutputdata.m_pData[j]*(mTargetNormFactor(j,1)-mTargetNormFactor(j,0))+mTargetNormFactor(j,0));
	}	
	return (mResult);
}

void CBpNet::LoadBpNet(CString &strNetName)
{
	CFile file;	
	if(file.Open(strNetName,CFile::modeRead)==0)   
	{
		MessageBox(NULL,"无法打开文件!","错误",MB_OK);
		return;
	}
	else
	{
		CArchive myar(&file,CArchive::load);
		Serialize(myar); 
		myar.Close();
	}
	file.Close(); 
}

bool CBpNet::SaveBpNet(CString &strNetName)
{
	CFile file;
	if(strNetName.GetLength()==0)
		return(false);
	if(file.Open(strNetName,CFile::modeCreate|CFile::modeWrite)==0)   
	{
		MessageBox(NULL,"无法创建文件!","错误",MB_OK);
		return(false);
	}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -