📄 bp_with_ga.cc
字号:
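/* Back-propagation training of a feed-forward network (I inputs, H hidden
   nodes, J outputs), modified to embed a genetic algorithm (GA) into BP:
   when the training error plateaus, the hidden-to-output weights are
   randomly mutated and the lowest-error candidate replaces the weights. */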
#include <stdio.h>
#include <stdlib.h>
#include <math.h>
#include <conio.h>
#include <ctype.h>
#include <string.h>
#include <time.h>
#define ERRORLEVEL 0.001   /* stop once the mean squared error drops below this */
#define NITERATIONS 3000   /* no. of training iterations to be run */
#define P 32               /* no. of patterns to be trained */
#define I 5                /* no. of input nodes */
#define H 12               /* no. of hidden nodes */
#define J 6                /* no. of output nodes */
#define POPSIZE 6          /* population size for GA */
/* Total no. of weights: every hidden node has I input weights + 1 bias and
   every output node has H input weights + 1 bias.  Derived from the layer
   sizes so it cannot go stale when I, H or J are changed. */
#define N (H*(I+1) + J*(H+1))

/* One GA chromosome: a candidate flat weight vector and its mean squared
   error (lower is fitter). */
typedef struct chromosome
{
    float wts[N];
    float err;
} Ctype;

/* Per-pattern data and activations:
   out0   - input values of each pattern (layer 0),
   target - desired outputs, out1 - hidden activations, out2 - outputs. */
float target[P][J], out0[P][I], out1[P][H], out2[P][J];

/* Flat weight vector.  Hidden node h: weights[h*(I+1)+i], bias at i == I.
   Output node j: weights[H*(I+1)+j*(H+1)+h], bias at h == H.
   delta1/delta2 hold the back-propagated error signals per pattern. */
float weights[N], delta1[P][H], delta2[P][J];

/* Previous weight changes, used as the momentum term of the update. */
float delw1[H][I+1], delw2[J][H+1];
/*
 * Build a GA population around the current weight vector.
 *
 * chrom[0] becomes an exact copy of `weights`.  Every other chromosome keeps
 * the input->hidden weights (indices 0 .. H*(I+1)-1) unchanged and, for each
 * hidden->output weight (indices H*(I+1) .. N-1), with probability Pm adds a
 * uniform perturbation in [-0.5, +0.5]; otherwise copies it unchanged.
 *
 * Pm      - mutation probability per output-layer weight (0 disables
 *           mutation, values > 1 mutate every such weight).
 * weights - current network weights, N entries (only read).
 * chrom   - population of POPSIZE chromosomes (written).
 */
void generate_pop (float Pm, float *weights, Ctype *chrom)
{
    int l, n;
    const int first_out = H * (I + 1);   /* index of first hidden->output weight */

    for (l = 0; l < N; l++)
        chrom[0].wts[l] = weights[l];

    for (n = 1; n < POPSIZE; n++)
    {
        for (l = 0; l < first_out; l++)
            chrom[n].wts[l] = weights[l];

        for (l = first_out; l < N; l++)
        {
            /* Normalise by RAND_MAX, not a hard-coded 32767: RAND_MAX is
               implementation-defined (2^31-1 on glibc), and the literal made
               both the probability test and the perturbation size wrong. */
            if (Pm > rand() / (double) RAND_MAX)
                chrom[n].wts[l] = weights[l] + (rand() / (double) RAND_MAX - 0.5);
            else
                chrom[n].wts[l] = weights[l];
        }
    }
}
/*
 * Return the index (0 .. POPSIZE-1) of the chromosome with the smallest err.
 * Only a strictly smaller err replaces the current best, so ties resolve to
 * the lowest index.
 */
int find_min (Ctype *chrom)
{
    const Ctype *best = chrom;
    const Ctype *cand;

    for (cand = chrom + 1; cand != chrom + POPSIZE; cand++)
    {
        if (cand->err < best->err)
            best = cand;
    }
    return (int) (best - chrom);
}
/*
 * Evaluate the population: set chrom[n].err to the mean squared error of the
 * network using chrom[n].wts over all P training patterns (global out0 as
 * inputs, global target as desired outputs), with logistic-sigmoid nodes.
 *
 * chrom[0] is not re-evaluated; its err is taken from current_err.
 * NOTE(review): in main, current_err is the error measured *before* the BP
 * weight update, while chrom[0].wts holds the updated weights, so chrom[0]
 * is compared on a slightly stale error.
 *
 * Activations are kept in local scratch storage so the network state in the
 * globals out1/out2 is not overwritten by candidate evaluations.
 */
void fitness (Ctype *chrom, float current_err)
{
    int i, h, j, p, n;
    float sum, temp, out;
    float hid[H];          /* hidden activations for the current pattern */
    const float *w;

    chrom[0].err = current_err;
    for (n = 1; n < POPSIZE; n++)
    {
        w = chrom[n].wts;
        chrom[n].err = 0.0;
        for (p = 0; p < P; p++)
        {
            /* hidden layer: bias stored after the I input weights */
            for (h = 0; h < H; h++)
            {
                sum = w[h*(I+1) + I];
                for (i = 0; i < I; i++)
                    sum += w[h*(I+1) + i] * out0[p][i];
                hid[h] = 1.0 / (1.0 + exp(-sum));
            }
            /* output layer: bias stored after the H hidden weights */
            for (j = 0; j < J; j++)
            {
                sum = w[H*(I+1) + j*(H+1) + H];
                for (h = 0; h < H; h++)
                    sum += w[H*(I+1) + j*(H+1) + h] * hid[h];
                out = 1.0 / (1.0 + exp(-sum));
                temp = target[p][j] - out;
                chrom[n].err += temp * temp;
            }
        }
        chrom[n].err /= (float)(P * J);
    }
}
void main (int argc, char *argv[])
{
Ctype chrom[POPSIZE];
float eta, alpha, Pm;
float ErrorLevel = ERRORLEVEL;
float error[10], derror, temp, sum, dw;
register int h, i, j, p, q, r, l, min;
int nIterations=NITERATIONS,count=0;
FILE *fpRun, *fpPattern, *fpWeights;
FILE *fpWeightsOut, *fpResults, *fpError;
char szResults[66], szError[66], szPattern[66];
char szWeights[66], szWeightsOut[66];
clock_t start, end;
randomize();
if (argc < 2)
{
fprintf(stderr, "Usage: %s runfilename\n", argv[0]);
exit(1);
}
if ((fpRun = fopen(*++argv,"r")) == NULL)
{
fprintf(stderr, "can't open file %s\n", *argv);
exit(1);
}
fscanf(fpRun, "%s %s %s %s %s %f %f %f",
szResults, szError, szPattern, szWeights, szWeightsOut,
&eta, &alpha, &Pm);
fclose(fpRun);
if ((fpWeights = fopen(szWeights,"r")) == NULL)
{
fprintf(stderr, "can't open file %s\n", szWeights);
exit(1);
}
for (h=0; h<H; h++)
for (i=0; i<=I; i++)
{
fscanf (fpWeights, "%f", &weights[h*(I+1)+i]);
delw1[h][i] = 0.0;
}
for (j=0; j<J; j++)
for (h=0; h<=H; h++)
{
fscanf(fpWeights, "%f", &weights[H*(I+1)+j*(H+1)+h]);
delw2[j][h] = 0.0;
}
fclose(fpWeights);
if ((fpPattern = fopen(szPattern, "r")) == NULL)
{
fprintf(stderr, "can't open file %s\n", szPattern);
exit(1);
}
for (p=0; p<P; p++)
{
for (i=0; i<I; i++)
fscanf(fpPattern, "%f", &out0[p][i]);
for (j=0; j<J; j++)
fscanf(fpPattern, "%f", &target[p][j]);
}
fclose(fpPattern);
if ((fpError = fopen(szError, "w")) == NULL)
{
fprintf(stderr, "can't open file %s \n", szError);
exit(1);
}
/* begin processing */
start = clock();
for (q=0; q <= nIterations; q++)
{
/* calculate feed-forward net */
for (p=0; p<P; p++)
{
for (h=0; h<H; h++)
{
sum = weights[h*(I+1)+I];
for (i=0; i< I; i++)
sum += weights[h*(I+1)+i] * out0[p][i];
out1[p][h] = 1.0 / (1.0 + exp(-sum));
}
for (j=0; j<J; j++)
{
sum = weights[H*(I+1)+j*(H+1)+H];
for (h=0; h< H; h++)
sum += weights[H*(I+1)+j*(H+1)+h] * out1[p][h];
out2[p][j] = 1.0 / (1.0 + exp(-sum));
}
/* calculate error signals */
for (j=0; j<J; j++)
delta2[p][j] = (target[p][j] - out2[p][j]) *
out2[p][j] * (1.0 - out2[p][j]);
for (h=0; h<H; h++)
{
sum = 0.0;
for (j=0; j<J; j++)
sum += delta2[p][j] * weights[H*(I+1)+j*(H+1)+h];
delta1[p][h] = sum * out1[p][h] * (1.0 - out1[p][h]);
}
}
/* calculate system error */
if (q==0) r=0;
for (p=0, error[r]=0.0; p<P; p++)
{
for (j=0; j<J; j++)
{
temp = target[p][j] - out2[p][j];
error[r] += temp * temp;
}
}
error[r] /= (P * J);
fprintf (stderr, "Iteration %5d/%-5d Error %f\r",
q, nIterations, error[r]);
fprintf (fpError, "%f\n", error[r]);
if (error[r] < ErrorLevel)
break;
/* calculate backward net and update weights */
for (j=0; j<J; j++)
{
sum = 0.0;
for (p=0; p<P; p++)
sum += delta2[p][j];
dw = eta * sum + alpha * delw2[j][H];
weights[H*(I+1)+j*(H+1)+H] += dw;
delw2[j][H] = dw;
for (h=0; h<H; h++)
{
sum = 0.0;
for (p=0; p<P; p++)
sum += delta2[p][j] * out1[p][h];
dw = eta * sum + alpha * delw2[j][h];
weights[H*(I+1)+j*(H+1)+h] += dw;
delw2[j][h] = dw;
}
}
for (h=0; h<H; h++)
{
sum = 0.0;
for (p=0; p<P; p++)
sum += delta1[p][h];
dw = eta * sum + alpha * delw1[h][I];
weights[h*(I+1)+I] += dw;
delw1[h][I] = dw;
for (i=0; i<I; i++)
{
sum = 0.0;
for (p=0; p < P; p++)
sum += delta1[p][h] * out0[p][i];
dw = eta * sum + alpha * delw1[h][i];
weights[h*(I+1)+i] += dw;
delw1[h][i] = dw;
}
}
/* start mutation if the rate of change of error is less than
the rate_threshold */
if (q>=10)
{
derror = (error[r]-error[(r+1)%10])/9;
if ((derror<=0.0) && (derror>-0.0003) && (error[r]>ERRORLEVEL*5))
{
count++;
generate_pop (Pm, weights, chrom);
fitness (chrom,error[r]);
min = find_min (chrom);
for (l=0; l<N; l++)
weights[l] = chrom[min].wts[l];
}
}
if (++r==10) r=0;
}
/* end processing */
end = clock();
printf ("Iteration %5d/%-5d Error %lf\n",q,nIterations,error[r]);
// printf ("\nElapsed time = %f, count=%d\n",(end-start)/CLK_TCK,count);
fprintf(stderr, "\n");
fclose(fpError);
if ((fpWeightsOut = fopen(szWeightsOut, "w")) == NULL)
{
fprintf(stderr, "can't write file %s\n", szWeightsOut);
exit(1);
}
for (h=0; h < H; h++)
for (i=0; i <= I; i++)
fprintf(fpWeightsOut, "%9.6f%c", weights[h*(I+1)+i],
(i == I) ? '\n':' ');
for (j=0; j < J; j++)
for (h=0; h <= H; h++)
fprintf(fpWeightsOut, "%9.6f%c", weights[H*(I+1)+j*(H+1)+j],
(h == H) ? '\n':' ');
fclose(fpWeightsOut);
if ((fpResults = fopen(szResults,"w")) == NULL)
{
fprintf(stderr, "can't write file %s\n", szResults);
fpResults = stderr;
}
for (p=0; p<P; p++)
{
fprintf(fpResults, "%d ", p);
for (j=0; j < J; j++)
fprintf(fpResults, " %f", out2[p][j]);
fprintf (fpResults,"\n");
}
fclose(fpResults);
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -