learn_mix1d.m

function src = learn_mix1d(src,x,x_sq,tol,max_steps)
% src = learn_mix1d(src,x,x_sq,tol,max_steps)
%
% Train a 1-dimensional mixture model using the
% Variational Bayes framework.
%
% Called from 'mixmodel1d', 'learn_ica1' and 'learn_ica2'.
%
%
% -----------
% Input
% -----------
%
% Necessary parameters
%
% src       Current model
% x         mixmodel1d: The observation data (sent as mxN)
%           vbICA1: Expectation of source signals (mxN)
%           vbICA2: Location parameters of source posterior
% x_sq      mixmodel1d: x.^2 (sent as mxN)
%           vbICA1: Expectation of (source signal.^2) (mxN)
%           vbICA2: Precisions of source posterior
%
%
% Optional parameters
%
% tol       Convergence tolerance           (Default = 1e-5)
% max_steps Max number of iteration steps   (Default = 500)
%
%
% Note:     Exponentials take longer to converge, particularly
%           if the source pdfs are not well described by a mixture
%           of exponentials, so tol is capped at 1e-6 for MoE.
%           Truncated Gaussians are more robust and converge better.
%
%
% -----------
% Output
% -----------
%
% SRC is a data structure with the following fields:
%
% type             'g', 'e', 't' - Mixture of Gaussians,
%                  exponentials or truncated Gaussians
%                  'hg','he','ht' - HMM versions
% m                The number of components
% f_hidd           The negative free energy of hidden vars.
% f_params         The negative free energy of the params.
% ftot             The total negative free energy
%
%                  In the field priors:
% lambda_0         Dirichlet parameters for mixing coeffs
% b_0,c_0          Gamma parameters for precisions
% m_0,tau_0        Normal parameters for means
%
%                  In the field posts:
% lambda           Dirichlet parameters for mixing coeffs
% b,c              Gamma parameters for precisions
% mm,tau           Normal parameters for means
%
%                  Expected posterior values:
% pi               Mixing coefficients
% centres          Means
% precs            Precisions
% gammas           Component probabilities
%
%
% If HMM, also
%
% pi               Initial state probabilities.
% P                Calculated trans. probs. for HMM
% posts.eta        Dirichlet parameters for P
% posts.lambda     Dirichlet parameters for pi
%
%
% --------------------------------------------------------------
%
% Original code by Rizwan Choudrey
% Thesis: Variational Methods for Bayesian Independent
%         Component Analysis (www.robots.ox.ac.uk/~parg)

global CHECK_PROGRESS;

if nargin<4
  tol = 1e-5;
end
if nargin<5
  max_steps = 500;
end

src_type = src.type;
pdf_fact = 0.5;
HMM=0;
if length(src_type) == 2
  HMM=1;
  src_type = src_type(2);
end

%========================Extract appropriate variables=====================
[m N]=size(x);  %N is number of data points
if m==1
  ALGORITHM = 1;
  m=src.m;        %m is the number of components in mixture model
  x = repmat(x,m,1);
  x_sq = repmat(x_sq,m,1);
  %Dummy variables for gammas.m
  m_q = x;
  b_q = x_sq;
elseif m==src.m
  ALGORITHM = 2;
  %Dummy variables for gammas.m
  m_q = x;
  b_q = x_sq;
  %Calculate 'true' x,x_sq.
  switch src_type
   case 'g'
    [x,x_sq] = deal(m_q,m_q.^2+1./b_q);
   otherwise
    [x,x_sq] = rect_expect(m_q,b_q);
  end
else
  error('Source signal dimensionality and source model mismatch!');
end

if(src_type == 'e')
  x_sq = x;
  pdf_fact=1;
  tol = min(tol,1e-6);
end

% PRIORS
% mixture prior
lambda_0=src.priors.lambda_0;
% component precision priors
b_0=src.priors.b_0;
c_0=src.priors.c_0;
% component mean priors
m_0=src.priors.m_0;
tau_0=src.priors.tau_0;

% POSTERIORS
% component precision posts
b=src.posts.b;
c=src.posts.c;
% component mean posts
mm=src.posts.mm;
tau=src.posts.tau;
%========================Extract appropriate variables=====================

% Initialise
ftot=0;
lik=[];
littlebit=eps;
Fgauss = 0;

% ALGORITHM = 2 for vbICA2: run the E-step only once for good convergence.
outer_steps = max_steps;
inner_steps = 1;
if ALGORITHM == 2
  outer_steps = 1;
  inner_steps = max_steps;
end

% Run...
for steps1=1:outer_steps

  %==================================E_Step================================
  if m==1
    gamma = ones(1,N);
  else
    [gamma xi] = gammas(src,m_q,b_q,ALGORITHM);
  end
  %==================================E_Step================================

  for steps2=1:inner_steps

    %==================================M_Step================================

    %--------------------Update lambda etc.--------------------
    if HMM
      % M-step for HMM specific variables
      [src,Fp,Fpi,ent_gam] = hmm_mstep(src,gamma,xi);
      gamma_sum = sum(gamma');
      lambda = gamma_sum;  % for plotting only

      % contribution to energy
      Fdir = Fp+Fpi;
    else
      gamma_sum = sum(gamma');
      lambda = lambda_0+gamma_sum;
      src.pi = lambda./sum(lambda);

      % store for E-step
      src.posts.lambda=lambda;

      % contribution to energy
      lambda_p=lambda_0*ones(1,m);
      dir1 = sum(gammaln(lambda+eps) - gammaln(lambda_p+eps));
      dir2 = gammaln(sum(lambda+eps)) - gammaln(sum(lambda_p+eps));
      Fdir = dir1-dir2;
      ent_gam=-sum(sum(gamma.*log(gamma+eps)));
    end
    %--------------------Update lambda etc.--------------------

    %--------------------Update precisions---------------------
    mean_xsq = sum(gamma.*x_sq,2);
    mean_x = sum(gamma.*x,2);
    mu_sq = (gamma_sum.*(mm.^2+1./tau))';

    data_bit = mean_xsq-2*mm'.*mean_x+mu_sq;
    b = 1./(1/b_0+pdf_fact*data_bit');
    c = c_0+pdf_fact*gamma_sum;
    mean_beta = b.*c;

    % store for E-step
    src.posts.b=b;
    src.posts.c=c;

    % contribution to energy
    beta1 = gammaln(c) - gammaln(c_0);
    beta2 = c.*log(b) - c_0.*log(b_0);
    Fbeta = sum(beta1 + beta2);
    %--------------------Update precisions---------------------

    %-----------------------Update means-----------------------
    if src_type == 'g'
      tau = tau_0+mean_beta.*gamma_sum;
      % posterior mean: precision-weighted combination of prior mean and data
      mm = (tau_0.*m_0+mean_beta.*mean_x')./tau;

      % store for E-step
      src.posts.mm=mm;
      src.posts.tau=tau;

      % contribution to energy
      b_ratio = tau_0./tau;
      gauss1 = -log(b_ratio);
      gauss2 = b_ratio-1;
      gauss3 = tau_0.*(mm-m_0).^2;
      Fgauss = -0.5*sum(gauss1 + gauss2 + gauss3);
    end
    %-----------------------Update means-----------------------

    %==================================M_Step================================

    %==================================Energy================================
    oldfm = ftot;
    f_hidd = ent_gam-N/2*log(2*pi);
    f_params = Fgauss+Fbeta+Fdir;
    ftot = f_hidd+f_params;
    %==================================Energy================================

    % Plot if required
    if CHECK_PROGRESS
      switch src_type
       case 'g'
        plotMoG(mm,1./(b.*c),lambda./sum(lambda))
        drawnow
       case 'e'
        plotMoE(1./(b.*c),lambda./sum(lambda),max(max(x'))+1/max(sqrt(b.*c)))
        drawnow
       case 't'
        plotRMoG(1./(b.*c),lambda./sum(lambda),max(max(x'))+1/max(sqrt(b.*c)))
        drawnow
      end
    end

    % Convergence criterion for inner_steps
    err = abs((oldfm-ftot)/ftot);
    if err<tol
      break
    end

  end % End of inner_steps

  % Convergence criterion for outer_steps
  if err<tol
    break
  end
end  % End of outer_steps

% Put variables into data structure
src.centres=mm;
src.precs=mean_beta;
src.gammas=gamma;
src.f_hidd = f_hidd;
src.f_params = f_params;
src.ftot=ftot;

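For vbICA2 with exponential or truncated-Gaussian sources, the function calls `rect_expect(m_q,b_q)` to turn the Gaussian source posterior (location `m_q`, precision `b_q`) into the moments `x` and `x_sq`. That helper is not part of this file. The sketch below assumes it returns the first two moments of N(m_q, 1/b_q) truncated to the positive half-line; the truncation point and every value here are assumptions, not taken from the original.

```matlab
% Hypothetical stand-in for rect_expect: moments of N(m_q, 1/b_q)
% truncated to x > 0 (assumed behaviour, not the original helper).
m_q = [-1 0 2];                        % location parameters (made up)
b_q = [4 1 0.25];                      % precisions (made up)

s   = 1./sqrt(b_q);                    % standard deviations
z   = m_q./s;                          % standardised location
Phi = 0.5*erfc(-z/sqrt(2));            % standard normal CDF at z
phi = exp(-0.5*z.^2)/sqrt(2*pi);       % standard normal pdf at z
r   = phi./Phi;                        % inverse Mills ratio
x    = m_q + s.*r;                     % E[x | x > 0]
x_sq = m_q.^2 + s.^2 + m_q.*s.*r;      % E[x^2 | x > 0] = m_q*E[x] + s^2

% Caveat: for very negative z, Phi underflows and r loses accuracy;
% a robust version would switch to an asymptotic expansion there.
```

As a sanity check, `m_q = 0`, `b_q = 1` gives the half-normal moments E[x] = sqrt(2/pi) ≈ 0.798 and E[x^2] = 1.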
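The E-step delegates to `gammas.m`, which is not included here. For orientation, the block below computes the textbook variational responsibilities for a 1-D Gaussian mixture from the same posterior quantities `lambda`, `b`, `c`, `mm` and `tau`, assuming the Gamma posterior uses scale `b` and shape `c` (consistent with `mean_beta = b.*c` in the M-step). It is a sketch of the standard update, not a reconstruction of `gammas.m`; it ignores the HMM case (`xi`) and the exponential and truncated-Gaussian component types, and all numbers are made up.

```matlab
% Standard VB responsibilities for a 1-D Gaussian mixture (sketch only).
% Assumed posterior state for m = 2 components:
lambda = [6 6];                      % Dirichlet parameters
b = [0.2 0.2];  c = [2.5 2.5];       % Gamma scale/shape, so <beta> = b.*c = 0.5
mm = [3 13];  tau = [2.5 2.5];       % Normal posterior over the means
xdata = [1 2 3 4 5 11 12 13 14 15];
m = numel(lambda);  N = numel(xdata);

E_log_pi   = psi(lambda) - psi(sum(lambda));   % E[ln pi_k]
E_log_beta = psi(c) + log(b);                  % E[ln beta_k], Gamma(scale b, shape c)
E_beta     = b.*c;                             % E[beta_k]

% E[(x_n - mu_k)^2] = x_n^2 - 2 x_n mm_k + mm_k^2 + 1/tau_k, as an m x N array
E_sq = repmat(xdata.^2, m, 1) - 2*mm'*xdata + repmat((mm.^2 + 1./tau)', 1, N);

% Unnormalised log responsibilities
log_rho = repmat((E_log_pi + 0.5*E_log_beta)', 1, N) ...
        - 0.5*repmat(E_beta', 1, N).*E_sq - 0.5*log(2*pi);

% Normalise over components with log-sum-exp for numerical stability
mx = max(log_rho, [], 1);
gamma = exp(log_rho - repmat(mx, m, 1));
gamma = gamma./repmat(sum(gamma, 1), m, 1);    % columns sum to 1
```

Each column of `gamma` is a distribution over components, the same m x N layout that the M-step in `learn_mix1d` consumes.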
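In the non-HMM branch, the Dirichlet posterior over the mixing coefficients is the prior plus the expected counts, `lambda = lambda_0 + sum_n gamma_kn`, and `Fdir` equals ln B(lambda) - ln B(lambda_0), where B is the multivariate Beta function. The toy sketch below repeats that update outside the function with made-up responsibilities and an assumed symmetric prior `lambda_0 = 1`. It drops the `eps` guards, which only matter when an argument to `gammaln` would be zero.

```matlab
% Dirichlet update for the mixing coefficients and its free-energy term
% (sketch; responsibilities and prior are made up).
lambda_0 = 1;                                   % assumed symmetric prior
gamma = [0.9 0.8 0.7 0.4 0.2;                   % responsibilities, m x N
         0.1 0.2 0.3 0.6 0.8];
m = size(gamma, 1);

gamma_sum = sum(gamma, 2)';                     % expected counts: [3 2]
lambda = lambda_0 + gamma_sum;                  % Dirichlet posterior: [4 3]
mix = lambda./sum(lambda);                      % expected mixing coefficients

% ln B(lambda) - ln B(lambda_0), with B the multivariate Beta function
lambda_p = lambda_0*ones(1, m);
Fdir = sum(gammaln(lambda) - gammaln(lambda_p)) ...
     - (gammaln(sum(lambda)) - gammaln(sum(lambda_p)));
```

With two components B reduces to MATLAB's `beta`, so here `Fdir` equals `log(beta(4,3))`, that is log(1/60), and the expected mixing coefficients are [4/7 3/7].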
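The precision update adds `pdf_fact` times the expected counts to the Gamma shape and `pdf_fact` times the expected squared deviations to the inverse scale. `pdf_fact` is 0.5 for Gaussian components and 1 for exponential components, where `x_sq` is replaced by `x`. The sketch below wraps the same arithmetic in anonymous functions and applies it to made-up data. For the exponential case it assumes the stored means are 0 with infinite precision, so that `data_bit` reduces to sum_n gamma_n x_n; the file does not show how `src.posts.mm` and `src.posts.tau` are initialised for exponentials.

```matlab
% Precision update from learn_mix1d, wrapped in anonymous functions (sketch).
b_0 = 1e3;  c_0 = 1e-3;                 % assumed vague Gamma prior

% sum_n gamma_kn * E[(x_n - mu_k)^2], returned as an m x 1 column
data_bit_fn = @(gamma, x, x_sq, mm, tau) ...
    sum(gamma.*x_sq, 2) - 2*mm'.*sum(gamma.*x, 2) ...
    + (sum(gamma, 2)'.*(mm.^2 + 1./tau))';

% Posterior mean precision <beta_k> = b_k*c_k, with
% c = c_0 + pdf_fact*sum(gamma) and b = 1/(1/b_0 + pdf_fact*data_bit)
mean_beta_fn = @(gamma, x, x_sq, mm, tau, pdf_fact) ...
    (c_0 + pdf_fact*sum(gamma, 2)') ./ ...
    (1/b_0 + pdf_fact*data_bit_fn(gamma, x, x_sq, mm, tau)');

% Gaussian case: two clusters with hard (made-up) responsibilities, means known
xg = [1 2 3 4 5 11 12 13 14 15];
gamma_g = [ones(1,5) zeros(1,5); zeros(1,5) ones(1,5)];
Xg = repmat(xg, 2, 1);                  % m x N, as learn_mix1d replicates x
mean_beta_g = mean_beta_fn(gamma_g, Xg, Xg.^2, [3 13], [Inf Inf], 0.5);

% Exponential case: x_sq = x, pdf_fact = 1; assumes mean 0 with infinite precision
xe = [0.5 1.5 1 2 0];
mean_beta_e = mean_beta_fn(ones(1,5), xe, xe, 0, Inf, 1);
```

With this vague prior the Gaussian estimates come out close to 0.5, the reciprocal of each cluster's variance of 2. The exponential rate is exactly 1 here only because `c_0` happens to equal `1/b_0`; in general it approaches the reciprocal of the sample mean (also 1 for this data).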
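For Gaussian components the standard conjugate update gives the component mean a Normal posterior with precision `tau = tau_0 + <beta>*sum_n gamma_n` and mean `(tau_0*m_0 + <beta>*sum_n gamma_n x_n)/tau`, a precision-weighted combination of the prior mean `m_0` and the data. The sketch below uses a single component, made-up data and an assumed `<beta>` of 0.5, compares a vague prior with one as informative as the data, and evaluates the same negative-KL expression used for `Fgauss`.

```matlab
% Conjugate Normal update for one component mean (sketch, made-up numbers).
xdata = [1 2 3 4 5];
gamma = ones(1, 5);                  % single component owns every point
gamma_sum = sum(gamma, 2)';          % expected count: 5
mean_x = sum(gamma.*xdata, 2);       % responsibility-weighted sum of x: 15
mean_beta = 0.5;                     % assumed <beta> from the precision update

% (a) Vague prior: posterior mean is close to the sample mean 3
m_0 = 0;  tau_0 = 1e-3;
tau_a = tau_0 + mean_beta.*gamma_sum;
mm_a  = (tau_0.*m_0 + mean_beta.*mean_x')./tau_a;

% (b) Prior precision equal to the data precision <beta>*sum(gamma) = 2.5:
%     posterior mean lies halfway between m_0 = 10 and the sample mean 3
m_0 = 10;  tau_0 = 2.5;
tau_b = tau_0 + mean_beta.*gamma_sum;
mm_b  = (tau_0.*m_0 + mean_beta.*mean_x')./tau_b;

% Negative KL(q(mu) || p(mu)) for case (b), same expression as Fgauss
b_ratio  = tau_0./tau_b;
Fgauss_b = -0.5*sum(-log(b_ratio) + b_ratio - 1 + tau_0.*(mm_b - m_0).^2);
```

With the vague prior the posterior mean is about 2.999, essentially the sample mean; with `tau_0 = 2.5` prior and data carry equal weight and the mean lands at 6.5. `Fgauss_b` is negative, as a negative KL divergence must be.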