⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 lvq_train.c

📁 有SOM、LVQ、ART三種不同方式的類神經網路可以參考的實例
💻 C
字号:
#include  <stdio.h>
#include  <stdlib.h>
#include  <math.h>
#include  <time.h>

/*----- network and training configuration -----*/
/* Every setting is wrapped in #ifndef so it can be overridden from the
   compiler command line (e.g. -DNcycle=200 or -Dtrain_file='"xor.tra"')
   without editing the source. */
#ifndef Ncycle
#define  Ncycle  			20	/* passes over the training file */
#endif
#ifndef Ntrain
#define  Ntrain  			4	/* samples read per pass */
#endif
#ifndef Ninp
#define  Ninp	  			2	/* input values per sample */
#endif
#ifndef Nhid
#define  Nhid    			6	/* prototype vectors (hidden units) */
#endif
#ifndef Nout
#define  Nout    			2	/* target values per sample = number of classes */
#endif
#ifndef eta_rate1
#define  eta_rate1 		0.9	/* per-pass decay factor of eta1 (attraction rate) */
#endif
#ifndef eta_rate2
#define  eta_rate2		0.9	/* per-pass decay factor of eta2 (repulsion rate) */
#endif
#ifndef train_file
#define  train_file  	"d:\\tc\\xor.tra"	/* input: Ninp inputs then Nout targets per sample */
#endif
#ifndef weight_file
#define  weight_file 	"d:\\tc\\xor.wei"	/* output: trained prototype weights */
#endif
#ifndef mse_file
#define  mse_file    	"d:\\tc\\xor.mes"	/* output: cycle number and mean squared error */
#endif

/* Prototypes are divided evenly among the Nout classes (Nhid/Nout per
   class), so Nhid must be a positive multiple of Nout. */
#if (Nout) <= 0 || (Nhid) < (Nout) || (Nhid) % (Nout) != 0
#error "Nhid must be a positive multiple of Nout"
#endif

/* Returns a pseudo-random initial weight in roughly [-0.5, 0.5). */
float random_value(void);

/*
 * LVQ (learning vector quantization) training driver.
 *
 * Each cycle re-reads Ntrain samples from train_file.  A sample is Ninp
 * input values followed by Nout target values; the desired class is the
 * index j whose target equals 1.  The Nhid prototype vectors (columns of
 * W_xh) are split into Nout contiguous, equal-sized groups, one per class.
 * For every sample the nearest prototype (Euclidean distance) is moved
 * toward the sample when its class matches the desired class (rate eta1)
 * and away from it otherwise (rate eta2).  Both rates decay geometrically
 * after each cycle and are floored at 0.1.
 *
 * Every 10th cycle the mean squared error between T and the one-hot
 * output Y is printed and written to mse_file.  The final weights are
 * printed and written to weight_file, one prototype per line.
 *
 * Returns EXIT_SUCCESS, or EXIT_FAILURE when a file cannot be opened or
 * written, or a training sample cannot be parsed.
 */
int main(void)
{
	FILE  *fp1 = NULL, *fp2 = NULL, *fp3 = NULL;
	float X[Ninp], T[Nout], Y[Nout];
	float W_xh[Ninp][Nhid];          /* column h is prototype vector h */
	float dist[Nhid];
	float sum, mse, min;
	float eta1 = 0.5f, eta2 = 0.5f;  /* attraction / repulsion learning rates */
	int   Icycle = 0, Itrain = 0;
	int   i, j, h;
	int   desired_class, hmin, compute_class;
	int   status = EXIT_FAILURE;

	/*----- open files -----*/
	/* Open the input first so a missing training file does not truncate
	   the previous weight / mse outputs. */
	fp1 = fopen(train_file, "r");
	if (fp1 == NULL)
	{
		puts("Train file not exist !!");
		getchar();
		return EXIT_FAILURE;
	}
	fp2 = fopen(weight_file, "w");
	fp3 = fopen(mse_file, "w");
	if (fp2 == NULL || fp3 == NULL)
	{
		puts("Cannot create output file !!");
		goto cleanup;
	}

	/*----- initialize weights -----*/
	srand((unsigned)time(NULL));
	for (h = 0; h < Nhid; h++)
		for (i = 0; i < Ninp; i++)
			W_xh[i][h] = random_value();

	/*----- start learning -----*/
	for (Icycle = 0; Icycle < Ncycle; Icycle++)
	{
		mse = 0.0f;
		rewind(fp1);   /* every cycle re-reads the same Ntrain samples */
		for (Itrain = 0; Itrain < Ntrain; Itrain++)
		{
			/* read input X and desired output T; stop on short or malformed
			   data instead of training on stale or uninitialised values */
			for (i = 0; i < Ninp; i++)
				if (fscanf(fp1, "%f", &X[i]) != 1)
					goto bad_input;
			for (j = 0; j < Nout; j++)
				if (fscanf(fp1, "%f", &T[j]) != 1)
					goto bad_input;

			/*----- find desired class: index of the target equal to 1 -----*/
			desired_class = -1;
			for (j = 0; j < Nout; j++)
				if (T[j] == 1.0f)
					desired_class = j;
			if (desired_class < 0)
				goto bad_input;

			/*----- Euclidean distance from X to every prototype -----*/
			for (h = 0; h < Nhid; h++)
			{
				sum = 0.0f;
				for (i = 0; i < Ninp; i++)
					sum += (X[i] - W_xh[i][h]) * (X[i] - W_xh[i][h]);
				dist[h] = (float)sqrt(sum);
			}

			/*----- winner: nearest prototype (first one on ties) -----*/
			/* Seeding with prototype 0 guarantees hmin is always a valid
			   index, even for huge or NaN distances. */
			hmin = 0;
			min = dist[0];
			for (h = 1; h < Nhid; h++)
			{
				if (dist[h] < min)
				{
					hmin = h;
					min = dist[h];
				}
			}

			/*----- compute Y: one-hot output for the winner's class -----*/
			for (j = 0; j < Nout; j++)
				Y[j] = 0.0f;
			/* Equals hmin / (Nhid / Nout) when Nhid is a multiple of Nout,
			   but always lies in [0, Nout) and never divides by zero. */
			compute_class = hmin * Nout / Nhid;
			Y[compute_class] = 1.0f;

			/*----- LVQ update: attract on a correct match, repel otherwise -----*/
			if (compute_class == desired_class)
			{
				for (i = 0; i < Ninp; i++)
					W_xh[i][hmin] = W_xh[i][hmin] + eta1 * (X[i] - W_xh[i][hmin]);
			}
			else
			{
				for (i = 0; i < Ninp; i++)
					W_xh[i][hmin] = W_xh[i][hmin] - eta2 * (X[i] - W_xh[i][hmin]);
			}

			/*----- accumulate the squared error of this sample -----*/
			for (j = 0; j < Nout; j++)
				mse += (T[j] - Y[j]) * (T[j] - Y[j]);
		} /* end of 1 learning cycle */

		/*---- decay learning rates, floored at 0.1 ----*/
		eta1 = eta1 * eta_rate1;
		if (eta1 < 0.1) eta1 = 0.1f;
		eta2 = eta2 * eta_rate2;
		if (eta2 < 0.1) eta2 = 0.1f;

		/* report the mean squared error every 10th cycle */
		mse = mse / Ntrain;
		if ((Icycle % 10) == 9)
		{
			printf("\nIcycle=%3d mse=%-8.4f\n", Icycle, mse);
			fprintf(fp3, "%3d %-8.4f\n", Icycle, mse);
		}
	} /* end of total learning cycle */

	/*---- Write the weights to weight_file, one prototype per line ----*/
	printf("\n");
	for (h = 0; h < Nhid; h++)
	{
		for (i = 0; i < Ninp; i++)
		{
			printf("W_xh[%2d][%2d]=%-8.2f", i, h, W_xh[i][h]);
			fprintf(fp2, "%-8.2f", W_xh[i][h]);
		}
		printf("\n");
		fprintf(fp2, "\n");
	}
	printf("\n");

	status = EXIT_SUCCESS;
	goto cleanup;

bad_input:
	printf("Bad training sample %d in %s (cycle %d) !!\n",
	       Itrain + 1, train_file, Icycle);

cleanup:
	/*----- close files; fclose flushes, so check the written streams -----*/
	fclose(fp1);
	if (fp2 != NULL && fclose(fp2) != 0)
	{
		puts("Error writing weight file !!");
		status = EXIT_FAILURE;
	}
	if (fp3 != NULL && fclose(fp3) != 0)
	{
		puts("Error writing mse file !!");
		status = EXIT_FAILURE;
	}
	getchar();   /* keep the console window open until a key is pressed */
	return status;
} /* end of the program */

/*
 * Produce one pseudo-random initial weight.
 *
 * rand() is scaled into the unit interval [0, 1) by dividing by
 * RAND_MAX + 1, then shifted down by 0.5 so the result is centred on zero
 * (roughly [-0.5, 0.5)).  The sequence follows whatever seed was given
 * to srand().
 */
float random_value(void)
{
	double unit = (double)rand() / ((double)RAND_MAX + 1.0);

	return (float)(unit - 0.5);
}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -