scg.c

From "NIST Handwriting OCR Testbed" · C code · 359 lines

/* [Check whether some of the loops in here should be replaced by
vectorish blas calls; already uses snrm2 and sdot, and maybe should
now use saxpy or other blas routines.] */

/*
# proc: scg - Uses a Scaled Conjugate Gradients (SCG) algorithm to train
# proc:       (optimize) the MLP, optionally performing Boltzmann pruning during
# proc:       training.

Input args:
  do_confuse: If TRUE, will compute, and write to stderr and to the
    short outfile, the confusion matrices for the network at the
    end of the scg training run.
  do_long_outfile: If TRUE, will produce long_outfile at the end of
    the scg training run.
  long_outfile: (Used only if do_long_outfile is TRUE.)  Filename of
    the long outfile to be produced.
  show_acs_times_1000: (Used only if do_long_outfile is TRUE.)
    Passed to e_and_g(); see comment in e_and_g.c.
  do_cvr: If TRUE, will compute, and write to stderr and to the short
    outfile, a correct-vs.-rejected table.
  niter_max: Maximum number of training iterations allowed.
  ninps, nhids, nouts: Numbers of input, hidden, and output nodes.
  npats: Number of patterns.
  featvecs: Feature vectors of the patterns, an npats by ninps matrix.
  use_targvecs: If TRUE, parm targvecs below is used; if FALSE, parm
    classes below is used.  Note that if errfunc != MSE, use_targvecs
    must be FALSE.
  targvecs: Target vectors, an npats by nouts matrix; used if
    use_targvecs is TRUE.  These must be used if the mlp is to be a
    function FITTER (not CLASSIFIER).  (If use_targvecs is FALSE, just
    set this to (float *)NULL.)
  classes: Classes of the patterns, an array of npats shorts;
    used if use_targvecs is FALSE.  (If use_targvecs is TRUE, just set
    this to (short *)NULL.)
  acfunc_and_deriv_hids: A function that computes the activation
    function to be used on the hidden nodes, and its derivative.
    This should be a void-returning function that takes three args:
    the input value (float), the output activation function value
    (float pointer), and the output activation function derivative
    value (float pointer).
  acfunc_and_deriv_outs: Like acfunc_and_deriv_hids, but for the
    output nodes.
  errfunc: Type of function used to compute the error contribution
    of a pattern from the activations vector and either the target
    vector or the actual class.  Must be one of the following
    (defined in parms.h):
    MSE: mean-squared error between activations and targets, or its
      equivalent using actual class instead of targets.
    TYPE_1: error function of type 1, using parm alpha (below).  (If
      this is used, use_targvecs must be FALSE.)
    POS_SUM: positive sum.  (If this is used, use_targvecs must be
      FALSE.)
  alpha: A parm that must be set if errfunc is TYPE_1.  (If errfunc is
    not TYPE_1, set value 0. for this parm.)
  patwts: Pattern-weights.
  regfac: Regularization factor.  The regularization term of the error
    is this factor times half the average of the squares of the
    weights.
  pct: A threshold for pctmin, such that if pctmin >= pct then
    the routine returns (under some conditions).
  boltzmann: Decides whether to use Boltzmann pruning, and if
    so, what kind of threshold to use (see boltz.c).  Must be one of:
      NO_PRUNE: Do not prune.
      ABS_PRUNE: Prune using threshold exp(-|wt|/temperature),
        where wt is a weight being considered for pruning.
      SQUARE_PRUNE: Prune using threshold exp(-wt^2/temperature).
  temperature: For Boltzmann pruning.  (Not used if boltzmann
    is NO_PRUNE.)
  nfreq:
  egoal: If *rmserr becomes < egoal, the routine returns.
  gwgoal: If size(g) / size(w) becomes < gwgoal, the routine returns.
  oklvl: Threshold for rejection.
  purpose, long_classnames, short_classnames: Passed to accum_print.

Input/output args:
  w: The network weights, in this order:
    1st-layer weights (nhids by ninps "matrix", row-major);
    1st-layer biases (nhids elts);
    2nd-layer weights (nouts by nhids "matrix", row-major);
    2nd-layer biases (nouts elts).
  ncalls: A counter which this routine increments each time it calls
    e_and_g(), which computes the error and its gradient.

Scratch-buffer arg:
  g: For holding the gradient of the error w.r.t. the network weights.
    Caller is to provide this buffer, allocated to (at least) as many
    floats as there are weights, i.e. nhids * (ninps + 1) + nouts *
    (nhids + 1) floats.  This buffer's contents upon entry to this
    routine do not matter, and its contents after return are not
    intended to be used.

Output args:
  rmserr: Error value.
  gw: Size(gradient) / size(weights).
  iter: How many iterations were used.
  ierr: Error code. [Show what the values mean; also, better yet,
    define names for the values in an scg.h.  And perhaps it would
    be better to let this integer code be the function value, instead
    of an arg.]
*/

#include <stdio.h>
#include <stdlib.h>  /* malloc, free */
#include <string.h>  /* memcpy */
#include <math.h>
#include <mlp/blas.h>
#include <mlp/defs.h>
#include <mlp/macros.h>
#include <mlp/parms.h>
#include <mlp/scg.h>

void
scg(do_confuse, do_long_outfile, long_outfile, show_acs_times_1000,
  do_cvr, niter_max, ninps, nhids, nouts, npats, featvecs,
  use_targvecs, targvecs, classes, acfunc_and_deriv_hids,
  acfunc_and_deriv_outs, errfunc, alpha, patwts, regfac, pct,
  boltzmann, temperature, nfreq, egoal, gwgoal, oklvl, purpose,
  long_classnames, short_classnames, w, g, rmserr, gw, iter, ncalls,
  ierr)
char do_confuse, do_long_outfile, long_outfile[], show_acs_times_1000,
  do_cvr, use_targvecs, errfunc, boltzmann, purpose,
  **long_classnames, **short_classnames;
int niter_max, ninps, nhids, nouts, npats, nfreq, *iter, *ncalls,
  *ierr;
short *classes;
float *featvecs, *targvecs, alpha, *patwts, regfac, pct, temperature,
  egoal, gwgoal, oklvl, *w, *g, *rmserr, *gw;
void (*acfunc_and_deriv_hids)(), (*acfunc_and_deriv_outs)();
{
  char str[100];
  int numwts, i, k,
    icount /* number of steps since last restart */,
    fcount /* number of consecutive failures */,
    check, success, kmin, iover, junkint0, junkint1;
  static int i1 = 1;
  float pctmin, deltak, err, enew, xl, xlb, wsiz, sigma, psiz, psq,
    sigmak, c, xmu, alphak, delta, beta, gsiz, e1, e2, junkfloat,
    *wnew, /* new weights */
    *p,    /* direction vector */
    *r,    /* remembered-g */
    *s;    /* second deriv. info along p direction */

  numwts = nhids * (ninps + 1) + nouts * (nhids + 1);
  kmin = max(NF * nfreq, NITER);
  if(boltzmann != NO_PRUNE && temperature > 0.)
    kmin = max(kmin, NBOLTZ);
  if(niter_max > 0) {
    sprintf(str, "\n SCG: doing <= %d iterations; %d variables.\n\n",
      niter_max, numwts);
    fsaso(str);
  }
  if(boltzmann != NO_PRUNE)
    boltz(ninps, nhids, nouts, boltzmann, temperature, w);

  /* Get initial error and gradient. */
  e_and_g(TRUE, TRUE, FALSE, FALSE, (char *)NULL, FALSE, FALSE, ninps,
    nhids, nouts, w, npats, featvecs, use_targvecs, targvecs, classes,
    acfunc_and_deriv_hids, acfunc_and_deriv_outs, errfunc, alpha,
    patwts, regfac, oklvl, &err, g, &e1, &e2);
  (*ncalls)++;

  wsiz = snrm2_(&numwts, w, &i1);
  *iter = 0;
  optchk_store_e1_e2(e1, e2);
  optchk(FALSE, 0, w, err, ierr, &pctmin);
  *rmserr = sqrt((double)(err * 2.));

  if((wnew = (float *)malloc(i = numwts * sizeof(float))) ==
    (float *)NULL)
    syserr("scg", "malloc", "wnew");
  if((p = (float *)malloc(i)) == (float *)NULL)
    syserr("scg", "malloc", "p");
  if((r = (float *)malloc(i)) == (float *)NULL)
    syserr("scg", "malloc", "r");
  if((s = (float *)malloc(i)) == (float *)NULL)
    syserr("scg", "malloc", "s");

  *ierr = 1;
  sigma = 1.e-4; /* relative distance for numerical derivative */
  xl = XLSTART; /* lambda_k */
  xlb = 0.; /* lambda_k bar */
  deltak = 0.;
  success = TRUE;
  /* rmsold = *rmserr */
  for(i = 0; i < numwts; i++)
    p[i] = r[i] = -g[i];
  iover = min(numwts, NRESTART); /* how often to restart the
                                 algorithm */
  icount = 0; /* number of iterations since last restart */
  fcount = 0; /* number of failed iterations in a row */
  *iter = niter_max;
  k = 0;
  /* notimp = 0; (fortran original has this commented out) */

  while(k < niter_max) {

    icount++;
    psiz = snrm2_(&numwts, p, &i1);
    psq = psiz * psiz;

    if(success) {
      sigmak = sigma * wsiz / psiz;
      for(i = 0; i < numwts; i++)
        wnew[i] = w[i] + sigmak * p[i];

      /* Get error and gradient 2nd derivative information. */
      e_and_g(TRUE, TRUE, FALSE, FALSE, (char *)NULL, FALSE, FALSE,
        ninps, nhids, nouts, wnew, npats, featvecs, use_targvecs,
        targvecs, classes, acfunc_and_deriv_hids,
        acfunc_and_deriv_outs, errfunc, alpha, patwts, regfac, oklvl,
        &enew, s, &e1, &e2);
      (*ncalls)++;

      /* dE/d(dist) along p is (enew-err)/sigmak. */
      for(i = 0; i < numwts; i++)
        s[i] = (s[i] - g[i]) / sigmak;
      deltak = sdot_(&numwts, s, &i1, p, &i1);
    }

    c = xl - xlb;
    if(c != 0.) {
      for(i = 0; i < numwts; i++)
        s[i] += c * p[i];
      deltak += c * psq;
    }

    /* Maybe need to make "Hessian" positive definite. */
    if(deltak <= 0) {
      c = xl - 2. * deltak / psq;
      for(i = 0; i < numwts; i++)
        s[i] += c * p[i];
      xlb = 2. * (xl - deltak / psq);
      deltak = -deltak + xl * psq;
      xl = xlb;
    }

    /* Get the right step size. */
    xmu = sdot_(&numwts, p, &i1, r, &i1);
    alphak = xmu / deltak;
    for(i = 0; i < numwts; i++)
      wnew[i] = w[i] + alphak * p[i];

    /* Get new error and gradient. */
    e_and_g(TRUE, TRUE, FALSE, FALSE, (char *)NULL, FALSE, FALSE,
      ninps, nhids, nouts, wnew, npats, featvecs, use_targvecs,
      targvecs, classes, acfunc_and_deriv_hids, acfunc_and_deriv_outs,
      errfunc, alpha, patwts, regfac, oklvl, &enew, g, &e1, &e2);
    (*ncalls)++;

    delta = 2. * deltak * (err - enew) / (xmu * xmu);

    if(delta >= 0.) {
      k++;
      fcount = 0;
      memcpy((char *)w, (char *)wnew, numwts * sizeof(float));
      if(boltzmann != NO_PRUNE) {
        boltz(ninps, nhids, nouts, boltzmann, temperature, w);
        e_and_g(TRUE, TRUE, FALSE, FALSE, (char *)NULL, FALSE, FALSE,
          ninps, nhids, nouts, w, npats, featvecs, use_targvecs,
          targvecs, classes, acfunc_and_deriv_hids,
          acfunc_and_deriv_outs, errfunc, alpha, patwts, regfac,
          oklvl, &enew, g, &e1, &e2);
        (*ncalls)++;
      }
      wsiz = snrm2_(&numwts, w, &i1);
      err = enew;
      xlb = 0;
      success = check = TRUE;
      if(icount % iover == 0) /* restart */
        for(i = 0; i < numwts; i++)
          p[i] = -g[i];
      else { /* find conjugate direction */
        beta = (sdot_(&numwts, g, &i1, g, &i1) +
          sdot_(&numwts, g, &i1, r, &i1)) / xmu;
        for(i = 0; i < numwts; i++)
          p[i] = -g[i] + beta * p[i];
      }
      for(i = 0; i < numwts; i++)
        r[i] = -g[i];
      if(delta >= 0.75) /* trustworthy */
        xl /= 2; /* maybe try something else */
    }
    else {
      xlb = xl;
      success = check = FALSE;
      fcount++;
      if(fcount > 2) {
        if(icount > fcount) { /* At least one good step since
                              restart. */
          sprintf(str, " restart SCG %4d\n", k);
          fsaso(str);
          for(i = 0; i < numwts; i++)
            p[i] = -g[i];
          xl = XLSTART;
          xlb = 0.;
          success = TRUE;
          delta = 1.;
          icount = fcount = 0;
        }
        else {
          *ierr = 3;
          *iter = k;
          break;
        }
      }
    }

    /* If not nearly as good as predicted, increase xl. */
    if(delta < 0.25)
      xl *= 4.;
    /* maybe try xl = xl + deltak * (1 - delta) / psq */

    *rmserr = sqrt((double)(err * 2.));
    gsiz = snrm2_(&numwts, r, &i1);

    if(check) {
      optchk_store_e1_e2(e1, e2);
      optchk(FALSE, k, w, err, ierr, &pctmin);
      if(*ierr != 0 || pctmin >= pct) {
        *iter = k;
        break;
      }
    }

    /* Terminate when error satisfactory. */
    if(*rmserr < egoal) {
      *ierr = 0;
      *iter = k;
      break;
    }

    /* Terminate when gradient is too small. */
    if(gsiz < gwgoal * max(1., wsiz)) {
      *ierr = 2;
      *iter = k;
      break;
    }

  } /* while(k < niter_max) */

  if(k >= niter_max)
    *ierr = 1;
  *gw = gsiz / wsiz;

  /* Do another call of this with the (same) final weights, to
  accumulate at least the minimal counting information, which is
  always needed; and the switches also may activate optional
  computations: confusion matrices, long outfile, and
  correct-vs.-rejected table. */
  e_and_g(FALSE, TRUE, do_confuse, do_long_outfile, long_outfile,
    show_acs_times_1000, do_cvr, ninps, nhids, nouts, w, npats,
    featvecs, use_targvecs, targvecs, classes, acfunc_and_deriv_hids,
    acfunc_and_deriv_outs, errfunc, alpha, patwts, regfac, oklvl,
    &err, (float *)NULL, &e1, &e2);

  /* Finishes and writes the minimal counting info, and if desired
  also the confusion matrices. */
  accum_print(do_confuse, purpose, npats, *iter, err, e1, e2, 'F', w,
    long_classnames, short_classnames, &junkint0, &junkint1,
    &junkfloat);

  if(do_cvr)
    /* Finishes and writes correct-vs.-rejected table. */
    cvr_print(TRAIN, npats);

  free((char *)wnew);
  free((char *)p);
  free((char *)r);
  free((char *)s);
}
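
The header comment of scg() says that the weights array w packs four blocks in a fixed order (1st-layer weights, 1st-layer biases, 2nd-layer weights, 2nd-layer biases) and that the total count is nhids * (ninps + 1) + nouts * (nhids + 1). The following is a minimal sketch of how a caller might compute the block offsets when sizing w and the scratch buffer g. The names wt_layout, wt_layout_make and w1_index are hypothetical helpers introduced here for illustration; they are not part of the testbed.

```c
#include <stddef.h>

/* Hypothetical helper (not in the testbed): offsets of the four blocks
   inside the flat weights array w, in the order documented for scg(). */
typedef struct {
  size_t w1;     /* 1st-layer weights, nhids by ninps, row-major */
  size_t b1;     /* 1st-layer biases, nhids elts */
  size_t w2;     /* 2nd-layer weights, nouts by nhids, row-major */
  size_t b2;     /* 2nd-layer biases, nouts elts */
  size_t numwts; /* total number of weights */
} wt_layout;

/* Compute the block offsets for a network with the given node counts. */
static wt_layout wt_layout_make(int ninps, int nhids, int nouts)
{
  wt_layout l;

  l.w1 = 0;
  l.b1 = l.w1 + (size_t)nhids * (size_t)ninps;
  l.w2 = l.b1 + (size_t)nhids;
  l.b2 = l.w2 + (size_t)nouts * (size_t)nhids;
  l.numwts = l.b2 + (size_t)nouts;
  return l;
}

/* Index in w of the 1st-layer weight in row `hid`, column `inp`,
   assuming rows correspond to hidden nodes (nhids by ninps, row-major). */
static size_t w1_index(const wt_layout *l, int ninps, int hid, int inp)
{
  return l->w1 + (size_t)hid * (size_t)ninps + (size_t)inp;
}
```

With ninps = 3, nhids = 2 and nouts = 4, the total comes to 2 * (3 + 1) + 4 * (2 + 1) = 20 floats, which is the minimum size the caller must allocate for both w and g.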
