ef.c

来自「NIST Handwriting OCR Testbed」· C语言 代码 · 共 169 行

C
169
字号
/* "Error functions" each of which, given an output activationsvector, and a target vector or an actual class, and possibly a parm(alpha, for type_1), computes the resulting error contribution and itsgradient w.r.t. the activations vector.  For use by e_and_g, whichcomputes the error and its gradient w.r.t. the weights.Contains:# proc: ef_mse_t - computes "mse" error function, using a target vector.# proc: ef_mse_c - computes "mse" error function, using a target class.# proc: ef_t1_c - computes "type_1" error function, using a class.# proc: ef_ps_c - computes "pos_sum" error function, using a class.[Maybe some or all of these should use blas routines.]*/#include <math.h>/*******************************************************************//* For "mse" error function, using a target vector.Input args:  nouts: Number of output nodes.  acsvec: Output activations vector.  targvec: Target activations vector.Output args:  e: Error contribution of this pattern.  g: Gradient of the error contribution of this pattern w.r.t. the     output activations.  A vector of nouts elts, to be allocated by     caller.*/voidef_mse_t(nouts, acsvec, targvec, e, g)int nouts;float *acsvec, *targvec, *e, *g;{  float a, e_yow, *ac_e;  ac_e = acsvec + nouts;  e_yow = 0.;  while(acsvec < ac_e) {    *g++ = 2. * (a = *acsvec++ - *targvec++);    e_yow += a * a;  }  *e = e_yow;}/*******************************************************************//* For "mse" error function, using a class.Input args:  nouts: Number of output nodes.  acsvec: Output activations vector.  actual_class: The actual class of this pattern (in range 0 through    nouts - 1).Output args:  e: Error contribution of this pattern.  g: Gradient of the error contribution of this pattern w.r.t. the     output activations.  A vector of nouts elts, to be allocated by     caller.*/voidef_mse_c(nouts, acsvec, actual_class, e, g)int nouts;short actual_class;float *acsvec, *e, *g;{  float *actual_p, *ac_e, e_yow, a;  actual_p = acsvec + actual_class;  ac_e = acsvec + nouts;  e_yow = 0.;  while(acsvec < ac_e) {    *g++ = 2. * (a = (acsvec == actual_p ? *acsvec++ - 1. :      *acsvec++));    e_yow += a * a;  }  *e = e_yow;}/*******************************************************************//* For "type_1" error function (using a class, the only possibilityfor type_1).Input args:  nouts: Number of output nodes.  acsvec: Output activations vector.  actual_class: The actual class of this pattern (in range 0 through    nouts - 1).  alpha: Factor used (with minus sign) before taking the exp.Output args:  e: Error contribution of this pattern.  g: Gradient of the error contribution of this pattern w.r.t. the     output activations.  A vector of nouts elts, to be allocated by     caller.*/voidef_t1_c(nouts, acsvec, actual_class, alpha, e, g)int nouts;short actual_class;float *acsvec, alpha, *e, *g;{  float *actual_p, actual_ac, *ac_e, beta, *g_p, ep;  actual_ac = *(actual_p = acsvec + actual_class);  ac_e = acsvec + nouts;  beta = 0.;  for(g_p = g; acsvec < ac_e; acsvec++, g_p++)    if(acsvec != actual_p) {      beta += (ep = exp((double)(-alpha * (actual_ac - *acsvec))));      *g_p = alpha * ep;    }  *e = 1. - 1. / (1. + beta);  *(g + actual_class) = -alpha * beta;}/*******************************************************************//* For "pos_sum" error function (using a class, the only possibilityfor pos_sum).Input args:  nouts: Number of output nodes.  acsvec: Output activations vector.  actual_class: The actual class of this pattern (in range 0 through    nouts - 1).Output args:  e: Error contribution of this pattern.  g: Gradient of the error contribution of this pattern w.r.t. the     output activations.  A vector of nouts elts, to be allocated by     caller.*/voidef_ps_c(nouts, acsvec, actual_class, e, g)int nouts;short actual_class;float *acsvec, *e, *g;{  float *actual_p, *ac_e, e_yow, a;  actual_p = acsvec + actual_class;  ac_e = acsvec + nouts;  e_yow = 0.;  while(acsvec < ac_e) {    if(acsvec == actual_p) {      a = 1. - *acsvec++;      *g++ = -20. * a - 1.;    }    else {      a = *acsvec++;      *g++ = 20. * a + 1.;    }    e_yow += (10. * a + 1.) * a;  }}/*******************************************************************/

⌨️ 快捷键说明

复制代码Ctrl + C
搜索代码Ctrl + F
全屏模式F11
增大字号Ctrl + =
减小字号Ctrl + -
显示快捷键?