⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 bp_with_ga.cc

📁 把genetic algorithm加到传统BP算法
💻 CC
字号:
/* Back-propagation (BP) training algorithm

	- a modified version that embeds a genetic algorithm (GA) into BP */



#include <stdio.h>

#include <stdlib.h>

#include <math.h>

#include <conio.h>

#include <ctype.h>

#include <string.h>

#include <time.h>



/* Training-run configuration.  I, H and J fix the network shape; N is
   derived from them so the weight count cannot drift out of step when
   one of the layer sizes is changed. */
#define ERRORLEVEL  0.001		/* stopping criterion on the system error */
#define NITERATIONS 3000		/* no. of iterations to be run */
#define P           32			/* no. of patterns to be trained */
#define I           5			/* no. of input nodes */
#define H           12			/* no. of hidden nodes */
#define J           6			/* no. of output nodes */
#define POPSIZE     6			/* population size for GA */
#define N           (H*(I+1) + J*(H+1))	/* no. of weights, one bias per node included */





/* One member of the GA population: a candidate weight vector together
   with the error score it was given. */
typedef struct _chrom

{

  float wts[N];   /* candidate weights, indexed like the global weights[] array */

  float err;      /* error score of this candidate; find_min() picks the smallest */

} Ctype;





/* Per-pattern data, first index is the pattern number:
   target = desired outputs and out0 = network inputs (both read from the
   pattern file), out1 = hidden-layer outputs, out2 = output-layer outputs. */
float target[P][J], out0[P][I], out1[P][H], out2[P][J];

/* weights: every network weight in one flat array -- H rows of I+1 values
   for the hidden nodes followed by J rows of H+1 values for the output
   nodes; the last value of each row is the node's bias.
   delta1/delta2: per-pattern error signals of the hidden/output layer. */
float weights[N], delta1[P][H], delta2[P][J];

/* Most recent weight change of each hidden (delw1) and output (delw2)
   weight, reused as the momentum term of the next update. */
float delw1[H][I+1], delw2[J][H+1];





/*
 * Build the GA population from the current network weights.
 *
 *   Pm      - mutation probability applied to each output-layer weight
 *   weights - current weight vector (N values)
 *   chrom   - population to fill (POPSIZE entries)
 *
 * chrom[0] is an exact copy of the current weights.  Every other member
 * keeps the hidden-layer weights unchanged and, with probability Pm per
 * weight, offsets an output-layer weight by a random amount in
 * [-0.5..+0.5].
 *
 * Random numbers are scaled by RAND_MAX rather than a literal 32767, so
 * the probability test and the perturbation range hold on every C
 * library, not only on those whose RAND_MAX happens to be 32767.
 */
void generate_pop (float Pm, float *weights, Ctype *chrom)
{
  const int first_output_wt = H*(I+1);   /* hidden-layer weights come first */
  int l, n;

  for (l=0; l<N; l++)
    chrom[0].wts[l] = weights[l];

  for (n=1; n<POPSIZE; n++)
  {
    for (l=0; l<first_output_wt; l++)
      chrom[n].wts[l] = weights[l];

    /* perturb by [-0.5..+0.5] */
    for (l=first_output_wt; l<N; l++)
    {
      if (Pm > (rand()/(float)RAND_MAX))
        chrom[n].wts[l] = weights[l] + (rand()/(float)RAND_MAX - 0.5);
      else
        chrom[n].wts[l] = weights[l];
    }
  }
}





/*
 * Return the index of the population member with the smallest error.
 * When several members share the smallest error the lowest index wins.
 */
int find_min (Ctype *chrom)
{
  int best = 0;
  int k;

  for (k = 1; k < POPSIZE; k++)
  {
    if (chrom[k].err < chrom[best].err)
      best = k;
  }

  return best;
}





/*
 * Score every member of the GA population.
 *
 *   chrom       - population of POPSIZE candidates; only the err fields
 *                 are written
 *   current_err - error of the unmodified network, stored as chrom[0].err
 *                 (chrom[0] is not re-evaluated)
 *
 * For members 1..POPSIZE-1 the network is run forward over all P
 * patterns with the candidate's weights, and err is set to the squared
 * output error averaged over P*J values.
 *
 * The layer activations are kept in local variables instead of the
 * global out1/out2 arrays, so scoring candidates no longer overwrites
 * the outputs of the real network that the caller reports later.
 */
void fitness (Ctype *chrom, float current_err)
{
  const int out_base = H*(I+1);   /* index of the first output-layer weight */
  float     hidden[H];            /* hidden-layer outputs for one pattern */
  float     sum, out, temp;
  int       i, h, j, p, n;

  chrom[0].err = current_err;

  for (n=1; n<POPSIZE; n++)
  {
    const float *w = chrom[n].wts;

    chrom[n].err = 0.0;

    for (p=0; p<P; p++)
    {
      /* hidden layer: bias is the last weight of each row */
      for (h=0; h<H; h++)
      {
        sum = w[h*(I+1)+I];

        for (i=0; i<I; i++)
          sum += w[h*(I+1)+i] * out0[p][i];

        hidden[h] = 1.0 / (1.0 + exp(-sum));
      }

      /* output layer, accumulating the squared error as we go */
      for (j=0; j<J; j++)
      {
        sum = w[out_base+j*(H+1)+H];

        for (h=0; h<H; h++)
          sum += w[out_base+j*(H+1)+h] * hidden[h];

        out = 1.0 / (1.0 + exp(-sum));

        temp = target[p][j] - out;
        chrom[n].err += temp * temp;
      }
    }

    chrom[n].err /= (float)(P * J);
  }
}





/* Read one float from fp; on failure report the file name and exit(1). */
static float read_float (FILE *fp, const char *fname)
{
  float value;

  if (fscanf(fp, "%f", &value) != 1)
  {
    fprintf(stderr, "bad or missing value in file %s\n", fname);
    exit(1);
  }
  return value;
}


/*
 * Train the network with batch back-propagation plus a GA mutation step.
 *
 * Usage: <program> runfilename
 *
 * The run file holds eight whitespace-separated fields: the names of the
 * results, error-log, pattern, input-weights and output-weights files
 * (at most 65 characters each), followed by the learning rate eta, the
 * momentum factor alpha and the mutation probability Pm.
 *
 * Each iteration runs all P patterns forward, logs the system error to
 * the error file and updates the weights.  Once at least 10 errors have
 * been recorded and the error has almost stopped falling, a mutated
 * population is generated and the weights are replaced by its best
 * member.  Training stops when the error drops below ERRORLEVEL or after
 * NITERATIONS+1 passes.  The final weights and the network outputs of
 * the last forward pass are then written out.
 *
 * Returns 0 on success; exits with status 1 when a required file cannot
 * be opened or an input file is malformed.
 */
int main (int argc, char *argv[])
{
  enum { HISTORY = 10 };            /* length of the circular error history */

  Ctype chrom[POPSIZE];
  float eta, alpha, Pm;
  float ErrorLevel = ERRORLEVEL;
  float error[HISTORY], derror, temp, sum, dw;
  float last_error = 0.0;           /* error of the last completed pass */
  int   h, i, j, p, q, l, min;
  int   r = 0;                      /* current slot in error[] */
  int   last_q = 0;                 /* number of the last completed pass */
  int   nIterations = NITERATIONS;
  FILE  *fpRun, *fpPattern, *fpWeights;
  FILE  *fpWeightsOut, *fpResults, *fpError;
  char  szResults[66], szError[66], szPattern[66];
  char  szWeights[66], szWeightsOut[66];

  /* portable replacement for the Borland-only randomize() */
  srand((unsigned)time(NULL));

  if (argc < 2)
  {
    fprintf(stderr, "Usage: %s runfilename\n", argv[0]);
    exit(1);
  }

  if ((fpRun = fopen(argv[1],"r")) == NULL)
  {
    fprintf(stderr, "can't open file %s\n", argv[1]);
    exit(1);
  }

  /* field widths keep over-long names from overflowing the 66-byte buffers */
  if (fscanf(fpRun, "%65s %65s %65s %65s %65s %f %f %f",
             szResults, szError, szPattern, szWeights, szWeightsOut,
             &eta, &alpha, &Pm) != 8)
  {
    fprintf(stderr, "bad run file %s\n", argv[1]);
    fclose(fpRun);
    exit(1);
  }

  fclose(fpRun);

  if ((fpWeights = fopen(szWeights,"r")) == NULL)
  {
    fprintf(stderr, "can't open file %s\n", szWeights);
    exit(1);
  }

  /* initial weights: hidden layer first, then output layer; the
     momentum terms start at zero */
  for (h=0; h<H; h++)
    for (i=0; i<=I; i++)
    {
      weights[h*(I+1)+i] = read_float(fpWeights, szWeights);
      delw1[h][i] = 0.0;
    }

  for (j=0; j<J; j++)
    for (h=0; h<=H; h++)
    {
      weights[H*(I+1)+j*(H+1)+h] = read_float(fpWeights, szWeights);
      delw2[j][h] = 0.0;
    }

  fclose(fpWeights);

  if ((fpPattern = fopen(szPattern, "r")) == NULL)
  {
    fprintf(stderr, "can't open file %s\n", szPattern);
    exit(1);
  }

  /* each pattern: I input values followed by J target values */
  for (p=0; p<P; p++)
  {
    for (i=0; i<I; i++)
      out0[p][i] = read_float(fpPattern, szPattern);

    for (j=0; j<J; j++)
      target[p][j] = read_float(fpPattern, szPattern);
  }
  fclose(fpPattern);

  if ((fpError = fopen(szError, "w")) == NULL)
  {
    fprintf(stderr, "can't open file %s \n", szError);
    exit(1);
  }

  /* begin processing */
  for (q=0; q <= nIterations; q++)
  {
    /* calculate feed-forward net */
    for (p=0; p<P; p++)
    {
      for (h=0; h<H; h++)
      {
        sum = weights[h*(I+1)+I];

        for (i=0; i<I; i++)
          sum += weights[h*(I+1)+i] * out0[p][i];

        out1[p][h] = 1.0 / (1.0 + exp(-sum));
      }

      for (j=0; j<J; j++)
      {
        sum = weights[H*(I+1)+j*(H+1)+H];

        for (h=0; h<H; h++)
          sum += weights[H*(I+1)+j*(H+1)+h] * out1[p][h];

        out2[p][j] = 1.0 / (1.0 + exp(-sum));
      }

      /* calculate error signals */
      for (j=0; j<J; j++)
        delta2[p][j] = (target[p][j] - out2[p][j]) *
                        out2[p][j] * (1.0 - out2[p][j]);

      for (h=0; h<H; h++)
      {
        sum = 0.0;

        for (j=0; j<J; j++)
          sum += delta2[p][j] * weights[H*(I+1)+j*(H+1)+h];

        delta1[p][h] = sum * out1[p][h] * (1.0 - out1[p][h]);
      }
    }

    /* calculate system error */
    error[r] = 0.0;
    for (p=0; p<P; p++)
    {
      for (j=0; j<J; j++)
      {
        temp = target[p][j] - out2[p][j];
        error[r] += temp * temp;
      }
    }
    error[r] /= (P * J);

    /* remembered for the final summary: r and q both move on before
       the loop is left normally */
    last_error = error[r];
    last_q     = q;

    fprintf (stderr, "Iteration %5d/%-5d  Error %f\r",
             q, nIterations, error[r]);

    fprintf (fpError, "%f\n", error[r]);

    if (error[r] < ErrorLevel)
      break;

    /* calculate backward net and update weights */
    for (j=0; j<J; j++)
    {
      /* bias weight of output node j */
      sum = 0.0;

      for (p=0; p<P; p++)
        sum += delta2[p][j];

      dw = eta * sum + alpha * delw2[j][H];
      weights[H*(I+1)+j*(H+1)+H] += dw;
      delw2[j][H] = dw;

      for (h=0; h<H; h++)
      {
        sum = 0.0;

        for (p=0; p<P; p++)
          sum += delta2[p][j] * out1[p][h];

        dw = eta * sum + alpha * delw2[j][h];
        weights[H*(I+1)+j*(H+1)+h] += dw;
        delw2[j][h] = dw;
      }
    }

    for (h=0; h<H; h++)
    {
      /* bias weight of hidden node h */
      sum = 0.0;

      for (p=0; p<P; p++)
        sum += delta1[p][h];

      dw = eta * sum + alpha * delw1[h][I];
      weights[h*(I+1)+I] += dw;
      delw1[h][I] = dw;

      for (i=0; i<I; i++)
      {
        sum = 0.0;

        for (p=0; p<P; p++)
          sum += delta1[p][h] * out0[p][i];

        dw = eta * sum + alpha * delw1[h][i];
        weights[h*(I+1)+i] += dw;
        delw1[h][i] = dw;
      }
    }

    /* start mutation if the rate of change of error is less than
       the rate_threshold */
    if (q >= HISTORY)
    {
      /* error[(r+1)%HISTORY] is the oldest entry of the circular
         history, recorded HISTORY-1 passes ago */
      derror = (error[r]-error[(r+1)%HISTORY])/(HISTORY-1);

      if ((derror<=0.0) && (derror>-0.0003) && (error[r]>ERRORLEVEL*5))
      {
        generate_pop (Pm, weights, chrom);
        fitness (chrom,error[r]);
        min = find_min (chrom);

        for (l=0; l<N; l++)
          weights[l] = chrom[min].wts[l];
      }
    }

    if (++r == HISTORY) r = 0;
  }

  /* end processing */
  printf ("Iteration %5d/%-5d  Error %lf\n", last_q, nIterations, last_error);

  fprintf(stderr, "\n");

  fclose(fpError);

  if ((fpWeightsOut = fopen(szWeightsOut, "w")) == NULL)
  {
    fprintf(stderr, "can't write file %s\n", szWeightsOut);
    exit(1);
  }

  /* one line per node, in the layout the weights file is read in */
  for (h=0; h<H; h++)
    for (i=0; i<=I; i++)
      fprintf(fpWeightsOut, "%9.6f%c", weights[h*(I+1)+i],
              (i == I) ? '\n':' ');

  for (j=0; j<J; j++)
    for (h=0; h<=H; h++)
      fprintf(fpWeightsOut, "%9.6f%c", weights[H*(I+1)+j*(H+1)+h],
              (h == H) ? '\n':' ');

  fclose(fpWeightsOut);

  /* fall back to stderr when the results file cannot be created */
  if ((fpResults = fopen(szResults,"w")) == NULL)
  {
    fprintf(stderr, "can't write file %s\n", szResults);
    fpResults = stderr;
  }

  for (p=0; p<P; p++)
  {
    fprintf(fpResults, "%d   ", p);

    for (j=0; j<J; j++)
      fprintf(fpResults, " %f", out2[p][j]);

    fprintf (fpResults,"\n");
  }

  /* stderr is not ours to close */
  if (fpResults != stderr)
    fclose(fpResults);

  return 0;
}



⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -