⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 quickprop.c

📁 神经网络系统 quickprop算法 大幅提高运算速度
💻 C
字号:
/* Quickprop algorithm (developed by Scott Fahlman)
	- a standard modified bp algorithm for comparison
*/

#include <ctype.h>
#include <limits.h>
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>

#define GA          1           /* not referenced in the visible code */
#define S           1           /* generalized BP: the sigmoid derivative term
                                   out*(1-out) is raised to 1/S (1 = plain BP) */
#define mu          1.75        /* maximum growth factor */
#define threshold   0.0         /* mode switch threshold: a previous step no
                                   larger than this in magnitude falls back to
                                   gradient descent with momentum */
#define ERRORLEVEL  0.001       /* stopping criterion (mean squared error) */
#define NITERATIONS 3000        /* no. of iterations to be run */
#define P           4           /* no. of patterns to be trained */
#define I           2           /* no. of input nodes */
#define H           2           /* no. of hidden nodes */
#define J           1           /* no. of output nodes */
/* no. of weights including biases.  Derived from the layer sizes so that
   weights[] can never be sized inconsistently with H, I and J. */
#define N           (H*(I+1)+J*(H+1))
#define weightfile  30          /* no. of weight files used */


/* Per-pattern data: target outputs, inputs (out0), hidden activations (out1)
   and output activations (out2). */
double target[P][J], out0[P][I], out1[P][H], out2[P][J];
/* All weights in one vector (hidden layer first), and per-pattern error
   signals of the hidden (delta1) and output (delta2) units. */
double weights[N],delta1[P][H],delta2[P][J];
/* Per-weight slopes accumulated over all patterns this epoch; the last
   column of each row is the bias weight. */
double dE1[H][I+1],dE2[J][H+1];
/* Slopes from the previous epoch. */
double pre_dE1[H][I+1],pre_dE2[J][H+1];
/* Weight changes applied in the previous epoch. */
double delw1[H][I+1], delw2[J][H+1];

FILE   *fpRun, *fpPattern, *fpWts;
FILE   *fpWeightsOut, *fpResults, *fpError;

/* Convert n to a NUL-terminated decimal string in s.
 *
 * Only 0..99 is converted correctly.  One character is written when n/10 is
 * zero and two otherwise, followed by the terminator, so s needs at least
 * 3 bytes.  For n >= 100 the first character is (n/10) + '0', which is not
 * a digit; negative values are not handled either.
 * NOTE(review): some C libraries (e.g. MSVC's) declare a non-standard itoa
 * with a different signature in <stdlib.h>; that may clash with this one.
 */
void itoa(int n, char s[])
{
  char *out = s;

  if (n / 10 != 0)
  {
    *out++ = (char) ((n / 10) + '0');   /* tens digit */
    *out++ = (char) ((n % 10) + '0');   /* units digit */
  }
  else
  {
    *out++ = (char) (n + '0');
  }

  *out = '\0';
}

/* Return the smaller of a and b.  When a < b is false (ties, or either
   argument is NaN) b is returned. */
double minimum(double a, double b)
{
  return (a < b) ? a : b;
}

/* Return the larger of a and b.  When a > b is false (ties, or either
   argument is NaN) b is returned. */
double maximum(double a, double b)
{
  return (a > b) ? a : b;
}

/* Return -1.0 for strictly negative a, otherwise +1.0.  Zero (including
   -0.0) and NaN therefore map to +1.0. */
double sign(double a)
{
  return (a < 0.0) ? -1.0 : 1.0;
}

main (argc, argv)
int argc;
char *argv[];
{

  double   eta,alpha;
  double   error[20],derror,temp,temp1,sum,dw;
  double   converge=0.0;  
  register int   h,i,j,p,q,r,l,x,min,tmp;
  int      nIterations=NITERATIONS;
  unsigned steady=0,non_conv=0;
  char     szResults[66],szError[66],szPattern[66],szWeightsOut[66];
  char     charstr[12],tmpstr[3];
  double   optwts[N], minerr;
  time_t   t,start,end;

  int offset; /* a value to set the weight files */

  t = time(NULL);      /* randomize the seed for each run */
/*  tmp = srand(t);
*/
  srand(t);

  if (argc < 2)
  {
	 fprintf(stderr, "Usage: %s runfilename\n", argv[0]);
	 exit(1);
  }

  if ((fpRun = fopen(*++argv,"r")) == NULL)
  {
	 fprintf(stderr, "can't open file %s\n", *argv);
	 exit(1);
  }

  offset = atoi(argv[1]);

  fscanf(fpRun, "%s %s %s %s %lf %lf",
			  szResults, szError, szPattern, szWeightsOut, &eta, &alpha);
  fclose(fpRun);

  if ((fpPattern = fopen(szPattern, "r")) == NULL)
  {
	 fprintf(stderr, "can't open file %s\n", szPattern);
	 exit(1);
  }

  for (p=0; p<P; p++)
  {
	for (i=0; i<I; i++)
	  fscanf(fpPattern, "%lf", &out0[p][i]);

	for (j=0; j<J; j++)
	  fscanf(fpPattern, "%lf", &target[p][j]);
  }
  fclose(fpPattern);

  if ((fpError = fopen(szError, "w")) == NULL)
  {
	 fprintf(stderr, "can't open file %s \n", szError);
	 exit(1);
  }

  start = time(NULL);

  for (x=offset; x<weightfile+offset; x++)
  {
	 minerr = 99999999.0; steady=0; 
	 strcpy(charstr,"w");
	 itoa(x,tmpstr);
	 strcat(charstr,tmpstr);
	 strcat(charstr,".wts");
	 if ((fpWts = fopen(charstr,"r")) == NULL)
	 {
		fprintf (stderr, "can't open wts file\n");
		exit(1);
	 }

	 for (h=0; h<H; h++)
		for (i=0; i<=I; i++)
		{
		  fscanf (fpWts, "%lf", &weights[h*(I+1)+i]);
		  delw1[h][i] = 0.0;
                  pre_dE1[h][i] = 0.0;
		}

	 for (j=0; j<J; j++)
		for (h=0; h<=H; h++)
		{
		  fscanf(fpWts, "%lf", &weights[H*(I+1)+j*(H+1)+h]);
		  delw2[j][h] = 0.0;
                  pre_dE2[j][h]= 0.0;
		}

	 fclose(fpWts);


	 /* begin processing */

	 for (q=0; q <= nIterations; q++)
	 {
		/* calculate feed-forward net = forward pass */
		for (p=0; p<P; p++)
		{
		  for (h=0; h<H; h++)
		  {
			 sum = weights[h*(I+1)+I];
			 for (i=0; i< I; i++)
				sum += weights[h*(I+1)+i] * out0[p][i];
			 out1[p][h] = 1.0 / (1.0 + exp(-sum));
		  }

		  for (j=0; j<J; j++)
		  {
			 sum = weights[H*(I+1)+j*(H+1)+H];
			 for (h=0; h< H; h++)
				sum += weights[H*(I+1)+j*(H+1)+h] * out1[p][h];
			 out2[p][j] = 1.0 / (1.0 + exp(-sum));
		  }

		  /* calculate error signals */
		  for (j=0; j<J; j++)
		  {
			 temp = target[p][j]-out2[p][j];
			 temp1= pow(out2[p][j]*(1.0-out2[p][j]),1.0/S);
			 delta2[p][j] = temp * temp1;

		  }
		  for (h=0; h<H; h++)
		  {
			 sum = 0.0; 
			 temp1= pow(out1[p][h]*(1.0-out1[p][h]),1.0/S);
			 for (j=0; j<J; j++)
				sum += delta2[p][j]*weights[H*(I+1)+j*(H+1)+h];
			 delta1[p][h] = sum * temp1;
	     	  }
		}
		/* calculate system error */
		if (q==0) r=0;

		for (p=0, error[r]=0.0; p<P; p++)
		{
		  for (j=0; j<J; j++)
		  {
			 temp = out2[p][j]-target[p][j];
			 error[r] += temp * temp;
		  }
		}
		error[r] /= (P * J);

		if (error[r] < ERRORLEVEL)
		  break;
		if (error[r] < minerr)
		{
		  minerr = error[r];
		  for (l=0; l<N; l++) optwts[l] = weights[l];
		}
		fprintf (stderr,"Iteration %5d/%-5d  Error %lf minerr %lf\r",
				q, nIterations, error[r], minerr);


		/* calculate rate of change of error with respect to weights */
		for (j=0; j<J; j++)
		{
		  dE2[j][H] = 0.0;     
		  for (p=0; p<P; p++)
			 dE2[j][H] += delta2[p][j];

		  for (h=0; h<H; h++)
		  {
			 dE2[j][h] = 0.0; 
			 for (p=0; p<P; p++)
				dE2[j][h] += delta2[p][j] * out1[p][h];
		  }
		}

		for (h=0; h<H; h++)
		{
		  dE1[h][I] = 0.0;  
		  for (p=0; p<P; p++)
			 dE1[h][I] += delta1[p][h];

		  for (i=0; i<I; i++)
		  {
			 dE1[h][i] = 0.0;   
			 for (p=0; p < P; p++)
				dE1[h][i] += delta1[p][h] * out0[p][i];
		  }
		}
		/* calculate weight update rule */
         	for (j=0; j<J; j++)
          	  for (h=0; h<=H; h++)
    		  { 
                    dw = 0.0;
               	    if (delw2[j][h] > threshold)
                    {
                       if (dE2[j][h] > 0.0)
                         dw += eta*dE2[j][h];
                       if (dE2[j][h] > (mu/(1.0+mu))*pre_dE2[j][h])
                         dw += mu*delw2[j][h];
                       else 
        	         dw +=dE2[j][h]/(pre_dE2[j][h]-dE2[j][h])*delw2[j][h];
	  	    } 
                    else if (delw2[j][h] < -threshold)
                    {
                       if (dE2[j][h] <0.0)
                         dw += eta*dE2[j][h];
                       if (dE2[j][h] < (mu/(1.0+mu))*pre_dE2[j][h])
                         dw += mu*delw2[j][h];
                       else 
        	         dw +=dE2[j][h]/(pre_dE2[j][h]-dE2[j][h])*delw2[j][h];
                    }
                    else 
                      dw += eta*dE2[j][h]+alpha*delw2[j][h];
		    weights[H*(I+1)+j*(H+1)+h] += dw;
                    pre_dE2[j][h] = dE2[j][h];
	            delw2[j][h] = dw;
		  }	    
	
         	for (h=0; h<H; h++)
          	  for (i=0; i<=I; i++)
    		  {
                    dw = 0.0;
               	    if (delw1[h][i] > threshold)
                    {
                       if (dE1[h][i] >0.0)
                         dw += eta*dE1[h][i];
                       if (dE1[h][i] > (mu/(1+mu))*pre_dE1[h][i])
                         dw += mu*delw1[h][i];
                       else 
        	         dw +=dE1[h][i]/(pre_dE1[h][i]-dE1[h][i])*delw1[h][i];
	  	    }
                    else if (delw1[h][i] < -threshold) 
                    {
                       if (dE1[h][i] >0.0)
                         dw += eta*dE1[h][i];
                       if (dE1[h][i] < (mu/(1+mu))*pre_dE1[h][i])
                         dw += mu*delw1[h][i];
                       else 
        	         dw +=dE1[h][i]/(pre_dE1[h][i]-dE1[h][i])*delw1[h][i];
                    }
                    else
                      dw += eta*dE1[h][i]+alpha*delw1[h][i];
		    weights[h*(I+1)+i] += dw;
		    pre_dE1[h][i] = dE1[h][i];
	            delw1[h][i] = dw;
		  }	    

		fprintf (fpError, "%lf\n", error[r]);

	        if (++r==10) r=0;
	 }

	 /* end processing */
	 printf ("Iteration %5d/%-5d  Error %lf minerr %lf\n",q-1,nIterations,error[r],minerr);
	 if (q-1 == NITERATIONS)
	   non_conv ++;
	 else
           converge += q-1;
         fprintf (stderr,"\n");
	 fprintf (fpError, "%lf\n", error[r]);
	 fclose(fpError);

  }
  end = time(NULL);
  printf ("\nElapsed time = %ld sec\n",(long)end - (long)start);

  printf ("The avg rate is %5.2lf, percentage of conv is %5.2lf\n\n",converge/(weightfile-non_conv),(double)(weightfile-non_conv)/weightfile*100);
  if ((fpWeightsOut = fopen(szWeightsOut, "w")) == NULL)
  {
	  fprintf(stderr, "can't write file %s\n", szWeightsOut);
	  exit(1);
  }

  for (h=0; h < H; h++)
	for (i=0; i <= I; i++)
	  fprintf(fpWeightsOut, "%9.6f%c", weights[h*(I+1)+i],
				(i == I) ? '\n':' ');
  for (j=0; j < J; j++)
	for (h=0; h <= H; h++)
	  fprintf(fpWeightsOut, "%9.6f%c", weights[H*(I+1)+j*(H+1)+h],
			  (h == H) ? '\n':' ');
  fclose(fpWeightsOut);

  if ((fpResults = fopen(szResults,"w")) == NULL)
  {
	 fprintf(stderr, "can't write file %s\n", szResults);
	 fpResults = stderr;
  }

  for (p=0; p<P; p++)
  {
	 fprintf(fpResults, "%d   ", p);

	 for (j=0; j < J; j++)
		fprintf(fpResults, " %lf", out2[p][j]);
	 fprintf (fpResults,"\n");
  }
  fclose(fpResults);
  return 0;
}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -