📄 nnssif.c
字号:
/* NOTE(review): this chunk is the tail of nnssif()'s training loop; the
function header, variable declarations and loop opening are above this
excerpt. */
for(t=0;t<N;t++){
/*-- Derivative of states wrt. hidden outputs --*/
/* Linear output units: the derivative is just the hidden-to-output weight. */
for(j=0;j<louts;j++){
i=(int)cvget(L_output,j);
for(k=0;k<hidden;k++) put_val(dxdy1,i,k,get_val(W2,i,k));
}
/* tanh output units: chain rule brings in the factor (1 - y2^2). */
for(j=0;j<houts;j++){
i=(int)cvget(H_output,j);
for(k=0;k<hidden;k++) put_val(dxdy1,i,k,get_val(W2,i,k)*(1-\
get_val(y2,i,t)*get_val(y2,i,t)));
}
/*-- Partial deriv. of output from each hidden unit wrt. net inputs --*/
/* dy1dx: w.r.t. the state portion of the regressor (columns 0..nx-1 of W1);
dy1de: w.r.t. the residual portion (columns nxu..inputs-1). */
for(j=0;j<lhids;j++){
i=(int)cvget(L_hidden,j);
for(k=0;k<nx;k++) put_val(dy1dx,i,k,get_val(W1,i,k));
for(k=nxu;k<inputs;k++) put_val(dy1de,i,k-nxu,get_val(W1,i,k));
}
/* tanh hidden units: same weights scaled by (1 - y1^2). */
for(j=0;j<hhids;j++){
i=(int)cvget(H_hidden,j);
for(k=0;k<nx;k++) put_val(dy1dx,i,k,\
get_val(W1,i,k)*(1-get_val(y1,i,t)*get_val(y1,i,t)));
for(k=nxu;k<inputs;k++) put_val(dy1de,i,\
k-nxu,get_val(W1,i,k)*(1-get_val(y1,i,t)*get_val(y1,i,t)));
}
/*--Partial derivative of states w.r.t. past states and residuals --*/
mmul(Ahat,dxdy1,dy1dx);
/* Insert the superdiagonal 1's of the (pseudo-)observable canonical form. */
for(k=0;k<nx-ny;k++) put_val(Ahat,(int)cvget(nrowidx,k),\
(int)cvget(nrowidx,k)+1,1.0);
mmul(Khat,dxdy1,dy1de);
mmul(AKC,Khat,C);
msub(AKC,Ahat,AKC); /* AKC = Ahat - Khat*C */
/*
>>>>>>>>>>>>>>>>>>> Filter partial derivatives <<<<<<<<<<<<<<<<<<<<
*/
if(t>=1){
/* PSIx = PSIx + PSIx*AKC' */
/* Propagate the state sensitivities one step; only rows listed in
theta_index (the reduced parameter set) are touched. */
index5 = t*nx;
index6 = (t-1)*nx;
for(i=0;i<reduced;i++){
ii =(int)cvget(theta_index,i);
for(j=0;j<nx;j++){
for(k=0;k<nx;k++){
PSIx->mat[ii][index5+j]+=get_val(PSIx,ii,index6+k)*\
get_val(AKC,j,k);
}
}
}
}
/*PSI=PSIx*C';*/
/* Project state sensitivities to output sensitivities for time t. */
index5 = t*ny;
index6 = t*nx;
for(i=0;i<reduced;i++){
ii =(int)cvget(theta_index,i);
for(j=0;j<ny;j++){
for(sum=0,k=0;k<nx;k++){
sum+=get_val(PSIx,ii,index6+k)*get_val(C,j,k);
}
put_val(PSI,ii,index5+j,sum);
}
}
}
minit(PSIx); /* Clear state sensitivities before the next pass. */
dw = 0; /* Flag: weights not updated (yet) in this iteration. */
/*
>>>>>>>>>>>> Gradient (G = PSI_red*E_vector - D*theta_red) <<<<<<<<<<<<<
*/
/* Gradient of the regularized criterion; the first 'skipstart' residual
entries are excluded from the sums. */
for(i=0; i<reduced; i++){
ii = (int)cvget(theta_index,i);
for(sum=0.0,k=skipstart; k<Nny; k++)
sum+=get_val(PSI,ii,k)*rvget(Evec,k);
cvput(G,i,sum - cvget(D,i)*cvget(theta_red,i));
}
/*
>>>>>>>>>> Mean square error part of Hessian (PSI_red*PSI_red') <<<<<<<<<<
*/
/* R is symmetric: compute the upper triangle and mirror it. */
for(i=0; i<reduced; i++){
ii = (int)cvget(theta_index,i);
for(j=i; j<reduced; j++){
jj = (int)cvget(theta_index,j);
for(sum=0.0,k=skipstart; k<Nny; k++)
sum += get_val(PSI,ii,k)*get_val(PSI,jj,k);
put_val(R,i,j,sum);
put_val(R,j,i,sum);
}
}
}
/*
>>>>>>>>>>>>>>>>>>>>>>>>>>> COMPUTE h_k <<<<<<<<<<<<<<<<<<<<<<<<<<<
*/
/* -- Hessian (H = R + lambda*I + D) --*/
mset(H,R);
for(i=0;i<reduced;i++) /* Add diagonal matrix */
put_val(H,i,i,get_val(H,i,i)+lambda+cvget(D,i));
/* -- Search direction -- */
/* Solve H*h = G via Cholesky factorization (Levenberg-Marquardt step). */
choldc(H, Htmp);
cholsl(Htmp,h,G);
/* -- Compute 'apriori' iterate -- */
madd(theta_red_new,theta_red,h); /* Update parameter vector */
mcopyi(theta,theta_index,index0,theta_red_new,index7,index0);
/* -- Put the parameters back into the weight matrices -- */
v2mreshape(W1_new,theta,parameters2);
v2mreshape(W2_new,theta,0);
/*
>>>>>>>>>>>>> Compute network output y2(theta+h) <<<<<<<<<<<<<<
*/
/* Simulate the network with the trial weights to evaluate the step. */
for(t=0;t<N;t++){
mvmul(h1,W1_new,PHI,t);
vtanh(y1,H_hidden,t,h1,H_hidden,0);
vcopyi(y1,L_hidden,t,h1,L_hidden,0);
mvmul(h2,W2_new,y1,t);
vtanh(y2,H_output,t,h2,H_output,0);
vcopyi(y2,L_output,t,h2,L_output,0);
/* Add the shifted-state contribution from the canonical superdiagonal. */
for(k=0;k<(nx-ny);k++){
i=(int)cvget(nrowidx,k);
y2->mat[i][t]+=get_val(PHI,i+1,t);
}
mvmul(Yhat,C,y2,t); /* Output prediction */
for(i=0;i<ny;i++){
cvput(E,i,get_val(Y2,i,t)-cvget(Yhat,i));/* Prediction error */
rvput(Evec_new,t*ny+i,cvget(E,i)); /* Store E in Evec_new */
}
/* Feed predicted states and residuals into the next regressor column. */
if(t<N-1){
for(i=0;i<nx;i++)
put_val(PHI,i,t+1,get_val(y2,i,t));
for(i=0;i<ny;i++)
put_val(PHI,nx+nu+i,t+1,cvget(E,i));
}
}
for(SSE_new=0,t=skipstart;t<Nny;t++)
SSE_new+=rvget(Evec_new,t)*rvget(Evec_new,t); /* Sum of squared errors */
for(tmp1=0,i=0;i<reduced;i++) tmp1+=cvget(theta_red,i)*cvget(theta_red,i)*cvget(D,i);
NSSE_new = (SSE_new+tmp1)/(2*N2); /* Value of cost function*/
/*
>>>>>>>>>>>>>>>>>>>>>>>>>>> UPDATE lambda <<<<<<<<<<<<<<<<<<<<<<<<<<
*/
/* L is the reduction predicted by the quadratic model; comparing the
actual reduction against 0.75*L / 0.25*L is the standard
Levenberg-Marquardt trust-region update of lambda. */
for(tmp1=0,i=0;i<reduced;i++) tmp1+=cvget(h,i)*cvget(h,i)*(cvget(D,i)+lambda);
L = sprod3(h,G) + tmp1;
/* Decrease lambda if SSE has fallen 'sufficiently' */
if(2*N2*(NSSE - NSSE_new) > (0.75*L)) lambda = lambda/2;
/* Increase lambda if SSE has grown 'sufficiently' */
else if(2*N2*(NSSE-NSSE_new) <= (0.25*L)) lambda = 2*lambda;
/*
>>>>>>>>>>>>>>>>>>> UPDATES FOR NEXT ITERATION <<<<<<<<<<<<<<<<<<<<
*/
/* Update only if criterion has decreased */
if(NSSE_new<NSSE)
{
/* Accept the step: swap in new weights, parameters and residual vector
(pointer swaps, no copies). */
tmp = W1; W1=W1_new; W1_new=tmp;
tmp = W2; W2=W2_new; W2_new=tmp;
tmp = theta_red; theta_red=theta_red_new; theta_red_new = tmp;
tmp = Evec; Evec = Evec_new; Evec_new = tmp;
dw = 1;
NSSE = NSSE_new;
cvput(NSSEvec,iteration-1,NSSE);
printf("iteration # %i PI = %4.3e\r",iteration,NSSE); /* On-line information */
++iteration;
}
}
/*
>>>>>>>>>> RETURN POINTERS TO RETURN ARGUMENTS & FREE MEMORY <<<<<<<<<<<<
*/
/* Swap pointers if they have been messed up */
/* NOTE(review): presumably an even iteration count means W1/W2 point at
the internally allocated buffers, so the data is copied back before the
final swap — confirm against the allocation code above this excerpt. */
if ((iteration&1) == 0) {
mset(W1_new,W1);
tmp = W1; W1=W1_new; W1_new=tmp;
mset(W2_new,W2);
tmp = W2; W2=W2_new; W2_new=tmp;
}
iteration=iteration-1;
/* Return an empty (0x0-flagged) NSSE vector if no iteration improved the fit. */
if(iteration==0){
*NSSEvecpp = mmake(1,1);
(*NSSEvecpp)->row=0;
(*NSSEvecpp)->col=0;
}
else{
*NSSEvecpp = mmake(iteration,1);
subvec(*NSSEvecpp,NSSEvec,0,iteration-1);
}
*iter = iteration;
*lam = lambda;
mfree(L_hidden); mfree(H_hidden); mfree(L_output); mfree(H_output); mfree(Evec);
mfree(h1); mfree(h2); mfree(y1); mfree(y2); mfree(Y2); mfree(Ahat); mfree(C);
mfree(E); mfree(Evec_new); mfree(dxdy1); mfree(Khat); mfree(dy1de); mfree(dy1dx);
mfree(W1_new); mfree(W2_new); mfree(D);mfree(Dtmp); mfree(NSSEvec);mfree(Htmp);
mfree(theta); mfree(thtmp); mfree(theta_index); mfree(theta_red); mfree(theta_red_new);
/* NOTE(review): the ',' between mfree(h) and mfree(Yhat) below is the
comma operator — harmless, but ';' was probably intended. */
mfree(PSI); mfree(PSIx); mfree(G); mfree(H); mfree(R); mfree(h), mfree(Yhat);
mfree(all); mfree(index0); mfree(index7);mfree(onesvec);mfree(tmp0); mfree(tmp2);
mfree(tmp3); mfree(index); mfree(index2); mfree(PHI);
mfree(rowidx); mfree(nrowidx); mfree(AKC);
lt = time(NULL);
c = localtime(&lt);
printf("\n\nNetwork training ended at %.8s\n\n\n",asctime(c)+11);
}
/*
--------------------------------------------------------------------------------
---------------- END OF NETWORK TRAINING --------------
--------------------------------------------------------------------------------
*/
/*********************************************************************************
* *
* G A T E W A Y R O U T I N E *
* *
*********************************************************************************/
/* Gateway routine: validates the MATLAB call, converts the arguments to
the toolbox's internal 'matrix' format, invokes nnssif(), and copies the
results back to MATLAB.
Inputs (prhs): 0=NetDef, 1=nx, 2=W1, 3=W2, 4=obsidx, 5=trparms,
6=skip, 7=Y, 8=U.
Outputs (plhs): 0=W1, 1=W2, 2=obsidx, 3=NSSEvec, 4=iterations, 5=lambda. */
void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
{
/*
>>>>>>>>>>>>>>>>>> VARIABLE DECLARATIONS <<<<<<<<<<<<<<<<<<<
*/
matrix *NSSEvec;
matrix *NetDef, *W1, *W2, *obsidx, *U, *Y, *trparms;
double *M, lambda;
int iter, skip, ny, N, nu, nx, hidden; /* unused local 'k' removed */
/*
>>>>>>>>>>>>>>>> CHECK FOR PROPER NUMBER OF ARGUMENTS <<<<<<<<<<<<<<<
*/
if (nrhs!=9) mexErrMsgTxt("Wrong number of input arguments");
else if (nlhs > 6) mexErrMsgTxt("Too many output arguments"); /* was: "Two many" */
nu = mxGetM(prhs[8]); /* # of control signals */
ny = mxGetM(prhs[7]); /* Rows of vector Y */
if(nu<1) mexErrMsgTxt("Wrong dimension of input vector");
N = mxGetN(prhs[7]); /* Columns of vector Y */
/* BUG FIX: the original compared mxGetN(prhs[7]) with itself, so this
dimension check could never fire. U is prhs[8]. */
if(N!=mxGetN(prhs[8])) mexErrMsgTxt("U and Y should have the same number of columns!");
hidden = mxGetN(prhs[0]); /* # of hidden units */
if(mxGetM(prhs[0])!=2) mexErrMsgTxt("Error in architecture definition!");
if(hidden<2) mexErrMsgTxt("Use at least two hidden units!");
/*
>>>>>>>>>>>>>>>>> CONVERT INPUT ARGUMENTS TO SM FORMAT <<<<<<<<<<<<<<<<
*/
NetDef = matstring2sm(prhs[0]); /* Network architecture */
nx = (int)(*mxGetPr(prhs[1])); /* Number of states */
if(mxGetM(prhs[5])!=0) /* Parameters associated with minimization */
trparms = mat2sm(prhs[5]);
else{ /* Use defaults if passed as [] */
/* NOTE(review): presumably [max_iter, stop_crit, lambda, decay] —
verify against the toolbox documentation. */
trparms = mmake(1,4);
mput(trparms,0,0,500);
mput(trparms,0,1,0.0);
mput(trparms,0,2,1);
mput(trparms,0,3,0.0);
}
if(mxGetM(prhs[6])!=0) /* Skip first gradients */
skip = (int)(*mxGetPr(prhs[6]));
else skip=0; /* Use skip=0 if passed as [] */
Y = mat2sm(prhs[7]); /* Vector of observed outputs */
U = mat2sm(prhs[8]); /* Vector of inputs */
/* Initialize pseudo-observability indices if obsidx passed as [] */
if(mxGetM(prhs[4])==0 && mxGetN(prhs[4])==0){
obsidx=mmake(1,ny);
minitx(obsidx,1.0);
vput(obsidx,ny-1,(double)nx-ny+1);
}
else
obsidx = mat2sm(prhs[4]);
/* Initialize weight matrices if passed as [] */
if(mxGetM(prhs[2])==0 || mxGetN(prhs[2])==0 || mxGetM(prhs[3])==0\
|| mxGetN(prhs[3])==0){
W1 = mmake(hidden,nx+nu+ny+1);
W2 = mmake(nx,hidden+1);
W2->row = 0; /* Hack telling that the weights should be initialized */
}
else{
if(mxGetM(prhs[2])!=hidden) mexErrMsgTxt("W1 has the wrong dimension");
if(mxGetN(prhs[2])!=nx+nu+ny+1) mexErrMsgTxt("W1 has the wrong dimension");
if(mxGetM(prhs[3])!=nx) mexErrMsgTxt("W2 has the wrong dimension");
if(mxGetN(prhs[3])!=hidden+1) mexErrMsgTxt("W2 has the wrong dimension");
W1 = mat2sm(prhs[2]); /* Input-to-hidden layer weights */
W2 = mat2sm(prhs[3]); /* Hidden-to-output layer weights */
}
/*
>>>>>>>>>>>>>>>>>>>>>> CALL THE C-ROUTINE <<<<<<<<<<<<<<<<<<<<<
*/
nnssif(&NSSEvec, &iter, &lambda, NetDef, nx, W1, W2, obsidx, trparms, skip, Y, U);
/*
>>>>>>>>>>>>>>>>>>> CREATE OUTPUT MATRICES <<<<<<<<<<<<<<<<<<
*/
plhs[0] = mxCreateDoubleMatrix(getrows(W1),getcols(W1),mxREAL);
plhs[1] = mxCreateDoubleMatrix(getrows(W2),getcols(W2),mxREAL);
plhs[2] = mxCreateDoubleMatrix(getrows(obsidx),getcols(obsidx),mxREAL);
plhs[3] = mxCreateDoubleMatrix(getrows(NSSEvec),getcols(NSSEvec),mxREAL);
plhs[4] = mxCreateDoubleMatrix(1,1,mxREAL);
plhs[5] = mxCreateDoubleMatrix(1,1,mxREAL);
sm2mat(plhs[0],W1);
sm2mat(plhs[1],W2);
sm2mat(plhs[2],obsidx);
sm2mat(plhs[3],NSSEvec);
M = mxGetPr(plhs[4]); M[0] = (double)iter;
M = mxGetPr(plhs[5]); M[0] = (double)lambda;
/*
>>>>>>>>>>>>>>>>>>>> FREE ARGUMENT MATRICES <<<<<<<<<<<<<<<<<<<<<
*/
mfree(NetDef);
mfree(U);
mfree(Y);
mfree(trparms);
/* LEAK FIX: W1, W2 and obsidx are allocated above (mat2sm/mmake) and
NSSEvec is allocated by nnssif(); none of them were freed before. */
mfree(W1);
mfree(W2);
mfree(obsidx);
mfree(NSSEvec);
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -