accum.c
来自「NIST Handwriting OCR Testbed」· C语言 代码 · 共 578 行 · 第 1/2 页
C
578 行
/* Routines for doing various weighted accumulations (they replace theoriginal "confuse" routine):# proc: accum_init - Initialization routine.# proc: accum_zero - Zeros the accumulators.# proc: accum_cpat - Accumulates for the current pattern.# proc: accum_print - Prints some results.# proc: accum_free - Frees used local buffers.# proc: (accum_printer - Used by accum_print.)# proc: (accum_sumout - Used by accum_print.)# proc: (accum_yow - Used by accum_sumout.)NOTE: This file has several static variables whose scope is the file,and they are used by the routines (i.e., side effects) withoutcomments.*/#include <stdio.h>#include <math.h>#include <mlp/tda.h>#include <mlp/defs.h>#include <mlp/macros.h>#include <mlp/parms.h>#define ERRS_NERRS_DIM 11 /* The histogram of errors is made in the range 2^(-(ERRS_NERRS_DIM-1)) to 1 *//* If following line is left in force, accum_printer will includein its output a table showing, for each class, the "key" (short name)and the "name" (long name); if commented out, no table. */#define KEYS_NAMES_TABLEstatic int nouts_loc, nerrs[ERRS_NERRS_DIM], *npats_bc, *iwtd_rpct_bc;static float errs[ERRS_NERRS_DIM], *r_acc_bc, *w_acc_bc, *rej_acc_bc, oklvl_loc, *outrej, sum1, sum2;static TDA_FLOAT confuse_acc, outlvl;/*******************************************************************//* accum_init: Mallocs local (to this file) buffers, and stores alocal copy of oklvl. Call this at start of each run.Input args: nouts: Number of output nodes. do_confuse: If TRUE, mallocs some extra buffers. oklvl: Threshold for rejection.*/voidaccum_init(nouts, do_confuse, oklvl)int nouts;char do_confuse;float oklvl;{ if((r_acc_bc = (float *)malloc(nouts * sizeof(float))) == (float *)NULL) syserr("accum_init (accum.c)", "malloc", "r_acc_bc"); if((w_acc_bc = (float *)malloc(nouts * sizeof(float))) == (float *)NULL) syserr("accum_init (accum.c)", "malloc", "w_acc_bc"); if((rej_acc_bc = (float *)malloc(nouts * sizeof(float))) == (float *)NULL) syserr("accum_init (accum.c)", "malloc", "rej_acc_bc"); if((iwtd_rpct_bc = (int *)malloc(nouts * sizeof(int))) == (int *)NULL) syserr("accum_init (accum.c)", "malloc", "iwtd_pct_bc"); if((outrej = (float *)malloc(nouts * sizeof(float))) == (float *)NULL) syserr("accum_init (accum.c)", "malloc", "outrej"); oklvl_loc = oklvl; if(do_confuse) { confuse_acc.dim2 = nouts; if((confuse_acc.buf = (float *)malloc(nouts * nouts * sizeof(float))) == (float *)NULL) syserr("accum_init (accum.c)", "malloc", "confuse_acc.buf"); outlvl.dim2 = nouts; if((outlvl.buf = (float *)malloc(nouts * nouts * sizeof(float))) == (float *)NULL) syserr("accum_init (accum.c)", "malloc", "outlvl.buf"); if((npats_bc = (int *)malloc(nouts * sizeof(float))) == (int *)NULL) syserr("accum_init (accum.c)", "malloc", "npats_bc"); } else { confuse_acc.buf = outlvl.buf = (float *)NULL; npats_bc = (int *)NULL; } nouts_loc = nouts;}/*******************************************************************//* accum_zero: Zeros out the accumulators, and initializes the errsand nerrs arrays. Call this from e_and_g before any calls ofaccum_cpat.Input arg: do_confuse: If TRUE, zeros out the confusion accumumlators, as well as the basic accumulators that it always zeros out.*/voidaccum_zero(do_confuse)char do_confuse;{ int i; memset((char *)r_acc_bc, 0, nouts_loc * sizeof(float)); memset((char *)w_acc_bc, 0, nouts_loc * sizeof(float)); memset((char *)rej_acc_bc, 0, nouts_loc * sizeof(float)); memset((char *)outrej, 0, nouts_loc * sizeof(float)); if(do_confuse) memset((char *)confuse_acc.buf, 0, nouts_loc * nouts_loc * sizeof(float)); if(do_confuse) { memset((char *)outlvl.buf, 0, nouts_loc * nouts_loc * sizeof(float)); memset((char *)npats_bc, 0, nouts_loc * sizeof(int)); } sum1 = sum2 = 0.; for(i = 0; i < ERRS_NERRS_DIM; i++) { errs[i] = pow((double)2, (double)(i - ERRS_NERRS_DIM + 1)); nerrs[i] = 0; }}/*******************************************************************//* accum_cpat: Updates the accumulators according to the currentpattern.Input args: do_confuse: If TRUE, accumulate into the confusion accumulators. purpose: CLASSIFIER or FITTER. vout_cpat: output activations vector for the current pattern. idpat_cpat: Id (class) of the current pattern, if purpose is CLASSIFIER. target_cpat: Target vector of current pattern, if purpose is FITTER. patwt_cpat: Pattern-weight of the current pattern.*/voidaccum_cpat(do_confuse, purpose, vout_cpat, idpat_cpat, target_cpat, patwt_cpat)char do_confuse, purpose;float *vout_cpat, *target_cpat, patwt_cpat;short idpat_cpat;{ short idres_cpat; char okay; int ibig1, ibig2, j, jj, k; float big1, big2, ee; /* Find biggest two activation levels. */ ibig1 = ((vout_cpat[0] >= vout_cpat[1]) ? 0 : 1); ibig2 = 1 - ibig1; big1 = vout_cpat[ibig1]; big2 = vout_cpat[ibig2]; for(j = 2; j < nouts_loc; j++) { if(vout_cpat[j] > big1) { big2 = big1; ibig2 = ibig1; big1 = vout_cpat[j]; ibig1 = j; } else if(vout_cpat[j] > big2) { big2 = vout_cpat[j]; ibig2 = j; } } okay = (idpat_cpat < 0 ? FALSE : big1 > oklvl_loc); if(okay) { if(idpat_cpat == ibig1) r_acc_bc[idpat_cpat] += patwt_cpat; else w_acc_bc[idpat_cpat] += patwt_cpat; if(do_confuse) e(confuse_acc, ibig1, idpat_cpat) += patwt_cpat; if(do_confuse) e(outlvl, ibig1, idpat_cpat) += big1; idres_cpat = ibig1; } else { if(idpat_cpat >= 0) { rej_acc_bc[idpat_cpat] += patwt_cpat; outrej[idpat_cpat] += big1; } idres_cpat = ibig1 - nouts_loc; /* makes value betw. -nouts_loc and -1 */ } sum1 += big1; sum2 += big2; if(do_confuse) { /* Count nos. of patterns by class; (unrelatedly) accumulate histogram of output errors */ npats_bc[idpat_cpat]++; for(j = 0; j < nouts_loc; j++) { ee = fabs((double)(vout_cpat[j] - (purpose == CLASSIFIER ? (j == idpat_cpat ? 1. : 0.) : target_cpat[j]))); jj = ERRS_NERRS_DIM - 1; for(k = 0; k < ERRS_NERRS_DIM - 1; k++) if(ee <= errs[k]) { jj = k; break; } nerrs[jj]++; } }}/*******************************************************************//* accum_print: Prints info from the finished counters, andoptionally also prints the confusion matrices; and, returns someinfo.Input args: do_confuse: If TRUE, print the confusion information. purpose: CLASSIFIER or FITTER. npats: Number of patterns. iter: Current iteration of the optimization. err: Error, including regularization term. e1: Main part of error, basically. e2: Mean squared weight. c: Passed to accum_sumout(). w: Passed to accum_sumout(): weights. long_classnames: Passed to accum_printer(): long names of the classes. short_classnames: Passed to accum_printer(): short names of the classes.Output args: wtd_nr: Weighted "number right". wtd_nw: Weighted "number wrong". wtd_rpct_min: Minimum, across classes, of weighted "right-percentage by class".*/voidaccum_print(do_confuse, purpose, npats, iter, err, e1, e2, c, w, long_classnames, short_classnames, wtd_nr, wtd_nw, wtd_rpct_min)int npats, iter, *wtd_nr, *wtd_nw;char do_confuse, purpose, c, **long_classnames, **short_classnames;float err, e1, e2, w[], *wtd_rpct_min;{ int i, j, ntotal, *wtd_nrej_bc; float a; TDA_INT confuse_wtd_counts; void accum_printer(), accum_sumout(); if(do_confuse) { sum1 /= (float)npats; sum2 /= (float)npats; if((wtd_nrej_bc = (int *)malloc(nouts_loc * sizeof(int))) == (int *)NULL) syserr("accum_print (accum.c)", "malloc", "wtd_nrej_bc"); confuse_wtd_counts.dim2 = nouts_loc; if((confuse_wtd_counts.buf = (int *)malloc(nouts_loc * nouts_loc * sizeof(int))) == (int *)NULL) syserr("accum_print (accum.c)", "malloc", "confuse_wtd_counts.buf"); for(i = 0; i < nouts_loc; i++) { for(j = 0, a = rej_acc_bc[i]; j < nouts_loc; j++) a += e(confuse_acc, j, i); if(a > 0.) a = npats_bc[i] / a; for(j = 0; j < nouts_loc; j++) e(confuse_wtd_counts, j, i) = round(a * e(confuse_acc, j, i)); wtd_nrej_bc[i] = round((float)npats_bc[i] * rej_acc_bc[i] / (r_acc_bc[i] + w_acc_bc[i] + rej_acc_bc[i])); outrej[i] *= (100. / (float)max(1, wtd_nrej_bc[i])); for(j = 0; j < nouts_loc; j++) e(outlvl, j, i) *= (100. / (float)max(1, e(confuse_wtd_counts, j, i))); } ntotal = nouts_loc * npats; accum_printer(long_classnames, short_classnames, ntotal, *wtd_nw, wtd_nrej_bc, &confuse_wtd_counts); free((char *)wtd_nrej_bc); free((char *)(confuse_wtd_counts.buf)); } accum_sumout(purpose, npats, iter, c, w,
⌨️ 快捷键说明
复制代码Ctrl + C
搜索代码Ctrl + F
全屏模式F11
增大字号Ctrl + =
减小字号Ctrl + -
显示快捷键?