📄 vssminpn.m
字号:
SigB = inv( Udd - GA'*SigA*GA + diag(beta) ); Bb = MB' - GA'*SigA*SA; % VBM step compute sufficient statistics (dynamics model) B = Bb'*SigB; % < B > A = ( SA - GA*B' )'*SigA; % < A > AA = A'*A + k*( SigA + SigA*GA*SigB*GA'*SigA ); % < A' A > AB = SigA*SA*B - SigA*GA*( B'*B + k*SigB ); % < A' B > BB = B'*B + k*SigB; % < B' B > %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% % UPDATE rho,C,D %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% % VBM step (observation model) % Compute parameters to represent Q(rho,C,D) WC = zeros(k,k); SC = zeros(k,p); MD = zeros(k,pinp); % Nota Bene, MD is equivalent to GC in thesis. for i=1:n WC = WC+Xmn{i}*Xmn{i}'; for t = 1:Tn(i) WC = WC+inv(Xcin{i}(:,:,t)); end SC = SC+ Xmn{i}*Yn{i}'; MD = MD+ Xmn{i}*inpn{i}'; end SigC = inv( WC + diag(gamma) ); SigD = inv( Udd - MD'*SigC*MD + diag(delta) ); Db = UY - MD'*SigC*SC; G = diag( Ydd - SC'*SigC*SC -Db'*SigD*Db ); % VBM step compute sufficient statistics (observation model) rho = (pa+sum(Tn,2)/2)*diag(1./(pb+G/2)); % < R^{-1} > matrix lnrho = digamma(pa+sum(Tn,2)/2) - log(pb+G/2); % < -log R_ss > vector D = Db'*SigD; % < D > C = ( SC - MD*D' )'*SigC; % < C > CrhoC = C'*rho*C + p*( SigC + SigC*MD*SigD*MD'*SigC );% < C' R^{-1} C > rhoC = rho*C; % < R^{-1} C > CrhoD = SigC*( SC*rho*D - MD*D'*rho*D - p*MD*SigD ); % < C' R^{-1} D > rhoD = rho*D; % < R^{-1} D > DrhoD = D'*rho*D + p*SigD; % < D' R^{-1} D > % Is this the last iteration? finalit = ~(it<its); %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% % UPDATE hidden state %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% % VBE step : q(X) is Gaussian for i=1:n [lnZpTn{i},Xm_a,Xci_a] = forwardpass(Yn{i},inpn{i},X0ci_p,X0m_p,A,AA,AB,B,BB,rho,lnrho,CrhoC,rhoC,CrhoD,rhoD,DrhoD,cFbool|finalit); if cFbool|finalit , lnZpn(i) = sum(lnZpTn{i},2); end [Xm_b,Xci_b,X0m_b,X0ci_b] = backwardpass(Yn{i},inpn{i},A,AA,B,AB,CrhoC,rhoC,CrhoD); [Xmn{i},Xcin{i},Upsn{i},X0mn{i},X0cin{i},Ups0n{i}] = getmarginals(Xm_a,Xci_a,Xm_b,Xci_b,X0m_p,X0ci_p,X0m_b,X0ci_b,A,AA,CrhoC); end %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% % CALCULATE LOWER BOUND, F %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% if cFbool|finalit % Calculate F = -kl_A -kl_B -kl_C -kl_D -kl_rho + \sum_{i=1}^n lnZpn_i % For this we require parameter sufficient statistics and lnZp [F] = Fcalc(pa,pb,alpha,beta,gamma,delta,SigA,SigB,SigC,SigD,Bb,Db,SA,SC,GA,MD,G,lnZpn,Tn); % book keeping for F constituents hist = [hist; F sum(F,2)]; dF = hist(end,7)-hist(end-1,7); if (dispopt>=3)&cFbool for hh = 1:size(hFpart); set(hFpart(hh),'XData',1:(size(hist,1)-1),'YData',hist(2:end,hh)); end set(hFdiff,'XData',1:(size(hist,1)-1),'YData',log(diff(hist(1:end,end)))); drawnow end if dispopt>=2 fprintf('\nit:%4i F: B=%.3f A|B=%.3f rho=%.3f D|rho=%.3f C|rho,D=%.3f Y=%.3f F=%.3f dF=%.3f ',it,hist(end,:),dF); if dF < 0 % it *should* monotonically increase fprintf('\n\n!!!!!\n\n F violation \n\n!!!!!\n'); alpha, beta, gamma, delta, A,B,C,D fprintf('BREAKING (paused)'); pause break end end end % Terminate VBEM iteration loop if number of iterations has been reached, or if the lower bound plateaus. if ~( (it<its) & (dF>1e-14) ) break end % Display hint of progress if (round(it/50)==it/50) & (dispopt>=1) & ~cFbool, fprintf('.'); end %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% % HYPERPARAMETER UPDATES %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% if it>=dynhypit % ARD for the dynamics alpha = k./diag( k*SigA + SigA*( SA*SA' - 2*GA*B'*SA' + GA*( k*SigB+B'*B )*GA' )*SigA ); beta = k./diag( k*SigB + B'*B ); end if it>=outhypit % ARD for the output process gamma = p./diag( p*SigC + SigC*( SC*rho*SC' - 2*SC*rho*D*MD' + p*MD*SigD*MD' + MD*D'*rho*D*MD' )*SigC ); delta = p./diag( p*SigD + D'*rho*D ); end if it>=etchypit % Hyperparameters for output noise precision (please see report) [pa,pb] = fixedpointsolver(pa,pb,1/p*sum(lnrho,1),1/p*sum(diag(rho),1)); % Hyperparameters for prior mean and covariance of auxiliary state x_0m (please see report) X0m_p = 1/n*sum(cat(2,X0mn{:}),2); tmp = zeros(k,k); for i=1:n tmp = tmp+inv(X0cin{i}) + (X0mn{i}-X0m_p)*(X0mn{i}-X0m_p)'; end X0ci_p = diag(n./diag(tmp)); end %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% % SAVE HISTORIES %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% alphahist = [alphahist alpha]; betahist = [betahist beta]; gammahist = [gammahist gamma]; deltahist = [deltahist delta]; pahist = [pahist pa]; pbhist = [pbhist pb]; X0m_phist = [X0m_phist X0m_p]; X0ci_phist = [X0ci_phist X0ci_p]; %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% % DISPLAY HYP-PROGRESS %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% if dispopt>=3 for kk = 1:k set(halpha(kk),'XData',1:size(alphahist,2),'YData',-log(alphahist(kk,:))); set(hgamma(kk),'XData',1:size(gammahist,2),'YData',-log(gammahist(kk,:))); end for pinpp = 1:pinp set(hbeta(pinpp),'XData',1:size(betahist,2),'YData',-log(betahist(pinpp,:))); set(hdelta(pinpp),'XData',1:size(deltahist,2),'YData',-log(deltahist(pinpp,:))); end set(hrhomean,'XData',1:size(pahist,2),'YData',pahist./pbhist); set(hrhovar,'XData',1:size(pahist,2),'YData',log(pahist./(pbhist.^2))); end endif dispopt>=1; fprintf('\n--finished--\n'); end%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% RETURN LEARNT NETWORK FOR ANALYSIS & PERUSAL%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%net = struct('type','Variational State-Space Model (vSSM) with inputs');net.hyp.pa = pa;net.hyp.pb = pb;net.hyp.alpha = alpha;net.hyp.beta = beta;net.hyp.gamma = gamma;net.hyp.delta = delta;net.hyp.X0m_p = X0m_p;net.hyp.X0ci_p = X0ci_p;net.param.SigA = SigA;net.param.SigB = SigB;net.param.SigC = SigC;net.param.SigD = SigD;net.param.Bb = Bb;net.param.Db = Db;net.param.SA = SA;net.param.SC = SC;net.param.GA = GA;net.param.MB = MB;net.param.MD = MD;net.param.G = G;net.param.qrhoa = pa+sum(Tn,2)/2;net.param.qrhob = pb+G/2;net.exp.A = A;net.exp.B = B;net.exp.AA = AA;net.exp.AB = AB;net.exp.BB = BB;net.exp.rho = rho;net.exp.lnrho = lnrho;net.exp.C = C;net.exp.D = D;net.exp.CrhoC = CrhoC;net.exp.rhoC = rhoC;net.exp.CrhoD = CrhoD;net.exp.rhoD = rhoD;net.exp.DrhoD = DrhoD;net.hidden.X0mn = X0mn;net.hidden.Xmn = Xmn;net.hidden.X0cin= X0cin;net.hidden.Xcin = Xcin;net.hidden.Ups0n= Ups0n;net.hidden.Upsn = Upsn;net.hist.F = hist;net.hist.pa = pahist;net.hist.pb = pbhist;net.hist.alpha = alphahist;net.hist.beta = betahist;net.hist.gamma = gammahist;net.hist.delta = deltahist;net.hist.X0m_p = X0m_phist;net.hist.X0ci_p = X0ci_phist;
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -