⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 em.m

📁 Matlab源代码
💻 M
字号:
function EM(data,r)
% EM  Fit a Gaussian mixture model with the EM algorithm using random restarts.
%
%   EM(data, r) loads one of three data sets, runs EM from r random
%   initializations, keeps the run with the highest final log-likelihood,
%   and plots (via the external function plotgauss) the initial parameters
%   of the first run and the parameters of the best run. It also plots the
%   log-likelihood of the best run against the iteration number.
%
%   Input:
%     data - data set selector: 1 = gauss2 (2 components),
%            2 = gauss3 (3 components), 3 = iris (3 components)
%     r    - number of random initializations (positive integer)
%
%   The last column of each loaded matrix is dropped before fitting
%   (presumably a class label - confirm against the data files).
%
%   Requires mvnpdf (Statistics Toolbox) and plotgauss (defined elsewhere).

% Tuning constants (same values the original code hard-coded).
max_iter    = 50;     % iteration cap per restart
tol         = 1e-4;   % stop when the log-likelihood changes by less than this
covar_floor = 1e-3;   % lower bound for every diagonal covariance entry

% Choose a data set; m is the number of mixture components.
switch data
    case 1
        load gauss2;
        m = 2;
        raw = gauss2;
    case 2
        load gauss3;
        m = 3;
        raw = gauss3;
    case 3
        load iris;
        m = 3;
        raw = iris;
    otherwise
        error('EM:invalidData', 'data must be 1, 2 or 3.');
end

if ~isscalar(r) || r < 1
    error('EM:invalidRestarts', 'r must be a positive integer.');
end

[n, f] = size(raw);
d = f - 1;
data1 = raw(:,1:d);

% -Inf (not a finite sentinel) so that the first finite run always wins,
% however negative its log-likelihood is.
max_log_like = -Inf;
max_cParams  = [];
max_ind      = [];

log_like_plot = NaN(max_iter, r);  % log-likelihood trace of every restart
num_iter      = zeros(1, r);       % iterations actually run per restart

for ini_num = 1:r
    % Initialization: each mean is a uniformly chosen data point,
    % covariance is the identity, priors are equal.
    for k = 1:m
        initial = randi(n);
        cParams(k).mu = data1(initial,:);
        cParams(k).covar = eye(d);
        cParams(k).prior = 1/m;
    end

    % Plot the initial parameter values of the first restart only.
    if ini_num == 1
        plotgauss(cParams,200,m);
    end

    log_like = -Inf;
    log_like_last = Inf;
    counter = 0;

    while abs(log_like_last - log_like) > tol && counter < max_iter
        counter = counter + 1;
        log_like_last = log_like;

        %%%%%%%%%%%%%%%%%%%%%%%%%%%%% E-Step %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
        % weighted(i,j) = prior_j * N(x_i | mu_j, covar_j), evaluated once.
        weighted = mixture_density(data1, cParams, m);
        sum_dem  = sum(weighted, 2);                 % mixture density per point
        E_z      = bsxfun(@rdivide, weighted, sum_dem);  % responsibilities E[z(ij)]

        %%%%%%%%%%%%%%%%%%%%%%%%%%%%% M-Step %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
        for j = 1:m
            resp = E_z(:,j);
            resp_total = sum(resp);   % equals n * prior_j

            % update prior
            cParams(j).prior = resp_total / n;

            % update mu: responsibility-weighted mean
            cParams(j).mu = (resp' * data1) / resp_total;

            % update covar: responsibility-weighted scatter around the new mean
            centered = bsxfun(@minus, data1, cParams(j).mu);
            covar = (bsxfun(@times, centered, resp)' * centered) / resp_total;
            covar = (covar + covar') / 2;   % remove round-off asymmetry for mvnpdf

            % make sure covar(i,i) is not below the threshold (keeps it > 0)
            for i = 1:d
                if covar(i,i) < covar_floor
                    covar(i,i) = covar_floor;
                end
            end
            cParams(j).covar = covar;
        end

        %%%%%%%%%%%%%%%% compute the log-likelihood %%%%%%%%%%%%%%%%%%%%%%%
        % log L = sum_i log( sum_j prior_j * N(x_i | mu_j, covar_j) ),
        % evaluated with the parameters just updated in the M-step.
        log_like = sum(log(sum(mixture_density(data1, cParams, m), 2)));
        log_like_plot(counter,ini_num) = log_like; % for plotting log_like
    end
    num_iter(ini_num) = counter;

    % Save the max log-likelihood and the corresponding parameters over
    % the r trials.
    if max_log_like < log_like
        max_log_like = log_like;
        max_cParams = cParams;
        max_ind = ini_num;
    end
end

if isempty(max_ind)
    error('EM:noValidRun', ...
          'No initialization produced a finite log-likelihood.');
end

plotgauss(max_cParams,200,m); % plot the final parameter values

% Plot of the log-likelihood as a function of iteration number, restricted
% to the iterations the best restart actually performed.
a = num_iter(max_ind);
figure;plot(1:a,log_like_plot(1:a,max_ind));
xlabel('iteration number');
ylabel('log-likelihood');
title('the log-likelihood as a function of iteration number');


function weighted = mixture_density(X, cParams, m)
% Return the n-by-m matrix whose (i,j) entry is
% cParams(j).prior * mvnpdf(X(i,:), cParams(j).mu, cParams(j).covar).
weighted = zeros(size(X,1), m);
for j = 1:m
    weighted(:,j) = cParams(j).prior * ...
                    mvnpdf(X, cParams(j).mu, cParams(j).covar);
end

    
    
            
            
            
            
    
    
    

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -