📄 bp.c
字号:
case 'm': {LAYER *layer; int itemp, itemp2;
nlayers = 0;
wttotal = 0;
pccnet = 0;
ioconnects = 0;
ch = readch();
p = NULL;
while (ch != 'x' && ch != '\n' && ch != '*')
{
itemp = readint(1,MAXINT,'m');
if (readerror)
{
wttotal = 0;
start = NULL;
goto endm;
};
/* check for + number for a recurrent net */
if (nlayers == 0)
{
do ch = readch(); while (ch == ' ');
if (ch != '+')
{
bufferptr = bufferptr - 1;
stmsize = 0;
}
else
{
itemp2 = readint(1,MAXINT,'m');
if (readerror)
{
wttotal = 0;
start = NULL;
goto endm;
};
stmsize = itemp2;
itemp = itemp + itemp2;
};
};
nlayers = nlayers + 1;
p = mklayer(p,itemp);
if (nlayers == 1) start = p;
ch = readch();
while (ch == ' ') ch = readch();
if (ch >= '0' && ch <= '9') bufferptr = bufferptr - 1;
};
last = p;
#ifndef SYMMETRIC
if (ch == 'x')
{
n1 = start->units;
while (n1 != NULL)
{
n2 = last->units;
while (n2 != NULL)
{
connect(n1,n2,(WTTYPE) 0);
n2 = n2->next;
};
n1 = n1->next;
}
ioconnects = 1;
};
#endif
wtsinuse = wttotal;
bufferptr = bufferptr - 1;
nullpatterns(TRAIN);
nullpatterns(TEST);
/* should really free up the old network structure as well */
last->activation = ao;
last->D = Do;
layer = last->backlayer;
while (layer != start)
{
layer->D = Dh;
layer->activation = ah;
layer = layer->backlayer;
};
clear();
endm: break;};
case 'o':
do ch = readch(); while (ch == ' ');
{
PATNODE *targetpn;
bufferptr = bufferptr - 1; /* unget the character */
if (nonetwork() || nopatterns()) break;
itemp = readint(1,s[TOL][TRAIN].npats,'o');
if (readerror) break;
resetpats(TRAIN);
for (i=1;i<=itemp;i++) nextpat(TRAIN);
/* setoutputpat(); */
u = (UNIT *) last->units;
pl = last->currentpat[TRAIN];
targetpn = pl->pats;
itemp2 = 0; /* unit counter */
i = 1; /* format counter */
while (u != NULL)
{
sprintf(outstr,"%5.2f",unscale(targetpn->val)); pg(stdout,outstr);
targetpn++;
itemp2 = itemp2 + 1;
if (format[i] == itemp2)
{
if (outformat == 'r') pg(stdout,"\n"); else pg(stdout," ");
if (i < MAXFORMAT - 1) i = i + 1;
};
u = u->next;
}
pg(stdout,"\n");
};
endo:
break;
case 'p':
do ch = readch(); while (ch == ' ');
bufferptr = bufferptr - 1;
if (ch == 'a' || ch == '\n' || ch == '*')
{int stopprinting;
if (nonetwork() || nopatterns()) break;
stopprinting = eval(TRAIN,1);
if (!stopprinting) printstats(stdout,TRAIN,0,s);
}
else
{
itemp = readint(0,s[TOL][TRAIN].npats,'p');
if (readerror) break;
if (nonetwork() || nopatterns()) break;
if (itemp == 0) {eval(TRAIN,0); printstats(stdout,TRAIN,0,s);}
else evalone(itemp,TRAIN,1,0);
};
break;
case '.':
case '0':
case '1':
case '2':
case '3':
case '4':
case '5':
case '6':
case '7':
case '8':
case '9':
case 'x':
case '-':
case '\\': {int status;
if (nonetwork()) break;
if (ch != '\\') bufferptr = bufferptr - 1;
status = loadpat();
if (status == 1) printoutunits(0,last,-1,0.0,(char) 0,(char) 0);
break;}
case 'q':
ch = readch();
if (ch == '\n') return;
#ifndef SYMMETRIC
else if (ch != 'p') texterror();
else
{
while (ch != '\n' && ch != '*')
{
ch = readch();
if (ch == 'd')
{REAL rtemp;
do ch = readch(); while (ch == ' ');
if (ch == 'h')
{
rtemp = readreal(GE,0.0,'q');
if (readerror) goto endq;
qpdecayh = rtemp;
}
else if (ch == 'o')
{
rtemp = readreal(GE,0.0,'q');
if (readerror) goto endq;
qpdecayo = rtemp;
}
else if (ch >= '0' && ch <= '9')
{
bufferptr = bufferptr - 1;
rtemp = readreal(GE,0.0,'q');
if (readerror) goto endq;
qpdecayh = rtemp;
qpdecayo = rtemp;
}
}
else if (ch == 'e')
{
temp = rdr(GT,0.0,'d');
if (!readerror && !nonetwork()) qpeta = temp;
}
else if (ch == 'm')
{
temp = rdr(GT,0.0,'d');
if (!readerror) mu = temp;
}
else if (ch == 'n')
{
temp = rdr(GE,0.0,'d');
if (!readerror) qpnoise = temp;
}
else if (ch == 's')
{
do (ch = readch()); while (ch == ' ');
if (ch == '+' || ch == '-') qpslope = ch;
else texterror();
}
else if (ch == '*' || ch == '\n' || ch == ' ');
else texterror();
}
};
endq:
bufferptr = bufferptr - 1;
break;
#endif
case 'r': /* r for run, rw for restore weights */
do ch = readch(); while (ch == ' ');
if (ch == '\n' || ch == '*')
{
if (nonetwork() || nopatterns()) goto endr;
run(maxiter,printrate);
}
else if (ch == 'w')
{
do ch = readch(); while (ch == ' ');
bufferptr = bufferptr - 1;
if (ch == '*' || ch == '\n') /* nothing */ ;
else
{
wtfile = readstr();
strcpy(wtfilename,wtfile);
if (saveonminimum == '+')
{int i, ok, temp;
i = 0;
while (wtfile[i] != '\0') i = i + 1;
while (wtfile[i] != '.') i = i - 1;
wtfile[i] = '\0';
ok = sscanf(&wtfilename[i+1],"%d",&temp);
if (ok == 1) wtfilecount = temp;
else
{
wtfilecount = 0;
pg(stdout,"weight file number error; number starts a 0\n");
};
};
};
if (nonetwork()) break; else restoreweights();
}
else if (ch == 't')
{
do ch = readch(); while (ch == ' ');
if (ch == '{') {/* nothing */}
else
{
bufferptr = bufferptr - 1;
trainfile = readstr();
itemp = MAXINT;
if (!pushfile(trainfile,1)) goto endr;
strcat(trfiles,"rt ");
strcat(trfiles,trainfile);
strcat(trfiles,"\n");
};
nullpatterns(TRAIN);
readingpattern = 1;
itemp2 = readpats(TRAIN,'r');
readingpattern = 0;
s[TOL][TRAIN].npats = itemp2;
s[MAX][TRAIN].npats = itemp2;
sprintf(outstr,"%d training patterns read\n\n",itemp2); pg(stdout,outstr);
goto endr;
}
else if (ch == 'x')
{
if (nonetwork()) break;
trainfile = readstr();
if (!pushfile(trainfile,1)) goto endr;
strcat(trfiles,"rx ");
strcat(trfiles,trainfile);
strcat(trfiles,"\n");
prevnpats = s[TOL][TRAIN].npats;
findendofpats(last);
findendofpats(start);
readingpattern = 1;
itemp2 = readpats(TRAIN,'r');
sprintf(outstr,"%d patterns added\n\n",itemp2); pg(stdout,outstr);
readingpattern = 0;
itemp = prevnpats + itemp2;
s[TOL][TRAIN].npats = itemp;
s[MAX][TRAIN].npats = itemp;
goto endr;
}
else if (ch >= '1' && ch <= '9')
{
bufferptr = bufferptr - 1;
itemp = readint(1,MAXINT,'r');
if (!readerror) maxiter = itemp; else goto endr;
itemp = readint(1,MAXINT,'r');
if (!readerror) printrate = itemp; else goto endr;
do ch = readch(); while (ch == ' ');
if (ch == '"') goto endr; else bufferptr = bufferptr - 1;
if (!nonetwork() && !nopatterns()) run(maxiter,printrate);
}
else texterror();
endr:
break;
case 's': /* s <int> for seed, sw <filename> for save weights */
do ch = readch(); while (ch == ' ');
if (ch == 'w')
{
do ch = readch(); while (ch == ' ');
bufferptr = bufferptr - 1;
if (ch == '*' || ch == '\n') /* nothing */ ; else wtfile = readstr();
if (nonetwork()) break; else saveweights();
}
else if (ch == 'e' || ch == 'a')
{char *savefile;
FILE *sf;
saveweights();
do ch = readch(); while (ch == ' ' || ch == '\n');
bufferptr = bufferptr - 1;
savefile = readstr();
sf = fopen(savefile,"w");
if (sf == NULL)
{
sprintf(outstr,"cannot open the file: %s\n",savefile);
goto ends;
};
parameters(sf);
fflush(sf);
fclose(sf);
}
else if (ch == 'i' || (ch >= '0' && ch <= '9'))
{
bufferptr = bufferptr - 1;
sprev = seedstart;
while (ch != '\n' && ch != '*')
{
do ch = readch(); while (ch == ' ');
if (ch == 'i')
{
do ch = readch(); while (ch == ' ');
if (ch == '+' || ch == '-') incrementseed = ch;
else {texterror(); goto ends;};
}
else if (ch >= '0' && ch <= '9')
{
bufferptr = bufferptr - 1;
seed = readint(0,MAXINT,'s');
if (readerror) goto ends;
snode = (SEEDNODE *) malloc(sizeof(SEEDNODE));
snode->val = seed;
snode->next = NULL;
sprev->next = snode;
sprev = snode;
}
else if (ch == '\n' || ch == '*');
else {texterror(); goto ends;};
};
snode = seedstart->next;
seed = snode->val;
}
else if (ch == 'b')
{WTTYPE temp; LAYER *layer; UNIT *u;
if (nonetwork()) break;
temp = rdr(GT,(REAL) -unscale(32767),'s');
if (readerror) break;
stdthresh = temp;
layer = start->next;
while (layer != NULL)
{
u = (UNIT *) layer->units;
while (u != NULL)
{
w = (WTNODE *) u->wtlist;
while (w->next != NULL) w = w->next;
#ifdef SYMMETRIC
*(w->weight) = stdthresh;
#else
w->weight = stdthresh;
#endif
u = u->next;
};
layer = layer->next;
};
biasset = 1;
}
else texterror();
ends:
endsw:
break;
case 't':
do ch = readch(); while (ch == ' ');
if ((ch == 'a' || ch == '\n' || ch == '*') && (!nonetwork()))
{
if (s[TOL][TEST].npats == 0) pg(stdout,"no test patterns\n");
else
{
eval(TEST,1);
printstats(stdout,TEST,0,s);
bufferptr = bufferptr - 1;
};
}
else if (ch == 'o')
{
rtemp = readreal(GE,0.0,'t');
if (readerror) break;
toloverall = rtemp;
}
else if (ch == 'f')
{
do ch = readch(); while (ch == ' ');
bufferptr = bufferptr - 1;
testfile = readstr();
nullpatterns(TEST);
if (!pushfile(testfile,1)) goto endt;
readingpattern = 1;
itemp = readpats(TEST,'t');
readingpattern = 0;
sprintf(outstr,"%d test patterns read\n",itemp); pg(stdout,outstr);
s[TOL][TEST].npats = itemp;
s[MAX][TEST].npats = itemp;
}
else if (ch == 'r')
{
do ch = readch(); while (ch == ' ');
if (ch == 'p') itemp2 = 1;
else {itemp2 = 0; bufferptr = bufferptr - 1;};
itemp = readint(1,s[TOL][TEST].npats,'t');
if (!readerror) evalr(TEST,itemp2,itemp);
}
else
{
bufferptr = bufferptr - 1;
rtemp = readreal(GE,0.0,'t');
if (readerror) break;
else if (rtemp > 0 && rtemp < 1.0) toler = scale(rtemp);
else
{
bufferptr = bufferptr - 1;
if (s[TOL][TEST].npats == 0)
{
pg(stdout,"there is no test set\n");
break;
};
if (rtemp == 0.0) {eval(TEST,0); printstats(stdout,TEST,0,s);}
else
{
itemp = rtemp;
if (itemp > s[TOL][TEST].npats)
pg(stdout,"not that many patterns in the test set\n");
else evalone(itemp,TEST,1,0);
};
};
};
endt: break;
case 'w':
{int i;
if (nonetwork()) break;
layerno = readint(2,nlayers,'w');
if (readerror) break;
unitno = readint(1,MAXINT,'w');
if (readerror) break;
u = locateunit(layerno,unitno);
if (u != NULL) printweights(u,layerno);
break;
};
case 'g': /* a maze (game) problem for temporal difference */
{int ngames, nlevels, ndoors;
ngames = readint(1,10,'x');
nlevels = readint(1,10,'x');
ndoors = readint(1,10,'x');
maze(ngames,nlevels,ndoors);
ch = ' ';
break;
};
case 'z': /* reads in non-user set (hidden) parameters */
{REAL rtemp; int itemp;
rtemp = readreal(GE,0.0,'z');
if (readerror) break;
minimumsofar = rtemp; /* minimum test set error */
};
break;
default: if (ch >= 'A' && ch <= 'Z') menus(ch); else texterror();
break;
};
if (ch != '\n') do ch = readch(); while (ch != '\n');
}while (!finished);
}
/*
 * Program entry point.
 *
 * Makes stdout unbuffered, prints the banner, and sets up the input-file
 * stack with stdin at the bottom (filestack[0], data).  If a command-line
 * argument is given, argv[1] is pushed onto that stack with pushfile();
 * otherwise (or if pushfile() reports failure) datafile is recorded as
 * "stdin".  It then calls init(), installs restartcmdloop() as the SIGINT
 * handler and runs cmdloop().  After cmdloop() returns, the copy stream is
 * closed if it was opened.
 *
 * argc, argv: standard command-line arguments; only argv[1] is examined.
 * Returns 0.  (Previously declared as `void main`, which is not one of the
 * signatures the C standard permits for a hosted program.)
 */
int main(int argc, char *argv[])
{
  setbuf(stdout,NULL); /* set unbuffered output */
#if defined(UNIX) && defined(HOTKEYS)
  initraw();
#endif
  lineno = 0;
  pg(stdout,"Basis of AI Backprop (c) 1990-96 by Donald R. Tveter\n");
  pg(stdout," drt@mcs.com - http://www.mcs.com/~drt/home.html\n");
  pg(stdout," April 10, 1996 version.\n");
  filestackptr = 0;
  filestack[0] = stdin;
  data = stdin;
  emptystring = '\0';
  trainfile = &emptystring;
  testfile = &emptystring;
  if (argc == 1)
  {
    /* route through pg() like the other messages in this function */
    pg(stdout,"no data file, stdin assumed\n");
    datafile = "stdin";
  }
  else
  {
    datafile = argv[1];
    /* pushfile() returns 0 on failure (see its other call sites); in that
       case nothing was pushed, so record the input source as stdin rather
       than leaving datafile naming a file that is not being read */
    if (!pushfile(datafile,1)) datafile = "stdin";
  }
  init();
  signal(SIGINT,restartcmdloop); /* restart from interrupt */
  cmdloop();
  if (copy != NULL)
  {
    fflush(copy);
    fclose(copy);
  }
  return 0;
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -