/* lvq_train.c */
#include <stdio.h>
#include <stdlib.h>
#include <math.h>
#include <time.h>
/*
 * Training configuration.  Each setting may be overridden at compile time
 * (for example -DNcycle=200) instead of editing the source; string settings
 * must be given as a quoted string literal.  The defaults are the original
 * XOR settings, with the data files under the d:\tc directory.
 */

/* Number of passes over the training set. */
#ifndef Ncycle
#define Ncycle 20
#endif

/* Patterns read from train_file on every pass. */
#ifndef Ntrain
#define Ntrain 4
#endif

/* Input values per pattern (length of each prototype vector). */
#ifndef Ninp
#define Ninp 2
#endif

/* Number of prototype units; the training loop assigns them to classes
   in consecutive groups of Nhid/Nout. */
#ifndef Nhid
#define Nhid 6
#endif

/* Number of classes (target values per pattern). */
#ifndef Nout
#define Nout 2
#endif

/* Per-cycle multiplicative decay of the two learning rates: eta1 (used
   when the winning prototype's class is correct) and eta2 (used when it
   is wrong). */
#ifndef eta_rate1
#define eta_rate1 0.9
#endif
#ifndef eta_rate2
#define eta_rate2 0.9
#endif

/* Training patterns (read), final prototypes (written), and the periodic
   error log (written). */
#ifndef train_file
#define train_file "d:\\tc\\xor.tra"
#endif
#ifndef weight_file
#define weight_file "d:\\tc\\xor.wei"
#endif
#ifndef mse_file
#define mse_file "d:\\tc\\xor.mes"
#endif

/* The training loop splits the Nhid prototypes into Nhid/Nout per class,
   indexes an Nout-element output with the resulting class, and averages
   the error over Ntrain patterns.  Reject settings that would divide by
   zero or index out of range there. */
#if Ninp < 1 || Ntrain < 1 || Nout < 1 || Nhid < Nout
#error "Ninp, Ntrain and Nout must be >= 1, and Nhid must be >= Nout"
#elif (Nhid % Nout) != 0
#error "Nhid must be a multiple of Nout so prototypes split evenly among classes"
#endif

/* Pseudo-random initial weight value; defined later in this file. */
float random_value(void);
void main(void){
FILE *fp1,*fp2,*fp3;
float X[Ninp],T[Nout],Y[Nout];
float W_xh[Ninp][Nhid];
float sum,mse;
int Icycle,Itrain;
int i,j,h;
int desired_class,hmin,compute_class;
float min;
float dist[Nhid];
float eta1, eta2;
eta1=0.5;
eta2=0.5;
/*----- open files -----*/
fp1=fopen(train_file,"r");
fp2=fopen(weight_file,"w");
fp3=fopen(mse_file,"w");
if(fp1==NULL)
{
puts("Train file not exist !!");
getchar();
exit(1);
}
/*----- initialize weights -----*/
srand((int)time(0));
for(h=0;h<Nhid;h++)
for(i=0;i<Ninp;i++)
W_xh[i][h]=random_value();
/*----- start learning -----*/
for(Icycle=0;Icycle<Ncycle;Icycle++)
{
mse=0.0;
/* read input X and desired output T */
fseek(fp1,0,0);
for(Itrain=0;Itrain<Ntrain;Itrain++)
{
for(i=0;i<Ninp;i++)
fscanf(fp1,"%f",&X[i]);
for(j=0;j<Nout;j++)
fscanf(fp1,"%f",&T[j]);
/*----- find desired class -----*/
for (j=0;j<Nout;j++)
if (T[j]==1) desired_class=j;
/*----- compute net[h] -----*/
for(h=0;h<Nhid;h++)
{
sum=0.0;
for(i=0;i<Ninp;i++)
sum=sum+(X[i]-W_xh[i][h])*(X[i]-W_xh[i][h]);
dist[h]=sqrt(sum);
}
/*----- find net[h*] -----*/
min=1.0e+10;
for (h=0;h<Nhid;h++)
{
if (dist[h] < min)
{
hmin=h;
min=dist[h];
}
}
/*----- compute Y -----*/
for (j=0;j<Nout;j++)
Y[j]=0.0;
compute_class=hmin/(Nhid/Nout);
Y[compute_class]=1.0;
/*----- compute new W -----*/
if (compute_class==desired_class)
{
for (i=0;i<Ninp;i++)
W_xh[i][hmin]=W_xh[i][hmin]+eta1*(X[i]-W_xh[i][hmin]);
}
else
{
for (i=0;i<Ninp;i++)
W_xh[i][hmin]=W_xh[i][hmin]-eta2*(X[i]-W_xh[i][hmin]);
}
/*----- compute the mean_square_error -----*/
for(j=0;j<Nout;j++)
mse+=(T[j]-Y[j])*(T[j]-Y[j]);
}/*end of 1 learning cycle */
/*---- update learning factor ------*/
eta1=eta1*eta_rate1;
if (eta1<0.1) eta1=0.1;
eta2=eta2*eta_rate2;
if (eta2<0.1) eta2=0.1;
/* write the mse_value to mse_file */
mse=mse/Ntrain;
if((Icycle%10) == 9)
{
printf("\nIcycle=%3d mse=%-8.4f\n",Icycle,mse);
fprintf(fp3,"%3d %-8.4f\n",Icycle,mse);
}
}/* end of total learning cycle */
/*----Write the weights to weight_file----*/
printf("\n");
for(h=0;h<Nhid;h++)
{
for(i=0;i<Ninp;i++)
{
printf("W_xh[%2d][%2d]=%-8.2f",i,h,W_xh[i][h]);
fprintf(fp2,"%-8.2f",W_xh[i][h]);
}
printf("\n");
fprintf(fp2,"\n");
}
printf("\n");
/*----- close files -----*/
fclose(fp1);
fclose(fp2);
fclose(fp3);
getchar();
} /* end of the program */
/*
 * random_value: pseudo-random initial weight.
 * rand() is scaled into [0, 1) by dividing by RAND_MAX + 1, then shifted
 * down by one half, giving a value in [-0.5, 0.5) before the final
 * conversion to float (which may round the top end up to 0.5).
 */
float random_value(void)
{
    double unit = (double)rand() / ((double)RAND_MAX + 1.0);
    double centered = unit - 0.5;

    return (float)centered;
}
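/*
 * Residue from the web code viewer this listing was copied from
 * (translated from Chinese), kept for reference:
 *   Keyboard shortcuts
 *   Copy code:          Ctrl + C
 *   Search code:        Ctrl + F
 *   Full-screen mode:   F11
 *   Toggle theme:       Ctrl + Shift + D
 *   Show shortcuts:     ?
 *   Increase font size: Ctrl + =
 *   Decrease font size: Ctrl + -
 */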