📄 bpnetdlg.cpp
字号:
// BPNetDlg.cpp : implementation file
//
#include "stdafx.h"
#include "process.h"
#include "BPNetDlg.h"
#include <float.h>
#include <stdio.h>
#include <math.h>
#include <ctype.h>
#include <malloc.h>
#include <iostream.h>
#include <fstream.h>
#ifdef _DEBUG
#define new DEBUG_NEW
#undef THIS_FILE
static char THIS_FILE[] = __FILE__;
#endif
/////////////////////////////////////////////////////////////////////////////
// CBPNetDlg dialog
// File handles shared by the training routines: fp1 is used by dwrite() for
// the "<task>_v.txt" parameter dump and fp3 by rumelhart() for the per-epoch
// error log (criter.txt). NOTE(review): fp2 is not used in the visible code.
// fopen() itself is declared by <stdio.h>; the old "*fopen()" here declared
// an extra, never-defined zero-argument overload and has been removed.
FILE *fp1,*fp2,*fp3;
// Constructs the BP-network dialog: zeroes the DDX-bound dialog fields and
// sets the default training parameters used by learning()/rumelhart().
CBPNetDlg::CBPNetDlg(CWnd* pParent /*=NULL*/)
: CDialog(CBPNetDlg::IDD, pParent)
{
//{{AFX_DATA_INIT(CBPNetDlg)
m_error = 0.0;
m_alpha = 0.0;
m_maxnum = 0;
m_eta = 0.0;
m_ninattr = 0;
m_noutattr = 0;
m_nhlayer = 0;
m_nodehlayer = 0;
m_inputdata = _T("");
// m_errortmp = _T("");
// m_maxtemp = 0;
m_INPUTNET = _T("");
m_INPUTFANGZHEN = _T("");
//}}AFX_DATA_INIT
// Seed for the LCG in random(), which initwt() uses for the initial weights.
randseed=568731L;
// eta scales the gradient term and alpha the momentum term of the weight
// delta computed in rumelhart(): delw = eta*delta*out + alpha*delw.
eta=0.9;
alpha=0.7;
// introspective() reports success when err_curr <= maxe, or when every
// sample's error ep[i] <= maxep.
maxe=0.000001;
maxep=0.0000001;
// introspective() gives up (FEXIT) once the epoch counter exceeds cnt_num.
cnt_num=10000;
// Input/output attribute counts and number of samples read by user_session().
ninattr=1;
noutattr=1;
ninput=1;
// Number of hidden layers (set_up() assigns each one a fixed width).
nhlayer=1;
// Base name for the output files written by dwrite() ("output_v.txt") and wtwrite().
strcpy(task_name,"output");
// NOTE(review): nsample is not read anywhere in the visible code.
nsample=380;//
m_bSimulateDataFlag = true;
m_bEnableLearn=false;
INPUTNETWEIGHT=FALSE;
INPUTNETEMULATION=FALSE;
}
// MFC dialog data exchange: binds the IDC_STATE button to the m_statebtn
// control object and the edit boxes IDC_EDIT31..IDC_EDIT43 to the m_* fields
// listed in the AFX_DATA_MAP section (two of them are commented out).
void CBPNetDlg::DoDataExchange(CDataExchange* pDX)
{
CDialog::DoDataExchange(pDX);
//{{AFX_DATA_MAP(CBPNetDlg)
DDX_Control(pDX, IDC_STATE, m_statebtn);
DDX_Text(pDX, IDC_EDIT31, m_error);
DDX_Text(pDX, IDC_EDIT33, m_alpha);
DDX_Text(pDX, IDC_EDIT32, m_maxnum);
DDX_Text(pDX, IDC_EDIT34, m_eta);
DDX_Text(pDX, IDC_EDIT35, m_ninattr);
DDX_Text(pDX, IDC_EDIT36, m_noutattr);
DDX_Text(pDX, IDC_EDIT37, m_nhlayer);
DDX_Text(pDX, IDC_EDIT38, m_nodehlayer);
DDX_Text(pDX, IDC_EDIT39, m_inputdata);
// DDX_Text(pDX, IDC_EDIT40, m_errortmp);
// DDX_Text(pDX, IDC_EDIT41, m_maxtemp);
DDX_Text(pDX, IDC_EDIT42, m_INPUTNET);
DDX_Text(pDX, IDC_EDIT43, m_INPUTFANGZHEN);
//}}AFX_DATA_MAP
}
// Message map: routes the dialog's button clicks to their handlers and
// handles WM_DESTROY. The AFX_MSG_MAP comment markers delimit the section
// maintained by ClassWizard; keep new entries inside them.
BEGIN_MESSAGE_MAP(CBPNetDlg, CDialog)
//{{AFX_MSG_MAP(CBPNetDlg)
ON_BN_CLICKED(IDC_LEARN, OnBeginLearn)
ON_BN_CLICKED(IDC_OUTPUT, OnOutput)
ON_BN_CLICKED(IDC_DATA_INPUT, OnDataInputbrowse)
ON_BN_CLICKED(IDC_INPUTNET, Oninputnetdata)
ON_BN_CLICKED(IDC_BUTTON2, Oninputfangzhendata)
ON_BN_CLICKED(IDC_STATE, OnchangeState)
ON_BN_CLICKED(IDC_CANCEL, OnCancel)
ON_WM_DESTROY()
//}}AFX_MSG_MAP
END_MESSAGE_MAP()
/////////////////////////////////////////////////////////////////////////////
// CBPNetDlg message handlers
// Linear congruential generator used by initwt() for the initial weights.
// Advances the state as x' = 15625*x + 22221 and returns bits 16..30 of the
// new state, i.e. a value in [0, 0x7FFF].
// The multiply is done in unsigned long so wrap-around is well defined; the
// previous signed multiply overflowed (undefined behaviour) whenever long is
// 32 bits wide, as on Win32, although it produced the same bits in practice.
// NOTE(review): assumes randseed is an integral member (long or unsigned
// long) - confirm against BPNetDlg.h.
int CBPNetDlg::random(void)
{
    const unsigned long next =
        15625UL * static_cast<unsigned long>(randseed) + 22221UL;
    randseed = next;
    return static_cast<int>((next >> 16) & 0x7FFFUL);
}
///////////////////////
// Allocates the contiguous buffers backing the network and carves them into
// per-layer pointers, using the layer widths nunit[0..nhlayer+1] that
// set_up() fills in.
//   wtptr[j] / delw[j]    : weights and previous weight deltas feeding layer
//                           j+1; (nunit[j]+1)*nunit[j+1] entries (+1 = bias).
//   outptr[i] / errptr[i] : outputs and error terms of layer i; nunit[i]+1
//                           entries, the last slot being the bias unit.
// NOTE(review): the calloc results are never checked, and the previous
// buffers are not freed here, so each call to learning() appears to leak the
// old allocation unless code outside this view releases it.
void CBPNetDlg::init(void)
{
int len1,len2,i;
double *p1,*p2,*p3,*p4;
len1=len2=0;
// Sentinel: no layer above the output layer, so the last len1 term is 0.
// NOTE(review): requires nunit to have at least nhlayer+3 elements.
nunit[nhlayer+2]=0;
for(i=0;i<(nhlayer+2);i++)
{
len1+=(nunit[i]+1)*nunit[i+1];
len2+=nunit[i]+1;
}
// calloc zero-initialises; each buffer gets one spare element.
p1=(double *) calloc(len1+1,sizeof(double));
p2=(double *) calloc(len2+1,sizeof(double));
p3=(double *) calloc(len2+1,sizeof(double));
p4=(double *) calloc(len1+1,sizeof(double));
wtptr[0]=p1;
outptr[0]=p2;
errptr[0]=p3;
delw[0]=p4;
// Each weight block starts where the previous one (nunit[i]*(nunit[i-1]+1)
// entries) ends.
for(i=1;i<(nhlayer+1);i++)
{
wtptr[i]=wtptr[i-1]+nunit[i]*(nunit[i-1]+1);
delw[i]=delw[i-1]+nunit[i]*(nunit[i-1]+1);
}
for(i=1;i<(nhlayer+2);i++)
{
outptr[i]=outptr[i-1]+nunit[i-1]+1;
errptr[i]=errptr[i-1]+nunit[i-1]+1;
}
// The bias ("threshold") unit of every non-output layer always outputs 1.
for(i=0;i<(nhlayer+1);i++) /*set up threshold outputs*/
{
*(outptr[i]+nunit[i])=1.0;
}
}
///////////////////////////
void CBPNetDlg::initwt(void)
{
int i,j;
for(j=0;j<(nhlayer+1);j++)
for(i=0;i<(nunit[j]+1)*nunit[j+1];i++)
{
*(wtptr[j]+i)=random()/pow(2.0,15.0)-0.5;
*(delw[j]+i)=0.0;
}
}
///////////////////
// Runs one full training session: loads the samples, builds the topology,
// allocates the buffers, then trains from fresh random weights for as long
// as rumelhart() asks for a restart. The resulting parameters and weights
// are always written out, and the user is told whether the epoch limit was
// hit (FEXIT) or training converged.
void CBPNetDlg::learning()
{
    user_session();   // read training samples
    set_up();         // choose layer widths
    init();           // allocate network buffers

    int outcome;
    for (;;)
    {
        initwt();
        outcome = rumelhart(0, ninput);
        if (outcome != RESTRT)
            break;
    }

    // Persist the network regardless of the outcome.
    dwrite(task_name);
    wtwrite(task_name);

    if (outcome == FEXIT)
        AfxMessageBox("\nMax number of learning reached,you failed !");
    else
        AfxMessageBox("恭喜你! 训练成功!! 继续努力!!!");
}
///////////////////
void CBPNetDlg::set_up()
{
for(int i=0;i<nhlayer;i++)
{
nunit[i+1]=3;//隐层的神经元数,暂时都设为8!!!
}
fplot10=1;
nunit[0]=ninattr;//nunit[0]为输入层的神经元数
nunit[nhlayer+1]=noutattr;//nunit[nhlayer+1]为输出层的神经元数
}
//////////
void CBPNetDlg::user_session()
{
int i,j;
ifstream istrm;
istrm.open(m_inputdata);//m_inputdata是用来训练网络的样本文件
for (i=0; i<ninput; i++)
{
for (j=0; j<ninattr;j++)
{
istrm>>input[i][j];
}
for (j=0; j<noutattr;j++)
{
istrm>>target[i][j];
}
}
istrm.close();
}
///////////////////////
// Trains the network by backpropagation with momentum on samples
// [from_snum, to_snum), updating the weights after every sample, and repeats
// epochs until introspective() returns something other than CONTNE.
// Per sample: forward pass; output deltas (t-o)*(1-o)*o; back-propagate the
// deltas; delw = eta*delta*out + alpha*delw; add delw to the weights.
// ep[i] is the sum of absolute output errors of sample i, and
// err_curr = 0.5 * sum(ep[i]^2) / ninput. When fplot10==1 each epoch's
// "cnt,err_curr" line is written to criter.txt. The error and epoch count are
// shown in IDC_STATIC3 / IDC_STATIC4.
// Returns the final introspective() status (never CONTNE). Afterwards every
// sample is forwarded once more so outpt[][] holds the final outputs.
// NOTE(review): fp3 is never fclose'd here; exit(0) on fopen failure ends the
// whole GUI application without a message; j, index and hWnd are unused;
// `out` is a float, so output values are truncated to float precision when the
// output deltas are computed.
int CBPNetDlg::rumelhart(int from_snum, int to_snum)
{
int i,j,m,n,p,offset,index;
float out;
char *err_file="criter.txt";
CString str;
HWND hWnd;
nsold=0;
cnt=0;
result=CONTNE; /////////
// Open the per-epoch error log.
if(fplot10==1)
if((fp3=fopen(err_file,"w"))==NULL)
{
exit(0);
}
do{
err_curr=0.0;
for(i=from_snum;i<to_snum;i++)
{
forward(i);
// Output-layer deltas: (target - out) * sigmoid'(out).
for(m=0;m<nunit[nhlayer+1];m++)
{
out=*(outptr[nhlayer+1]+m);
*(errptr[nhlayer+1]+m)
=(target[i][m]-out)*(1-out)*out;
}
// Walk back from the output layer: compute the momentum weight deltas into
// layer m and the deltas of layer m-1 (including its bias slot), using the
// weights from before this sample's update.
for(m=nhlayer+1;m>=1;m--)
{
for(n=0;n<(nunit[m-1]+1);n++)
{
*(errptr[m-1]+n)=0.0;
for(p=0;p<nunit[m];p++)
{
offset=(nunit[m-1]+1)*p+n;
*(delw[m-1]+offset)
=eta*(*(errptr[m]+p))
*(*(outptr[m-1]+n))
+alpha*(*(delw[m-1]+offset));
*(errptr[m-1]+n)+=*(errptr[m]+p)
*(*(wtptr[m-1]+offset));
}
*(errptr[m-1]+n)=*(errptr[m-1]+n)
*(1-*(outptr[m-1]+n))
*(*(outptr[m-1]+n));
}
}
// Apply the deltas to every weight block.
for(m=1;m<nhlayer+2;m++)
{
for(n=0;n<nunit[m];n++)
{
for(p=0;p<nunit[m-1]+1;p++)
{
offset=(nunit[m-1]+1)*n+p;
*(wtptr[m-1]+offset)+=*(delw[m-1]+offset);
}
}
}
// Per-sample error, measured after this sample's weight update.
ep[i]=0.0;
for(m=0;m<nunit[nhlayer+1];m++)
{
ep[i]+=fabs((target[i][m]-
*(outptr[nhlayer+1]+m)));
}
err_curr+=ep[i]*ep[i];
}
err_curr=0.5*err_curr/ninput;
if(fplot10==1)
fprintf(fp3,"%1d,%f\n",cnt,err_curr);
cnt++;
result=introspective(from_snum,to_snum);
// Show the current error and epoch count in the dialog.
CStatic *static1,*static2;
static1=(CStatic*)GetDlgItem(IDC_STATIC3);
static2=(CStatic*)GetDlgItem(IDC_STATIC4);
str.Format("%1.9f",err_curr);
static1->SetWindowText(str);
str.Format("%u",cnt);
static2->SetWindowText(str);
}while(result==CONTNE);
// Refresh outpt[][] with the outputs of the final weights.
for(i=from_snum;i<to_snum;i++) forward(i);
// str.Format("%.8f",err_curr);
// m_errortmp=str;
// m_maxtemp =cnt;
// UpdateData(false);
return(result);
}
/////////////////////
// Decides whether training should continue after an epoch.
// Returns FEXIT once the epoch counter cnt exceeds cnt_num. Returns SEXIT when
// every sample in [nfrom, nto) has ep[] <= maxep, or when err_curr <= maxe.
// Otherwise returns CONTNE.
// Side effect: nsnew is set to the number of consecutive samples, starting at
// nfrom, whose ep[] is within maxep.
int CBPNetDlg::introspective(int nfrom,int nto)
{
    if (cnt > cnt_num)
        return FEXIT;

    nsnew = 0;
    int k = nfrom;
    while (k < nto && ep[k] <= maxep)
    {
        ++nsnew;
        ++k;
    }

    // k reached nto only if no sample exceeded the per-sample tolerance.
    if (k >= nto || err_curr <= maxe)
        return SEXIT;
    return CONTNE;
}
//////////////////
void CBPNetDlg::forward(int i)
{
int m,n,p,offset;
float net;
for(m=0;m<ninattr;m++)
*(outptr[0]+m)=input[i][m];
for(m=1;m<nhlayer+2;m++)
{
for(n=0;n<nunit[m];n++)
{
net=0.0;
for(p=0;p<nunit[m-1]+1;p++)
{
offset=(nunit[m-1]+1)*n+p;
net+=*(wtptr[m-1]+offset)
*(*(outptr[m-1]+p));
}
*(outptr[m]+n)=1/(1+exp(-net));
}
}
for(n=0;n<nunit[nhlayer+1];n++)
outpt[i][n]=*(outptr[nhlayer+1]+n);//???????
}
////////////////////////////
void CBPNetDlg::dwrite(char *taskname)
{
int i,j,c;
char var_file_name[20];
strcpy(var_file_name,taskname);
strcat(var_file_name,"_v.txt");
if((fp1=fopen(var_file_name,"w+"))==NULL)
{
exit(0);
}
fprintf(fp1,"%u %u %u %f %f %u %u\n",ninput,noutattr,ninattr,eta,alpha,nhlayer,cnt_num);
for(i=0;i<nhlayer+2;i++)
{
fprintf(fp1,"%d ",nunit[i]);//输出所有层的节点数到文件
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -