📄 icamf.m
字号:
% ---- tail of an enclosing routine whose header lies outside this excerpt ----
% Multiplicative update of A driven by the ratio 'amp'; the loop stops as
% soon as any entry of 'amp' becomes negative, and quadprog is used instead.
invSigmaXSt = invSigma * XSt ;
amp = invSigmaXSt ./ ( invSigma * A * traceSS ) ;
f = sum( sum( traceSS .* ( A' * invSigma * A ) ) ) ...
    - 2 * sum( sum( A .* ( invSigma * XSt ) ) ) ;
Aneg = any(any(amp<0)) ;
KT_ite = 0 ; Aerror = inf ;
alpha = 2 ; %1.1 ;
eta = 1 ;
%maxratio=4; minratio=0.25; % maximum increase/decrease factors
while ~Aneg & KT_ite<KT_max_ite & Aerror > Atol
    KT_ite = KT_ite + 1 ;
    Aold = A ;
    A = A .* amp.^eta ;
    %A=min(A,maxratio*Aold); A=max(A,minratio*Aold);
    fnew = sum( sum( traceSS .* ( A' * invSigma * A ) ) ) ...
        - 2 * sum( sum( A .* ( invSigma * XSt ) ) ) ;
    if f < fnew
        % accelerated step increased the cost: redo it with exponent 1
        eta = 1 ;
        A = Aold .* amp ;
        %A=min(A,maxratio*Aold); A=max(A,minratio*Aold);
        f = sum( sum( traceSS .* ( A' * invSigma * A ) ) ) ...
            - 2 * sum( sum( A .* ( invSigma * XSt ) ) ) ;
    else
        % step accepted: use a larger exponent next time
        eta = alpha * eta ;
        f = fnew ;
    end
    amp = invSigmaXSt ./ ( invSigma * A * traceSS ) ;
    Aneg = any(any(amp<0)) ;
    Aerror = sum(sum(abs(A-Aold))) / sizeA ;
end
%KT_ite
if Aneg % use quadratic programming instead - doesn't work for Sigma full
    M = size(traceSS,1) ; D = size(XSt,1) ;
    options = optimset('Display','off','TolX',10^5*eps,'TolFun',10^5*eps) ;
    for i = 1:D
        % one lower-bounded (>= 0) QP per row of A, warm-started at A(i,:)
        B = quadprog(traceSS,-XSt(i,:)',[],[],[],[],zeros(M,1),[],A(i,:)',options) ;
        A(i,:) = B' ;
    end
end

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%%%                           initialization                             %%%
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

function [A,Sigma,par] = initializeMF(X,par)
% INITIALIZEMF  Fill in defaults in par and build initial A and Sigma.
%
%   [A,Sigma,par] = initializeMF(X,par)
%
%   X     data matrix; its size defines par.D (rows) and par.N (columns).
%   par   struct of options. Every field is optional except where noted;
%         fields read here: solver, method, Aprior, sources, A_init,
%         S_init, Sprior, S, Sigmaprior, Sigma_init, optimizer, draw,
%         max_ite, S_tol, S_max_ite, ecchiupdate.
%         par.A_init is required when Aprior is 'constant', and
%         par.Sigma_init is required when Sigmaprior is 'constant'.
%
%   A      initial mixing matrix (par.A_init, or a small random D x M matrix).
%   Sigma  initial noise (co)variance (par.Sigma_init, or computed from the
%          residual X - A*S according to par.Sigmaprior).
%   par    the input struct with defaults and derived fields added
%          (D, N, M, Asize, Sigmasize, mf_solver, Smeanf, logZ0f, ...).
%
%   Missing fields are detected with try/catch around a bare field access.

% function handle to mean field solver
try
    par.mf_solver = str2func( sprintf('%s_solver',par.solver) ) ;
catch
    par.mf_solver = str2func('ec_solver') ;
end

% method: shorthand that presets several priors at once
try
    switch par.method
        case 'constant'
            par.Aprior = 'constant' ; par.Sigmaprior = 'constant' ;
        case 'fa' % factor analysis
            par.Aprior = 'free' ; par.Sprior = 'Gauss' ; par.Sigmaprior = 'diagonal' ;
        case 'neg_kurtosis'
            par.Sprior = 'bigauss' ;
        case 'pos_kurtosis'
            par.Sprior = 'mog' ;
        case 'positive'
            par.Sprior = 'exponential' ; par.Aprior = 'positive' ;
        case 'ppca' % probabilistic PCA
            par.Aprior = 'free' ; par.Sprior = 'Gauss' ; par.Sigmaprior = 'isotropic' ;
    end
catch
    % par.method not supplied: nothing to preset
end

% number of inputs, examples
[ par.D , par.N ] = size(X) ;

% A and number of sources
try, par.Aprior ; catch, par.Aprior = 'free' ; end
try, par.M = par.sources ; catch, par.M = par.D ; end % default square mixing
try
    A = par.A_init ;
    [ par.D , par.M ] = size( A ) ;
catch
    Ascale = 0.1 / sqrt(par.M) ;
    A = Ascale * randn( par.D , par.M ) ; % use small random matrix
    if strcmp( par.Aprior , 'positive' )
        A = A .* sign(A) / sqrt( par.M ) ;
    end
end
if strcmp( par.Aprior , 'constant' )
    par.Asize = 0 ;
    try
        par.A_init ;
    catch
        error('par.A_init should be defined when par.Aprior = ''constant''') ;
    end
else
    par.Asize = par.D * par.M ;
end

% the sources
% so far the initial value of S is only used to initialize Sigma!!!
% NOTE(review): the size-check error below is raised inside the try and is
% therefore swallowed by the catch, which silently falls back to zeros.
% Behavior kept as-is; confirm whether the error was meant to propagate.
try
    S = par.S_init ;
    if size(S,1) ~= par.M
        error('size(S,1) must equal the number of sources') ;
    end
catch
    S = zeros(par.M,par.N) ;
end

% function handle to mean functions
try
    par.Smeanf = str2func(par.Sprior) ;
    par.logZ0f = str2func( sprintf('logZ0_%s',par.Sprior) ) ;
catch
    par.Smeanf = str2func('mog') ;
    par.logZ0f = str2func('logZ0_mog') ;
end

% parameters for prior distribution should be set here in field
% par.S(i).xxx
try, par.Sprior ; catch, par.Sprior = 'mog' ; end
switch lower(par.Sprior)
    case 'combi'
        % par.M (not par.sources) is the source count: par.sources is
        % optional and par.M already holds its value or the default.
        for i = 1:par.M % maximum number of different priors
            try
                par.S(i).meanf = str2func( par.S(i).prior ) ;
                par.S(i).logZ0f = str2func( sprintf('logZ0_%s',par.S(i).prior) ) ;
                switch par.S(i).prior
                    case 'mog'
                        try, par.S(i).pi ; catch, par.S(i).pi = [] ; end
                        try, par.S(i).mu ; catch, par.S(i).mu = [] ; end
                        try, par.S(i).Sigma ; catch, par.S(i).Sigma = [] ; end
                        if isempty(par.S(i).pi)
                            par.S(i).pi = [ 0.5 0.5 ]' ;
                        end
                        if isempty(par.S(i).mu)
                            par.S(i).mu = [ 0 0 ]' ;
                        end
                        if isempty(par.S(i).Sigma)
                            par.S(i).Sigma = [ 1 0.01 ]' ;
                        end
                end % parameter values for other priors should be added here!
            catch
                error('Sprior is combi, but not all sources are specified in par.S(i).prior') ;
            end
        end
    case 'mog'
        try, par.S.pi ; catch, par.S(1).pi = [0.5 0.5]' ; end    % mixing proportions
        try, par.S.mu ; catch, par.S(1).mu = [0 0]' ; end        % means
        try, par.S.Sigma ; catch, par.S(1).Sigma = [1 0.01]' ; end % default for the Sigma field
        Ssources = length(par.S) ;
        par.S(1).K = length(par.S(1).pi) ;
        if Ssources ~= par.M % use same parameters for all sources
            for indx = 2:par.M
                par.S(indx) = par.S(1) ;
            end
        end
end

% Sigma
try, par.Sigmaprior ; catch, par.Sigmaprior = 'isotropic' ; end
switch par.Sigmaprior
    case 'isotropic'
        par.Sigmasize = 1 ;
    case 'diagonal'
        par.Sigmasize = par.D ;
    case 'free'
        par.Sigmasize = par.D^2 ;
    case 'constant'
        par.Sigmasize = 0 ;
        try
            par.Sigma_init ;
        catch
            error('par.Sigma_init should be defined when par.Sigmaprior = ''constant''') ;
        end
end
try
    Sigma = par.Sigma_init ;
catch
    Sigmascale = 1 ;
    switch par.Sigmaprior
        case {'isotropic','constant'}
            Sigma = Sigmascale * sum(sum( ( X - A * S ).^2 ) ) / (par.D*par.N) ;
        case 'diagonal'
            Sigma = Sigmascale * sum( ( X - A * S ).^2 , 2 ) / par.N ;
        case 'free'
            Sigma = Sigmascale * ( X - A * S ) * ( X - A * S )' / par.N ;
    end
end

% optimizer
try, par.optimizer ; catch, par.optimizer = 'aem' ; end
if par.Asize + par.Sigmasize == 0 % no parameters to be optimized
    par.optimizer = 'constant' ;
end

% more init of parameters...
if strcmp(par.Sigmaprior,'free') & ...
        ( strcmp(par.optimizer,'bfgs') | strcmp(par.optimizer,'conjgrad') )
    % use par.sSigma as variable for unconstrained optimization in bfgs and
    % conjgrad
    par.sSigma = sqrtm(Sigma) ;
end

% run time output
try, par.draw ; catch, par.draw = 1 ; end

% number of iterations and error tolerance
try, par.max_ite ; catch, par.max_ite = 50 ; end
try, par.S_tol ; catch, par.S_tol = 1e-10 ; end
try, par.S_max_ite ; catch, par.S_max_ite = 100 ; end

% chi update specific for ec_solver
try
    par.ecchiupdate ;
catch
    if par.M > 10
        par.ecchiupdate = 'sequential' ;
    else
        par.ecchiupdate = 'parallel' ;
    end
end

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%%%                         mean (S) functions                           %%%
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% All mean functions share the signature [f,df] = name(gamma,lambda,par).
% For binary, binary_01 and Gauss, df is the derivative of f with respect to
% gamma; presumably the same holds for the others - confirm where it matters.

function [f,df] = bigauss(gamma,lambda,par)
% BIGAUSS  Mean function for the 'bigauss' prior (sigma2 = 1, mu = 1 fixed).
sigma2 = 1 ; mu = 1 ;
lambda = 1 ./ (1 + lambda*sigma2) ;   % lambda is reused for the rescaled value
f = tanh( mu*gamma.*lambda ) ;
if nargout > 1
    % must be evaluated while f still holds the tanh term
    df = lambda .* ( sigma2 + mu^2 .* lambda .* (1 - f.^2) ) ;
end
f = lambda .* ( sigma2*gamma + mu*f ) ;

function [f,df] = binary(gamma,lambda,par)
% BINARY  Mean function f = tanh(gamma); lambda and par are not used.
f = tanh(gamma) ;
if nargout > 1
    df = 1 - f.^2 ;
end

function [f,df] = binary_01(gamma,lambda,par)
% BINARY_01  Mean function f = 1./(1+exp(0.5*lambda-gamma)); par is not used.
f = 1 ./ ( 1 + exp(0.5*lambda - gamma) ) ;
if nargout > 1
    df = f .* (1 - f) ;
end

function [f,df] = combi(gamma,lambda,par)
% COMBI  Row-wise dispatch to the per-source mean functions par.S(k).meanf.
%   Row i of gamma/lambda is handled by source par.Sindx(i).
f = zeros(size(gamma)) ; df = zeros(size(gamma)) ;
Sindx = par.Sindx ;
for i = 1:length(Sindx)
    par.Sindx = Sindx(i) ; % to pass information about current source
    [f(i,:),df(i,:)] = par.S(Sindx(i)).meanf(gamma(i,:),lambda(i,:),par) ;
end

function [f,df] = exponential(gamma,lambda,par)
% EXPONENTIAL  Mean function for the 'exponential' prior (eta = 1 fixed).
%   Entries with xi <= erfclimit are masked out and return f = 0 (df = 0).
eta = 1 ;
erfclimit = -35 ;
%minlambda=10^-4;
%lambda=lambda.*(lambda>minlambda)+minlambda.*(lambda<=minlambda);
xi = (gamma - eta) ./ sqrt(lambda) ;
cc = (xi > erfclimit) ;
%i=find(cc==0);
%if ~isempty(i)
%  cc, pause
%end
xi1 = xi .* cc ;
epfrac = exp(-(xi1.^2)/2) ./ ( Phi(xi1)*sqrt(2*pi) ) ;
f = cc .* (xi1 + epfrac) ./ sqrt(lambda) ; % need to go to higher order to get asymptotics right
if nargout > 1
    df = cc .* ( 1./lambda + f.*(xi1./sqrt(lambda) - f) ) ; % need to go to higher order to get asymptotics right -fix at some point!!!
end

function [f,df] = Gauss(gamma,lambda,par)
% GAUSS  Mean function f = gamma./(1+lambda); par is not used.
f = gamma ./ (1 + lambda) ;
if nargout > 1
    df = 1 ./ (1 + lambda) ;
end

function [f,df] = heavy_tail(gamma,lambda,par)
% HEAVY_TAIL  Mean function for the 'heavy_tail' prior (alpha = 1 fixed).
alpha = 1 ; % if changed change also in logZ0_heavy_tail
f = gamma./lambda - alpha*gamma ./ (2*alpha*lambda + gamma.^2) ;
if nargout > 1
    df = 1./lambda + alpha*(gamma.^2 - 2*alpha*lambda) ./ (2*alpha*lambda + gamma.^2).^2 ;
end

function [f,df] = heavy_tail_plus_delta(gamma,lambda,par)
% HEAVY_TAIL_PLUS_DELTA  Mean function for the heavy-tail prior mixed with a
%   delta component of weight beta (alpha = 1, beta = 0.3 fixed).
alpha = 1 ;  % if changed change also in logZ0_heavy_tail_plus_delta
beta = 0.3 ; % proportion delta; if changed change also in logZ0_heavy_tail_plus_delta
Z0ht = exp(0.5*gamma.^2./lambda) .* (1 + gamma.^2./(2*alpha*lambda)).^(-0.5*alpha) ;
f = (1-beta) * ( gamma./lambda - alpha*gamma./(2*alpha*lambda + gamma.^2) ) ./ ...
    ( beta./Z0ht + (1-beta) ) ;
if nargout > 1
    df = (1-beta) ./ ( beta./Z0ht + (1-beta) ) .* ...
        ( 1./lambda + alpha*(gamma.^2 - 2*alpha*lambda)./(2*alpha*lambda + gamma.^2).^2 + ...
        ( gamma./lambda - alpha*gamma./(2*alpha*lambda + gamma.^2) ).^2 ) - f.^2 ;
end

function [f,df] = Laplace(gamma,lambda,par)
% LAPLACE  Mean function for the 'Laplace' prior (eta = 1 fixed).
%   ccp/ccm mask the entries whose xip/xim exceed erfclimit; the
%   complementary masks ccpc/ccmc select separate fallback expressions.
erfclimit = -25 ;
eta = 1 ;
%minlambda=10^-4;
%lambda=lambda.*(lambda>minlambda)+minlambda.*(lambda<=minlambda);
xip = (gamma - eta) ./ sqrt(lambda) ;
ccp = (xip > erfclimit) ; ccpc = not(ccp) ;
xip1 = ccp .* xip ;
xim = -(gamma + eta) ./ sqrt(lambda) ;
ccm = (xim > erfclimit) ; ccmc = not(ccm) ;
xim1 = ccm .* xim ;
Dp = exp(-(xip1.^2)/2) / sqrt(2*pi) ;
Dm = exp(-(xim1.^2)/2) / sqrt(2*pi) ;
kp = Phi(xip1) .* Dm ;
km = Phi(xim1) .* Dp ;
f = ccp.*ccm.*(xip.*kp - xim.*km) ./ (sqrt(lambda).*(kp+km)) + (-ccpc.*xim + ccmc.*xip) ./ sqrt(lambda) ;
if nargout > 1
    df = ( ccp.*ccm.*( 1 + xim.*xip + Dp.*Dm.*(xip+xim)./(kp+km) + sqrt(lambda).*(xip.*km - xim.*kp)./(kp+km).*f ) + ccpc + ccmc ) ./ lambda ;
end

function [f,df] = mog(gamma,lambda,par)
% MOG  Mean function for the mixture-of-Gaussians prior.
%   Row i of gamma/lambda uses the parameters par.S(par.Sindx(i)).pi/.mu/.Sigma.
%   K is taken from the first indexed source only.
%
%   NOTE(review): for nonzero mu the terms (gamma + mu).*Sigma and
%   mu.^2./Sigma appear to assume different parametrizations of mu; the
%   default mu = 0 is unaffected. Left unchanged - confirm against logZ0_mog.
[M N ] = size(gamma) ;
oN = ones(N,1) ;
K = length( par.S(par.Sindx(1)).pi ) ;
oK = ones(K,1) ;
for i = 1:length(par.Sindx) % loop over sources
    cindx = par.Sindx(i) ;
    cgamma = gamma(i,:) ; clambda = lambda(i,:) ; % 1 * N
    logpi = log( par.S(cindx).pi ) ;  % replicated to K * N via (:,oN) below
    cmu = par.S(cindx).mu ;           % K * 1
    cSigma = par.S(cindx).Sigma ;     % K * 1
    mu2dSigma = cmu.^2 ./ cSigma ;
    musqrtSigma = cmu .* sqrt(cSigma) ;
    opSigmalambda = 1 + cSigma * clambda ;
    resp = logpi(:,oN) - 0.5 * log( opSigmalambda ) ...
        + 0.5 * ( sqrt(cSigma) * cgamma + musqrtSigma(:,oN) ) .^2 ./ ...
        opSigmalambda - 0.5 * mu2dSigma(:,oN) ; % K * N
    % subtract the column-wise maximum before exponentiating
    maxresp = max( resp , [] , 1 ) ;
    resp = exp( resp - maxresp(oK,:) ) ;
    normalizer = sum( resp , 1 ) ;
    mconst = ( cgamma(oK,:) + cmu(:,oN) ) .* cSigma(:,oN) ./ opSigmalambda ;
    f(i,:) = sum( mconst .* resp , 1 ) ./ normalizer ;
    if nargout > 1
        df(i,:) = ...
            sum( ( cSigma(:,oN) ./ opSigmalambda + mconst.^2 ) .* resp , 1 ) ./ ...
            normalizer ;
    end
end
if nargout > 1
    df = df - f.^2 ;
end

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%%%                          logZ0 functions                             %%%
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% All logZ0 functions share the signature logZ0 = name(gamma,lambda,par) and
% return a scalar: the sum of the element-wise terms over all entries.

function [logZ0] = logZ0_bigauss(gamma,lambda,par)
% LOGZ0_BIGAUSS  logZ0 for the 'bigauss' prior (sigma2 = 1, mu = 1 fixed).
sigma2 = 1 ; mu = 1 ;
lambda = 1 ./ (1 + lambda*sigma2) ;
logZ0terms = 0.5*( log(lambda) - mu^2/sigma2 + lambda.*(mu^2/sigma2 + gamma.^2*sigma2) ) ...
    - log(2) + abs(gamma*mu.*lambda) + log(1 + exp(-2*abs(gamma*mu.*lambda))) ; % log(cosh(gamma*mu.*lambda))
logZ0 = sum(sum(logZ0terms)) ;

function [logZ0] = logZ0_binary(gamma,lambda,par)
% LOGZ0_BINARY  logZ0 for the 'binary' prior.
logZ0terms = -0.5*lambda - log(2) + abs(gamma) + log(1 + exp(-2*abs(gamma))) ;
logZ0 = sum(sum(logZ0terms)) ;

function [logZ0] = logZ0_binary_01(gamma,lambda,par)
% LOGZ0_BINARY_01  logZ0 for the 'binary_01' prior.
gamma = 0.5*(gamma - 0.5*lambda) ;
logZ0terms = gamma - log(2) + abs(gamma) + log(1 + exp(-2*abs(gamma))) ;
logZ0 = sum(sum(logZ0terms)) ;

function [logZ0] = logZ0_combi(gamma,lambda,par)
% LOGZ0_COMBI  Sum of the per-source logZ0 functions par.S(i).logZ0f.
logZ0 = 0 ;
for i = 1:par.M % assumes that run through all sources
    par.Sindx = i ; % to pass information about current source
    logZ0 = logZ0 + par.S(i).logZ0f(gamma(i,:),lambda(i,:),par) ;
end

function [logZ0] = logZ0_exponential(gamma,lambda,par)
% LOGZ0_EXPONENTIAL  logZ0 for the 'exponential' prior (eta = 1 fixed).
eta = 1 ;
erfclimit = -35 ;
%minlambda=10^-4;
%lambda=lambda.*(lambda>minlambda)+minlambda.*(lambda<=minlambda);
xi = (gamma - eta) ./ sqrt(lambda) ;
cc = (xi > erfclimit) ;
xi1 = xi .* cc ;
logZ0terms = cc.*( log(Phi(xi1)) + 0.5*log(2*pi) + 0.5*xi1.^2 ) - (1-cc).*log(abs(xi) + cc) + log(eta) - 0.5*log(lambda) ;
logZ0 = sum(sum(logZ0terms)) ;

function [logZ0] = logZ0_Gauss(gamma,lambda,par)
% LOGZ0_GAUSS  logZ0 for the 'Gauss' prior.
logZ0terms = 0.5*gamma.^2./(1+lambda) - 0.5*log(1+lambda) ;
logZ0 = sum(sum(logZ0terms)) ;

function [logZ0] = logZ0_heavy_tail(gamma,lambda,par)
% LOGZ0_HEAVY_TAIL  logZ0 for the 'heavy_tail' prior (alpha = 1 fixed).
alpha = 1 ; % if changed change also in heavy_tail.m
logZ0terms = 0.5*gamma.^2./lambda - 0.5*alpha*log(1 + gamma.^2./(2*lambda*alpha)) ;
logZ0 = sum(sum(logZ0terms)) ;

function [logZ0] = logZ0_heavy_tail_plus_delta(gamma,lambda,par)
% LOGZ0_HEAVY_TAIL_PLUS_DELTA  logZ0 for the heavy-tail-plus-delta prior
%   (alpha = 1, beta = 0.3 fixed).
alpha = 1 ;  % if changed change also in heavy_tail_plus_delta
beta = 0.3 ; % proportion delta; if changed change also in heavy_tail_plus_delta
Z0ht = exp(0.5*gamma.^2./lambda) .* (1 + gamma.^2./(2*alpha*lambda)).^(-0.5*alpha) ;
logZ0terms = log( beta + (1-beta)*Z0ht ) ;
logZ0 = sum(sum(logZ0terms)) ;

function [logZ0] = logZ0_Laplace(gamma,lambda,par)
% LOGZ0_LAPLACE  logZ0 for the 'Laplace' prior (eta = 1 fixed).
%   NOTE(review): the visible source ends after logZ0terms is computed and
%   never assigns logZ0. The other logZ0_* functions here finish with
%   logZ0 = sum(sum(logZ0terms)); confirm whether that line was truncated.
erfclimit = -25 ;
eta = 1 ;
%minlambda=10^-4;
%lambda=lambda.*(lambda>minlambda)+minlambda.*(lambda<=minlambda);
xip = (gamma - eta) ./ sqrt(lambda) ;
ccp = (xip > erfclimit) ; ccpc = not(ccp) ;
xip1 = ccp .* xip ;
xim = -(gamma + eta) ./ sqrt(lambda) ;
ccm = (xim > erfclimit) ; ccmc = not(ccm) ;
xim1 = ccm .* xim ;
Dp = exp(-(xip1.^2)/2) / sqrt(2*pi) ;
Dm = exp(-(xim1.^2)/2) / sqrt(2*pi) ;
Phip = Phi(xip1) ;
Phim = Phi(xim1) ;
logZ0terms = log(0.5*eta) + 0.5*log(2*pi./lambda) + ...
    ccp.*ccm.*( 0.5*xip1.^2 + 0.5*xim1.^2 + log(Dm.*Phip + Dp.*Phim) - 0.5*log(2*pi) ) + ...
    ccp.*ccmc.*( 0.5*xip1.^2 + log(Phip + Dp./(abs(xim) + ccm)) ) + ...
    ccpc.*ccm.*( 0.5*xim1.^2 + log(Phim + Dm./(abs(xip) + ccp)) ) + ...
    ccpc.*ccmc.*( log(1./(abs(xip) + ccp) + 1./(abs(xim) + ccm)) - 0.5*log(2*pi) ) ;
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -