⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 demo1.m

📁 EM算法介绍及Matlab演示代码(一维和多维高斯混合模型学习算法)
💻 M
字号:
%% settings

% mixture size and amount of data
M = 4;          % number of Gaussian components in the mixture
N = 65536;      % total number of data samples to draw

% EM stopping rules and random restarts
th   = 1e-3;    % convergence threshold on the largest relative parameter change
Nit  = 200;     % maximum number of EM iterations per restart
Nrep = 4;       % number of random restarts; the best log-likelihood is kept

% output switches (1 = enabled)
plot_flag  = 1;
print_flag = 1;

%% parameters for the random signal generator

% Draw random true parameters for the M Gaussian components.
% (The order of the randn calls is significant: it fixes which random
% numbers end up in which parameter for a given RNG state.)
mu_real  = randn(M,1);          % component means
var_real = abs(randn(M,1));     % component spreads -- used as STANDARD DEVIATIONS by the sampler and PDFs, not variances

% mixing weights: probability that a sample is drawn from each component
a_real = abs(randn(M,1));
a_real = a_real./sum(a_real);   % rescale so that the weights sum to 1

% display the true parameters
if print_flag==1
    a_real, mu_real, var_real
end

%% generate random samples of the Gaussian mixture
%
% Each sample n is assigned a component label m(n) in 1..M with probability
% a_real(m), and is then drawn as x(n) = mu_real(m(n)) + var_real(m(n))*randn.
% NOTE: var_real therefore acts as the standard deviation of each component.

x=randn(1,N);          % standard-normal draws, rescaled per component below

% Inverse-CDF sampling of the labels: with lower bin edges
% a_cum = [0; cumsum(a_real(1:M-1))], the label of a uniform draw u in (0,1)
% is the number of edges strictly below u (always at least 1 since a_cum(1)=0).
a_cum=cumsum(a_real);
a_cum=[0;a_cum(1:end-1)];
m=sum(bsxfun(@gt,rand(1,N),a_cum),1);

x=x.*var_real(m)'+mu_real(m)';

if plot_flag==1
    % fraction of samples drawn from each component vs. the true weights
    figure(1); clf; hold on; grid on;
    [h hx]=hist(m,1:M);      % one bin per component label (default would be 10 bins)
    stem(hx,h/N)
    plot(1:M,a_real,'ro')
    legend('sampled fraction','a\_real');
end

%% EM Algorithm
%
% Fit an M-component 1-D Gaussian mixture to the samples x by
% expectation-maximisation. The fit is restarted Nrep times from random
% initial parameters, and the restart with the highest log-likelihood is kept.
%
% NOTE: as in the data generator, "var" holds each component's STANDARD
% DEVIATION (it divides (x-mu) directly), not its variance. The name also
% shadows MATLAB's built-in var() function for the rest of this script.
%
% Uses:    x (1xN), M, N, th, Nit, Nrep, print_flag
% Defines: a_best, mu_best, var_best, f_best; afterwards a, mu, var are set
%          to the best restart so the sections that follow report it.

% log( a(k) * N(x(n); mu(k), s(k)) ) for component k (rows) and sample n (columns).
% Working in the log domain avoids 0/0 when every density underflows for a sample.
logjoint=@(a,mu,s) bsxfun(@minus, log(a./(sqrt(2*pi)*s)), ...
                          0.5*bsxfun(@rdivide,bsxfun(@minus,x,mu),s).^2);

f_best=-inf;
for crep=1:Nrep
    c=1;    % iteration counter of this restart

    % random initial parameters
    a=abs(randn(M,1));
    a=a/sum(a);         % normalise so that sum(a)=1
    mu=randn(M,1);
    var=abs(randn(M,1));

    while 1
          a_old=  a;
         mu_old= mu;
        var_old=var;

        % E-step: posterior p(m|x,param), normalised per sample with log-sum-exp
        lp=logjoint(a,mu,var);
        q=exp(bsxfun(@minus,lp,max(lp,[],1)));
        p=bsxfun(@rdivide,q,sum(q,1));

        % M-step: posterior-weighted re-estimation of weights, means and std devs
        Nk=sum(p,2);                    % effective number of samples per component
          a=Nk/N;
         mu=(p*x')./Nk;
        var=sqrt(sum(p.*bsxfun(@minus,x,mu).^2,2)./Nk);

        % largest relative change of any parameter vector
        t=max([norm(  a_old-  a)/norm(  a_old);
               norm( mu_old- mu)/norm( mu_old);
               norm(var_old-var)/norm(var_old)]);
        if print_flag==1
            fprintf('c=%04d: t=%f\n',c,t);
        end
        c=c+1;  % count every iteration; otherwise Nit never stops the loop when print_flag~=1

        if t<th
            break;
        end

        if c>Nit
            disp('reach maximal iteration')
            break;
        end
    end

    % log-likelihood of x under this restart's final parameters:
    % f = sum_n log sum_m a(m) N(x(n); mu(m), var(m))
    lp=logjoint(a,mu,var);
    lmax=max(lp,[],1);
    f=sum(lmax+log(sum(exp(bsxfun(@minus,lp,lmax)),1)));
    if crep==1 || f>f_best || isnan(f_best)
          a_best=a;
         mu_best=mu;
        var_best=var;
          f_best=f
    end
end

% hand the best restart (not merely the last one) to the following sections
  a=  a_best;
 mu= mu_best;
var=var_best;

%% plot all
%
% Compare the histogram of x with the mixture PDFs built from the best EM
% estimate and from the true parameters. Every curve is scaled to a peak of 1
% so that their shapes can be compared on one axis.
% NOTE: var_best and var_real are used as standard deviations (they divide
% (hx-mu) directly), consistent with the rest of the script.

if plot_flag==1
    [h hx]=hist(x,round(N/50));   % hist expects an integer number of bins

    % mixture density sum_k w(k)*N(t; mu(k), s(k)) evaluated on the row vector t
    mixpdf=@(t,w,mu,s) sum(bsxfun(@times, w./(sqrt(2*pi)*s), ...
                   exp(-0.5*bsxfun(@rdivide,bsxfun(@minus,t,mu),s).^2)),1);

    px     =mixpdf(hx,a_best,mu_best,var_best);
    px_real=mixpdf(hx,a_real,mu_real,var_real);

    figure(2); clf; hold on; grid on;
    plot(hx,h/max(h), 'c')
    plot(hx,px     /max(px     ),'r-')
    plot(hx,px_real/max(px_real),'k-')
    legend('normalized hist', ...
           'estimated PDF', ...
           'real PDF');
end

% Print the best EM estimate next to the true parameters.
% A mixture fit recovers the components only up to a permutation (label
% switching), so both sets are listed in ascending order of their means; the
% printed index is that rank, not the original component number.
if print_flag==1
    [~,ie]=sort(mu_best);    % estimated components ordered by mean
    [~,ir]=sort(mu_real);    % true components ordered by mean
    for cm=1:M
        fprintf(1,'a[%d]=%+01.04f\t, a_real[%d]=%+01.04f\n', cm,a_best(ie(cm)),cm,a_real(ir(cm)));
    end
    for cm=1:M
        fprintf(1,'mu[%d]=%+01.04f\t, mu_real[%d]=%+01.04f\n', cm,mu_best(ie(cm)),cm,mu_real(ir(cm)));
    end
    for cm=1:M
        fprintf(1,'var[%d]=%+01.04f\t, var_real[%d]=%+01.04f\n', cm,var_best(ie(cm)),cm,var_real(ir(cm)));
    end
end

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -