nnssif.c

来自「基于MATLAB的神经网络非线性系统辨识软件包.」· C语言 代码 · 共 818 行 · 第 1/3 页

C
818
字号
          break;
       default:
          printf("iteration # %i   W = %4.3e\r",iteration,NSSE);
     }
     ++iteration;
     }
}


/*
 >>>>>>>>>>    RETURN POINTERS TO RETURN ARGUMENTS & FREE MEMORY    <<<<<<<<<<<<
 */

/* Swap pointers if they have been messed up */
if ((iteration&1) == 0) {
	mset(W1_new,W1);
	tmp = W1; W1=W1_new; W1_new=tmp;
	mset(W2_new,W2);
     	tmp = W2; W2=W2_new; W2_new=tmp;
}
iteration=iteration-1;
if(iteration==0){
	*NSSEvecpp = mmake(1,1);
	(*NSSEvecpp)->row=0;
	(*NSSEvecpp)->col=0;
}
else{
	*NSSEvecpp = mmake(iteration,1);
	subvec(*NSSEvecpp,NSSEvec,0,iteration-1);
}
*iter = iteration;
*lam  = lambda;
mfree(L_hidden); mfree(H_hidden); mfree(L_output); mfree(H_output); mfree(Evec);
mfree(h1); mfree(h2); mfree(y1); mfree(y2); mfree(Y2); mfree(Ahat); mfree(C);
mfree(E); mfree(Evec_new); mfree(dxdy1); mfree(Khat); mfree(dy1de); mfree(dy1dx);
mfree(W1_new); mfree(W2_new); mfree(D);mfree(Dtmp); mfree(NSSEvec);mfree(Htmp);
mfree(theta); mfree(thtmp); mfree(theta_index); mfree(theta_red); mfree(theta_red_new);
mfree(PSI); mfree(PSIx); mfree(G); mfree(H); mfree(h), mfree(Yhat);
mfree(all); mfree(index0); mfree(index7);mfree(onesvec);mfree(tmp0); mfree(tmp2);
mfree(tmp3); mfree(index); mfree(index2); mfree(PHI);
mfree(rowidx); mfree(nrowidx); mfree(AKC);

printf("\n\nNetwork training ended.\n\n\n");
}
/*
  --------------------------------------------------------------------------------
  ----------------             END OF NETWORK TRAINING              --------------
  --------------------------------------------------------------------------------
*/


/*********************************************************************************
 *                                                                               *
 *                           G A T E W A Y   R O U T I N E                       *
 *                                                                               *
 *********************************************************************************/
void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
{
  /*
   >>>>>>>>>>>>>>>>>>           VARIABLE DECLARATIONS          <<<<<<<<<<<<<<<<<<<
   */
   matrix *NSSEvec;
   matrix *NetDef, *W1, *W2, *obsidx, *U, *Y;
   double *M, lambda;
   int iter, skip, ny, N, nu, nx, hidden, k, n, a, decays;
   trparmstruct *trparms;
   mxArray  *Matmatrix;
   char *infolevelstr[] = {"infolevel", "Infolevel", "INFOLEVEL", "InfoLevel"};
   char *maxiterstr[] = {"maxiter", "MAXITER", "Maxiter", "MaxIter"};
   char *critminstr[] = {"critmin", "Critmin", "CRITMIN", "CritMin"};
   char *crittermstr[] = {"critterm", "Critterm", "CRITTERM", "CritTerm"};
   char *gradtermstr[] = {"gradterm", "Gradterm", "GRADTERM", "GradTerm"};
   char *paramtermstr[] = {"paramterm", "Paramterm", "PARAMTERM", "ParamTerm"};
   char *Dstr[] = {"D", "d"};
   char *lambdastr[] = {"lambda", "Lambda", "LAMBDA"};
   char *skipstr[] = {"skip", "Skip", "SKIP"};


  /*
   >>>>>>>>>>>>>>>>      CHECK FOR PROPER NUMBER OF ARGUMENTS      <<<<<<<<<<<<<<<
   */
   if (nrhs!=8) mexErrMsgTxt("Wrong number of input arguments");
   else if (nlhs > 6) mexErrMsgTxt("Too many output arguments");
   nu = mxGetM(prhs[7]);   /* # of control signals */
   ny  = mxGetM(prhs[6]);  /* Rows of vector Y */
   if(nu<1) mexErrMsgTxt("Wrong dimension of input vector");
   N = mxGetN(prhs[6]);      /* Columns of vector Y */
   if(N!=mxGetN(prhs[7])) mexErrMsgTxt("U and Y should have the same number of columns!");
   hidden = mxGetN(prhs[0]); /* # of hidden units */
   if(mxGetM(prhs[0])!=2) mexErrMsgTxt("Error in architecture definition!");
   if(hidden<2) mexErrMsgTxt("Use at least two hidden units!");
         

  /*
   >>>>>>>>>>>>>>>>>     CONVERT INPUT ARGUMENTS TO SM FORMAT     <<<<<<<<<<<<<<<<
   */
  NetDef  = matstring2sm(prhs[0]);     /* Network architecture */
  nx = (int)(*mxGetPr(prhs[1]));
  Y    = mat2sm(prhs[6]);        /* Vector of observed outputs  */
  U = mat2sm(prhs[7]);           /* Vector of inputs            */

  /* Initialize pseudo-observability indices if obsidx passed as [] */
  if(mxGetM(prhs[4])==0 && mxGetN(prhs[4])==0){
  	obsidx=mmake(1,ny);
  	minitx(obsidx,1.0);
  	vput(obsidx,ny-1,(double)nx-ny+1);
  }
  else
  	obsidx = mat2sm(prhs[4]);


  /* Initialize weight matrices if passed as [] */
  if(mxGetM(prhs[2])==0 || mxGetN(prhs[2])==0 || mxGetM(prhs[3])==0\
                        || mxGetN(prhs[3])==0){
	W1 = mmake(hidden,nx+nu+ny+1);
   	W2 = mmake(nx,hidden+1);
      	W2->row = 0;   /* Hack telling that the weights should be initialized */ 
   }
   else{
   	if(mxGetM(prhs[2])!=hidden) mexErrMsgTxt("W1 has the wrong dimension");
   	if(mxGetN(prhs[2])!=nx+nu+ny+1) mexErrMsgTxt("W1 has the wrong dimension");
   	if(mxGetM(prhs[3])!=nx) mexErrMsgTxt("W2 has the wrong dimension");
   	if(mxGetN(prhs[3])!=hidden+1) mexErrMsgTxt("W2 has the wrong dimension");
   	W1 = mat2sm(prhs[2]);     /* Input-to-hidden layer weights */
   	W2 = mat2sm(prhs[3]);     /* Hidden-to-output layer weights */
   }
   
 trparms = (trparmstruct*)malloc(sizeof(trparmstruct)); 
 a = 5;
 if(mxGetN(prhs[a])!=0|| mxGetM(prhs[a])!=0) {
    /* INFOLEVEL */
    trparms->infolevel   = TRDINFOLEVEL;    
    for(n=0;n<4;n++){
        if ((Matmatrix=mxGetField(prhs[a], 0, infolevelstr[n]))!=NULL){
           trparms->infolevel=(int)(*mxGetPr(Matmatrix));
           break;
        }
    }

    /* MAXITER */
    trparms->maxiter   = TRDMAXITER;    
    for(n=0;n<4;n++){
        if ((Matmatrix=mxGetField(prhs[a], 0, maxiterstr[n]))!=NULL){
           trparms->maxiter=(int)(*mxGetPr(Matmatrix));
           break;
        }
    }

    /* CRITMIN */
    trparms->critmin   = TRDCRITMIN;    
    for(n=0;n<4;n++){
        if ((Matmatrix=mxGetField(prhs[a], 0, critminstr[n]))!=NULL){
           trparms->critmin=(double)(*mxGetPr(Matmatrix));
           break;
        }
    }

    
    /* CRITTERM */
    trparms->critterm   = TRDCRITTERM;    
    for(n=0;n<4;n++){
        if ((Matmatrix=mxGetField(prhs[a], 0, crittermstr[n]))!=NULL){
           trparms->critterm=(double)(*mxGetPr(Matmatrix));
           break;
        }
    }

    /* GRADTERM */
    trparms->gradterm   = TRDGRADTERM;    
    for(n=0;n<4;n++){
        if ((Matmatrix=mxGetField(prhs[a], 0, gradtermstr[n]))!=NULL){
           trparms->gradterm=(double)(*mxGetPr(Matmatrix));
           break;
        }
    }

    /* PARAMTERM */
    trparms->paramterm   = TRDPARAMTERM;    
    for(n=0;n<4;n++){
        if ((Matmatrix=mxGetField(prhs[a], 0, paramtermstr[n]))!=NULL){
           trparms->paramterm=(double)(*mxGetPr(Matmatrix));
           break;
        }
    }

    /* Lambda */
    trparms->lambda   = TRDLAMBDA;    
    for(n=0;n<3;n++){
        if ((Matmatrix=mxGetField(prhs[a], 0, lambdastr[n]))!=NULL){
           trparms->lambda=(double)(*mxGetPr(Matmatrix));
           break;
        }
    }

    /* Skip */
    trparms->skip   = TRDSKIP;    
    for(n=0;n<3;n++){
        if ((Matmatrix=mxGetField(prhs[a], 0, skipstr[n]))!=NULL){
           trparms->skip=(int)(*mxGetPr(Matmatrix));
           break;
        }
    }


    /* D */
    for(n=0;n<2;n++){
        if ((Matmatrix=mxGetField(prhs[a], 0, Dstr[n]))!=NULL){
           decays = mxGetM(Matmatrix)*mxGetN(Matmatrix);
           trparms->D         = mmake(1,decays);
           M    = mxGetPr(Matmatrix);
           for(n=0;n<decays;n++){
              rvput(trparms->D,n,M[n]);
           }
           break;
        }
    }
    if(Matmatrix==NULL){
       trparms->D         = mmake(1,1);
       put_val(trparms->D,0,0,TRDD);
    }
}
  else
  {
    trparms->infolevel = TRDINFOLEVEL;
    trparms->maxiter   = TRDMAXITER;
    trparms->critmin   = TRDCRITMIN;
    trparms->critterm  = TRDCRITTERM;
    trparms->gradterm  = TRDGRADTERM;
    trparms->paramterm = TRDPARAMTERM;
    trparms->D         = mmake(1,1);
    put_val(trparms->D,0,0,TRDD);
    trparms->lambda    = TRDLAMBDA;
    trparms->skip      = TRDSKIP;
  }


  /*
   >>>>>>>>>>>>>>>>>>>>>>         CALL THE C-ROUTINE         <<<<<<<<<<<<<<<<<<<<<
   */
  nnssif(&NSSEvec, &iter, &lambda, NetDef, nx, W1, W2, obsidx, trparms, Y, U);


  /*
   >>>>>>>>>>>>>>>>>>>         CREATE OUTPUT MATICES            <<<<<<<<<<<<<<<<<<
   */
  plhs[0] = mxCreateDoubleMatrix(getrows(W1),getcols(W1),mxREAL);
  plhs[1] = mxCreateDoubleMatrix(getrows(W2),getcols(W2),mxREAL);
  plhs[2] = mxCreateDoubleMatrix(getrows(obsidx),getcols(obsidx),mxREAL);
  plhs[3] = mxCreateDoubleMatrix(getrows(NSSEvec),getcols(NSSEvec),mxREAL);
  plhs[4] = mxCreateDoubleMatrix(1,1,mxREAL);
  plhs[5] = mxCreateDoubleMatrix(1,1,mxREAL);

  sm2mat(plhs[0],W1);
  sm2mat(plhs[1],W2);
  sm2mat(plhs[2],obsidx);
  sm2mat(plhs[3],NSSEvec);
  M = mxGetPr(plhs[4]); M[0] = (double)iter;
  M = mxGetPr(plhs[5]); M[0] = (double)lambda;

  /*
   >>>>>>>>>>>>>>>>>>>>        FREE ARGUMENT MATRICES        <<<<<<<<<<<<<<<<<<<<<
   */
  mfree(NetDef);
  mfree(U);
  mfree(Y);
  mfree(trparms->D);
  free(trparms);
}



⌨️ 快捷键说明

复制代码Ctrl + C
搜索代码Ctrl + F
全屏模式F11
增大字号Ctrl + =
减小字号Ctrl + -
显示快捷键?