⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 nnarmax2.c

📁 类神经网路─MATLAB的应用(范例程式)
💻 C
📖 第 1 页 / 共 2 页
字号:
     >>>>>>>>>>>>>>>>>>>>        Linearize network           <<<<<<<<<<<<<<<<<<<<<  
    */
    for(t=0;t<N;t++){
    	/*-- Derivative of output wrt. hidden outputs --*/
    	if(louts==1) for(k=0;k<hidden;k++) rvput(dy2dy1,k,rvget(W2,k));
    	else for(k=0;k<hidden;k++) rvput(dy2dy1,k,rvget(W2,k)*(1-\
    	                             		rvget(y2,t)*rvget(y2,t)));

      	/*-- Partial deriv. of output from each hidden unit wrt. past net outp. --*/
      	for(j=0;j<lhids;j++){
      		i=(int)cvget(L_hidden,j);
      		for(k=nab;k<nabc;k++) put_val(dy1de,i,k-nab,\
      		        get_val(W1,i,k));
      	}
      	for(j=0;j<hhids;j++){
      		i=(int)cvget(H_hidden,j);
      		for(k=nab;k<nabc;k++) put_val(dy1de,i,k-nab,\
      			get_val(W1,i,k)*(1-get_val(y1,i,t)*get_val(y1,i,t)));
      	}

     	/*--Partial derivative of net output w.r.t. past net outputs --*/
     	mmul(dy2de_vec,dy2dy1,dy1de);
     	for(i=0;i<nc;i++) put_val(dy2de,i,t,rvget(dy2de_vec,i));
    }


    /* 
     >>>>>>>>>>>>>>>>>>>     Filter partial derivatives        <<<<<<<<<<<<<<<<<<<<  
    */
	for(t=0;t<N;t++){
		j=nc;
		if(t<nc) j=t;
		for(k=1;k<=j;k++){
			for(i=0;i<reduced;i++) {
			  ii =(int)cvget(theta_index,i);
			  PSI->mat[ii][t] -= get_val(dy2de,k-1,t)*get_val(PSI,ii,t-k);
			  }
		}
	}

	dw = 0;
	/* 
     	 >>>>>>>>>>>>  Gradient (G = PSI_red*E_vector - D*theta_red)  <<<<<<<<<<<<<  
         */    
     	for(i=0; i<reduced; i++){
      		ii = (int)cvget(theta_index,i);
    		for(sum=0.0,k=skip; k<N; k++) sum+=get_val(PSI,ii,k)*rvget(E,k);
    		cvput(G,i,sum - cvget(D,i)*cvget(theta_red,i));
      	}

    	/* 
     	 >>>>>>>>>> Mean square error part of Hessian (PSI_red*PSI_red') <<<<<<<<<<  
         */
    	for(i=0; i<reduced; i++){
    		ii = (int)cvget(theta_index,i);
    		for(j=i; j<reduced; j++){
      			jj = (int)cvget(theta_index,j);
      			for(sum=0.0,k=skip; k<N; k++)
				sum += get_val(PSI,ii,k)*get_val(PSI,jj,k);
      			put_val(R,i,j,sum);
      			put_val(R,j,i,sum);	
    		}
  	}
 }

/*
 >>>>>>>>>>>>>>>>>>>>>>>>>>>        COMPUTE h_k        <<<<<<<<<<<<<<<<<<<<<<<<<<<
 */
 
  /* -- Hessian (H = R + lambda*I + D)  --*/
  mset(H,R);
  for(i=0;i<reduced;i++)                            /* Add diagonal matrix     */
    put_val(H,i,i,get_val(H,i,i)+lambda+cvget(D,i));               

  /* -- Search direction -- */
  choldc(H, Htmp);
  cholsl(Htmp,h,G);

  /* -- Compute 'apriori' iterate -- */
  madd(theta_red_new,theta_red,h);                  /* Update parameter vector */
  mcopyi(theta,theta_index,index0,theta_red_new,index7,index0);

  /* -- Put the parameters back into the weight matrices -- */
  v2mreshape(W1_new,theta,parameters2);
  v2mreshape(W2_new,theta,0);


  /*
   >>>>>>>>>>>>>       Compute network output y2(theta+h)          <<<<<<<<<<<<<< 
  */
  for(t=0;t<N;t++){
	mvmul(h1,W1_new,PHI,t);
	vtanh(y1,H_hidden,t,h1,H_hidden,0);
	vcopyi(y1,L_hidden,t,h1,L_hidden,0);
	
	mvmul(h2,W2_new,y1,t);
	vtanh(y2,H_output,t,h2,H_output,0);
	vcopyi(y2,L_output,t,h2,L_output,0);
	
	rvput(E_new,t,rvget(Y2,t)-rvget(y2,t));          /* Prediction error        */	
	j=nc;
	if(N-t-1<nc) j=N-t-1;
	for(i=1;i<=j;i++){
		put_val(PHI,nab+i-1,t+i,rvget(E_new,t));
	}
  }
  for(SSE_new=0,t=skip;t<N;t++)                        /* Sum of squared errors */
              SSE_new+=rvget(E_new,t)*rvget(E_new,t);
  for(tmp1=0,i=0;i<reduced;i++) tmp1+=cvget(theta_red,i)*cvget(theta_red,i)*cvget(D,i); 
  NSSE_new = (SSE_new+tmp1)/(2*N2);                    /* Value of cost function*/


  /*
   >>>>>>>>>>>>>>>>>>>>>>>>>>>       UPDATE  lambda     <<<<<<<<<<<<<<<<<<<<<<<<<<
   */
    for(tmp1=0,i=0;i<reduced;i++) tmp1+=cvget(h,i)*cvget(h,i)*(cvget(D,i)+lambda);
    L = sprod3(h,G) + tmp1;

    /* Decrease lambda if SSE has fallen 'sufficiently' */
    if(2*N2*(NSSE - NSSE_new) > (0.75*L)) lambda = lambda/2;
  
    /* Increase lambda if SSE has grown 'sufficiently'  */
    else if(2*N2*(NSSE-NSSE_new) <= (0.25*L)) lambda = 2*lambda;  


  /*
   >>>>>>>>>>>>>>>>>>>       UPDATES FOR NEXT ITERATION        <<<<<<<<<<<<<<<<<<<<
   */
    /* Update only if criterion has decreased */
    if(NSSE_new<NSSE)
    {
     tmp = W1; W1=W1_new; W1_new=tmp;
     tmp = W2; W2=W2_new; W2_new=tmp;
     tmp = theta_red; theta_red=theta_red_new; theta_red_new = tmp;
     tmp = E; E = E_new; E_new = tmp;
     dw = 1;
     NSSE = NSSE_new;
     cvput(NSSEvec,iteration-1,NSSE);
     printf("iteration # %i   PI = %4.3e\r",iteration,NSSE); /* On-line information  */
     ++iteration;
   }
}


/*
 >>>>>>>>>>    RETURN POINTERS TO RETURN ARGUMENTS & FREE MEMORY    <<<<<<<<<<<<
 */

/* Swap pointers if they have been messed up */
if ((iteration&1) == 0) {
	mset(W1_new,W1);
	tmp = W1; W1=W1_new; W1_new=tmp;
	mset(W2_new,W2);
     	tmp = W2; W2=W2_new; W2_new=tmp;
}
iteration=iteration-1;
if(iteration==0){
	*NSSEvecpp = mmake(1,1);
	(*NSSEvecpp)->row=0;
	(*NSSEvecpp)->col=0;
}
else {
	*NSSEvecpp = mmake(iteration,1);
	subvec(*NSSEvecpp,NSSEvec,0,iteration-1);
}
*iter = iteration;
*lam  = lambda;
mfree(L_hidden); mfree(H_hidden); mfree(L_output); mfree(H_output);
mfree(h1); mfree(h2); mfree(y1); mfree(y2); mfree(Y2);
mfree(E); mfree(E_new); mfree(dy2dy1); mfree(dy2de); mfree(dy1de); mfree(dy2de_vec);
mfree(W1_new); mfree(W2_new); mfree(D);mfree(Dtmp); mfree(NSSEvec);mfree(Htmp);
mfree(theta); mfree(thtmp); mfree(theta_index); mfree(theta_red); mfree(theta_red_new);
mfree(PSI); mfree(G); mfree(H); mfree(R); mfree(h);
mfree(all); mfree(index0); mfree(index7);mfree(onesvec);mfree(tmp0); mfree(tmp2);
mfree(tmp3); mfree(index); mfree(index2); mfree(PHI);
if(nu!=0){
	mfree(nb); mfree(nk);
}
lt = time(NULL);
c  = localtime(&lt);
printf("\n\nNetwork training ended at %.8s\n\n\n",asctime(c)+11);
}
/*
  --------------------------------------------------------------------------------
  ----------------             END OF NETWORK TRAINING              --------------
  --------------------------------------------------------------------------------
*/


/*********************************************************************************
 *                                                                               *
 *                           G A T E W A Y   R O U T I N E                       *
 *                                                                               *
 *********************************************************************************/
/*
 * MATLAB entry point: converts the MATLAB arguments to the SM matrix format,
 * calls nnarmax2() and converts its results back to MATLAB arrays.
 *
 * Inputs (positions as read below):
 *   prhs[0] NetDef   architecture definition, 2 rows, one column per hidden unit
 *   prhs[1] NN       1 x (2*nu+2) regressor structure; element 0 is na,
 *                    elements 1..nu are nb, element nu+1 is nc
 *   prhs[2] W1       hidden x (nabc+1) input-to-hidden weights, or [] (together
 *                    with W2) to request initialization
 *   prhs[3] W2       1 x (hidden+1) hidden-to-output weights, or []
 *   prhs[4] trparms  training parameters; [] selects defaults [500 0.0 1 0.0]
 *   prhs[5] skip     number of leading samples skipped in the sums; [] means 0
 *   prhs[6] Y        1 x N vector of observed outputs
 *   prhs[7] U        optional nu x N input matrix (nu = number of rows)
 *
 * Outputs: plhs[0] = W1, plhs[1] = W2, plhs[2] = criterion value per accepted
 * iteration, plhs[3] = iteration count, plhs[4] = final lambda.
 */
void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
{
  /*
   >>>>>>>>>>>>>>>>>>           VARIABLE DECLARATIONS          <<<<<<<<<<<<<<<<<<<
   */
   matrix *NSSEvec, *NN, *nb = NULL;   /* nb only allocated when nu != 0 */
   matrix *NetDef, *W1, *W2, *U, *Y, *trparms;
   double *M, lambda;
   int iter, skip, m, N, na, nc, nabc, nu, hidden, k;

  /*
   >>>>>>>>>>>>>>>>      CHECK FOR PROPER NUMBER OF ARGUMENTS      <<<<<<<<<<<<<<<
   */
   if (nrhs<7 || nrhs>8) mexErrMsgTxt("Wrong number of input arguments");
   else if (nlhs > 5) mexErrMsgTxt("Too many output arguments");
   if(nrhs==8)
     nu = mxGetM(prhs[7]);   /* # of control signals (rows of U) */
   else
     nu=0;
   m  = mxGetM(prhs[6]);     /* Rows of vector Y */
   if(m!=1) mexErrMsgTxt("Wrong dimension of vector of observed outputs");
   N = mxGetN(prhs[6]);      /* Columns of vector Y */
   /* U must span the same samples as Y: compare with U's columns (prhs[7]) */
   if(nu!=0 && N!=(int)mxGetN(prhs[7]))
     mexErrMsgTxt("U and Y should have the same number of columns!");
   hidden = mxGetN(prhs[0]); /* # of hidden units */
   if(mxGetM(prhs[0])!=2) mexErrMsgTxt("Error in architecture definition!");
   if(mxGetM(prhs[1])!=1) mexErrMsgTxt("NN should only have 1 row!");
   if(mxGetN(prhs[1])!=2*nu+2) mexErrMsgTxt("Mismatch between U and NN");


  /*
   >>>>>>>>>>>>>>>>>     CONVERT INPUT ARGUMENTS TO SM FORMAT     <<<<<<<<<<<<<<<<
   */
  NetDef  = matstring2sm(prhs[0]);/* Network architecture */
  NN      = mat2sm(prhs[1]);     /* Regressor structure  */
  if(mxGetM(prhs[4])!=0)         /* Parameters associated with minimization */
 	trparms = mat2sm(prhs[4]);
  else {                         /* Use defaults if passed as [] */
   	trparms = mmake(1,4);
   	mput(trparms,0,0,500);
   	mput(trparms,0,1,0.0);
   	mput(trparms,0,2,1);
   	mput(trparms,0,3,0.0);
  }
  if(mxGetM(prhs[5])!=0)         /* Skip first gradients */
 	skip = (int)(*mxGetPr(prhs[5]));
  else skip=0;                   /* Use skip=0 if passed as []  */
  Y    = mat2sm(prhs[6]);        /* Vector of observed outputs  */
  if(nu!=0) U = mat2sm(prhs[7]); /* Vector of inputs            */
  else{                          /* Empty placeholder: 0x0 matrix */
  	U=mmake(1,1);
  	U->row=0;
  	U->col=0;
  }


  /*
   >>>>>>>>>>>>>>>>>>       DETERMINE REGRESSOR STRUCTURE       <<<<<<<<<<<<<<<<<<
   */
   na        = (int)vget(NN,0);    /* Past outputs used as input           */
   nc        = (int)vget(NN,nu+1); /* Past prediction errors used as input */
   if(nu!=0){
   	nb = mmake(1,nu);    /* Past controls used as input          */
   	subvec(nb,NN,1,nu);
   }
   nabc      = na+nc;        /* na+nc here; sum(nb) is added below    */
   for(k=0;k<nu;k++) nabc=nabc+(int)rvget(nb,k);

   /* Initialize weight matrices if passed as [] */
   if(mxGetM(prhs[2])==0 || mxGetN(prhs[2])==0 || mxGetM(prhs[3])==0\
                         || mxGetN(prhs[3])==0){
   	W1 = mmake(hidden,nabc+1);
   	W2 = mmake(1,hidden+1);
   	W2->row = 0;   /* Hack telling that the weights should be initialized */
   }
   else{
   	if(mxGetM(prhs[2])!=hidden) mexErrMsgTxt("W1 has the wrong dimension");
   	if(mxGetN(prhs[2])!=nabc+1) mexErrMsgTxt("W1 has the wrong dimension");
   	if(mxGetM(prhs[3])!=1) mexErrMsgTxt("W2 has the wrong dimension");
   	if(mxGetN(prhs[3])!=hidden+1) mexErrMsgTxt("W2 has the wrong dimension");
   	W1 = mat2sm(prhs[2]);     /* Input-to-hidden layer weights */
   	W2 = mat2sm(prhs[3]);     /* Hidden-to-output layer weights */
   }


  /*
   >>>>>>>>>>>>>>>>>>>>>>         CALL THE C-ROUTINE         <<<<<<<<<<<<<<<<<<<<<
   */
  /* nnarmax2 writes the trained weights into W1/W2 and allocates NSSEvec */
  nnarmax2(&NSSEvec, &iter, &lambda, NetDef, NN, W1, W2, trparms, skip, Y, U);


  /*
   >>>>>>>>>>>>>>>>>>>         CREATE OUTPUT MATRICES           <<<<<<<<<<<<<<<<<<
   */
  plhs[0] = mxCreateDoubleMatrix(getrows(W1),getcols(W1),mxREAL);
  plhs[1] = mxCreateDoubleMatrix(getrows(W2),getcols(W2),mxREAL);
  plhs[2] = mxCreateDoubleMatrix(getrows(NSSEvec),getcols(NSSEvec),mxREAL);
  plhs[3] = mxCreateDoubleMatrix(1,1,mxREAL);
  plhs[4] = mxCreateDoubleMatrix(1,1,mxREAL);

  sm2mat(plhs[0],W1);
  sm2mat(plhs[1],W2);
  sm2mat(plhs[2],NSSEvec);
  M = mxGetPr(plhs[3]); M[0] = (double)iter;
  M = mxGetPr(plhs[4]); M[0] = lambda;

  /*
   >>>>>>>>>>>>>>>>>>>>        FREE ARGUMENT MATRICES        <<<<<<<<<<<<<<<<<<<<<
   */
  /* Results have been copied into plhs, so the SM copies can be released too */
  mfree(W1);
  mfree(W2);
  mfree(NSSEvec);
  mfree(NetDef);
  mfree(NN);
  mfree(U);
  if(nu!=0) mfree(nb);
  mfree(Y);
  mfree(trparms);
}



⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -