/* wts.c -- from the "NIST Handwriting OCR Testbed" (C source code,
494 lines in total; this listing is page 1 of 2). */
/* Routines for setting the MLP weights -- either randomly or byreading a file -- and for writing the weights as a file:# proc: randwts - Gets random MLP network weights.# proc: randwts_oldorder - Gets random MLP network weights in an old format order.# proc: readwts - Reads MLP weights file. Returns the "network architecture"# proc: info stored in the weights file, as well as the weights.# proc: readwts_np - Reads MLP weights file. Returns the "network architecture"# proc: info stored in the weights file, as well as the weights, but# proc: NOT using structures in parms.h.# proc: putwts - Writes the MLP network weights as an ascii file.*/#include <stdio.h>#include <mlp/defs.h>#include <mlp/parms.h>#include <mlp/little.h>#include <mlp/acsmaps.h>#define SCALE 0.5 /* Random weights will be uniformly distributed in the range -SCALE through SCALE. *//********************************************************************//* randwts: Gets random network weights.Input args: ninps, nhids, nouts: Numbers of input, hidden, and output nodes. seed: For the "uni" uniform pseudorandom number generator. Must be a nonzero integer.Output arg: w: The weights, in a buffer allocated by this routine.*/voidrandwts(ninps, nhids, nouts, seed, w)int ninps, nhids, nouts, seed;float **w;{ int numwts; float *p, *ep; float uni(); numwts = nhids * (ninps + 1) + nouts * (nhids + 1); if((*w = (float *)malloc(numwts * sizeof(float))) == (float *)NULL) syserr("randwts (wts.c)", "malloc", "*w"); uni(seed); ep = (p = *w) + numwts; while(p < ep) *p++ = 2. * SCALE * (uni(0) - 0.5);}/********************************************************************//* randwts_oldorder: The difference between this and randwts, is thatthis version installs the weights in the order corresponding to theorder the old random weights installer (in the Fortran version)used.Input args: ninps, nhids, nouts: Numbers of input, hidden, and output nodes. seed: For the "uni" uniform pseudorandom number generator. 
Must be a nonzero integer.Output arg: w: The weights, in a buffer allocated by this routine.*/voidrandwts_oldorder(ninps, nhids, nouts, seed, w)int ninps, nhids, nouts, seed;float **w;{ int numwts, h, i, j; float *w1, *b1, *w2, *b2; float uni(); numwts = nhids * (ninps + 1) + nouts * (nhids + 1); if((*w = (float *)malloc(numwts * sizeof(float))) == (float *)NULL) syserr("randwts_oldorder (wts.c)", "malloc", "*w"); b2 = (w2 = (b1 = (w1 = *w) + nhids * ninps) + nhids) + nouts * nhids; uni(seed); for(h = 0; h < nhids; h++) { for(i = 0; i < ninps; i++) *(w1 + h * ninps + i) = 2. * SCALE * (uni(0) - 0.5); *(b1 + h) = 2. * SCALE * (uni(0) - 0.5); } for(j = 0; j < nouts; j++) { for(h = 0; h < nhids; h++) *(w2 + j * nhids + h) = 2. * SCALE * (uni(0) - 0.5); *(b2 + j) = 2. * SCALE * (uni(0) - 0.5); }}/********************************************************************//* readwts: Reads a weights file. Returns the "network architecture"info stored in the weights file, as well as the weights.Input/output arg: parms: The PARMS structure, with its members set according to the specfile. This routine finds wts_infile, the file from which to read the weights, in the parms structure, and it sets into the parms members the values of purpose, ninps, nhids, nouts, acfunc_hids, and acfunc_outs, which are specified at the top of wts_infile. When its sets these values in, it also turns on the corresponding "set" members.Output arg: w: The weights, in a buffer allocated by this routine.*/voidreadwts(parms, w)PARMS *parms;float **w;{ FILE *fp; char line[100], name_str[100], val_str[100], errstr[200]; int numwts; float *p, *pe; if((fp = fopen(parms->wts_infile.val, "rb")) == (FILE *)NULL) syserr("readwts (wts.c)", "fopen for reading", parms->wts_infile.val); /* Read the header info, which is in the form of name-value pairs in a defined order. */ /* network_type: must be mlp. (Other routines will read weights files for other types of network, e.g. rbf1.) 
*/ if(!fgets(line, 100, fp) || sscanf(line, "%s %s", name_str, val_str) != 2 || strcmp(name_str, "network_type")) fatalerr("readwts (wts.c)", "improper weights file", parms->wts_infile.val); if(strcmp(val_str, "mlp")) { sprintf(errstr, "network_type must be mlp; it is %s", val_str); fatalerr("readwts (wts.c)", errstr, parms->wts_infile.val); } /* purpose: classifier or fitter. */ if(!fgets(line, 100, fp) || sscanf(line, "%s %s", name_str, val_str) != 2 || strcmp(name_str, "purpose")) fatalerr("readwts (wts.c)", "improper weights file", parms->wts_infile.val); if(!strcmp(val_str, "classifier")) parms->purpose.val = CLASSIFIER; else if(!strcmp(val_str, "fitter")) parms->purpose.val = FITTER; else fatalerr("readwts (wts.c)", "improper weights file", parms->wts_infile.val); parms->purpose.ssl.set = TRUE; /* If wts_infile sets purpose to fitter, then check whether the parms sitaution in specfile is inconsistent with fitter, and if so, error exit. */ if(parms->purpose.val == FITTER) { if(parms->class_wts_infile.ssl.set) { sprintf(errstr, "purpose, as read from wts_infile %s, is \fitter,\nbut class_wts_infile is set in specfile; class_wts_infile \is used only for classifier.", parms->wts_infile.val); fatalerr("readwts (wts.c)", errstr, NULL); } if(parms->lcn_scn_infile.ssl.set) { sprintf(errstr, "purpose, as read from wts_infile %s, is \fitter,\nbut lcn_scn_infile is set in specfile; lcn_scn_infile \is used only for classifier.", parms->wts_infile.val); fatalerr("readwts (wts.c)", errstr, NULL); } if(parms->nokdel.ssl.set) { sprintf(errstr, "purpose, as read from wts_infile %s, is \fitter,\nbut nokdel is set in specfile; nokdel is used only for \classifier.", parms->wts_infile.val); fatalerr("readwts (wts.c)", errstr, NULL); } if(parms->trgoff.ssl.set) { sprintf(errstr, "purpose, as read from wts_infile %s, is \fitter,\nbut trgoff is set in specfile; trgoff is used only for \classifier.", parms->wts_infile.val); fatalerr("readwts (wts.c)", errstr, NULL); } 
if(parms->scg_earlystop_pct.ssl.set) { sprintf(errstr, "purpose, as read from wts_infile %s, is \fitter,\nbut scg_earlystop_pct is set in specfile; \scg_earlystop_pct is used only for classifier.", parms->wts_infile.val); fatalerr("readwts (wts.c)", errstr, NULL); } if(parms->alpha.ssl.set) { sprintf(errstr, "purpose, as read from wts_infile %s, is \fitter,\nbut alpha is set in specfile; alpha is used only for \classifier.", parms->wts_infile.val); fatalerr("readwts (wts.c)", errstr, NULL); } if(parms->oklvl.ssl.set) { sprintf(errstr, "purpose, as read from wts_infile %s, is \fitter,\nbut oklvl is set in specfile; oklvl is used only for \classifier.", parms->wts_infile.val); fatalerr("readwts (wts.c)", errstr, NULL); } if(parms->priors.val == CLASS) { sprintf(errstr, "purpose, as read from wts_infile %s, is \fitter,\nbut priors is set to class in specfile; that makes sense \only for classifier.", parms->wts_infile.val); fatalerr("readwts (wts.c)", errstr, NULL); } if(parms->priors.val == BOTH) { sprintf(errstr, "purpose, as read from wts_infile %s, is \fitter,\nbut priors is set to both in specfile; that makes sense \only for classifier.", parms->wts_infile.val); fatalerr("readwts (wts.c)", errstr, NULL); } if(parms->errfunc.val == TYPE_1) { sprintf(errstr, "purpose, as read from wts_infile %s, is \fitter,\nbut errfunc is set to type_1 in specfile; that makes sense \only for classifier.", parms->wts_infile.val); fatalerr("readwts (wts.c)", errstr, NULL); } if(parms->errfunc.val == POS_SUM) { sprintf(errstr, "purpose, as read from wts_infile %s, is \fitter,\nbut errfunc is set to pos_sum in specfile; that makes sense \only for classifier.", parms->wts_infile.val); fatalerr("readwts (wts.c)", errstr, NULL); } if(parms->do_confuse.val == TRUE) { sprintf(errstr, "purpose, as read from wts_infile %s, is \fitter,\nbut do_confuse is set to true in specfile; that makes sense \only for classifier.", parms->wts_infile.val); fatalerr("readwts (wts.c)", errstr, NULL); } 
if(parms->do_cvr.val == TRUE) { sprintf(errstr, "purpose, as read from wts_infile %s, is \fitter,\nbut do_cvr is set to true in specfile; that makes sense \only for classifier.", parms->wts_infile.val); fatalerr("readwts (wts.c)", errstr, NULL); } } /* ninps: number of input nodes */ if(!fgets(line, 100, fp) || sscanf(line, "%s %d", name_str, &(parms->ninps.val)) != 2 || strcmp(name_str, "ninps")) fatalerr("readwts (wts.c)", "improper weights file",
⌨️ 快捷键说明
复制代码Ctrl + C
搜索代码Ctrl + F
全屏模式F11
增大字号Ctrl + =
减小字号Ctrl + -
显示快捷键?