/* sarprop.c */
/* SARPROP
   - a modified backpropagation algorithm (RPROP-style adaptive step sizes
     with simulated-annealing terms), used for comparison
*/
#include <stdio.h>
#include <stdlib.h>
#include <math.h>
#include <ctype.h>
#include <string.h>
#include <time.h>
/* Training hyper-parameters. */
#define S 1                /* generalized BP with the S parameter */
#define ETA 0.1            /* initial step size (weight-update value) */
#define ETA_BIG 1.2        /* factor that increases the step size */
#define ETA_SMALL 0.5      /* factor that reduces the step size */
#define MAX_D 50           /* max. step size (bounds |dw|) */
#define MIN_D 0.000001     /* min. step size */
#define ERRORLEVEL 0.001   /* stopping criterion */
#define NITERATIONS 3000   /* no. of iterations to be run */

/* Network topology and data-set size. */
#define P 4                /* no. of patterns to be trained */
#define I 2                /* no. of input nodes */
#define H 2                /* no. of hidden nodes */
#define J 1                /* no. of output nodes */
#define N 9                /* no. of weights = H*(I+1)+J*(H+1) */
#define weightfile 30      /* no. of weight files used */

/* SARPROP (simulated annealing) parameters. */
#define TEMP 0.08          /* SA temperature parameter */

/* RAND_MAX is deliberately NOT redefined here: <stdlib.h> defines it to
   match rand()'s real range (2147483647 on glibc, for example). A local
   "#define RAND_MAX 32767" made (double) rand() / RAND_MAX exceed 1.0 on
   such platforms and inflated the random step-size kicks. */

/* Compile-time guard that N agrees with the topology: a mismatch yields a
   negative array size and the file fails to compile (valid C89). */
typedef char n_matches_topology[(N == H * (I + 1) + J * (H + 1)) ? 1 : -1];

/* A weight vector of N entries plus an error value. Not referenced
   elsewhere in this listing; presumably a GA-style "chromosome" record
   (tag _chromq) -- confirm before relying on it. */
typedef struct _chromq
{
    double wts[N];
    double err;
} Ctype;
/* Per-pattern data and layer activations. */
double target[P][J];                          /* desired outputs */
double out0[P][I];                            /* pattern inputs */
double out1[P][H];                            /* hidden-layer outputs */
double out2[P][J];                            /* output-layer outputs */

/* Flat weight vector and per-pattern back-propagated error signals. */
double weights[N];
double delta1[P][H];                          /* hidden-layer signals */
double delta2[P][J];                          /* output-layer signals */

/* Per-weight SARPROP state; column I (hidden) / H (output) is the bias. */
double pre_dw1[H][I + 1], pre_dw2[J][H + 1];  /* last weight change */
double dE1[H][I + 1], dE2[J][H + 1];          /* current gradients */
double pre_dE1[H][I + 1], pre_dE2[J][H + 1];  /* remembered gradients */
double d1[H][I + 1], d2[J][H + 1];            /* current step sizes */
double pre_d1[H][I + 1], pre_d2[J][H + 1];    /* previous step sizes */

/* Input and output streams. */
FILE *fpRun, *fpPattern, *fpWts;
FILE *fpWeightsOut, *fpResults, *fpError;
/* Write the decimal text of n into s, NUL-terminated.
   Handles at most two digits: the output is correct only for
   0 <= n <= 99. Values >= 100 produce a non-digit "tens" character, and
   negative values produce non-digit characters. At most 3 bytes
   (including the terminator) are written, so s needs room for 3 chars. */
void itoa(int n, char s[])
{
    char *out = s;

    if (n / 10 != 0)
        *out++ = (char) (n / 10 + '0');
    /* For single-digit n, n % 10 == n (C truncating division). */
    *out++ = (char) (n % 10 + '0');
    *out = '\0';
}
/* Return the smaller of a and b; b is returned when they compare equal
   or when either is NaN (the comparison a < b is then false). */
double minimum(double a, double b)
{
    return (a < b) ? a : b;
}
/* Return the larger of a and b; b is returned when they compare equal
   or when either is NaN (the comparison a > b is then false). */
double maximum(double a, double b)
{
    return (a > b) ? a : b;
}
/* Return -1.0 when a is negative, otherwise +1.0.
   Zero (including -0.0) and NaN therefore map to +1.0.
   NOTE(review): RPROP formulations usually treat sign(0) as 0, which would
   leave a weight unchanged on a zero gradient; confirm +1 is intended. */
double sign(double a)
{
    return (a < 0.0) ? -1.0 : 1.0;
}
main (argc, argv)
int argc;
char *argv[];
{
double eta,alpha;
double eta_big=ETA_BIG,eta_small=ETA_SMALL,max_d=MAX_D,min_d=MIN_D;
double error[20],derror,temp,temp1,sum,dw;
double converge=0.0;
register int h,i,j,p,q,r,l,x,min,tmp;
int nIterations=NITERATIONS;
unsigned steady=0,non_conv=0;
char szResults[66],szError[66],szPattern[66],szWeightsOut[66];
char charstr[12],tmpstr[3];
double optwts[N], minerr;
time_t t,start,end;
/* new part */
double temperature; /* temp = temperature */
double sa; /* sa = SA factor */
double random_no; /* random number */
int index; /* index */
double sa_value; /* SA value */
int seed; /* seed */
int offset; /* a value to select a set of weight files */
t = time(NULL); /* randomize the seed for each run */
/* tmp = srand(t);
*/
seed = (int) atoi(argv[3]);
srand(seed);
if (argc < 2)
{
fprintf(stderr, "Usage: %s runfilename\n", argv[0]);
exit(1);
}
if ((fpRun = fopen(*++argv,"r")) == NULL)
{
fprintf(stderr, "can't open file %s\n", *argv);
exit(1);
}
offset = atoi(argv[1]);
temperature = TEMP;
fscanf(fpRun, "%s %s %s %s %lf %lf",
szResults, szError, szPattern, szWeightsOut, &eta, &alpha);
fclose(fpRun);
if ((fpPattern = fopen(szPattern, "r")) == NULL)
{
fprintf(stderr, "can't open file %s\n", szPattern);
exit(1);
}
for (p=0; p<P; p++)
{
for (i=0; i<I; i++)
fscanf(fpPattern, "%lf", &out0[p][i]);
for (j=0; j<J; j++)
fscanf(fpPattern, "%lf", &target[p][j]);
}
fclose(fpPattern);
if ((fpError = fopen(szError, "w")) == NULL)
{
fprintf(stderr, "can't open file %s \n", szError);
exit(1);
}
start = time(NULL);
for (x=offset; x<weightfile+offset; x++)
{
minerr = 99999999.0; steady=0;
strcpy(charstr,"w");
itoa(x,tmpstr);
strcat(charstr,tmpstr);
strcat(charstr,".wts");
if ((fpWts = fopen(charstr,"r")) == NULL)
{
fprintf (stderr, "can't open wts file\n");
fprintf (stderr, "%s\n", charstr);
exit(1);
}
for (h=0; h<H; h++)
for (i=0; i<=I; i++)
{
fscanf (fpWts, "%lf", &weights[h*(I+1)+i]);
pre_dw1[h][i] = 0.0;
pre_dE1[h][i]=0.0;
d1[h][i]=ETA;
}
for (j=0; j<J; j++)
for (h=0; h<=H; h++)
{
fscanf(fpWts, "%lf", &weights[H*(I+1)+j*(H+1)+h]);
pre_dw2[j][h] = 0.0;
pre_dE2[j][h]=0.0;
d2[j][h]=ETA;
}
fclose(fpWts);
/* begin processing */
for (q=0; q <= nIterations; q++)
{
/* calculate feed-forward net */
for (p=0; p<P; p++)
{
for (h=0; h<H; h++)
{
sum = weights[h*(I+1)+I];
for (i=0; i< I; i++)
sum += weights[h*(I+1)+i] * out0[p][i];
out1[p][h] = 1.0 / (1.0 + exp(-sum));
}
for (j=0; j<J; j++)
{
sum = weights[H*(I+1)+j*(H+1)+H];
for (h=0; h< H; h++)
sum += weights[H*(I+1)+j*(H+1)+h] * out1[p][h];
out2[p][j] = 1.0 / (1.0 + exp(-sum));
}
/* calculate error signals */
for (j=0; j<J; j++)
{
temp = out2[p][j]-target[p][j];
temp1= pow(out2[p][j]*(1.0-out2[p][j]),1.0/S);
delta2[p][j] = temp * temp1;
}
for (h=0; h<H; h++)
{
sum = 0.0;
temp1= pow(out1[p][h]*(1.0-out1[p][h]),1.0/S);
for (j=0; j<J; j++)
sum += delta2[p][j]*weights[H*(I+1)+j*(H+1)+h];
delta1[p][h] = sum * temp1;
}
}
/* calculate system error */
if (q==0) r=0;
for (p=0, error[r]=0.0; p<P; p++)
{
for (j=0; j<J; j++)
{
temp = out2[p][j]-target[p][j];
error[r] += temp * temp;
}
}
error[r] /= (P * J);
if (error[r] < ERRORLEVEL)
break;
if (error[r] < minerr)
{
minerr = error[r];
for (l=0; l<N; l++) optwts[l] = weights[l];
}
fprintf (stderr,"Iteration %5d/%-5d Error %lf minerr %lf\r",
q, nIterations, error[r], minerr);
/* calculate rate of change of error with respect to weights */
sa = pow(2.0, -1.0 * q * temperature);
for (j=0; j<J; j++)
{
dE2[j][H] = 0.0;
for (p=0; p<P; p++)
dE2[j][H] += delta2[p][j];
index = H * (I + 1) + j * (H + 1) + H;
sa_value = 0.01 * sa * weights[index];
sa_value /= 1 + weights[index] * weights[index];
dE2[j][H] -= sa_value;
for (h=0; h<H; h++)
{
dE2[j][h] = 0.0;
for (p=0; p<P; p++)
dE2[j][h] += delta2[p][j] * out1[p][h];
index = H * (I + 1) + j * (H + 1) + h;
sa_value = 0.01 * sa * weights[index];
sa_value /= 1 + weights[index] * weights[index];
dE2[j][h] -= sa_value;
}
}
for (h=0; h<H; h++)
{
dE1[h][I] = 0.0;
for (p=0; p<P; p++)
dE1[h][I] += delta1[p][h];
index = h * (I + 1) + I;
sa_value = 0.01 * sa * weights[index];
sa_value /= 1 + weights[index] * weights[index];
dE1[h][I] -= sa_value;
for (i=0; i<I; i++)
{
dE1[h][i] = 0.0;
for (p=0; p < P; p++)
dE1[h][i] += delta1[p][h] * out0[p][i];
index = h * (I + 1) + i;
sa_value = 0.01 * sa * weights[index];
sa_value /= 1 + weights[index] * weights[index];
dE1[h][i] -= sa_value;
}
}
/* calculate weight update rule */
for (j=0; j<J; j++)
for (h=0; h<=H; h++)
{
if (dE2[j][h]*pre_dE2[j][h] > 0.0)
{
d2[j][h]=minimum(pre_d2[j][h]*eta_big,max_d);
dw = -sign(dE2[j][h])*d2[j][h];
weights[H*(I+1)+j*(H+1)+h] += dw;
pre_dE2[j][h] = dE2[j][h];
}
else if (dE2[j][h]*pre_dE2[j][h] < 0.0)
{
if (pre_d2[j][h] < 0.4 * sa * sa)
{
random_no = (double) rand() / RAND_MAX;
d2[j][h] = pre_d2[j][h] * eta_small
+ 0.8 * random_no * sa * sa;
}
else
d2[j][h] = pre_d2[j][h] * eta_small;
d2[j][h]=maximum(d2[j][h],min_d);
/* dw = -pre_dw2[j][h];
weights[H*(I+1)+j*(H+1)+h] += dw;
*/ pre_dE2[j][h] = 0.0;
}
else /* dE2 * pre_dE2 == 0.0 */
{
dw = -sign(dE2[j][h])*d2[j][h];
weights[H*(I+1)+j*(H+1)+h] += dw;
pre_dE2[j][h] = dE2[j][h];
}
pre_d2[j][h] = d2[j][h];
pre_dw2[j][h] = dw;
}
for (h=0; h<H; h++)
for (i=0; i<=I; i++)
{
if (dE1[h][i]*pre_dE1[h][i] > 0.0)
{
d1[h][i]=minimum(pre_d1[h][i]*eta_big,max_d);
dw = -sign(dE1[h][i])*d1[h][i];
weights[h*(I+1)+i] += dw;
pre_dE1[h][i] = dE1[h][i];
}
else if (dE1[h][i]*pre_dE1[h][i] < 0.0)
{
if (pre_d1[h][i] < 0.4 * sa * sa)
{
random_no = (double) rand() / RAND_MAX;
d1[h][i] = pre_d1[h][i] * eta_small
+ 0.8 * random_no * sa * sa;
}
else
d1[h][i] = pre_d1[h][i] * eta_small;
d1[h][i]=maximum(d1[h][i],min_d);
/* dw = -pre_dw1[h][i];
weights[h*(I+1)+i] += dw;
*/ pre_dE1[h][i] = 0.0;
}
else /* dE1 * pre_dE1 == 0.0 */
{
dw = -sign(dE1[h][i])*d1[h][i];
weights[h*(I+1)+i] += dw;
pre_dE1[h][i] = dE1[h][i];
}
pre_d1[h][i] = d1[h][i];
pre_dw1[h][i] = dw;
}
fprintf (fpError, "%lf\n", error[r]);
if (++r==10) r=0;
}
/* end processing */
printf ("Iteration %5d/%-5d Error %lf minerr %lf\n",q-1,nIterations,error[r],minerr);
if (q-1 == NITERATIONS)
non_conv ++;
else
converge += q-1;
fprintf (stderr,"\n");
fprintf (fpError, "%lf\n", error[r]);
fclose(fpError);
}
end = time(NULL);
printf ("\nElapsed time = %ld sec\n",(long)end - (long)start);
printf ("The avg rate is %5.2lf, percentage of conv is %5.2lf\n\n",converge/(weightfile-non_conv),(double)(weightfile-non_conv)/weightfile*100);
if ((fpWeightsOut = fopen(szWeightsOut, "w")) == NULL)
{
fprintf(stderr, "can't write file %s\n", szWeightsOut);
exit(1);
}
for (h=0; h < H; h++)
for (i=0; i <= I; i++)
fprintf(fpWeightsOut, "%9.6f%c", weights[h*(I+1)+i],
(i == I) ? '\n':' ');
for (j=0; j < J; j++)
for (h=0; h <= H; h++)
fprintf(fpWeightsOut, "%9.6f%c", weights[H*(I+1)+j*(H+1)+h],
(h == H) ? '\n':' ');
fclose(fpWeightsOut);
if ((fpResults = fopen(szResults,"w")) == NULL)
{
fprintf(stderr, "can't write file %s\n", szResults);
fpResults = stderr;
}
for (p=0; p<P; p++)
{
fprintf(fpResults, "%d ", p);
for (j=0; j < J; j++)
fprintf(fpResults, " %lf", out2[p][j]);
fprintf (fpResults,"\n");
}
fclose(fpResults);
return 0;
}
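/* (End of file. The web page this listing was copied from also showed its
   viewer shortcuts here: copy code Ctrl+C, search code Ctrl+F, full screen
   F11, toggle theme Ctrl+Shift+D, show shortcuts ?, increase font size
   Ctrl+=, decrease font size Ctrl+-. They are not part of the program.) */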