⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 nnssif.c

📁 matlab实现神经网络程序集合
💻 C
📖 第 1 页 / 共 2 页
字号:
    for(t=0;t<N;t++){
    	/*-- Derivative of states wrt. hidden outputs --*/
    	for(j=0;j<louts;j++){
    		i=(int)cvget(L_output,j);
    		for(k=0;k<hidden;k++) put_val(dxdy1,i,k,get_val(W2,i,k));
    	}
    	for(j=0;j<houts;j++){
    		i=(int)cvget(H_output,j);
    		for(k=0;k<hidden;k++) put_val(dxdy1,i,k,get_val(W2,i,k)*(1-\
    	                             	get_val(y2,i,t)*get_val(y2,i,t)));
    	}

      	/*-- Partial deriv. of output from each hidden unit wrt. net inputs --*/
      	for(j=0;j<lhids;j++){
      		i=(int)cvget(L_hidden,j);
		for(k=0;k<nx;k++) put_val(dy1dx,i,k,get_val(W1,i,k));    		
      		for(k=nxu;k<inputs;k++) put_val(dy1de,i,k-nxu,get_val(W1,i,k));
      	}
      	for(j=0;j<hhids;j++){
      		i=(int)cvget(H_hidden,j);
      		for(k=0;k<nx;k++) put_val(dy1dx,i,k,\
      			get_val(W1,i,k)*(1-get_val(y1,i,t)*get_val(y1,i,t)));
      		for(k=nxu;k<inputs;k++) put_val(dy1de,i,\
      			k-nxu,get_val(W1,i,k)*(1-get_val(y1,i,t)*get_val(y1,i,t)));
      	}

     	/*--Partial derivative of states w.r.t. past states and residuals --*/
     	mmul(Ahat,dxdy1,dy1dx);
     	for(k=0;k<nx-ny;k++) put_val(Ahat,(int)cvget(nrowidx,k),\
     					(int)cvget(nrowidx,k)+1,1.0);                
     	mmul(Khat,dxdy1,dy1de);
     	mmul(AKC,Khat,C);
     	msub(AKC,Ahat,AKC);


    /* 
     >>>>>>>>>>>>>>>>>>>     Filter partial derivatives        <<<<<<<<<<<<<<<<<<<<  
    */
    	if(t>=1){
    		/* PSIx = PSIx + PSIx*AKC' */
    		index5 = t*nx;
    		index6 = (t-1)*nx;
		for(i=0;i<reduced;i++){
			ii =(int)cvget(theta_index,i);
			for(j=0;j<nx;j++){
			   for(k=0;k<nx;k++){
			     PSIx->mat[ii][index5+j]+=get_val(PSIx,ii,index6+k)*\
				 	get_val(AKC,j,k);
			   }
			}
		}
	}

	
	/*PSI=PSIx*C';*/
	index5 = t*ny;
    	index6 = t*nx;
	for(i=0;i<reduced;i++){
		ii =(int)cvget(theta_index,i);
		for(j=0;j<ny;j++){
			for(sum=0,k=0;k<nx;k++){
			     sum+=get_val(PSIx,ii,index6+k)*get_val(C,j,k);
			}
			put_val(PSI,ii,index5+j,sum);
		}
	}
	
    }
	minit(PSIx);
	dw = 0;
	/* 
     	 >>>>>>>>>>>>  Gradient (G = PSI_red*E_vector - D*theta_red)  <<<<<<<<<<<<<  
         */    
     	for(i=0; i<reduced; i++){
      		ii = (int)cvget(theta_index,i);
    		for(sum=0.0,k=skipstart; k<Nny; k++)
    			sum+=get_val(PSI,ii,k)*rvget(Evec,k);
    		cvput(G,i,sum - cvget(D,i)*cvget(theta_red,i));
      	}

    	/* 
     	 >>>>>>>>>> Mean square error part of Hessian (PSI_red*PSI_red') <<<<<<<<<<  
         */
    	for(i=0; i<reduced; i++){
    		ii = (int)cvget(theta_index,i);
    		for(j=i; j<reduced; j++){
      			jj = (int)cvget(theta_index,j);
      			for(sum=0.0,k=skipstart; k<Nny; k++)
				sum += get_val(PSI,ii,k)*get_val(PSI,jj,k);
      			put_val(R,i,j,sum);
      			put_val(R,j,i,sum);	
    		}
  	}
 }

/*
 >>>>>>>>>>>>>>>>>>>>>>>>>>>        COMPUTE h_k        <<<<<<<<<<<<<<<<<<<<<<<<<<<
 */
 
  /* -- Hessian (H = R + lambda*I + D)  --*/
  mset(H,R);
  for(i=0;i<reduced;i++)                            /* Add diagonal matrix     */
    put_val(H,i,i,get_val(H,i,i)+lambda+cvget(D,i));               

  /* -- Search direction -- */
  choldc(H, Htmp);
  cholsl(Htmp,h,G);

  /* -- Compute 'apriori' iterate -- */
  madd(theta_red_new,theta_red,h);                  /* Update parameter vector */
  mcopyi(theta,theta_index,index0,theta_red_new,index7,index0);

  /* -- Put the parameters back into the weight matrices -- */
  v2mreshape(W1_new,theta,parameters2);
  v2mreshape(W2_new,theta,0);


  /*
   >>>>>>>>>>>>>       Compute network output y2(theta+h)          <<<<<<<<<<<<<< 
  */
  for(t=0;t<N;t++){
	mvmul(h1,W1_new,PHI,t);
	vtanh(y1,H_hidden,t,h1,H_hidden,0);
	vcopyi(y1,L_hidden,t,h1,L_hidden,0);
	
	mvmul(h2,W2_new,y1,t);
	vtanh(y2,H_output,t,h2,H_output,0);
	vcopyi(y2,L_output,t,h2,L_output,0);
	
	for(k=0;k<(nx-ny);k++){
		i=(int)cvget(nrowidx,k);
		y2->mat[i][t]+=get_val(PHI,i+1,t);
	}
	mvmul(Yhat,C,y2,t);                             /* Output prediction     */
	
	for(i=0;i<ny;i++){
		cvput(E,i,get_val(Y2,i,t)-cvget(Yhat,i));/* Prediction error     */
		rvput(Evec_new,t*ny+i,cvget(E,i));       /* Store E in Evec_new  */
	}
	if(t<N-1){
		for(i=0;i<nx;i++)
			put_val(PHI,i,t+1,get_val(y2,i,t));
		for(i=0;i<ny;i++)
			put_val(PHI,nx+nu+i,t+1,cvget(E,i));
	}
  }	
  for(SSE_new=0,t=skipstart;t<Nny;t++)
	SSE_new+=rvget(Evec_new,t)*rvget(Evec_new,t);   /* Sum of squared errors */
  for(tmp1=0,i=0;i<reduced;i++) tmp1+=cvget(theta_red,i)*cvget(theta_red,i)*cvget(D,i); 
	NSSE_new = (SSE_new+tmp1)/(2*N2);                               /* Value of cost function*/



  /*
   >>>>>>>>>>>>>>>>>>>>>>>>>>>       UPDATE  lambda     <<<<<<<<<<<<<<<<<<<<<<<<<<
   */
    for(tmp1=0,i=0;i<reduced;i++) tmp1+=cvget(h,i)*cvget(h,i)*(cvget(D,i)+lambda);
    L = sprod3(h,G) + tmp1;

    /* Decrease lambda if SSE has fallen 'sufficiently' */
    if(2*N2*(NSSE - NSSE_new) > (0.75*L)) lambda = lambda/2;
  
    /* Increase lambda if SSE has grown 'sufficiently'  */
    else if(2*N2*(NSSE-NSSE_new) <= (0.25*L)) lambda = 2*lambda;  


  /*
   >>>>>>>>>>>>>>>>>>>       UPDATES FOR NEXT ITERATION        <<<<<<<<<<<<<<<<<<<<
   */
    /* Update only if criterion has decreased */
    if(NSSE_new<NSSE)
    {
     tmp = W1; W1=W1_new; W1_new=tmp;
     tmp = W2; W2=W2_new; W2_new=tmp;
     tmp = theta_red; theta_red=theta_red_new; theta_red_new = tmp;
     tmp = Evec; Evec = Evec_new; Evec_new = tmp;
     dw = 1;
     NSSE = NSSE_new;
     cvput(NSSEvec,iteration-1,NSSE);
     printf("iteration # %i   PI = %4.3e\r",iteration,NSSE); /* On-line information  */
     ++iteration;
     }
}


/*
 >>>>>>>>>>    RETURN POINTERS TO RETURN ARGUMENTS & FREE MEMORY    <<<<<<<<<<<<
 */

/* Swap pointers if they have been messed up */
if ((iteration&1) == 0) {
	mset(W1_new,W1);
	tmp = W1; W1=W1_new; W1_new=tmp;
	mset(W2_new,W2);
     	tmp = W2; W2=W2_new; W2_new=tmp;
}
iteration=iteration-1;
if(iteration==0){
	*NSSEvecpp = mmake(1,1);
	(*NSSEvecpp)->row=0;
	(*NSSEvecpp)->col=0;
}
else{
	*NSSEvecpp = mmake(iteration,1);
	subvec(*NSSEvecpp,NSSEvec,0,iteration-1);
}
*iter = iteration;
*lam  = lambda;
mfree(L_hidden); mfree(H_hidden); mfree(L_output); mfree(H_output); mfree(Evec);
mfree(h1); mfree(h2); mfree(y1); mfree(y2); mfree(Y2); mfree(Ahat); mfree(C);
mfree(E); mfree(Evec_new); mfree(dxdy1); mfree(Khat); mfree(dy1de); mfree(dy1dx);
mfree(W1_new); mfree(W2_new); mfree(D);mfree(Dtmp); mfree(NSSEvec);mfree(Htmp);
mfree(theta); mfree(thtmp); mfree(theta_index); mfree(theta_red); mfree(theta_red_new);
mfree(PSI); mfree(PSIx); mfree(G); mfree(H); mfree(R); mfree(h), mfree(Yhat);
mfree(all); mfree(index0); mfree(index7);mfree(onesvec);mfree(tmp0); mfree(tmp2);
mfree(tmp3); mfree(index); mfree(index2); mfree(PHI);
mfree(rowidx); mfree(nrowidx); mfree(AKC);

lt = time(NULL);
c  = localtime(&lt);
printf("\n\nNetwork training ended at %.8s\n\n\n",asctime(c)+11);
}
/*
  --------------------------------------------------------------------------------
  ----------------             END OF NETWORK TRAINING              --------------
  --------------------------------------------------------------------------------
*/


/*********************************************************************************
 *                                                                               *
 *                           G A T E W A Y   R O U T I N E                       *
 *                                                                               *
 *********************************************************************************/
/*
 * MATLAB gateway for nnssif.
 *
 * Inputs (exactly 9):
 *   prhs[0] NetDef  - architecture string matrix with 2 rows; its number of
 *                     columns is the number of hidden units (at least 2)
 *   prhs[1] nx      - scalar passed to nnssif as nx
 *   prhs[2] W1      - hidden x (nx+nu+ny+1) weights, or [] to request initialization
 *   prhs[3] W2      - nx x (hidden+1) weights, or [] to request initialization
 *   prhs[4] obsidx  - pseudo-observability indices, or [] for the default
 *                     (ones, with the last entry set to nx-ny+1)
 *   prhs[5] trparms - training parameters, or [] for the 1x4 default
 *                     [500 0 1 0] (presumably max iterations, stop criterion,
 *                     initial lambda, weight decay - confirm against nnssif)
 *   prhs[6] skip    - number of leading gradients to skip, or [] for 0
 *   prhs[7] Y       - ny x N observed outputs
 *   prhs[8] U       - nu x N inputs (same number of columns as Y)
 * Outputs (at most 6):
 *   plhs[0] W1, plhs[1] W2, plhs[2] obsidx, plhs[3] NSSEvec (criterion per
 *   accepted iteration), plhs[4] iteration count, plhs[5] final lambda.
 */
void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
{
  /*
   >>>>>>>>>>>>>>>>>>           VARIABLE DECLARATIONS          <<<<<<<<<<<<<<<<<<<
   */
  matrix *NSSEvec;
  matrix *NetDef, *W1, *W2, *obsidx, *U, *Y, *trparms;
  double *M, lambda;
  int iter, skip, ny, N, nu, nx, hidden, init_weights;

  /*
   >>>>>>>>>>>>>>>>      CHECK FOR PROPER NUMBER OF ARGUMENTS      <<<<<<<<<<<<<<<
   */
  /* All validation is done before any sm matrix is allocated: mexErrMsgTxt
     aborts the MEX call, so matrices created earlier would never be freed. */
  if (nrhs!=9) mexErrMsgTxt("Wrong number of input arguments");
  else if (nlhs > 6) mexErrMsgTxt("Too many output arguments");
  nu = mxGetM(prhs[8]);     /* # of control signals */
  ny = mxGetM(prhs[7]);     /* Rows of vector Y */
  if(nu<1) mexErrMsgTxt("Wrong dimension of input vector");
  /* ny-1 is used as an index when the default obsidx is built below */
  if(ny<1) mexErrMsgTxt("Wrong dimension of output vector");
  N = mxGetN(prhs[7]);      /* Columns of vector Y */
  /* U is prhs[8]; Y is prhs[7] - both must have N columns */
  if(N!=(int)mxGetN(prhs[8])) mexErrMsgTxt("U and Y should have the same number of columns!");
  hidden = mxGetN(prhs[0]); /* # of hidden units */
  if(mxGetM(prhs[0])!=2) mexErrMsgTxt("Error in architecture definition!");
  if(hidden<2) mexErrMsgTxt("Use at least two hidden units!");

  /* nx is read through mxGetPr, so the argument must not be empty */
  if(mxGetM(prhs[1])==0 || mxGetN(prhs[1])==0) mexErrMsgTxt("nx must be a scalar");
  nx = (int)(*mxGetPr(prhs[1]));

  /* Weights passed as [] are initialized later; otherwise verify their sizes */
  init_weights = mxGetM(prhs[2])==0 || mxGetN(prhs[2])==0 ||
                 mxGetM(prhs[3])==0 || mxGetN(prhs[3])==0;
  if(!init_weights){
    if(mxGetM(prhs[2])!=hidden) mexErrMsgTxt("W1 has the wrong dimension");
    if(mxGetN(prhs[2])!=nx+nu+ny+1) mexErrMsgTxt("W1 has the wrong dimension");
    if(mxGetM(prhs[3])!=nx) mexErrMsgTxt("W2 has the wrong dimension");
    if(mxGetN(prhs[3])!=hidden+1) mexErrMsgTxt("W2 has the wrong dimension");
  }


  /*
   >>>>>>>>>>>>>>>>>     CONVERT INPUT ARGUMENTS TO SM FORMAT     <<<<<<<<<<<<<<<<
   */
  NetDef = matstring2sm(prhs[0]);       /* Network architecture */
  if(mxGetM(prhs[5])!=0)                /* Parameters associated with minimization */
    trparms = mat2sm(prhs[5]);
  else{                                 /* Use defaults if passed as [] */
    trparms = mmake(1,4);
    mput(trparms,0,0,500);
    mput(trparms,0,1,0.0);
    mput(trparms,0,2,1);
    mput(trparms,0,3,0.0);
  }
  if(mxGetM(prhs[6])!=0)                /* Skip first gradients */
    skip = (int)(*mxGetPr(prhs[6]));
  else skip=0;                          /* Use skip=0 if passed as [] */
  Y = mat2sm(prhs[7]);                  /* Vector of observed outputs */
  U = mat2sm(prhs[8]);                  /* Vector of inputs */

  /* Initialize pseudo-observability indices if obsidx passed as [] */
  if(mxGetM(prhs[4])==0 && mxGetN(prhs[4])==0){
    obsidx=mmake(1,ny);
    minitx(obsidx,1.0);
    vput(obsidx,ny-1,(double)nx-ny+1);
  }
  else
    obsidx = mat2sm(prhs[4]);

  /* Initialize weight matrices if passed as [] */
  if(init_weights){
    W1 = mmake(hidden,nx+nu+ny+1);
    W2 = mmake(nx,hidden+1);
    W2->row = 0;   /* Hack telling that the weights should be initialized */
  }
  else{
    W1 = mat2sm(prhs[2]);     /* Input-to-hidden layer weights */
    W2 = mat2sm(prhs[3]);     /* Hidden-to-output layer weights */
  }


  /*
   >>>>>>>>>>>>>>>>>>>>>>         CALL THE C-ROUTINE         <<<<<<<<<<<<<<<<<<<<<
   */
  nnssif(&NSSEvec, &iter, &lambda, NetDef, nx, W1, W2, obsidx, trparms, skip, Y, U);


  /*
   >>>>>>>>>>>>>>>>>>>         CREATE OUTPUT MATRICES           <<<<<<<<<<<<<<<<<<
   */
  plhs[0] = mxCreateDoubleMatrix(getrows(W1),getcols(W1),mxREAL);
  plhs[1] = mxCreateDoubleMatrix(getrows(W2),getcols(W2),mxREAL);
  plhs[2] = mxCreateDoubleMatrix(getrows(obsidx),getcols(obsidx),mxREAL);
  plhs[3] = mxCreateDoubleMatrix(getrows(NSSEvec),getcols(NSSEvec),mxREAL);
  plhs[4] = mxCreateDoubleMatrix(1,1,mxREAL);
  plhs[5] = mxCreateDoubleMatrix(1,1,mxREAL);

  sm2mat(plhs[0],W1);
  sm2mat(plhs[1],W2);
  sm2mat(plhs[2],obsidx);
  sm2mat(plhs[3],NSSEvec);
  M = mxGetPr(plhs[4]); M[0] = (double)iter;
  M = mxGetPr(plhs[5]); M[0] = (double)lambda;

  /*
   >>>>>>>>>>>>>>>>>>>>        FREE ARGUMENT MATRICES        <<<<<<<<<<<<<<<<<<<<<
   */
  mfree(NetDef);
  mfree(U);
  mfree(Y);
  mfree(trparms);
  /* The results were written into the mxArrays above by sm2mat, so the
     sm versions are no longer needed either */
  mfree(W1);
  mfree(W2);
  mfree(obsidx);
  mfree(NSSEvec);
}



⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -