📄 nn.c
字号:
/* nn.c */
/* Note the weight numbers are only used when defining the initial w matrix.
   The rc[] vector is another waste of space that might be removed.
   NB take care to distinguish net->w and net->wo */

#include "./ansi/r.h"
#include "./ansi/rand.h"
#include "./ansi/mynr.h"
#include "./ansi/cmatrix.h"
#include "./ansi/macopt.h"
#include "./thing_sort.h"
#include "./fe.h"
#include "./mnc.h"
#include "./mnc3.h"
#include "./nn.h"

/* NOTE(review): safe_divide is declared here but not defined in this chunk —
   presumably defined later in the file; confirm before removing. */
static double safe_divide ( int , int ) ;
static int nn_in_A_T ( int o , int i, mnc_all *all ) ;
static int nn_regclass ( int o , int i , int regularize, mnc_all *all ) ;

/* Install default values into the net and its control structure.
   Most defaults come from the generated include below; also resets the
   loop/derivative counters and the macopt optimizer arguments. */
void nn_defaults ( mnc_net *net , mnc_net_control *nc )
{
  /* command line flag to change this variable */
#include "nn_var_def.c"
  nc->loop = 1 ;
  nc->num_derivatives = 0 ; /* this is a counter */
  macopt_defaults ( &(nc->macarg) ) ;
}

/* Reconcile the train/read/write flags, build the weight/report file names
   from the data-creation parameters, and derive the net dimensions
   (I inputs, O outputs, K = (I+1)*O weights including biases).
   Returns 0 on success (status is only ever 0 here). */
int nn_make_sense ( mnc_net_control *nc , mnc_all *all )
{
  int status = 0 ;
  char junk[100] ;
  data_creation_param *dc = all->dc ;
  mnc_net *net = all->net ;
  fe_min_param *p = all->p ;

  /* possible situations:
       training && randomize weights, or hand set them (default is no training)
       no training && read in weights (default if read in w)
       training && read in weights
     weights can be read in from a default file
     if there is training then it makes sense for the weights to be written */
  if ( !(nc->train) && !(nc->read) ) { nc->read=1; }

  /* if reading weights, make file name */
  /* NOTE(review): junk[100] is a fixed buffer filled by sprintf — assumes the
     encoded parameter string stays under 100 chars; confirm for large params. */
  sprintf ( junk , "%dN%dM%dt%dS%lds%gn%gS%ldn%da%ga%g" ,
            dc->MNC , dc->N , dc->M , dc->tr , dc->mseed , dc->fs , dc->fn ,
            nc->trseed , nc->train_n , nc->alpha[2] , nc->alpha[3] ) ;
  if ( nc->read==1 ) sprintf ( nc->infile , "w/%s" , junk ) ;
  if ( nc->report==1 ) sprintf ( nc->reportfile , "rnn/%s" , junk ) ;

  /* if training, ensure that a special name has been given for the weights */
  if ( (nc->train) && !(nc->write) ) {
    fprintf ( stderr , "warning: nn weights will not be written\n" ) ;
  } else if ( !(nc->train) ) {
    nc->write = 0 ;
  }

  /* if writing weights, make file name */
  if ( nc->write==1 ) sprintf ( nc->outfile , "w/%s" , junk ) ;

  /* a little calculation (inaccurate) */
  nc->typ_input_sum = dc->fs * (double) ( dc->tr * dc->tr * dc->M ) / (double)(dc->N) ;

  /* geometric tolerance schedule: after LOOP loops, tol goes tol0 -> tolmin */
  if ( nc->LOOP > 1 ) {
    nc->tolf = exp ( log ( nc->tolmin / nc->tol0 ) / ( double ) ( nc->LOOP - 1 ) ) ;
  }

  net->I = p->M ;
  net->O = p->N ;
  net->K = ( net->I + 1 ) * net->O ; /* +1 for the bias weight of each output */
  if ( nc->hitlist_n > net->O ) nc->hitlist_n = net->O ;
  return status ;
}

/* Allocate all per-net vectors/matrices (NR-style 1-based allocators) and
   set up the regularization-class bookkeeping.  net->w initially aliases
   net->wo (see file-top note).  Returns the status from nn_set_up_rc. */
int nn_allocate ( fe_min_param *p , mnc_net *net , mnc_net_control *nc , mnc_all *all )
{
  int status = 0 ;

  net->wo = dvector ( 1 , net->K ) ;
  net->w = net->wo ; /* NB take care to distinguish net->w and net->wo */
  net->a = dvector ( 1 , net->O ) ;
  net->ca = dvector ( 1 , net->O ) ;
  net->e = dvector ( 1 , net->O ) ;
  net->y = dvector ( 1 , net->O ) ;
/*  net->g = dvector ( 1 , net->K ) ; */
  net->rc = ivector ( 1 , net->K ) ;
  net->wn = imatrix ( 1 , net->O , 0 , net->I ) ; /* wn[o][i] -> flat weight index */
  net->h = cvector ( 1 , net->O ) ;

  status += nn_set_up_rc ( net , nc , all ) ;

  /* total example counts = patterns * outputs */
  nc->tot_tr = nc->train_n * net->O ;
  nc->tot_te = nc->test_n * net->O ;
  nc->tot_dec = nc->decodn * net->O ;
  return status ;
}

/* Assign each weight k = wn[o][i] a regularization class rc[k] in 1..RC
   (how many classes depends on nc->regularize), count members per class
   in nc->ninrc[], and sanity-check that exactly net->K weights were indexed.
   Returns 0 on success, negative if the count check fails. */
int nn_set_up_rc ( mnc_net *net , mnc_net_control *nc , mnc_all *all)
{
  int i , o , k , c ;
  int status = 0 ;

  k = 1 ;
  switch ( nc->regularize ) {
  case ( 0 ) :
  default :
    nc->RC = 1 ;
    break ;
  case ( 1 ) :
    nc->RC = 2 ; /* biases and inputs */
    break ;
  case ( 2 ) :
    nc->RC = 3 ; /* biases and inputs in A^T and other inputs */
    break ;
  }
  for ( c = 1 ; c <= nc->RC ; c++ ) { nc->ninrc[c] = 0 ; }
  for ( o = 1 ; o <= net->O ; o ++ ) {
    for ( i = 0 ; i <= net->I ; i++ ) { /* i == 0 is the bias weight */
      c = nn_regclass ( o , i , nc->regularize , all ) ;
      net->rc[k] = c ;
      nc->ninrc[c] ++ ;
      net->wn[o][i] = k ;
      k ++ ;
    }
  }
  k-- ;
  if ( k != net->K ) status -- ; /* consistency check on the flat indexing */
  if ( nc->verbose ) {
    for ( c = 1 ; c <= nc->RC ; c++ ) {
      printf ( "in rc %d: %d; " , c , nc->ninrc[c] ) ;
    }
    printf ( "\n" ) ;
  }
  return status ;
}

/* Map one weight (output o, input i; i==0 means bias) to its regularization
   class under the given regularize mode. */
static int nn_regclass ( int o , int i , int regularize, mnc_all *all )
{
  switch ( regularize ) {
  case ( 0 ) :
  default :
    return 1 ;
    break ;
  case (1) :
    return ( i==0 ) ? 1 : 2 ; /* bias is 1, inputs are 2 */
    break ;
  case (2) :
    /* bias is 1; inputs that appear in A^T are 2; all other inputs are 3 */
    if ( i == 0 ) return 1 ;
    else { return ( nn_in_A_T ( o , i , all ) ? 2 : 3 ) ; }
    break ;
  }
}

/* Return 1 if input i appears in row o of the sparse matrix A (i.e. the
   connection exists in A^T), else 0.  Scans the adjacency list of row o. */
static int nn_in_A_T ( int o , int i, mnc_all *all )
{
  alist_matrix *a = &(all->p->a) ;
  int m , u , n ;

  n = o ;
  for ( u = 1 ; u <= a->num_nlist[n] ; u++ ) {
    m = a->nlist[n][u] ;
    if ( m == i ) return 1 ;
  }
  return 0 ;
}

/* Get the net's weights ready: read them from file or synthesize them, then
   optionally train (which also writes them), and report the resulting errors.
   Returns accumulated nonzero status on any failure. */
int nn_initialize ( mnc_net *net , mnc_net_control *nc , mnc_all *all )
{
  /* this includes training the thing if necessary, and saving its weights */
  int status = 0 ;

  if ( nc->read ) {
    status += readdvector ( net->wo , 1 , net->K , nc->infile ) ;
    if ( status == 0 ) printf ("Read in weights\n" ) ;
  } else {
    status += nn_weight_init ( net , nc , all ) ;
  }
  if ( nc->train ) {
    status += nn_train ( net , nc , all ) ; /* includes writing of weights */
  }

  /* check out net performance */
  nn_eval_errors ( net , all ) ;
  nn_say_headings ( stdout , nc , 0 ) ;
  nn_say_errors ( stdout , nc , 0 ) ;

  /* OK, now we are initialized */
  return status ;
}

/* Initialize the weight vector net->wo according to nc->init_rule:
     1 -> deterministic: weights are a multiple (def_w) of A^T, with biases
          set to -def_w * typ_input_sum;
     0 (default) -> Gaussian random with std dev sigma_w0, seeded by wseed.
   Always returns 0. */
int nn_weight_init ( mnc_net *net , mnc_net_control *nc , mnc_all *all )
{
  int status = 0 ;
  int k , n , u , m ;
  double *w = net->wo ;
  int **wn = net->wn ;
  alist_matrix *a = &(all->p->a) ;

  switch ( nc->init_rule ) {
  case ( 1) : /* set w to multiple of A^T */
    nc->def_b = -nc->def_w * nc->typ_input_sum ;
    set_dvector_const ( w , 1 , net->K , 0.0 ) ;
    for ( n = 1 ; n <= a->N ; n ++ ) {
      w[wn[n][0]] = nc->def_b ; /* bias */
      for ( u = 1 ; u <= a->num_nlist[n] ; u++ ) {
        m = a->nlist[n][u] ;
        w[ wn[n][m] ] = nc->def_w ;
      }
    }
    break ;
  case ( 0 ) :
  default: /* randomize */
    ran_seed ( nc->wseed ) ;
    for ( k = 1 ; k <= net->K ; k ++ ) w[k] = rann() * nc->sigma_w0 ;
    break ;
  }
  return status ;
}

/* Run the optimizer (macopt or macoptII, per nc->opt) for nc->LOOP rounds,
   tightening the tolerance geometrically each round, optionally gradient-
   checking (nc->CG), reporting errors, and writing the weights after each
   round.  Returns accumulated status from the weight writes. */
int nn_train ( mnc_net *net , mnc_net_control *nc , mnc_all *all )
{
/* Run optimizer to minimize performance objective function */
  int status = 0 ;
  FILE *fp ;

  for ( nc->loop = 1 , nc->tol = nc->tol0 ;
        nc->loop <= nc->LOOP ;
        nc->loop ++ , nc->tol *= nc->tolf ) {
    nc->macarg.tol = nc->tol ;
    nc->macarg.itmax = nc->itmax ;
    nc->macarg.rich = nc->rich ;
    nc->macarg.end_if_small_step = nc->end_on_step ;

    /* NB always pass net->wo to optimizers.  They will then pass back other
       vectors which get copied to w for functional purposes, with the final
       result being returned by the optimizer in net->wo.  At the end of the
       day remember to reset net->w from whatever so that it points at
       net->wo again, before any function of the net is requested. */
    if ( nc->CG ) {
      checkgrad ( net->wo , net->K , nc->epsilon ,
                  nn_objective , (void *)(all) ,
                  nn_derivative , (void *)(all) );
    }
    if ( nc->opt == 1 )
      macopt ( net->wo , net->K , ( nc->end_on_step * 10 + nc->rich ) ,
               nc->tol , &(nc->iter) , nc->itmax ,
               nn_derivative , (void *)(all) );
    else
      macoptII ( net->wo , net->K ,
                 nn_derivative , (void *)(all) , &(nc->macarg) );

    /* nn_align_w_with_wo straight away! */
    net->w = net->wo ;

    /* chat */
    if ( nc->report || nc->verbose ) {
      nn_eval_errors ( net , all ) ;
      if ( nc->report ) {
        fp = fopen ( nc->reportfile , "a" ) ;
        /* re-print the column headings every 20 loops */
        if ( !((nc->loop-1)%20) ) nn_say_headings ( fp , nc , 1 ) ;
        nn_say_errors ( fp , nc , 1 ) ;
        fclose ( fp ) ;
      }
      if ( nc->verbose ) {
        nn_say_headings ( stdout , nc , 1 ) ;
        nn_say_errors ( stdout , nc , 1 ) ;
        fflush ( stdout ) ;
      }
    }
    if ( nc->write ) {
      status += writedvector ( net->wo , 1 , net->K , nc->outfile ) ;
    }
    if ( nc->CG ) {
      checkgrad ( net->wo , net->K , nc->epsilon ,
                  nn_objective , (void *)(all) ,
                  nn_derivative , (void *)(all) );
      net->w = net->wo ;
    }
  }
  return status ;
}

/* Evaluate and record the net's errors into nc: iterative-decoding error
   (if decodn > 0), training-set error/weight energy/objective (if train_n),
   and test-set error (if test_n > 0).  Each evaluation reseeds the RNG so
   the same example sets are regenerated each time. */
void nn_eval_errors ( mnc_net *net , mnc_all *all )
{
  mnc_net_control *nc = all->nc ;

  /* error on (maybe same) test set using iterative decoding */
  if ( nc->decodn > 0 ) {
    if ( nc->verbose ) printf ( "Finding iterative decoding error\n" ) ;
    ran_seed ( nc->decodseed ) ;
    nc->itEH = nn_iter_error_on ( net , all , nc->decodn ) ;
  }
  if ( nc->train_n ) {
    /* just like nn_objective ( net->wo , (void*)(all) ) ; */
    if ( nc->verbose ) printf ( "Finding training set error\n" ) ;
    ran_seed ( nc->trseed ) ;
    nc->ED = nn_error_on ( net , all , nc->train_n ) ;
    nc->QW = nn_weight_energy ( net , nc ) ;
    nc->M = nc->ED + nc->QW ; /* total objective = data error + weight energy */
    /* hard scores */
    nc->EDHtr = nc->EDH ;
    nc->EHwb_tr = nc->EHwb ;
  }
  /* error on test set */
  if ( nc->test_n > 0 ) {
    if ( nc->verbose ) printf ( "Finding test error\n" ) ;
    ran_seed ( nc->teseed ) ;
    nc->ET = nn_error_on ( net , all , nc->test_n ) ;
    nc->EDHte = nc->EDH ;
    nc->EHwb_te = nc->EHwb ;
  }
}

/* uses the iterative method to attempt to decode.
   each iteration, the bit error count is reported. */
/* Decode n freshly generated vectors by repeatedly running the net forward,
   selecting a hit list of bits to flip, applying the flips, and updating the
   net input via A until no bits remain to change or decodits is reached.
   Side effects: repoints net->x / net->t at v->z / v->x and mutates both;
   sets nc->itEHwb to the number of words decoded with residual errors. */
int nn_iter_error_on ( mnc_net *net , mnc_all *all , int n )
{
  mnc_net_control *nc = all->nc ;
  mnc_vectors *v = all->vec ;
  alist_matrix *a = &(all->p->a) ;
  int ET = 0 , E=999 , Ewb = 0 ;
  int decodit ;
  FILE *fp ;

  net->x = v->z ; /* input = output of A x + y = z */
  net->t = v->x ; /* target = original x vector */

  if ( nc->decwrite ) {
    fp = fopen ( nc->decfile , "w" ) ;
    if ( !fp ) {
      fprintf ( stderr , " couldn't open decwrite file %s\n" , nc->decfile ) ;
      nc->decwrite = 0 ; /* disable further attempts if the file won't open */
    } else {
      fprintf ( fp , "Decoding " ) ;
      fclose ( fp ) ;
    }
  }
  for ( ; n >= 1 ; n-- ) {
    make_vectors_quick ( all->dc , all->p , all->vec) ;
    if ( nc->decwrite ) {
      fp = fopen ( nc->decfile , "a" ) ;
      fprintf ( fp , " ... \n" ) ;
      fclose ( fp ) ;
    }
    if ( nc->decverbose > 0 ) printf ( "Decoding ... \n" ) ;
    set_ivector_const ( v->touches , 1 , net->O , 0 ) ;
    for ( decodit = 1 ; decodit <= nc->decodits ; decodit++ ) {
      if ( nc->decverbose > 1 ) printf ( "iteration %d / %d - " , decodit , nc->decodits ) ;
      nn_forward_pass ( net ) ;
      /* look at outputs and make a hit list of bits to change */
      nn_hitlist ( net , all ) ;
      /* change those bits and evaluate how many bits were
         flipped, corrected, spoilt */
      /* print to stdout and or file */
      if ( ( E = nn_flip_and_score ( net , all ) ) && /* if there are still bits to change , and */
           /* if this is not the last loop */
           (decodit < nc->decodits) ) {
        alist_times_cvector_mod2 ( a , v->x , v->z ) ;
        /* multiply the hit list by A and add the result to v->z
             (i.e. change the input of the net)
           add the hit list to net->t (i.e. change the target too)
           equivalently, just take the new x and produce a new z.
           So that no need for a zpartial vector. */
        /* NB since this routine messes with the input and target,
           no other routine should rely on them! */
      } else break ;
    }
    ET += E ;
    Ewb += ( E > 0 ) ? 1 : 0 ; /* count words with residual errors */
  }
  nc->itEHwb = Ewb ;
  return ( ET ) ; /* returns the total number of high bits that remained
                     uncorrected (not the number of bit errors, oops! */
}

/* NOTE(review): nn_say_errors is truncated at the end of this chunk —
   the fprintf argument list continues beyond the visible source. */
void nn_say_errors ( FILE *fp , mnc_net_control *nc , int extras )
{
  int c ;
  fprintf ( fp , "%-11.6g %-9d %-11.6g %-6d %-9.4g " , nc->ED ,
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -