⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 facetrain.c

📁 Mitchell的《机器学习》随书源码
💻 C
字号:
/* ****************************************************************** * HISTORY * 15-Oct-94  Jeff Shufelt (js), Carnegie Mellon University *      Prepared for 15-681, Fall 1994. * * Tue Oct  7 08:12:06 EDT 1997, bthom, added a few comments, *       tagged w/bthom * ****************************************************************** */#include <stdio.h>#include <math.h>#include <pgmimage.h>#include <backprop.h>extern char *strcpy();extern void exit();main(argc, argv)int argc;char *argv[];{  char netname[256], trainname[256], test1name[256], test2name[256];  IMAGELIST *trainlist, *test1list, *test2list;  int ind, epochs, seed, savedelta, list_errors;  seed = 102194;   /*** today's date seemed like a good default ***/  epochs = 100;  savedelta = 100;  list_errors = 0;  netname[0] = trainname[0] = test1name[0] = test2name[0] = '\0';  if (argc < 2) {    printusage(argv[0]);    exit (-1);  }  /*** Create imagelists ***/  trainlist = imgl_alloc();  test1list = imgl_alloc();  test2list = imgl_alloc();  /*** Scan command line ***/  for (ind = 1; ind < argc; ind++) {    /*** Parse switches ***/    if (argv[ind][0] == '-') {      switch (argv[ind][1]) {          case 'n': strcpy(netname, argv[++ind]);                  break;        case 'e': epochs = atoi(argv[++ind]);                  break;        case 's': seed = atoi(argv[++ind]);                  break;        case 'S': savedelta = atoi(argv[++ind]);                  break;        case 't': strcpy(trainname, argv[++ind]);                  break;        case '1': strcpy(test1name, argv[++ind]);                  break;        case '2': strcpy(test2name, argv[++ind]);                  break;        case 'T': list_errors = 1;	          epochs = 0;                  break;        default : printf("Unknown switch '%c'\n", argv[ind][1]);                  break;      }    }  }  /*** If any train, test1, or test2 sets have been specified, then       load them in. 
***/  if (trainname[0] != '\0')     imgl_load_images_from_textfile(trainlist, trainname);  if (test1name[0] != '\0')     imgl_load_images_from_textfile(test1list, test1name);  if (test2name[0] != '\0')    imgl_load_images_from_textfile(test2list, test2name);  /*** If we haven't specified a network save file, we should... ***/  if (netname[0] == '\0') {    printf("%s: Must specify an output file, i.e., -n <network file>\n",     argv[0]);    exit (-1);  }  /*** Don't try to train if there's no training data ***/  if (trainname[0] == '\0') {    epochs = 0;  }  /*** Initialize the neural net package ***/  bpnn_initialize(seed);  /*** Show number of images in train, test1, test2 ***/  printf("%d images in training set\n", trainlist->n);  printf("%d images in test1 set\n", test1list->n);  printf("%d images in test2 set\n", test2list->n);  /*** If we've got at least one image to train on, go train the net ***/  backprop_face(trainlist, test1list, test2list, epochs, savedelta, netname,		list_errors);  exit(0);}backprop_face(trainlist, test1list, test2list, epochs, savedelta, netname,	      list_errors)IMAGELIST *trainlist, *test1list, *test2list;int epochs, savedelta, list_errors;char *netname;{  IMAGE *iimg;  BPNN *net;  int train_n, epoch, i, imgsize;  double out_err, hid_err, sumerr;  train_n = trainlist->n;  /*** Read network in if it exists, otherwise make one from scratch ***/  if ((net = bpnn_read(netname)) == NULL) {    if (train_n > 0) {      printf("Creating new network '%s'\n", netname);      iimg = trainlist->list[0];      imgsize = ROWS(iimg) * COLS(iimg);      /* bthom ===========================	make a net with:	  imgsize inputs, 4 hiden units, and 1 output unit          */      net = bpnn_create(imgsize, 4, 1);    } else {      printf("Need some images to train on, use -t\n");      return;    }  }  if (epochs > 0) {    printf("Training underway (going to %d epochs)\n", epochs);    printf("Will save network every %d epochs\n", savedelta);    fflush(stdout);  
}  /*** Print out performance before any epochs have been completed. ***/  printf("0 0.0 ");  performance_on_imagelist(net, trainlist, 0);  performance_on_imagelist(net, test1list, 0);  performance_on_imagelist(net, test2list, 0);  printf("\n");  fflush(stdout);  if (list_errors) {    printf("\nFailed to classify the following images from the training set:\n");    performance_on_imagelist(net, trainlist, 1);    printf("\nFailed to classify the following images from the test set 1:\n");    performance_on_imagelist(net, test1list, 1);    printf("\nFailed to classify the following images from the test set 2:\n");    performance_on_imagelist(net, test2list, 1);  }  /************** Train it *****************************/  for (epoch = 1; epoch <= epochs; epoch++) {    printf("%d ", epoch);  fflush(stdout);    sumerr = 0.0;    for (i = 0; i < train_n; i++) {      /** Set up input units on net with image i **/      load_input_with_image(trainlist->list[i], net);      /** Set up target vector for image i **/      load_target(trainlist->list[i], net);      /** Run backprop, learning rate 0.3, momentum 0.3 **/      bpnn_train(net, 0.3, 0.3, &out_err, &hid_err);      sumerr += (out_err + hid_err);    }    printf("%g ", sumerr);    /*** Evaluate performance on train, test, test2, and print perf ***/    performance_on_imagelist(net, trainlist, 0);    performance_on_imagelist(net, test1list, 0);    performance_on_imagelist(net, test2list, 0);    printf("\n");  fflush(stdout);    /*** Save network every 'savedelta' epochs ***/    if (!(epoch % savedelta)) {      bpnn_save(net, netname);    }  }  printf("\n"); fflush(stdout);  /** Save the trained network **/  if (epochs > 0) {    bpnn_save(net, netname);  }}/*** Computes the performance of a net on the images in the imagelist. ***//*** Prints out the percentage correct on the image set, and the     average error between the target and the output units for the set. 
***/performance_on_imagelist(net, il, list_errors)BPNN *net;IMAGELIST *il;int list_errors;{  double err, val;  int i, n, j, correct;  err = 0.0;  correct = 0;  n = il->n;  if (n > 0) {    for (i = 0; i < n; i++) {      /*** Load the image into the input layer. **/      load_input_with_image(il->list[i], net);      /*** Run the net on this input. **/      bpnn_feedforward(net);      /*** Set up the target vector for this image. **/      load_target(il->list[i], net);      /*** See if it got it right. ***/      if (evaluate_performance(net, &val, 0)) {        correct++;      } else if (list_errors) {	printf("%s - outputs ", NAME(il->list[i]));	for (j = 1; j <= net->output_n; j++) {	  printf("%.3f ", net->output_units[j]);	}	putchar('\n');      }      err += val;    }    err = err / (double) n;    if (!list_errors)      /* bthom==================================	 this line prints part of the ouput line	 discussed in section 3.1.2 of homework          */      printf("%g %g ", ((double) correct / (double) n) * 100.0, err);  } else {    if (!list_errors)      printf("0.0 0.0 ");  }}evaluate_performance(net, err)BPNN *net;double *err;{  double delta;  delta = net->target[1] - net->output_units[1];  *err = (0.5 * delta * delta);  /*** If the target unit is on... ***/  if (net->target[1] > 0.5) {    /*** If the output unit is on, then we correctly recognized me! ***/    if (net->output_units[1] > 0.5) {      return (1);    /*** otherwise, we didn't think it was me... ***/    } else {      return (0);    }  /*** Else, the target unit is off... 
***/  } else {    /*** If the output unit is on, then we mistakenly thought it was me ***/    if (net->output_units[1] > 0.5) {      return (0);    /*** else, we correctly realized that it wasn't me ***/    } else {      return (1);    }  }}printusage(prog)char *prog;{  printf("USAGE: %s\n", prog);  printf("       -n <network file>\n");  printf("       [-e <number of epochs>]\n");  printf("       [-s <random number generator seed>]\n");  printf("       [-S <number of epochs between saves of network>]\n");  printf("       [-t <training set list>]\n");  printf("       [-1 <testing set 1 list>]\n");  printf("       [-2 <testing set 2 list>]\n");  printf("       [-T]\n");}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -