⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 sarprop.c

📁 神经网络系统 SARPROP 算法 加入sa系数
💻 C
字号:
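/* SARPROP
	- a modified backpropagation trainer (RPROP-style per-weight step-size
	  adaptation plus simulated-annealing terms), used as a standard for
	  comparison runs
*/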

#include <stdio.h>
#include <stdlib.h>
#include <math.h>
#include <ctype.h>
#include <string.h>
#include <time.h>

#define S           1           /* generalized BP with the S parameter */
#define ETA         0.1         /* initial per-weight step size */
#define ETA_BIG     1.2         /* factor to increase the step size */
#define ETA_SMALL   0.5         /* factor to reduce the step size */
#define MAX_D       50          /* max. step size (so |dw| <= MAX_D) */
#define MIN_D       0.000001    /* min. step size */
#define ERRORLEVEL  0.001       /* stopping criterion on the mean squared error */
#define NITERATIONS 3000        /* no. of iterations to be run */
#define P           4           /* no. of patterns to be trained */
#define I           2           /* no. of input nodes */
#define H           2           /* no. of hidden nodes */
#define J           1           /* no. of output nodes */
/* no. of weights including biases; derived so it can never disagree with
   the layer sizes above */
#define N           (H*(I+1) + J*(H+1))
#define weightfile  30          /* no. of weight files used */

/* SARPROP additions */
#define TEMP        0.08        /* SA temperature parameter */

/* RAND_MAX is deliberately NOT redefined here: <stdlib.h> supplies the real
   upper bound of rand(), which keeps (double)rand()/RAND_MAX within [0, 1].
   Redefining a standard library macro is undefined behavior (C11 7.1.3). */

/* A weight vector paired with an error value.
   The tag is "chromq" rather than "_chromq": file-scope tags beginning with
   an underscore are reserved for the implementation (C11 7.1.3).
   NOTE(review): this type is not referenced anywhere in this file. */
typedef struct chromq
{
  double wts[N];   /* N weights; presumably same layout as weights[] - TODO confirm */
  double err;      /* error associated with wts */
} Ctype;

/* Per-pattern data (first index = pattern number). */
double target[P][J];      /* desired outputs, read from the pattern file */
double out0[P][I];        /* network inputs, read from the pattern file */
double out1[P][H];        /* hidden-layer activations */
double out2[P][J];        /* output-layer activations */

/* All weights in one flat vector: hidden layer first at h*(I+1)+i (bias at
   i == I), then the output layer at H*(I+1)+j*(H+1)+h (bias at h == H). */
double weights[N];

double delta1[P][H];      /* hidden-layer error signals */
double delta2[P][J];      /* output-layer error signals */

/* Per-weight training state. Suffix 1 = hidden layer, indexed [h][i];
   suffix 2 = output layer, indexed [j][h]. */
double pre_dw1[H][I+1];   /* last weight change applied */
double pre_dw2[J][H+1];
double dE1[H][I+1];       /* gradient of the current iteration */
double dE2[J][H+1];
double pre_dE1[H][I+1];   /* gradient kept from the previous update (0 after a sign change) */
double pre_dE2[J][H+1];
double d1[H][I+1];        /* current step size */
double d2[J][H+1];
double pre_d1[H][I+1];    /* step size of the previous update */
double pre_d2[J][H+1];

/* File handles */
FILE *fpRun;
FILE *fpPattern;
FILE *fpWts;
FILE *fpWeightsOut;
FILE *fpResults;
FILE *fpError;

/* Convert a small integer to its decimal string.
 *
 * n: value to convert. Only 0..99 are converted correctly: at most two
 *    characters are produced, so other values yield non-digit characters.
 * s: destination buffer of at least 3 chars; NUL-terminated on return.
 *
 * Written as a prototype (not an old-style K&R definition) so callers get
 * argument type checking; K&R definitions are removed in C23.
 */
void itoa(int n, char s[])
{
  int len = 0;

  if (n / 10 != 0)                       /* two-digit value: tens digit first */
    s[len++] = (char) (n / 10 + '0');
  s[len++] = (char) (n % 10 + '0');     /* units digit (equals n when n/10 == 0) */
  s[len] = '\0';
}

/* Return a if a < b, otherwise b (so b is returned on ties or when either
 * argument is NaN).
 *
 * Prototype form: with the old K&R definition an int argument such as
 * minimum(x, 50) was passed unconverted, which is undefined behavior.
 */
double minimum(double a, double b)
{
  return (a < b) ? a : b;
}

/* Return a if a > b, otherwise b (so b is returned on ties or when either
 * argument is NaN).
 *
 * Prototype form: with the old K&R definition an int argument was passed
 * unconverted to a double parameter, which is undefined behavior.
 */
double maximum(double a, double b)
{
  return (a > b) ? a : b;
}

/* Return -1.0 when a < 0.0, otherwise +1.0 (including 0.0, -0.0 and NaN).
 *
 * NOTE(review): classic RPROP uses sign(0) == 0, so a zero gradient would
 * not move the weight; here a zero gradient yields a step of -d. Confirm
 * this is intended before changing it.
 *
 * Prototype form: the old K&R definition accepted an int argument without
 * converting it to double (undefined behavior).
 */
double sign(double a)
{
  return (a < 0.0) ? -1.0 : 1.0;
}

main (argc, argv)
int argc;
char *argv[];
{
  double   eta,alpha;
  double   eta_big=ETA_BIG,eta_small=ETA_SMALL,max_d=MAX_D,min_d=MIN_D;
  double   error[20],derror,temp,temp1,sum,dw;
  double   converge=0.0;  
  register int   h,i,j,p,q,r,l,x,min,tmp;
  int      nIterations=NITERATIONS;
  unsigned steady=0,non_conv=0;
  char     szResults[66],szError[66],szPattern[66],szWeightsOut[66];
  char     charstr[12],tmpstr[3];
  double   optwts[N], minerr;
  time_t   t,start,end;

  /* new part */
  double temperature; /* temp = temperature */
  double sa;   /* sa = SA factor */
  double random_no; /* random number */
  int index;  /* index */
  double sa_value; /* SA value */
  int seed; /* seed */

  int offset; /* a value to select a set of weight files */

  t = time(NULL);      /* randomize the seed for each run */
/*  tmp = srand(t);
*/
  seed = (int) atoi(argv[3]);
  srand(seed);

  if (argc < 2)
  {
	 fprintf(stderr, "Usage: %s runfilename\n", argv[0]);
	 exit(1);
  }

  
  if ((fpRun = fopen(*++argv,"r")) == NULL)
  {
	 fprintf(stderr, "can't open file %s\n", *argv);
	 exit(1);
  }

  offset = atoi(argv[1]);
  temperature = TEMP;

  fscanf(fpRun, "%s %s %s %s %lf %lf",

			  szResults, szError, szPattern, szWeightsOut, &eta, &alpha);
  fclose(fpRun);

  if ((fpPattern = fopen(szPattern, "r")) == NULL)
  {
	 fprintf(stderr, "can't open file %s\n", szPattern);
	 exit(1);
  }

  for (p=0; p<P; p++)
  {
	for (i=0; i<I; i++)
	  fscanf(fpPattern, "%lf", &out0[p][i]);

	for (j=0; j<J; j++)
	  fscanf(fpPattern, "%lf", &target[p][j]);
  }
  fclose(fpPattern);

  if ((fpError = fopen(szError, "w")) == NULL)
  {
	 fprintf(stderr, "can't open file %s \n", szError);
	 exit(1);
  }

  start = time(NULL);

  for (x=offset; x<weightfile+offset; x++)
  {
	 minerr = 99999999.0; steady=0; 
	 strcpy(charstr,"w");
	 itoa(x,tmpstr);
	 strcat(charstr,tmpstr);
	 strcat(charstr,".wts");
	 if ((fpWts = fopen(charstr,"r")) == NULL)
	 {
		fprintf (stderr, "can't open wts file\n");
                fprintf (stderr, "%s\n", charstr);
		exit(1);
	 }

	 for (h=0; h<H; h++)
		for (i=0; i<=I; i++)
		{
		  fscanf (fpWts, "%lf", &weights[h*(I+1)+i]);
		  pre_dw1[h][i] = 0.0;
 		  pre_dE1[h][i]=0.0;
 		  d1[h][i]=ETA;
		}

	 for (j=0; j<J; j++)
		for (h=0; h<=H; h++)
		{
		  fscanf(fpWts, "%lf", &weights[H*(I+1)+j*(H+1)+h]);
		  pre_dw2[j][h] = 0.0;
 		  pre_dE2[j][h]=0.0;
 		  d2[j][h]=ETA;
		}

	 fclose(fpWts);


	 /* begin processing */

	 for (q=0; q <= nIterations; q++)
	 {
		/* calculate feed-forward net */
		for (p=0; p<P; p++)
		{
		  for (h=0; h<H; h++)
		  {
			 sum = weights[h*(I+1)+I];
			 for (i=0; i< I; i++)
				sum += weights[h*(I+1)+i] * out0[p][i];
			 out1[p][h] = 1.0 / (1.0 + exp(-sum));
		  }

		  for (j=0; j<J; j++)
		  {
			 sum = weights[H*(I+1)+j*(H+1)+H];
			 for (h=0; h< H; h++)
				sum += weights[H*(I+1)+j*(H+1)+h] * out1[p][h];
			 out2[p][j] = 1.0 / (1.0 + exp(-sum));
		  }

		  /* calculate error signals */
		  for (j=0; j<J; j++)
		  {
			 temp = out2[p][j]-target[p][j];
			 temp1= pow(out2[p][j]*(1.0-out2[p][j]),1.0/S);
			 delta2[p][j] = temp * temp1;

		  }
		  for (h=0; h<H; h++)
		  {
			 sum = 0.0; 
			 temp1= pow(out1[p][h]*(1.0-out1[p][h]),1.0/S);
			 for (j=0; j<J; j++)
				sum += delta2[p][j]*weights[H*(I+1)+j*(H+1)+h];
			 delta1[p][h] = sum * temp1;
	     	  }
		}
		/* calculate system error */
		if (q==0) r=0;

		for (p=0, error[r]=0.0; p<P; p++)
		{
		  for (j=0; j<J; j++)
		  {
			 temp = out2[p][j]-target[p][j];
			 error[r] += temp * temp;
		  }
		}
		error[r] /= (P * J);

		if (error[r] < ERRORLEVEL)
		  break;
		if (error[r] < minerr)
		{
		  minerr = error[r];
		  for (l=0; l<N; l++) optwts[l] = weights[l];
		}
		fprintf (stderr,"Iteration %5d/%-5d  Error %lf minerr %lf\r",
				q, nIterations, error[r], minerr);


		/* calculate rate of change of error with respect to weights */
                sa = pow(2.0, -1.0 * q * temperature);

		for (j=0; j<J; j++)
		{
		  dE2[j][H] = 0.0;     
		  for (p=0; p<P; p++)
			 dE2[j][H] += delta2[p][j];
                  index = H * (I + 1) + j * (H + 1) + H;
                  sa_value = 0.01 * sa * weights[index];
                  sa_value /= 1 + weights[index] * weights[index];
                  dE2[j][H] -= sa_value;

		  for (h=0; h<H; h++)
		  {
			 dE2[j][h] = 0.0; 
			 for (p=0; p<P; p++)
				dE2[j][h] += delta2[p][j] * out1[p][h];
                         index = H * (I + 1) + j * (H + 1) + h;
                         sa_value = 0.01 * sa * weights[index];
                         sa_value /= 1 + weights[index] * weights[index];
                         dE2[j][h] -= sa_value;
		  }
		}

		for (h=0; h<H; h++)
		{
		  dE1[h][I] = 0.0;  
		  for (p=0; p<P; p++)
			 dE1[h][I] += delta1[p][h];
                  index = h * (I + 1) + I;
                  sa_value = 0.01 * sa * weights[index];
                  sa_value /= 1 + weights[index] * weights[index];
                  dE1[h][I] -= sa_value;

		  for (i=0; i<I; i++)
		  {
			 dE1[h][i] = 0.0;   
			 for (p=0; p < P; p++)
				dE1[h][i] += delta1[p][h] * out0[p][i];
                         index = h * (I + 1) + i;
                         sa_value = 0.01 * sa * weights[index];
                         sa_value /= 1 + weights[index] * weights[index];
                         dE1[h][i] -= sa_value;
		  }
		}
		/* calculate weight update rule */
         	for (j=0; j<J; j++)
          	  for (h=0; h<=H; h++)
    		  {
	  	    if (dE2[j][h]*pre_dE2[j][h] > 0.0)
		    {
			d2[j][h]=minimum(pre_d2[j][h]*eta_big,max_d);
			dw = -sign(dE2[j][h])*d2[j][h];
			weights[H*(I+1)+j*(H+1)+h] += dw;
			pre_dE2[j][h] = dE2[j][h];
		    }
		    else if (dE2[j][h]*pre_dE2[j][h] < 0.0)
			 {
                           if (pre_d2[j][h] < 0.4 * sa * sa)
                           {
                             random_no = (double) rand() / RAND_MAX;
                             d2[j][h] = pre_d2[j][h] * eta_small 
                                        + 0.8 *  random_no * sa * sa; 
                           }
                           else
                             d2[j][h] = pre_d2[j][h] * eta_small;
			   d2[j][h]=maximum(d2[j][h],min_d);
/*                           dw = -pre_dw2[j][h];
			   weights[H*(I+1)+j*(H+1)+h] += dw;
*/			   pre_dE2[j][h] = 0.0;
			 }
			 else  /* dE2 * pre_dE2 == 0.0 */
			 {
			   dw = -sign(dE2[j][h])*d2[j][h];
			   weights[H*(I+1)+j*(H+1)+h] += dw;
			   pre_dE2[j][h] = dE2[j][h];
			 }
		    pre_d2[j][h] = d2[j][h];	 
	            pre_dw2[j][h] = dw;
		  }	    
	
         	for (h=0; h<H; h++)
          	  for (i=0; i<=I; i++)
    		  {
	  	    if (dE1[h][i]*pre_dE1[h][i] > 0.0)
		    {
			d1[h][i]=minimum(pre_d1[h][i]*eta_big,max_d);
			dw = -sign(dE1[h][i])*d1[h][i];
			weights[h*(I+1)+i] += dw;
			pre_dE1[h][i] = dE1[h][i];
		    }
		    else if (dE1[h][i]*pre_dE1[h][i] < 0.0)
			 {
                           if (pre_d1[h][i] < 0.4 * sa * sa)
                           {
                             random_no = (double) rand() / RAND_MAX;
                             d1[h][i] = pre_d1[h][i] * eta_small
                                        + 0.8 * random_no * sa * sa;
                           }
                           else
                             d1[h][i] = pre_d1[h][i] * eta_small;
			   d1[h][i]=maximum(d1[h][i],min_d);
/*			   dw = -pre_dw1[h][i];
			   weights[h*(I+1)+i] += dw;
*/			   pre_dE1[h][i] = 0.0;
			 }
			 else  /* dE1 * pre_dE1 == 0.0 */
			 {
			   dw = -sign(dE1[h][i])*d1[h][i];
			   weights[h*(I+1)+i] += dw;
			   pre_dE1[h][i] = dE1[h][i];
			 }
		    pre_d1[h][i] = d1[h][i];	 
	            pre_dw1[h][i] = dw;
		  }	    
			   	

		fprintf (fpError, "%lf\n", error[r]);

	        if (++r==10) r=0;
	 }

	 /* end processing */
	 printf ("Iteration %5d/%-5d  Error %lf minerr %lf\n",q-1,nIterations,error[r],minerr);
	 if (q-1 == NITERATIONS)
	   non_conv ++;
	 else
           converge += q-1;
         fprintf (stderr,"\n");
	 fprintf (fpError, "%lf\n", error[r]);
	 fclose(fpError);

  }
  end = time(NULL);
  printf ("\nElapsed time = %ld sec\n",(long)end - (long)start);

  printf ("The avg rate is %5.2lf, percentage of conv is %5.2lf\n\n",converge/(weightfile-non_conv),(double)(weightfile-non_conv)/weightfile*100);
  if ((fpWeightsOut = fopen(szWeightsOut, "w")) == NULL)
  {
	  fprintf(stderr, "can't write file %s\n", szWeightsOut);
	  exit(1);
  }

  for (h=0; h < H; h++)
	for (i=0; i <= I; i++)
	  fprintf(fpWeightsOut, "%9.6f%c", weights[h*(I+1)+i],
				(i == I) ? '\n':' ');
  for (j=0; j < J; j++)
	for (h=0; h <= H; h++)
	  fprintf(fpWeightsOut, "%9.6f%c", weights[H*(I+1)+j*(H+1)+h],
			  (h == H) ? '\n':' ');
  fclose(fpWeightsOut);

  if ((fpResults = fopen(szResults,"w")) == NULL)
  {
	 fprintf(stderr, "can't write file %s\n", szResults);
	 fpResults = stderr;
  }

  for (p=0; p<P; p++)
  {
	 fprintf(fpResults, "%d   ", p);

	 for (j=0; j < J; j++)
		fprintf(fpResults, " %lf", out2[p][j]);
	 fprintf (fpResults,"\n");
  }
  fclose(fpResults);
  return 0;
}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -