% File: em.m
% (Code-viewer page residue: "Font size:" control; not part of the program.)
function EM(data,r)
% EM  Fit a Gaussian mixture model with Expectation-Maximization.
%
%   EM(data, r) runs r independent, randomly initialised EM trials on the
%   selected data set, keeps the trial with the highest final
%   log-likelihood, and plots its parameters and its log-likelihood curve.
%
% Inputs:
%   data - data set selector:
%            1 -> gauss2 (m = 2 mixture components)
%            2 -> gauss3 (m = 3 mixture components)
%            3 -> iris   (m = 3 mixture components)
%   r    - number of random initialisations (positive integer)
%
% Notes:
%   * The last column of every data set is dropped before fitting
%     (presumably it holds the class label -- confirm against the data).
%   * Requires mvnpdf and the project helper plotgauss on the path.
%   * Each trial stops when the log-likelihood changes by at most 1e-4
%     between iterations or after 50 iterations.

% ---- choose a data set; m is the number of mixture components ----
switch data
    case 1
        load gauss2;
        m = 2;
        raw = gauss2;
    case 2
        load gauss3;
        m = 3;
        raw = gauss3;
    case 3
        load iris;
        m = 3;
        raw = iris;
    otherwise
        error('EM:invalidData', 'data must be 1, 2 or 3.');
end
if ~isscalar(r) || r < 1 || r ~= floor(r)
    error('EM:invalidR', 'r must be a positive integer.');
end

[n, f] = size(raw);
d = f - 1;                 % number of feature columns (last column excluded)
data1 = raw(:,1:d);

max_iter  = 50;            % iteration cap per trial
tol       = 1e-4;          % convergence threshold on the log-likelihood change
min_var   = 1e-3;          % lower bound for the diagonal of each covariance

max_log_like  = -Inf;      % -Inf so that any finite log-likelihood wins
max_cParams   = [];
max_ind       = 0;
log_like_plot = zeros(max_iter, r);   % log-likelihood per iteration and trial
n_iter        = zeros(1, r);          % iterations actually run in each trial

for ini_num = 1:r
    % ---- initialisation: random data point as mean, identity covariance,
    %      uniform prior ----
    for k = 1:m
        initial = randi(n);            % uniform over 1..n
        cParams(k).mu    = data1(initial,:);
        cParams(k).covar = eye(d);
        cParams(k).prior = 1/m;
    end

    % plot the initial parameter values (first trial only)
    if ini_num == 1
        plotgauss(cParams,200,m);
    end

    log_like  = -Inf;
    counter   = 0;
    converged = false;
    while ~converged && counter < max_iter
        counter = counter + 1;
        log_like_last = log_like;

        % ------------------------- E-step -------------------------
        % weighted(i,j) = prior_j * N(x_i | mu_j, covar_j)
        weighted = zeros(n, m);
        for j = 1:m
            weighted(:,j) = cParams(j).prior * ...
                mvnpdf(data1, cParams(j).mu, cParams(j).covar);
        end
        % guard against 0/0 when every component density underflows
        sum_dem = max(sum(weighted, 2), realmin);
        E_z = bsxfun(@rdivide, weighted, sum_dem);   % responsibilities E[z(ij)]

        % ------------------------- M-step -------------------------
        for j = 1:m
            resp = E_z(:,j);
            Nj   = sum(resp);          % effective number of points, = n*prior

            cParams(j).prior = Nj / n;
            cParams(j).mu    = (resp' * data1) / Nj;

            centered = bsxfun(@minus, data1, cParams(j).mu);
            covar = (centered' * bsxfun(@times, centered, resp)) / Nj;
            covar = (covar + covar') / 2;   % remove round-off asymmetry for mvnpdf

            % make sure covar(i,i) stays above the threshold (larger than 0)
            for i = 1:d
                if covar(i,i) < min_var
                    covar(i,i) = min_var;
                end
            end
            cParams(j).covar = covar;
        end

        % ---------------- log-likelihood of the mixture ----------------
        % log L = sum_i log( sum_j prior_j * N(x_i | mu_j, covar_j) )
        mix = zeros(n, 1);
        for j = 1:m
            mix = mix + cParams(j).prior * ...
                mvnpdf(data1, cParams(j).mu, cParams(j).covar);
        end
        log_like = sum(log(mix));
        log_like_plot(counter, ini_num) = log_like;   % for the plot

        % written as a negation so a NaN difference also ends the trial
        converged = ~(abs(log_like_last - log_like) > tol);
    end
    n_iter(ini_num) = counter;

    % save the max log-likelihood and the corresponding parameters
    % over the r trials
    if max_log_like < log_like
        max_log_like = log_like;
        max_cParams  = cParams;
        max_ind      = ini_num;
    end
end

if max_ind == 0
    error('EM:noValidTrial', ...
        'No trial produced a finite log-likelihood; nothing to plot.');
end

plotgauss(max_cParams,200,m);   % plot the final parameter values

% a plot of the log-likelihood as a function of iteration number,
% restricted to the iterations the best trial actually ran
a = n_iter(max_ind);
figure;plot(1:a,log_like_plot(1:a,max_ind));
xlabel('iteration number');
ylabel('log-likelihood');
title('the log-likelihood as a function of iteration number');
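% Code-viewer page residue (keyboard shortcut help), not part of the program:
%   Copy code            Ctrl + C
%   Search code          Ctrl + F
%   Full-screen mode     F11
%   Toggle theme         Ctrl + Shift + D
%   Show shortcuts       ?
%   Increase font size   Ctrl + =
%   Decrease font size   Ctrl + -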