📄 icamf.m
字号:
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%                                  mean field solvers
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%                             expectation consistent solver
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

function [m,chi,G,ite,dmdm]=ec_solver(theta,J,par) % ,debug_draw)
% ec_solver formerly known as ec_fac_ep.m
% expectation consistent factorized inference
% optimized using ep-style updates
% Ole Winther, IMM, February 2005
%
% Inputs
%   theta  external field, indexed as theta(:,I) with I a subset of 1:par.N
%   J      coupling matrix; must be square of size par.M (used in
%          diag(Lam_r) - J)
%   par    struct. Required fields: M, N, S_tol, S_max_ite, Smeanf,
%          ecchiupdate ('parallel' or 'sequential'). Optional fields with
%          defaults: minchi (1e-7), m (zeros(M,N)), minLam_q (1e-5),
%          eta (1). The free energy additionally reads logZ0f,
%          logdet2piSigma and XinvSigmaX, and calls the external logdet.
%
% Outputs
%   m      M-by-N means returned by par.Smeanf
%   chi    M-by-M-by-N array, one matrix per sample
%   G      free energy per sample; only computed when more than two
%          outputs are requested
%   ite    number of sweeps performed
%   dmdm   sum over all entries of dm.^2 after the last sweep

% initialize variables
M = par.M ;
N = par.N ;

% initialize constants specific for ec_solver
minchi   = icamf_par_default( par , 'minchi'   , 1e-7 ) ;
% initial value of m=<S>=mean values of sources - called S in outer
% routines
m        = icamf_par_default( par , 'm'        , zeros(M,N) ) ;
minLam_q = icamf_par_default( par , 'minLam_q' , 1e-5 ) ;
eta      = icamf_par_default( par , 'eta'      , 1 ) ; % set learning rate for lam_r update

dm = zeros(M,N) ;
eigJ = eig(J) ; mineigJ = min(eigJ) ; maxeigJ = max(eigJ) ;
Lam_r = 1 * ( maxeigJ - mineigJ ) * ones(M,1) ;
chi = inv( diag(Lam_r) - J ) ;
Lam_r = repmat(Lam_r,[1 N]) ;
% set mean value of r-distribution to be the same as m
gam_r = zeros(M,N) ;
m_r = m ;
% one column per sample, so v_r has the same M-by-N shape as m_r, gam_r and
% Lam_r even when the sweep loop below never executes
v_r = repmat( diag(chi) , [1 N] ) ;
% make a covariance matrix for each sample
chi = repmat(chi,[1 1 N]) ; % could be made more effective with lightspeed.
%chiv = zeros(M,M,N) ; chivt = zeros(M,M,N) ;
ite = 0 ; dmdm = Inf ;
tolN = par.S_tol / N ;
I = 1:N ; % samples that have not yet converged
while (~isempty(I) && dmdm>par.S_tol && ite<par.S_max_ite)
    ite = ite + 1 ;
    indx = randperm(M) ; % visit the variables in random order
    for sindx=1:M
        cindx = indx(sindx) ; par.Sindx = cindx ;
        %% find mean and variance of r
        m_r(cindx,I) = ...
            sum( shiftdim( chi(cindx,:,I) , 1 ) .* ( gam_r(:,I) + theta(:,I) ) , 1 ) ;
        v_r(cindx,I) = chi(cindx,cindx,I) ;
        % find Lagrange parameters of s
        Lam_s1 = 1 ./ max( minchi , v_r(cindx,I) ) ;
        gam_s1 = Lam_s1 .* m_r(cindx,I) ;
        % update lam_q
        gam_q = gam_s1 - gam_r(cindx,I) ;
        Lam_q = Lam_s1 - Lam_r(cindx,I) ;
        % update moments of q distribution
        [m(cindx,I),v_q] = par.Smeanf( gam_q , max( Lam_q , minLam_q ) , par ) ; %%% OBS hack here!!!!
        dm(cindx,I) = m(cindx,I) - m_r(cindx,I) ;
        % find Lagrange parameters of s
        Lam_s2 = 1 ./ max( minchi , v_q ) ;
        gam_s2 = Lam_s2 .* m(cindx,I) ;
        % update lam_r
        dgam_r = eta * ( gam_s2 - gam_s1 ) ;
        gam_r(cindx,I) = dgam_r + gam_r(cindx,I) ;
        dLam_r = eta * ( Lam_s2 - Lam_s1 ) ;
        Lam_r(cindx,I) = dLam_r + Lam_r(cindx,I) ;
        % update chi using Sherman-Woodbury
        switch par.ecchiupdate
            case 'parallel'
                % rank-one update of all active samples at once; indexing
                % with oM replicates a singleton dimension M times
                oM = ones(M,1) ;
                kappa = dLam_r ./ ( 1 + dLam_r .* v_r(cindx,I) ) ;
                kappa = reshape(kappa,1,1,length(I)) ;
                chiv = chi(:,cindx,I) ; chiv = kappa(oM,:,:) .* chiv ; chiv = chiv(:,oM,:) ;
                chivt = chi(cindx,:,I) ; chivt = chivt(oM,:,:) ;
                chi(:,:,I) = chi(:,:,I) - chiv .* chivt ;
            case 'sequential'
                for j=1:length(I)
                    i = I(j) ;
                    chiv = chi(:,cindx,i) ;
                    chi(:,:,i) = chi(:,:,i) ...
                        - dLam_r(j) / ( 1 + dLam_r(j) * v_r(cindx,i) ) * chiv * chiv' ;
                    % v_r(cindx,:) = chi(cindx,cindx,:)
                end
        end
    end % over variables - sindx
    dmdmN = sum(dm.*dm,1) ;
    I = find(dmdmN > tolN) ;
    dmdm = sum(dmdmN) ;
end
if nargout > 2 % calculate free energy per sample
    Lam_s = 1 ./ max( minchi , v_r ) ;
    gam_s = Lam_s .* m_r ;
    gam_q = gam_s - gam_r ;
    Lam_q = Lam_s - Lam_r ;
    par.Sindx = (1:M)' ;
    logZ_q = par.logZ0f(gam_q, max( minLam_q , Lam_q ) , par ) ;
    logZ_r = 0.5 * ( N * log(2*pi) - N * par.logdet2piSigma - par.XinvSigmaX ) ;
    for i=1:N
        logZ_r = logZ_r + 0.5 * logdet( chi(:,:,i) ) ...
            + 0.5 * ( gam_r(:,i) + theta(:,i) )' * ...
            chi(:,:,i) * ( gam_r(:,i) + theta(:,i) ) ;
    end
    logZ_s = 0.5 * ( N * log(2*pi) + ...
        sum( sum( - log( Lam_s ) + gam_s.^2 ./ Lam_s ) ) ) ;
    G = ( - logZ_q - logZ_r + logZ_s ) / N ;
end

function value = icamf_par_default(par,name,default)
% Return par.(name) when the struct par has that field, otherwise default.
if isfield(par,name)
    value = par.(name) ;
else
    value = default ;
end

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%                                  variational solver
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

function [m,chi,G,ite,dmdm]=variational_solver(theta,J,par)
% Coordinate-wise mean field iteration with Lam_q fixed to -diag(J).
%
% Inputs
%   theta  external field, indexed as theta(cindx,I)
%   J      coupling matrix of size par.M
%   par    struct. Required fields: M, N, S_tol, S_max_ite, Smeanf.
%          Optional fields with defaults: minchi (1e-7), m (zeros(M,N)).
%          The free energy additionally reads logZ0f, logdet2piSigma and
%          XinvSigmaX, and calls the external logdet.
%
% Outputs
%   m      M-by-N means returned by par.Smeanf
%   chi    M-by-M-by-N array; left as zeros unless more than one output is
%          requested
%   G      free energy per sample; only computed when more than two
%          outputs are requested
%   ite    number of sweeps performed
%   dmdm   sum over all entries of dm.^2 after the last sweep

M = par.M ;
N = par.N ;
chi = zeros( M , M , N ) ;
% initialize constants
minchi = icamf_par_default( par , 'minchi' , 1e-7 ) ;
% initial value of m=<S>=mean values of sources - called S in outer
% routines
m      = icamf_par_default( par , 'm' , zeros(M,N) ) ;
dm = m ;
Lam_q = - diag(J) ; % variational result
ite = 0 ; dmdm = Inf ;
tolN = par.S_tol / N ;
I = 1:N ; % samples that have not yet converged
while (~isempty(I) && dmdm>par.S_tol && ite<par.S_max_ite)
    ite = ite + 1 ;
    indx = randperm(M) ; % visit the variables in random order
    for sindx=1:M
        cindx = indx(sindx) ; par.Sindx = cindx ; % current index
        % update hcav
        gam_q = theta(cindx,I) + J(cindx,:) * m(:,I) + Lam_q(cindx) * m(cindx,I) ;
        % update m and derivative
        m_old = m(cindx,I) ;
        m(cindx,I) = par.Smeanf(gam_q,Lam_q(cindx,ones(size(I))),par) ;
        dm(cindx,I) = m(cindx,I) - m_old ;
    end
    dmdmN = sum(dm.*dm,1) ;
    I = find(dmdmN > tolN) ;
    dmdm = sum(dmdmN) ;
end
if nargout > 1 % calculate chi
    gam_q = theta + ( J + diag(Lam_q) )* m ;
    par.Sindx = (1:M)' ;
    [m_q,v_q] = par.Smeanf( gam_q , Lam_q(:,ones(N,1)) , par) ;
    for i=1:N
        % Lam_s = 1 ./ v_q ;
        % chi = inv( diag( Lam_r ) - J ) = inv( diag( Lam_s -Lam_q ) - J )
        chi(:,:,i) = inv( eye(M) - diag( v_q(:,i) ) * ( J + diag( Lam_q ) ) ) * diag( v_q(:,i) ) ;
    end
end
if nargout > 2 % calculate free energy per sample
    Lam_s = 1 ./ max( minchi , v_q ) ;
    gam_s = Lam_s .* m_q ;
    gam_r = gam_s - gam_q ; % Lam_r = Lam_s - Lam_q(:,ones(N,1)) ;
    logZ_q = par.logZ0f(gam_q,Lam_q(:,ones(N,1)) , par ) ;
    logZ_r = 0.5 * ( N * log(2*pi) - N * par.logdet2piSigma - par.XinvSigmaX ) ;
    for i=1:N
        logZ_r = logZ_r + 0.5 * logdet( chi(:,:,i) ) ...
            + 0.5 * ( gam_r(:,i) + theta(:,i) )' * ...
            chi(:,:,i) * ( gam_r(:,i) + theta(:,i) ) ;
    end
    logZ_s = 0.5 * ( N * log(2*pi) + ...
        sum( sum( - log( Lam_s ) + gam_s.^2 ./ Lam_s ) ) ) ;
    G = ( - logZ_q - logZ_r + logZ_s ) / N ;
end

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%                                 parameter conversion
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

function [theta,J,par] = XASigma2thetaJ(X,A,Sigma,par)
% Convert (X, A, Sigma) into the external field theta and coupling matrix J.
%
% Sigma is interpreted by its number of elements: 1 (scalar), par.D
% (vector of diagonal entries) or par.D^2 (full matrix). The returned par
% has the fields logdet2piSigma and XinvSigmaX set. The full-matrix case
% calls the external logdet.

% invert Sigma
switch size(Sigma,1)*size(Sigma,2)
    case 1
        invSigma = 1 / Sigma ;
        par.logdet2piSigma = par.D * log( 2 * pi * Sigma ) ;
    case par.D
        invSigma = diag(1./Sigma) ;
        par.logdet2piSigma = sum( log( 2 * pi * Sigma ) ) ;
    case par.D^2
        invSigma = inv(Sigma) ;
        par.logdet2piSigma = logdet( 2 * pi * Sigma ) ;
    otherwise
        % without this branch invSigma would be undefined below
        error('icamf:XASigma2thetaJ:badSigmaSize', ...
            'Sigma must have 1, par.D or par.D^2 elements.') ;
end
% calculate external field and coupling matrix
theta = A' * invSigma * X ;
J = - A' * invSigma * A ;
par.XinvSigmaX = sum( sum( X .* ( invSigma * X ) ) ) ;

function [A,Sigma] = popt2ASigma(popt,par)
% Unpack the optimisation vector popt into A and Sigma.
%
% The first par.Asize entries of popt encode A according to par.Aprior
% ('free', 'positive' via exp, or 'constant' = par.A_init); the remaining
% entries encode Sigma according to par.Sigmaprior ('free', 'isotropic',
% 'diagonal' via exp, or 'constant' = par.Sigma_init).
D = par.D ; M = par.M ;
switch par.Aprior
    case 'free'
        A = reshape( popt(1:par.Asize) , D , M ) ;
    case 'positive'
        A = exp( reshape(popt(1:par.Asize), D , M ) ) ;
    case 'constant'
        A = par.A_init ;
end
switch par.Sigmaprior
    case 'free'
        % par is not an output, so this assignment is local to this call
        par.sSigma = reshape(popt(par.Asize+1:end),par.D,par.D) ;
        Sigma = par.sSigma' * par.sSigma ;
    case {'isotropic','diagonal'}
        Sigma = exp( popt(par.Asize+1:end) ) ;
    case 'constant'
        Sigma = par.Sigma_init ;
end

function [popt] = ASigma2popt(A,Sigma,par)
% Pack A and Sigma into an optimisation vector of length
% par.Asize + par.Sigmasize.
%
% A is stored as log(A) only when par.Aprior is 'positive' and
% par.optimizer is 'bfgs' or 'conjgrad'. For par.Sigmaprior 'free' the
% factor par.sSigma is stored instead of Sigma itself.
D = par.D ;
popt = zeros( par.Asize + par.Sigmasize , 1 ) ;
if strcmp(par.Aprior,'positive') && ...
        ( strcmp(par.optimizer,'bfgs') || strcmp(par.optimizer,'conjgrad') )
    popt(1:par.Asize) = log(A(:)) ;
else
    popt(1:par.Asize) = A(:) ;
end
switch par.Sigmaprior
    case 'free'
        popt(par.Asize+1:end) = par.sSigma(:) ; %R(par.triui) ;
    case {'isotropic','diagonal'}
        popt(par.Asize+1:end) = log(Sigma(:)) ;
    case 'constant'
end

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%                                parameter derivatives
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

function dpopt = dpoptf(X,A,Sigma,S,chi,par)
% Derivative with respect to the packed parameter vector.
%
% Returns a column of length par.Asize + par.Sigmasize whose layout matches
% ASigma2popt. Entries belonging to a 'constant' prior stay zero.
M = par.M ;
dpopt = zeros( par.Asize + par.Sigmasize , 1 ) ;
% set up matrices for finding derivatives
[tracechi,traceSS] = icamf_moment_sums(S,chi,M) ;
if isequal( size(Sigma) , [par.D,1] )
    invSigma = diag(1 ./ Sigma ) ;
else
    invSigma = inv(Sigma) ;
end
dA = invSigma * ( A * traceSS - X * S' ) / par.N ;
switch par.Aprior
    case 'free'
        dpopt(1:par.Asize) = dA(:) ;
    case 'positive'
        dpopt(1:par.Asize) = A(:) .* dA(:) ;
    case 'constant'
end
switch par.Sigmaprior
    case 'isotropic'
        dSigma = 0.5 * invSigma * ( par.D - ...
            ( sum(sum((X-A*S).^2)) + ...
            sum(sum((A*tracechi).*A)) ) * invSigma / par.N ) ;
        dpopt(par.Asize+1) = Sigma * dSigma ;
    case 'diagonal'
        invSigma = diag( invSigma ) ;
        dSigma = 0.5 * invSigma .* ( ones(par.D,1) - ...
            ( sum((X-A*S).^2,2) + ...
            sum((A*tracechi).*A,2) ) .* invSigma / par.N ) ;
        dpopt(par.Asize+1:end) = Sigma .* dSigma ;
    case 'free'
        dSigma = 0.5 * invSigma * ( eye(par.D) - ...
            ( (X-A*S)*(X-A*S)' + A*tracechi*A' ) * invSigma / par.N ) ;
        dSigma = 2 * dSigma - diag(diag(dSigma)) ; % take into account Sigma is symmetric
        dSigma = 2 * par.sSigma * dSigma ;
        dpopt(par.Asize+1:end) = dSigma(:) ;
    case 'constant'
end

function [A,Sigma] = dASigmaf(X,A,Sigma,S,chi,par)
% Closed-form update of A and Sigma from the current moments S and chi.
%
% A follows par.Aprior ('free', 'positive' via A_positive_aem, or
% 'constant' = par.A_init); Sigma follows par.Sigmaprior ('isotropic',
% 'diagonal', 'free', or 'constant' = par.Sigma_init).
M = par.M ;
% set up matrices for finding derivatives
[tracechi,traceSS] = icamf_moment_sums(S,chi,M) ;
if isequal( size(Sigma) , [par.D,1] )
    invSigma = diag(1 ./ Sigma ) ;
else
    invSigma = inv(Sigma) ;
end
switch par.Aprior
    case 'free'
        % solve A * traceSS = X * S' without forming the explicit inverse
        A = ( X * S' ) / traceSS ;
    case 'positive'
        % iterate eq. (2.13) in Ref. [3]
        A = A_positive_aem(traceSS,X * S',invSigma,A) ;
    case 'constant'
        A = par.A_init ;
end
switch par.Sigmaprior
    case 'isotropic'
        Sigma = ( sum(sum((X-A*S).^2)) + ...
            sum(sum((A*tracechi).*A)) ) / ( par.N * par.D ) ;
    case 'diagonal'
        Sigma = ( sum((X-A*S).^2,2) + ...
            sum((A*tracechi).*A,2) ) / par.N ;
    case 'free'
        Sigma = ( (X-A*S)*(X-A*S)' + A*tracechi*A' ) / par.N ;
    case 'constant'
        Sigma = par.Sigma_init ;
end

function [tracechi,traceSS] = icamf_moment_sums(S,chi,M)
% Sum chi and the outer products of the columns of S over all samples.
%
% Both outputs are built from the upper triangle (including the diagonal)
% and mirrored into the lower one, so they are exactly symmetric even when
% the slices of chi are not.
sumchi = sum( chi(1:M,1:M,:) , 3 ) ;
tracechi = triu(sumchi) + triu(sumchi,1)' ;
SS = S(1:M,:) * S(1:M,:)' ;
traceSS = tracechi + triu(SS) + triu(SS,1)' ;

function [A] = A_positive(traceSS,XSt,invSigma,A)
% Multiplicative fixed-point update of A with a quadratic-programming
% fallback.
%
% Iterates A <- A .* amp, amp = (invSigma*XSt) ./ max(eps,invSigma*A*traceSS),
% until the mean absolute change is at most Atol, KT_max_ite iterations
% have been done, or amp has a negative entry. In the last case each row of
% A is recomputed by quadprog with a zero lower bound.
Atol = 1e-6 ;
KT_max_ite = 100 ;
A = A + 10^-3*(A<eps) ; % lift entries below eps by 1e-3
sizeA = size(A,1) * size(A,2) ;
invSigmaXSt = invSigma * XSt ;
amp = invSigmaXSt./max(eps,invSigma*A*traceSS) ;
Aneg = any(any(amp<0)) ;
KT_ite = 0 ; Aerror = inf ;
while ~Aneg && KT_ite<KT_max_ite && Aerror > Atol
    KT_ite = KT_ite + 1 ;
    %dA = A - A.*amp ; dAerror = sum(sum(dA)) / (size(A,1)*size(A,2))
    Aold = A ;
    A = A.*amp ;
    amp = invSigmaXSt./max(eps,invSigma*A*traceSS) ;
    Aneg = any(any(amp<0)) ;
    Aerror = sum(sum(abs(A-Aold))) / sizeA ;
end
if Aneg % use quadratic programming instead - doesn't work for Sigma full
    M = size(traceSS,1) ; D = size(XSt,1) ;
    options = optimset('Display','off','TolX',10^5*eps,'TolFun',10^5*eps) ;
    for i=1:D
        B = quadprog(traceSS,-XSt(i,:)',[],[],[],[],zeros(M,1),[],A(i,:)',options) ;
        A(i,:) = B' ;
    end
end

function [A] = A_positive_aem(traceSS,XSt,invSigma,A)
% NOTE(review): only the opening statements of this function are present in
% this copy of the file; the rest of the body must be restored from the
% original source before this function is relied on.
Atol = 1e-6 ;
KT_max_ite = 100 ;
A = A + 10^-3*(A<eps) ;
sizeA = size(A,1) * size(A,2) ;
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -