nnssif.c
来自「基于MATLAB的神经网络非线性系统辨识软件包.」· C语言 代码 · 共 818 行 · 第 1/3 页
C
818 行
break;
default:
printf("iteration # %i W = %4.3e\r",iteration,NSSE);
}
++iteration;
}
}
/*
>>>>>>>>>> RETURN POINTERS TO RETURN ARGUMENTS & FREE MEMORY <<<<<<<<<<<<
*/
/* Swap pointers if they have been messed up */
if ((iteration&1) == 0) {
mset(W1_new,W1);
tmp = W1; W1=W1_new; W1_new=tmp;
mset(W2_new,W2);
tmp = W2; W2=W2_new; W2_new=tmp;
}
iteration=iteration-1;
if(iteration==0){
*NSSEvecpp = mmake(1,1);
(*NSSEvecpp)->row=0;
(*NSSEvecpp)->col=0;
}
else{
*NSSEvecpp = mmake(iteration,1);
subvec(*NSSEvecpp,NSSEvec,0,iteration-1);
}
*iter = iteration;
*lam = lambda;
mfree(L_hidden); mfree(H_hidden); mfree(L_output); mfree(H_output); mfree(Evec);
mfree(h1); mfree(h2); mfree(y1); mfree(y2); mfree(Y2); mfree(Ahat); mfree(C);
mfree(E); mfree(Evec_new); mfree(dxdy1); mfree(Khat); mfree(dy1de); mfree(dy1dx);
mfree(W1_new); mfree(W2_new); mfree(D);mfree(Dtmp); mfree(NSSEvec);mfree(Htmp);
mfree(theta); mfree(thtmp); mfree(theta_index); mfree(theta_red); mfree(theta_red_new);
mfree(PSI); mfree(PSIx); mfree(G); mfree(H); mfree(h), mfree(Yhat);
mfree(all); mfree(index0); mfree(index7);mfree(onesvec);mfree(tmp0); mfree(tmp2);
mfree(tmp3); mfree(index); mfree(index2); mfree(PHI);
mfree(rowidx); mfree(nrowidx); mfree(AKC);
printf("\n\nNetwork training ended.\n\n\n");
}
/*
--------------------------------------------------------------------------------
---------------- END OF NETWORK TRAINING --------------
--------------------------------------------------------------------------------
*/
/*********************************************************************************
* *
* G A T E W A Y R O U T I N E *
* *
*********************************************************************************/
void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
{
/*
>>>>>>>>>>>>>>>>>> VARIABLE DECLARATIONS <<<<<<<<<<<<<<<<<<<
*/
matrix *NSSEvec;
matrix *NetDef, *W1, *W2, *obsidx, *U, *Y;
double *M, lambda;
int iter, skip, ny, N, nu, nx, hidden, k, n, a, decays;
trparmstruct *trparms;
mxArray *Matmatrix;
char *infolevelstr[] = {"infolevel", "Infolevel", "INFOLEVEL", "InfoLevel"};
char *maxiterstr[] = {"maxiter", "MAXITER", "Maxiter", "MaxIter"};
char *critminstr[] = {"critmin", "Critmin", "CRITMIN", "CritMin"};
char *crittermstr[] = {"critterm", "Critterm", "CRITTERM", "CritTerm"};
char *gradtermstr[] = {"gradterm", "Gradterm", "GRADTERM", "GradTerm"};
char *paramtermstr[] = {"paramterm", "Paramterm", "PARAMTERM", "ParamTerm"};
char *Dstr[] = {"D", "d"};
char *lambdastr[] = {"lambda", "Lambda", "LAMBDA"};
char *skipstr[] = {"skip", "Skip", "SKIP"};
/*
>>>>>>>>>>>>>>>> CHECK FOR PROPER NUMBER OF ARGUMENTS <<<<<<<<<<<<<<<
*/
if (nrhs!=8) mexErrMsgTxt("Wrong number of input arguments");
else if (nlhs > 6) mexErrMsgTxt("Too many output arguments");
nu = mxGetM(prhs[7]); /* # of control signals */
ny = mxGetM(prhs[6]); /* Rows of vector Y */
if(nu<1) mexErrMsgTxt("Wrong dimension of input vector");
N = mxGetN(prhs[6]); /* Columns of vector Y */
if(N!=mxGetN(prhs[7])) mexErrMsgTxt("U and Y should have the same number of columns!");
hidden = mxGetN(prhs[0]); /* # of hidden units */
if(mxGetM(prhs[0])!=2) mexErrMsgTxt("Error in architecture definition!");
if(hidden<2) mexErrMsgTxt("Use at least two hidden units!");
/*
>>>>>>>>>>>>>>>>> CONVERT INPUT ARGUMENTS TO SM FORMAT <<<<<<<<<<<<<<<<
*/
NetDef = matstring2sm(prhs[0]); /* Network architecture */
nx = (int)(*mxGetPr(prhs[1]));
Y = mat2sm(prhs[6]); /* Vector of observed outputs */
U = mat2sm(prhs[7]); /* Vector of inputs */
/* Initialize pseudo-observability indices if obsidx passed as [] */
if(mxGetM(prhs[4])==0 && mxGetN(prhs[4])==0){
obsidx=mmake(1,ny);
minitx(obsidx,1.0);
vput(obsidx,ny-1,(double)nx-ny+1);
}
else
obsidx = mat2sm(prhs[4]);
/* Initialize weight matrices if passed as [] */
if(mxGetM(prhs[2])==0 || mxGetN(prhs[2])==0 || mxGetM(prhs[3])==0\
|| mxGetN(prhs[3])==0){
W1 = mmake(hidden,nx+nu+ny+1);
W2 = mmake(nx,hidden+1);
W2->row = 0; /* Hack telling that the weights should be initialized */
}
else{
if(mxGetM(prhs[2])!=hidden) mexErrMsgTxt("W1 has the wrong dimension");
if(mxGetN(prhs[2])!=nx+nu+ny+1) mexErrMsgTxt("W1 has the wrong dimension");
if(mxGetM(prhs[3])!=nx) mexErrMsgTxt("W2 has the wrong dimension");
if(mxGetN(prhs[3])!=hidden+1) mexErrMsgTxt("W2 has the wrong dimension");
W1 = mat2sm(prhs[2]); /* Input-to-hidden layer weights */
W2 = mat2sm(prhs[3]); /* Hidden-to-output layer weights */
}
trparms = (trparmstruct*)malloc(sizeof(trparmstruct));
a = 5;
if(mxGetN(prhs[a])!=0|| mxGetM(prhs[a])!=0) {
/* INFOLEVEL */
trparms->infolevel = TRDINFOLEVEL;
for(n=0;n<4;n++){
if ((Matmatrix=mxGetField(prhs[a], 0, infolevelstr[n]))!=NULL){
trparms->infolevel=(int)(*mxGetPr(Matmatrix));
break;
}
}
/* MAXITER */
trparms->maxiter = TRDMAXITER;
for(n=0;n<4;n++){
if ((Matmatrix=mxGetField(prhs[a], 0, maxiterstr[n]))!=NULL){
trparms->maxiter=(int)(*mxGetPr(Matmatrix));
break;
}
}
/* CRITMIN */
trparms->critmin = TRDCRITMIN;
for(n=0;n<4;n++){
if ((Matmatrix=mxGetField(prhs[a], 0, critminstr[n]))!=NULL){
trparms->critmin=(double)(*mxGetPr(Matmatrix));
break;
}
}
/* CRITTERM */
trparms->critterm = TRDCRITTERM;
for(n=0;n<4;n++){
if ((Matmatrix=mxGetField(prhs[a], 0, crittermstr[n]))!=NULL){
trparms->critterm=(double)(*mxGetPr(Matmatrix));
break;
}
}
/* GRADTERM */
trparms->gradterm = TRDGRADTERM;
for(n=0;n<4;n++){
if ((Matmatrix=mxGetField(prhs[a], 0, gradtermstr[n]))!=NULL){
trparms->gradterm=(double)(*mxGetPr(Matmatrix));
break;
}
}
/* PARAMTERM */
trparms->paramterm = TRDPARAMTERM;
for(n=0;n<4;n++){
if ((Matmatrix=mxGetField(prhs[a], 0, paramtermstr[n]))!=NULL){
trparms->paramterm=(double)(*mxGetPr(Matmatrix));
break;
}
}
/* Lambda */
trparms->lambda = TRDLAMBDA;
for(n=0;n<3;n++){
if ((Matmatrix=mxGetField(prhs[a], 0, lambdastr[n]))!=NULL){
trparms->lambda=(double)(*mxGetPr(Matmatrix));
break;
}
}
/* Skip */
trparms->skip = TRDSKIP;
for(n=0;n<3;n++){
if ((Matmatrix=mxGetField(prhs[a], 0, skipstr[n]))!=NULL){
trparms->skip=(int)(*mxGetPr(Matmatrix));
break;
}
}
/* D */
for(n=0;n<2;n++){
if ((Matmatrix=mxGetField(prhs[a], 0, Dstr[n]))!=NULL){
decays = mxGetM(Matmatrix)*mxGetN(Matmatrix);
trparms->D = mmake(1,decays);
M = mxGetPr(Matmatrix);
for(n=0;n<decays;n++){
rvput(trparms->D,n,M[n]);
}
break;
}
}
if(Matmatrix==NULL){
trparms->D = mmake(1,1);
put_val(trparms->D,0,0,TRDD);
}
}
else
{
trparms->infolevel = TRDINFOLEVEL;
trparms->maxiter = TRDMAXITER;
trparms->critmin = TRDCRITMIN;
trparms->critterm = TRDCRITTERM;
trparms->gradterm = TRDGRADTERM;
trparms->paramterm = TRDPARAMTERM;
trparms->D = mmake(1,1);
put_val(trparms->D,0,0,TRDD);
trparms->lambda = TRDLAMBDA;
trparms->skip = TRDSKIP;
}
/*
>>>>>>>>>>>>>>>>>>>>>> CALL THE C-ROUTINE <<<<<<<<<<<<<<<<<<<<<
*/
nnssif(&NSSEvec, &iter, &lambda, NetDef, nx, W1, W2, obsidx, trparms, Y, U);
/*
>>>>>>>>>>>>>>>>>>> CREATE OUTPUT MATICES <<<<<<<<<<<<<<<<<<
*/
plhs[0] = mxCreateDoubleMatrix(getrows(W1),getcols(W1),mxREAL);
plhs[1] = mxCreateDoubleMatrix(getrows(W2),getcols(W2),mxREAL);
plhs[2] = mxCreateDoubleMatrix(getrows(obsidx),getcols(obsidx),mxREAL);
plhs[3] = mxCreateDoubleMatrix(getrows(NSSEvec),getcols(NSSEvec),mxREAL);
plhs[4] = mxCreateDoubleMatrix(1,1,mxREAL);
plhs[5] = mxCreateDoubleMatrix(1,1,mxREAL);
sm2mat(plhs[0],W1);
sm2mat(plhs[1],W2);
sm2mat(plhs[2],obsidx);
sm2mat(plhs[3],NSSEvec);
M = mxGetPr(plhs[4]); M[0] = (double)iter;
M = mxGetPr(plhs[5]); M[0] = (double)lambda;
/*
>>>>>>>>>>>>>>>>>>>> FREE ARGUMENT MATRICES <<<<<<<<<<<<<<<<<<<<<
*/
mfree(NetDef);
mfree(U);
mfree(Y);
mfree(trparms->D);
free(trparms);
}
⌨️ 快捷键说明
复制代码Ctrl + C
搜索代码Ctrl + F
全屏模式F11
增大字号Ctrl + =
减小字号Ctrl + -
显示快捷键?