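/*
 * bpnn(momentum).cpp
 *
 * A test program for a back-propagation neural network (BPNN) that
 * solves the XOR problem.
 *
 * Retrieved from the "Chongchong" download site.  The page's own
 * navigation text (welcome banner, resource downloads, resource albums,
 * about us, the "CPP" language tag and the font-size control) was not
 * part of the program.
 */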
#include <stdio.h>
#include <stdlib.h>
#include <math.h>
#include <float.h>
#include <time.h>
/* Network dimensions and training-set size. */
#define IN            5       /* inputs per sample (width of each row of x) */
#define HID           8       /* slots in the hidden-layer output array y */
#define OUT           1       /* output units */
#define n_sample      16      /* number of training samples */

/* Training parameters. */
#define eta           0.7     /* step size applied to every weight update */
#define lambda        1.0     /* slope factor inside the sigmoid */
#define desired_error 0.001   /* training stops once the cycle error is at or below this */

/*
 * Logistic activation 1 / (1 + e^(-lambda*x)).
 * The argument is parenthesized so that an expression such as
 * sigmoid(a - b) is scaled as a whole; note that x is still a macro
 * argument and should not carry side effects.
 */
#define sigmoid(x)    (1.0/(1.0+exp(-lambda*(x))))

/* Pseudo-random double in [0, 9999/10001], built from rand(). */
#define frand()       (rand()%10000/10001.0)

/* Seeds rand() from the current wall-clock time. */
#define randomize()   srand((unsigned int)time(NULL))

#define N             30000   /* presumably the maximum number of learning cycles - confirm against main() */
#define rate          0.9     /* momentum coefficient; no active code in this file applies it */
/*
 * Training inputs, one row per sample: the sixteen 4-bit patterns in
 * counting order, each followed by a constant -1 in the last column.
 */
double x[n_sample][IN] = {
  { 0, 0, 0, 0, -1 },
  { 0, 0, 0, 1, -1 },
  { 0, 0, 1, 0, -1 },
  { 0, 0, 1, 1, -1 },
  { 0, 1, 0, 0, -1 },
  { 0, 1, 0, 1, -1 },
  { 0, 1, 1, 0, -1 },
  { 0, 1, 1, 1, -1 },
  { 1, 0, 0, 0, -1 },
  { 1, 0, 0, 1, -1 },
  { 1, 0, 1, 0, -1 },
  { 1, 0, 1, 1, -1 },
  { 1, 1, 0, 0, -1 },
  { 1, 1, 0, 1, -1 },
  { 1, 1, 1, 0, -1 },
  { 1, 1, 1, 1, -1 }
};

/*
 * Target output for each sample, in the same row order as x:
 * 1 when the pattern contains an even number of 1 bits, 0 otherwise.
 */
double d[n_sample][OUT] = {
  { 1 }, { 0 }, { 0 }, { 1 },
  { 0 }, { 1 }, { 1 }, { 0 },
  { 0 }, { 1 }, { 1 }, { 0 },
  { 1 }, { 0 }, { 0 }, { 1 }
};

double v[HID][IN];            /* input -> hidden weights, v[hidden][input]   */
double w[OUT][HID];           /* hidden -> output weights, w[output][hidden] */
double y[HID];                /* hidden-layer outputs for the current sample */
double o[OUT];                /* network outputs for the current sample      */
double delta_old_w[OUT][HID]; /* previous change of each w entry (momentum history) */
double delta_old_v[HID][IN];  /* previous change of each v entry (momentum history) */

/* Forward declarations for the helpers defined after main(). */
void Initialization(void);   /* seeds rand() and fills the weight tables */
void FindHidden(int p);      /* computes the hidden outputs y[] for training sample p */
void FindOutput(void);       /* computes the network outputs o[] from y[] */
void PrintResult(void);      /* prints the output-layer and hidden-layer weights */

void main()
{
  int    i,j,k,p,q=0;
  double Error=DBL_MAX;
  double delta_o[OUT];
  double delta_y[HID];

  Initialization();
  while(Error>desired_error)
  {
      q++;
      Error=0;
	  for(p=0; p<n_sample; p++)
	  {
		   FindHidden(p);
		   FindOutput();

		   for(k=0;k<OUT;k++)
		   {
			   Error += 0.5*pow(d[p][k]-o[k], 2.0);
			   delta_o[k]=(d[p][k]-o[k])*(1-o[k])*o[k];
		   }
      
		   for(j=0; j<HID; j++)
		   {
			   delta_y[j]=0;
			   for(k=0;k<OUT;k++)
			   delta_y[j]+=delta_o[k]*w[k][j];
			   delta_y[j]=(1-y[j])*y[j]*delta_y[j];
		   }

		   //double new_dw;
		   for(j=0; j<HID; j++)
		   for(i=0; i<IN; i++)
		   {
			   
		        v[j][i] += eta*delta_y[j]*x[p][i];
                 /* double new_dw;
                 new_dw = eta*delta_y[j]*x[p][i] + rate*delta_old_v[j][i];
                 v[k][j] += new_dw;
				 delta_old_v[k][j] = new_dw;*/
		   }
		   for(k=0; k<OUT; k++)
		   {
			   for(j=0; j<HID; j++)
			   {  
				   w[k][j]+=eta*delta_o[k]*y[j];
			       /*double new_dw;
				   new_dw=eta*delta_o[k]*y[j]+rate*delta_old_w[k][j];
				   w[k][j] += new_dw;
				   delta_old_w[k][j] = new_dw;*/
			   }
		   }
    }
    printf("Error in the %d-th learning cycle = %f\n",q,Error);
  } 
  PrintResult();
}
  
/*************************************************************/
/* Initialization of the connection weights                  */
/*                                                           */
/* Seeds rand() from the clock, then draws every entry of v  */
/* and w as frand()-0.5, i.e. a small value on either side   */
/* of zero.  The momentum history tables delta_old_v and     */
/* delta_old_w start at zero, because no weight change has   */
/* been made before the first training step.                 */
/*************************************************************/
void Initialization(void){
  int i,j,k;

  randomize();
  for(j=0; j<HID; j++)
    for(i=0; i<IN; i++)
    {
      v[j][i] = frand()-0.5;
      delta_old_v[j][i] = 0.0;
    }

  for(k=0; k<OUT; k++)
    for(j=0; j<HID; j++)
    {
      w[k][j] = frand()-0.5;
      delta_old_w[k][j] = 0.0;
    }
}

/*************************************************************/
/* Find the output of the hidden neurons                     */
/*************************************************************/
void FindHidden(int p){
  int    i,j;
  double temp;

  for(j=0;j<HID-1;j++){
    temp=0;
    for(i=0;i<IN;i++)
      temp+=v[j][i]*x[p][i];
    y[j]=sigmoid(temp);
  }
  y[HID-1]=-1;
}

/*************************************************************/
/* Compute the network outputs from the hidden layer         */
/*                                                           */
/* Each o[k] is the sigmoid of the weighted sum of all HID   */
/* entries of y, using row k of w.                           */
/*************************************************************/
void FindOutput(void){
  int out;

  for (out = 0; out < OUT; ++out) {
    double net = 0;
    int    hid;

    for (hid = 0; hid < HID; ++hid) {
      net += w[out][hid] * y[hid];
    }
    o[out] = sigmoid(net);
  }
}

/*************************************************************/
/* Print the learned weights to standard output              */
/*                                                           */
/* Writes all OUT rows of w, then the first HID-1 rows of v  */
/* (the row for the constant hidden slot is not printed).    */
/*************************************************************/
void PrintResult(void){
  int row, col;

  printf("\n\n");
  printf("The connection weights in the output layer:\n");
  for (row = 0; row < OUT; ++row) {
    for (col = 0; col < HID; ++col) {
      printf("%5f ", w[row][col]);
    }
    printf("\n");
  }

  printf("\n\n");
  printf("The connection weights in the hidden layer:\n");
  for (row = 0; row < HID-1; ++row) {
    for (col = 0; col < IN; ++col) {
      printf("%5f ", v[row][col]);
    }
    printf("\n");
  }
  printf("\n\n");
}

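/*
 * Keyboard shortcuts of the hosting web page (not part of the program):
 *
 *   Copy code            Ctrl + C
 *   Search code          Ctrl + F
 *   Full-screen mode     F11
 *   Switch theme         Ctrl + Shift + D
 *   Show shortcuts       ?
 *   Increase font size   Ctrl + =
 *   Decrease font size   Ctrl + -
 */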