⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 nnoe.c

📁 基于MATLAB的神经网络非线性系统辨识软件包.
💻 C
📖 第 1 页 / 共 2 页
字号:
    		ii = (int)cvget(theta_index,i);
    		for(j=i; j<reduced; j++){
      			jj = (int)cvget(theta_index,j);
      			for(sum=0.0,k=skip; k<N; k++)
				sum += get_val(PSI,ii,k)*get_val(PSI,jj,k);
      			put_val(H,i,j,sum);
      			put_val(H,j,i,sum);	
    		}
  	}
  for(i=0;i<reduced;i++)                            /* Add diagonal matrix     */
    put_val(H,i,i,get_val(H,i,i)+cvget(D,i));               
 }

/*
 >>>>>>>>>>>>>>>>>>>>>>>>>>>        COMPUTE h_k        <<<<<<<<<<<<<<<<<<<<<<<<<<<
 */
 
  /* -- Hessian (H = R + lambda*I + D)  --*/
  tmp1 = lambda - lambda_old;
  for(i=0;i<reduced;i++)                            /* Add diagonal matrix     */
    put_val(H,i,i,get_val(H,i,i)+tmp1);               

  /* -- Search direction -- */
  choldc(H, Htmp);
  cholsl(Htmp,h,G);

  /* -- Compute 'apriori' iterate -- */
  madd(theta_red_new,theta_red,h);                  /* Update parameter vector */
  mcopyi(theta,theta_index,index0,theta_red_new,index7,index0);

  /* -- Put the parameters back into the weight matrices -- */
  v2mreshape(W1_new,theta,parameters2);
  v2mreshape(W2_new,theta,0);


  /*
   >>>>>>>>>>>>>       Compute network output y2(theta+h)          <<<<<<<<<<<<<< 
  */
  for(t=0;t<N;t++){
	mvmul(h1,W1_new,PHI,t);
	vtanh(y1,H_hidden,t,h1,H_hidden,0);
	vcopyi(y1,L_hidden,t,h1,L_hidden,0);
	
	mvmul(h2,W2_new,y1,t);
	vtanh(y2,H_output,t,h2,H_output,0);
	vcopyi(y2,L_output,t,h2,L_output,0);
	
	j=na;
	if(N-t-1<na) j=N-t-1;
	for(i=1;i<=j;i++){
		put_val(PHI,i-1,t+i,rvget(y2,t));
	}
  }
  for(t=0;t<N;t++)                                     /* Prediction error      */
              rvput(E_new,t,rvget(Y2,t)-rvget(y2,t)); 
  for(SSE_new=0,t=skip;t<N;t++)                        /* Sum of squared errors */
              SSE_new+=rvget(E_new,t)*rvget(E_new,t);
  for(tmp1=0,i=0;i<reduced;i++) tmp1+=cvget(theta_red_new,i)*cvget(theta_red_new,i)*cvget(D,i); 
  NSSE_new = (SSE_new+tmp1)/(2*N2);                    /* Value of cost function*/


  /*
   >>>>>>>>>>>>>>>>>>>>>>>>>>>       UPDATE  lambda     <<<<<<<<<<<<<<<<<<<<<<<<<<
   */
    lambda_old = lambda;
    for(tmp1=0,i=0;i<reduced;i++) tmp1+=cvget(h,i)*cvget(h,i)*(cvget(D,i)+lambda);
    L = sprod3(h,G) + tmp1;

    /* Decrease lambda if SSE has fallen 'sufficiently' */
    if(2*N2*(NSSE - NSSE_new) > (0.75*L)) lambda = lambda/2;
  
    /* Increase lambda if SSE has grown 'sufficiently'  */
    else if(2*N2*(NSSE-NSSE_new) <= (0.25*L)) lambda = 2*lambda;  


  /*
   >>>>>>>>>>>>>>>>>>>       UPDATES FOR NEXT ITERATION        <<<<<<<<<<<<<<<<<<<<
   */
    /* Update only if criterion has decreased */
    if(NSSE_new<NSSE)
    {
     critdif  = NSSE-NSSE_new;                           /* Criterion difference */
     for(i=0,gradmax=0.0,ptm1=G->mat[0];i<reduced;i++){  /* Maximum gradient     */
        sum = fabs(*(ptm1++));
        if(gradmax<sum)
           gradmax = sum;
     }
     gradmax/=N2;
     ptm1=theta_red_new->mat[0];
     ptm2=theta_red->mat[0];
     for(i=0,paramdif=0.0;i<reduced;i++){  /* Maximum gradient     */
        sum = fabs(*(ptm1++) - *(ptm2++));
        if(paramdif<sum)
           paramdif = sum;
     }
     lambda_old = 0.0;
     tmp = W1; W1=W1_new; W1_new=tmp;
     tmp = W2; W2=W2_new; W2_new=tmp;
     tmp = theta_red; theta_red=theta_red_new; theta_red_new = tmp;
     tmp = E; E = E_new; E_new = tmp;
     dw = 1;
     NSSE = NSSE_new;
     cvput(NSSEvec,iteration-1,NSSE);
     switch(trparms->infolevel){                            /* Print on-line inform */
       case 1:
          printf("# %i   W=%4.3e  critdif=%3.2e  maxgrad=%3.2e  paramdif=%3.2e\n",
                                                  iteration,NSSE,critdif,gradmax,paramdif);
          break;
       default:
          printf("iteration # %i   W = %4.3e\r",iteration,NSSE);
     }
     ++iteration;
   }
}


/*
 >>>>>>>>>>    RETURN POINTERS TO RETURN ARGUMENTS & FREE MEMORY    <<<<<<<<<<<<
 */

/* Swap pointers if they have been messed up */
if ((iteration&1) == 0) {
	mset(W1_new,W1);
	tmp = W1; W1=W1_new; W1_new=tmp;
	mset(W2_new,W2);
     	tmp = W2; W2=W2_new; W2_new=tmp;
}
iteration=iteration-1;
if(iteration==0){
	*NSSEvecpp = mmake(1,1);
	(*NSSEvecpp)->row=0;
	(*NSSEvecpp)->col=0;
}
else {
	*NSSEvecpp = mmake(iteration,1);
	subvec(*NSSEvecpp,NSSEvec,0,iteration-1);
}
*iter = iteration;
*lam  = lambda;
mfree(L_hidden); mfree(H_hidden); mfree(L_output); mfree(H_output);
mfree(h1); mfree(h2); mfree(y1); mfree(y2); mfree(Y2);
mfree(E); mfree(E_new); mfree(dy2dy1); mfree(dy2dy); mfree(dy1dy); mfree(dy2dy_vec);
mfree(W1_new); mfree(W2_new); mfree(D);mfree(Dtmp); mfree(NSSEvec);mfree(Htmp);
mfree(theta); mfree(thtmp); mfree(theta_index); mfree(theta_red); mfree(theta_red_new);
mfree(PSI); mfree(G); mfree(H); mfree(h);
mfree(all); mfree(index0); mfree(index7);mfree(onesvec);mfree(tmp0); mfree(tmp2);
mfree(tmp3); mfree(index); mfree(index2); mfree(PHI); mfree(nb); mfree(nk);
printf("\n\nNetwork training ended.\n\n\n");
}
/*
  --------------------------------------------------------------------------------
  ----------------             END OF NETWORK TRAINING              --------------
  --------------------------------------------------------------------------------
*/


/*********************************************************************************
 *                                                                               *
 *                           G A T E W A Y   R O U T I N E                       *
 *                                                                               *
 *********************************************************************************/
void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
{
  /*
   >>>>>>>>>>>>>>>>>>           VARIABLE DECLARATIONS          <<<<<<<<<<<<<<<<<<<
   */
   matrix *NSSEvec, *NN, *nb;
   matrix *NetDef, *W1, *W2, *U, *Y;
   double *M, lambda;
   int iter, m, N, na, nab, nu, hidden, k, n, a, decays;
   trparmstruct *trparms;
   mxArray  *Matmatrix;
   char *infolevelstr[] = {"infolevel", "Infolevel", "INFOLEVEL", "InfoLevel"};
   char *maxiterstr[] = {"maxiter", "MAXITER", "Maxiter", "MaxIter"};
   char *critminstr[] = {"critmin", "Critmin", "CRITMIN", "CritMin"};
   char *crittermstr[] = {"critterm", "Critterm", "CRITTERM", "CritTerm"};
   char *gradtermstr[] = {"gradterm", "Gradterm", "GRADTERM", "GradTerm"};
   char *paramtermstr[] = {"paramterm", "Paramterm", "PARAMTERM", "ParamTerm"};
   char *Dstr[] = {"D", "d"};
   char *lambdastr[] = {"lambda", "Lambda", "LAMBDA"};
   char *skipstr[] = {"skip", "Skip", "SKIP"};

  /*
   >>>>>>>>>>>>>>>>      CHECK FOR PROPER NUMBER OF ARGUMENTS      <<<<<<<<<<<<<<<
   */
   if (nrhs!=7) mexErrMsgTxt("Wrong number of input arguments");
   else if (nlhs > 5) mexErrMsgTxt("Too many output arguments");
   nu = mxGetM(prhs[6]);     /* # of control signals */
   m  = mxGetM(prhs[5]);     /* Rows of vector Y */
   if(m!=1) mexErrMsgTxt("Wrong dimension of vector of observed outputs");
   N = mxGetN(prhs[5]);      /* Columns of vector Y */
   if(N!=mxGetN(prhs[6])) mexErrMsgTxt("U and Y should have the same number of columns!");
   hidden = mxGetN(prhs[0]); /* # of hidden units */
   if(mxGetM(prhs[0])!=2) mexErrMsgTxt("Error in architecture definition!");
   if(mxGetM(prhs[1])!=1) mexErrMsgTxt("NN should only have 1 row!");
   if(mxGetN(prhs[1])!=2*nu+1) mexErrMsgTxt("Mismatch between U and NN");
         

  /*
   >>>>>>>>>>>>>>>>>     CONVERT INPUT ARGUMENTS TO SM FORMAT     <<<<<<<<<<<<<<<<
   */
  NetDef  = matstring2sm(prhs[0]); /* Network architecture */
  NN      = mat2sm(prhs[1]);     /* Regressor structure  */
  Y    = mat2sm(prhs[5]);        /* Vector of observed outputs  */
  U    = mat2sm(prhs[6]);        /* Vector of inputs            */
  
  
  /*
   >>>>>>>>>>>>>>>>      CHECK FOR PROPER NUMBER OF ARGUMENTS      <<<<<<<<<<<<<<<
   */
   na        = vget(NN,0);  /* Past predictions used as inputs */   
   nb        = mmake(1,nu); /* Past controls used as input     */
   subvec(nb,NN,1,nu);
   nab       = na;          /* na+nb                           */
   for(k=0;k<nu;k++) nab=nab+rvget(nb,k);
   
   /* Initialize weight matrices if passed as [] */
   if(mxGetM(prhs[2])==0 || mxGetN(prhs[2])==0 || mxGetM(prhs[3])==0\
                         || mxGetN(prhs[3])==0){
   	W1 = mmake(hidden,nab+1);
   	W2 = mmake(1,hidden+1);
      	W2->row = 0;   /* Hack telling that the weights should be initialized */ 

   }
   else{
   	if(mxGetM(prhs[2])!=hidden) mexErrMsgTxt("W1 has the wrong dimension");
   	if(mxGetN(prhs[2])!=nab+1) mexErrMsgTxt("W1 has the wrong dimension");
   	if(mxGetM(prhs[3])!=1) mexErrMsgTxt("W2 has the wrong dimension");
   	if(mxGetN(prhs[3])!=hidden+1) mexErrMsgTxt("W2 has the wrong dimension");
   	W1 = mat2sm(prhs[2]);     /* Input-to-hidden layer weights */
   	W2 = mat2sm(prhs[3]);     /* Hidden-to-output layer weights */
   }


 trparms = (trparmstruct*)malloc(sizeof(trparmstruct)); 
 a = 4;
 if(mxGetN(prhs[a])!=0|| mxGetM(prhs[a])!=0) {
    /* INFOLEVEL */
    trparms->infolevel   = TRDINFOLEVEL;    
    for(n=0;n<4;n++){
        if ((Matmatrix=mxGetField(prhs[a], 0, infolevelstr[n]))!=NULL){
           trparms->infolevel=(int)(*mxGetPr(Matmatrix));
           break;
        }
    }

    /* MAXITER */
    trparms->maxiter   = TRDMAXITER;    
    for(n=0;n<4;n++){
        if ((Matmatrix=mxGetField(prhs[a], 0, maxiterstr[n]))!=NULL){
           trparms->maxiter=(int)(*mxGetPr(Matmatrix));
           break;
        }
    }

    /* CRITMIN */
    trparms->critmin   = TRDCRITMIN;    
    for(n=0;n<4;n++){
        if ((Matmatrix=mxGetField(prhs[a], 0, critminstr[n]))!=NULL){
           trparms->critmin=(double)(*mxGetPr(Matmatrix));
           break;
        }
    }

    
    /* CRITTERM */
    trparms->critterm   = TRDCRITTERM;    
    for(n=0;n<4;n++){
        if ((Matmatrix=mxGetField(prhs[a], 0, crittermstr[n]))!=NULL){
           trparms->critterm=(double)(*mxGetPr(Matmatrix));
           break;
        }
    }

    /* GRADTERM */
    trparms->gradterm   = TRDGRADTERM;    
    for(n=0;n<4;n++){
        if ((Matmatrix=mxGetField(prhs[a], 0, gradtermstr[n]))!=NULL){
           trparms->gradterm=(double)(*mxGetPr(Matmatrix));
           break;
        }
    }

    /* PARAMTERM */
    trparms->paramterm   = TRDPARAMTERM;    
    for(n=0;n<4;n++){
        if ((Matmatrix=mxGetField(prhs[a], 0, paramtermstr[n]))!=NULL){
           trparms->paramterm=(double)(*mxGetPr(Matmatrix));
           break;
        }
    }

    /* Lambda */
    trparms->lambda   = TRDLAMBDA;    
    for(n=0;n<3;n++){
        if ((Matmatrix=mxGetField(prhs[a], 0, lambdastr[n]))!=NULL){
           trparms->lambda=(double)(*mxGetPr(Matmatrix));
           break;
        }
    }

    /* Skip */
    trparms->skip   = TRDSKIP;    
    for(n=0;n<3;n++){
        if ((Matmatrix=mxGetField(prhs[a], 0, skipstr[n]))!=NULL){
           trparms->skip=(int)(*mxGetPr(Matmatrix));
           break;
        }
    }


    /* D */
    for(n=0;n<2;n++){
        if ((Matmatrix=mxGetField(prhs[a], 0, Dstr[n]))!=NULL){
           decays = mxGetM(Matmatrix)*mxGetN(Matmatrix);
           trparms->D         = mmake(1,decays);
           M    = mxGetPr(Matmatrix);
           for(n=0;n<decays;n++){
              rvput(trparms->D,n,M[n]);
           }
           break;
        }
    }
    if(Matmatrix==NULL){
       trparms->D         = mmake(1,1);
       put_val(trparms->D,0,0,TRDD);
    }
}
  else
  {
    trparms->infolevel = TRDINFOLEVEL;
    trparms->maxiter   = TRDMAXITER;
    trparms->critmin   = TRDCRITMIN;
    trparms->critterm  = TRDCRITTERM;
    trparms->gradterm  = TRDGRADTERM;
    trparms->paramterm = TRDPARAMTERM;
    trparms->D         = mmake(1,1);
    put_val(trparms->D,0,0,TRDD);
    trparms->lambda    = TRDLAMBDA;
    trparms->skip      = TRDSKIP;
  }


  /*
   >>>>>>>>>>>>>>>>>>>>>>         CALL THE C-ROUTINE         <<<<<<<<<<<<<<<<<<<<<
   */
  nnoe(&NSSEvec, &iter, &lambda, NetDef, NN, W1, W2, trparms, Y, U);


  /*
   >>>>>>>>>>>>>>>>>>>         CREATE OUTPUT MATICES            <<<<<<<<<<<<<<<<<<
   */
  plhs[0] = mxCreateDoubleMatrix(getrows(W1),getcols(W1),mxREAL);
  plhs[1] = mxCreateDoubleMatrix(getrows(W2),getcols(W2),mxREAL);
  plhs[2] = mxCreateDoubleMatrix(getrows(NSSEvec),getcols(NSSEvec),mxREAL);
  plhs[3] = mxCreateDoubleMatrix(1,1,mxREAL);
  plhs[4] = mxCreateDoubleMatrix(1,1,mxREAL);

  sm2mat(plhs[0],W1);
  sm2mat(plhs[1],W2);
  sm2mat(plhs[2],NSSEvec);
  M = mxGetPr(plhs[3]); M[0] = (double)iter;
  M = mxGetPr(plhs[4]); M[0] = (double)lambda;

  /*
   >>>>>>>>>>>>>>>>>>>>        FREE ARGUMENT MATRICES        <<<<<<<<<<<<<<<<<<<<<
   */
  mfree(NetDef);
  mfree(NN);
  mfree(U);
  mfree(Y);
  mfree(trparms->D);
  free(trparms); 
}



⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -