⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 neuralnetworkplatform.cpp

📁 使用神经网络开发包实现图形化的神经网络模拟
💻 CPP
📖 第 1 页 / 共 2 页
字号:
// NeuralNetworkPlatform.cpp: implementation of the NeuralNetworkPlatform class.
//
//////////////////////////////////////////////////////////////////////

#include "stdafx.h"
#include "NeuralNetwork.h"
#include "NeuralNetworkPlatform.h"
#include "NeuralNode.h"
#include "DlgNeuron.h"
#include "DlgLine.h"

#ifdef _DEBUG
#undef THIS_FILE
static char THIS_FILE[]=__FILE__;
#define new DEBUG_NEW
#endif

//////////////////////////////////////////////////////////////////////
// Construction/Destruction
//////////////////////////////////////////////////////////////////////
// Registers NeuralNetworkPlatform with MFC's run-time class / serialization
// machinery: base class CObject, schema version 1.
IMPLEMENT_SERIAL(NeuralNetworkPlatform, CObject, 1)

// Starts with an empty platform: no neurons or layers, no network or training
// set allocated, no mouse interaction in progress, and every drag-tracking
// point at the origin.
NeuralNetworkPlatform::NeuralNetworkPlatform()
{
	// Nothing has been built or trained yet.
	multiNet=NULL;
	trainSet=NULL;
	layerNum=0;
	neuralNum=0;

	// No neuron or edge node is currently being dragged.
	btnPosType=0;
	const CPoint origin(0,0);
	linePosTemp=origin;
	neuralPosTemp=origin;
	mousePosTemp=origin;
}

// Releases the training set and network that networkTrain() allocates with
// new; previously they were leaked when the platform was destroyed.
// NOTE(review): assumes no other code deletes multiNet/trainSet without
// resetting them to 0 - confirm against the rest of the class.
NeuralNetworkPlatform::~NeuralNetworkPlatform()
{
	delete trainSet;
	delete multiNet;
}

// Adds a new neuron at pos and gives it the next value of neuralNum as its
// identifier. The container is scanned from the back and the node is placed
// right after the last node whose position is <= pos (compared by x, then y),
// or at the front if there is none; nodes at an equal position keep their
// relative order, the new one going last.
// Returns true: the node is always inserted (the original returned false for
// the very first node even though it had been added).
bool NeuralNetworkPlatform::insertNode(CPoint pos)
{
	const NeuralNode node(pos,++neuralNum);

	// Index-based backward walk: the original stepped an iterator to
	// begin()-1, which is undefined behaviour for vector iterators.
	vector<NeuralNode>::size_type insertAt=neuralContainer.size();
	while(insertAt>0)
	{
		const NeuralNode& prev=neuralContainer[insertAt-1];
		if(pos.x>prev.nodePos.x || (pos.x==prev.nodePos.x && pos.y>=prev.nodePos.y))
			break;
		--insertAt;
	}
	neuralContainer.insert(neuralContainer.begin()+insertAt,node);
	return true;
}

bool NeuralNetworkPlatform::insertIndex()
{
	int layerTemp=0;
	int neuronTemp=-1;
	CPoint posTemp;
	if(neuralContainer.empty())
	{
		AfxMessageBox("Please insert neuron!");
		return false;
	}
	posTemp=(*neuralContainer.begin()).nodePos;
	for(vector<NeuralNode>::iterator it=neuralContainer.begin();it!=neuralContainer.end();it++)
	{
		if((*it).nodePos.x==posTemp.x)			//=======规定同一层的神经元必须横坐标相同======
		{
			(*it).layerIndex=layerTemp;
			(*it).neuronIndex=++neuronTemp;
			posTemp=(*it).nodePos;
		}
		if((*it).nodePos.x>posTemp.x)			//=====横坐标不同,层数加一========
		{
			layerNeurons.push_back(neuronTemp+1);
			(*it).layerIndex=++layerTemp;
			(*it).neuronIndex=0;
			neuronTemp=0;
			posTemp=(*it).nodePos;
		}
	}
	//=============插入最后一层神经元===============
	layerNeurons.push_back(neuronTemp+1);
	layerNum=layerTemp;
	return true;
}

bool NeuralNetworkPlatform::networkTrain(CString fileName,int maxStep,float learnRate)
{
	insertIndex();
	if(layerNum<2)
	{
		AfxMessageBox("请确保神经网络至少含有三层!");
		return false;
	}
	if(!neuralCheckPosition())
	{
		AfxMessageBox("请确保同层神经元横坐标相同!");
		return false;
	}
	multiNet=new MultiLayerNetwork(layerNeurons[0]);
	
	for(vector<int>::iterator itInt=layerNeurons.begin()+1;itInt!=layerNeurons.end();itInt++)
	{
		multiNet->addLayer(*itInt);
	}	
	for(vector<NeuralNode>::iterator itNeuron=neuralContainer.begin();itNeuron!=neuralContainer.end();itNeuron++)
	{
		for(int i=0;i<EDGENUM;i++)
		{
			if(1==(*itNeuron).edgeType[i])	//===如果边缘节点为输出,则循环访问边缘节点,找到索引与边缘节点输出索引相同的节点====
			{		
				for(vector<NeuralNode>::iterator iterNeuron=neuralContainer.begin();iterNeuron!=neuralContainer.end();iterNeuron++)
				{		
					if((*iterNeuron).totalIndex==(*itNeuron).outputLine[i].outputNeuronIndex)
					{
						if(0.0f==(*itNeuron).outputLine[i].getLineWeight())
						{	//=======如果权值为0,则表示未设定,使用随机值=========
							multiNet->connect((*itNeuron).layerIndex,(*itNeuron).neuronIndex,(*iterNeuron).neuronIndex);
						}
						else
						{	//=======如果权值不为0,则表示已设定,使用指定值=======
							multiNet->connect((*itNeuron).layerIndex,(*itNeuron).neuronIndex,(*iterNeuron).neuronIndex,(*itNeuron).outputLine[i].getLineWeight());
						}
					}
				}
			}
		}
	}
	//==================设置每一层的传输函数(不含输入层)======================
	int neuralNumTemp=layerNeurons[0];
	for(int i=1;i<=layerNum;i++)
	{
		multiNet->setActivationFunction(i,neuralContainer[neuralNumTemp].getActivationFunction(),dgaussian);
		neuralNumTemp+=layerNeurons[i];
	}
	//==================创建训练集=================================
	int inputNum=*(layerNeurons.begin());
	int outputNum=*(layerNeurons.end()-1);
	trainSet=new TrainingSet(inputNum,outputNum);
	//=====================================================================
	vector<real> vectorInput,vectorOutput;
	CString strBuf;
	CString strTemp,astr;
	int posOld=0,posNew=0;
	
	CFile fileTrain;
	fileTrain.Open(fileName,CFile::modeRead);

	char* charBuf=new char[fileTrain.GetLength()];
	UINT byteNum=fileTrain.Read(charBuf,fileTrain.GetLength());
	charBuf[byteNum]=NULL;

	strBuf=charBuf;
	while(posNew!=strBuf.GetLength()-1)
	{	//===========input==================
		for(i=0;i<inputNum-1;i++)
		{
			posNew=strBuf.Find(" ",posOld);
			strTemp=strBuf.Mid(posOld,posNew-posOld);
			if(strTemp.Find(" ")!=-1)
			{

				AfxMessageBox("请确定训练数据与神经网络相符!");
				fileTrain.Close();
				return false;
			}
			posOld=posNew+1;
			vectorInput.push_back(atof(strTemp));
		}
		posNew=strBuf.Find("\n",posOld);
		strTemp=strBuf.Mid(posOld,posNew-posOld);
		if(strTemp.Find(" ")!=-1)
		{
			AfxMessageBox("请确定训练数据与神经网络相符!");
			fileTrain.Close();
			return false;
		}
		posOld=posNew+1;
		vectorInput.push_back(atof(strTemp));
		//===========output=================
		for(i=0;i<outputNum-1;i++)
		{
			posNew=strBuf.Find(" ",posOld);
			strTemp=strBuf.Mid(posOld,posNew-posOld);
			if(strTemp.Find(" ")!=-1)
			{
				AfxMessageBox("请确定训练数据与神经网络相符!");
				fileTrain.Close();
				return false;
			}
			posOld=posNew+1;
			vectorOutput.push_back(atof(strTemp));
		}
		posNew=strBuf.Find("\n",posOld);
		strTemp=strBuf.Mid(posOld,posNew-posOld);
		if(strTemp.Find(" ")!=-1)
		{
			AfxMessageBox("请确定训练数据与神经网络相符!");
			fileTrain.Close();
			return false;
		}
		posOld=posNew+1;
		vectorOutput.push_back(atof(strTemp));
		//===========addIOpair===============
		trainSet->addIOpair(vectorInput,vectorOutput);
		vectorInput.clear();
		vectorOutput.clear();
	}
	fileTrain.Close();

	multiNet->train(*trainSet,maxStep,learnRate);
	AfxMessageBox("训练完成!");
	return true;
}

// Draws a new neuron bitmap (IDB_BMP_NEURON) on pDC and registers the node via
// insertNode(). The new neuron is placed MOVESIZE+NEURALSIZE to the right of
// and below the last node in neuralContainer, or relative to (0,0) when the
// container is empty.
void NeuralNetworkPlatform::drawNeuron(CDC *pDC)
{
	CBitmap bmpNeuron;
	bmpNeuron.LoadBitmap(IDB_BMP_NEURON);

	CDC dcCompatibleNeuron;
	dcCompatibleNeuron.CreateCompatibleDC(pDC);
	CBitmap* pOldBitmap=dcCompatibleNeuron.SelectObject(&bmpNeuron);

	CPoint pos(0,0);
	if(!neuralContainer.empty())
		pos=neuralContainer.back().nodePos;
	const CPoint newPos(pos.x+MOVESIZE+NEURALSIZE,pos.y+MOVESIZE+NEURALSIZE);
	pDC->BitBlt(newPos.x,newPos.y,NEURALSIZE,NEURALSIZE,&dcCompatibleNeuron,0,0,SRCCOPY);
	insertNode(newPos);

	// GDI expects the DC's original bitmap to be selected back before the DC
	// is deleted; the loaded bitmap was previously still selected.
	dcCompatibleNeuron.SelectObject(pOldBitmap);
	dcCompatibleNeuron.DeleteDC();
	bmpNeuron.DeleteObject();
}

// Handles a left-button click at point.
// - Click on a neuron's centre: if none of its edge nodes is in use, select it
//   for dragging (btnPosType=-1); a connected neuron cannot be moved.
// - Click on a free edge node (edgeType 0): mark it as an output (edgeType 1),
//   remember it (btnPosType = edge index + 1) and draw the line-start bitmap.
// Only the first hit is handled; the drag state (iter, neuralPosTemp,
// linePosTemp) is recorded for neuronMove.
void NeuralNetworkPlatform::neuronSelected(CDC *pDC,CPoint point)
{
	mousePosTemp=point;
	for(vector<NeuralNode>::iterator it=neuralContainer.begin();it!=neuralContainer.end();++it)
	{
		if((*it).centerPos.PtInRect(point))
		{	// neuron centre: only unconnected neurons may be dragged
			for(int i=0;i<EDGENUM;i++)
			{
				if((*it).edgeType[i]!=0)
					return;
			}
			btnPosType=-1;
			iter=it;
			neuralPosTemp=(*it).nodePos;
			return;
		}
		for(int i=0;i<EDGENUM;i++)
		{	// free edge node: start drawing an output line from it
			if((*it).edgeType[i]==0 && (*it).edgeNodePos[i].PtInRect(point))
			{
				(*it).edgeType[i]=1;
				btnPosType=i+1;
				iter=it;
				neuralPosTemp=(*it).nodePos;
				linePosTemp=(*it).edgeNodePos[i].TopLeft();

				CBitmap bmp;
				bmp.LoadBitmap(IDB_BMP_LINEBEGIN);
				CDC dcCompatible;
				dcCompatible.CreateCompatibleDC(pDC);
				CBitmap* pOldBitmap=dcCompatible.SelectObject(&bmp);
				pDC->BitBlt(linePosTemp.x,linePosTemp.y,LINESIZE,LINESIZE,&dcCompatible,0,0,SRCCOPY);
				dcCompatible.SelectObject(pOldBitmap);

				// Stop at the first hit: scanning on would mark further
				// overlapping edge nodes as outputs that never get a line,
				// or let a later neuron's centre override this selection.
				return;
			}
		}
	}
}

void NeuralNetworkPlatform::neuronMove(CDC *pDC,CPoint point)
{
	CBitmap bmp,bmpBack;
	CDC dcCompatible,dcCompatibleBack;

	if(-1==btnPosType)
	{	//===========选中神经元中心位置===========
		bmp.LoadBitmap(IDB_BMP_NEURON);
		bmpBack.LoadBitmap(IDB_BMP_NEURALBACK);

		dcCompatibleBack.CreateCompatibleDC(pDC);
		dcCompatibleBack.SelectObject(&bmpBack);

		dcCompatible.CreateCompatibleDC(pDC);
		dcCompatible.SelectObject(&bmp);
		if(point.x>=mousePosTemp.x+MOVESIZE)
		{
			pDC->BitBlt(neuralPosTemp.x,neuralPosTemp.y,NEURALSIZE,NEURALSIZE,&dcCompatibleBack,0,0,SRCCOPY);
			pDC->BitBlt(neuralPosTemp.x+MOVESIZE,neuralPosTemp.y,NEURALSIZE,NEURALSIZE,&dcCompatible,0,0,SRCCOPY);

			mousePosTemp=CPoint(mousePosTemp.x+MOVESIZE,mousePosTemp.y);
			neuralPosTemp=CPoint(neuralPosTemp.x+MOVESIZE,neuralPosTemp.y);
		}
		if(point.x<=mousePosTemp.x-MOVESIZE)
		{
			pDC->BitBlt(neuralPosTemp.x,neuralPosTemp.y,NEURALSIZE,NEURALSIZE,&dcCompatibleBack,0,0,SRCCOPY);
			pDC->BitBlt(neuralPosTemp.x-MOVESIZE,neuralPosTemp.y,NEURALSIZE,NEURALSIZE,&dcCompatible,0,0,SRCCOPY);

			mousePosTemp=CPoint(mousePosTemp.x-MOVESIZE,mousePosTemp.y);
			neuralPosTemp=CPoint(neuralPosTemp.x-MOVESIZE,neuralPosTemp.y);
		}
		if(point.y>=mousePosTemp.y+MOVESIZE)
		{
			pDC->BitBlt(neuralPosTemp.x,neuralPosTemp.y,NEURALSIZE,NEURALSIZE,&dcCompatibleBack,0,0,SRCCOPY);
			pDC->BitBlt(neuralPosTemp.x,neuralPosTemp.y+MOVESIZE,NEURALSIZE,NEURALSIZE,&dcCompatible,0,0,SRCCOPY);

			mousePosTemp=CPoint(mousePosTemp.x,mousePosTemp.y+MOVESIZE);
			neuralPosTemp=CPoint(neuralPosTemp.x,neuralPosTemp.y+MOVESIZE);
		}
		if(point.y<=mousePosTemp.y-MOVESIZE)
		{
			pDC->BitBlt(neuralPosTemp.x,neuralPosTemp.y,NEURALSIZE,NEURALSIZE,&dcCompatibleBack,0,0,SRCCOPY);
			pDC->BitBlt(neuralPosTemp.x,neuralPosTemp.y-MOVESIZE,NEURALSIZE,NEURALSIZE,&dcCompatible,0,0,SRCCOPY);

			mousePosTemp=CPoint(mousePosTemp.x,mousePosTemp.y-MOVESIZE);
			neuralPosTemp=CPoint(neuralPosTemp.x,neuralPosTemp.y-MOVESIZE);
		}
		//==================删除已经创建的对象及DC=================
		bmp.DeleteObject();
		bmpBack.DeleteObject();
		dcCompatible.DeleteDC();
		dcCompatibleBack.DeleteDC();
	}
	if(btnPosType>=1 && btnPosType<=EDGENUM)
	{
		bmp.LoadBitmap(IDB_BMP_LINE);

		dcCompatible.CreateCompatibleDC(pDC);
		dcCompatible.SelectObject(&bmp);
		if(point.x>=mousePosTemp.x+LINESIZE)
		{
			linePosTemp=CPoint(linePosTemp.x+LINESIZE,linePosTemp.y);
			pDC->BitBlt(linePosTemp.x,linePosTemp.y,LINESIZE,LINESIZE,&dcCompatible,0,0,SRCCOPY);
			(*iter).outputLine[btnPosType-1].outputPoint.push_back(CPoint(linePosTemp.x,linePosTemp.y));	
			mousePosTemp=CPoint(mousePosTemp.x+LINESIZE,mousePosTemp.y);
			
		}
		if(point.x<=mousePosTemp.x-LINESIZE)
		{
			linePosTemp=CPoint(linePosTemp.x-LINESIZE,linePosTemp.y);
			pDC->BitBlt(linePosTemp.x,linePosTemp.y,LINESIZE,LINESIZE,&dcCompatible,0,0,SRCCOPY);
			(*iter).outputLine[btnPosType-1].outputPoint.push_back(CPoint(linePosTemp.x,linePosTemp.y));	
			mousePosTemp=CPoint(mousePosTemp.x-LINESIZE,mousePosTemp.y);
			
		}
		if(point.y>=mousePosTemp.y+LINESIZE)
		{
			linePosTemp=CPoint(linePosTemp.x,linePosTemp.y+LINESIZE);
			pDC->BitBlt(linePosTemp.x,linePosTemp.y,LINESIZE,LINESIZE,&dcCompatible,0,0,SRCCOPY);
			(*iter).outputLine[btnPosType-1].outputPoint.push_back(CPoint(linePosTemp.x,linePosTemp.y));	

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -