📄 gmm_em.m
字号:
function [g, fit] = gmm_em(s, g, N)
%function [g, fit] = gmm_em(s, g, N)
%
% Refine a Gaussian mixture model with N iterations of expectation-maximisation.
%
% INPUTS:
%   s - samples, one sample per column
%   g - initial gmm with fields:
%       g.x - component means, one component per column (NM = size(g.x,2))
%       g.P - component covariances, g.P(:,:,i) for component i
%       g.w - component weights (normalised to sum to one on entry)
%   N - number of iterations of EM (if N <= 0, g is returned with only its
%       weights normalised)
%
% OUTPUT:
%   g   - resultant gmm
%   fit - negative log-likelihood of the samples. It is computed in the E-step
%         of the last iteration, i.e. for the gmm parameters *before* the final
%         M-step (as in NetLab gmmem.m). If N <= 0 it is the negative
%         log-likelihood of the (weight-normalised) input gmm.
%
% NOTES:
%   - Samples with zero total likelihood under every component are given equal
%     responsibility 1/NM for each component (as in NetLab 3.3 gmmpost.m).
%   - A component whose total responsibility is zero keeps its previous mean
%     and covariance (and gets zero weight) rather than being updated with a
%     divide-by-zero.
%
% REFERENCES:
% Figueiredo et al, On Fitting Mixture Models, 1999. Section 2.2, Equations 6 to 9.
% Ian Nabney, NetLab, http://www.ncrg.aston.ac.uk/netlab/index.php
%
% Tim Bailey 2005.
NM = size(g.x, 2); % number of mixtures
g.w = g.w / sum(g.w);

if N <= 0 && nargout > 1
    % No EM iterations requested: still assign the requested output by
    % reporting the fit of the input gmm.
    fit = -sum(log(sum(weighted_likelihoods(s, g), 1)));
end

while N > 0
    N = N - 1;

    % E-step: compute assignment likelihoods, w(i,j) for component i, sample j
    w = weighted_likelihoods(s, g);
    wsr = sum(w, 1); % wsr(j) = sum(w(1:NM,j)), total likelihood of sample j
    fit = -sum(log(wsr)); % negative log-likelihood of fit (from NetLab gmmem.m and gmmpost.m)

    % Samples no component explains: share their responsibility equally
    % between all components (NetLab 3.3 gmmpost.m), then avoid divide-by-zero.
    unexplained = (wsr == 0);
    w(:, unexplained) = 1/NM;
    wsr = checkzeros(wsr);
    w = w ./ reprow(wsr, NM); % normalise columns: sum(w(1:NM,j)) == 1 for every sample j

    % M-step: compute new (x,P,w) values for gmm
    wsc = sum(w, 2); % wsc(i) = sum(w(i,:)), total responsibility of component i
    g.w = wsc / sum(wsc); % sum(wsc) equals the number of samples (columns normalised above)
    for i = 1:NM
        if wsc(i) > 0
            w_norm = w(i,:) ./ wsc(i); % sum(w_norm) == 1
            [g.x(:,i), g.P(:,:,i)] = sample_mean_weighted(s, w_norm);
            g.P(:,:,i) = checkP(g.P(:,:,i)); % check P has not collapsed
        end
        % Otherwise component i supports no samples: its weight is now zero and
        % its mean and covariance are left unchanged (dividing by wsc(i) would
        % produce NaNs that propagate through all later iterations).
    end
end

function w = weighted_likelihoods(s, g)
% Weighted component likelihoods, an NM-by-NS matrix (NM components, NS samples):
% row i is g.w(i) * gauss_likelihood(v, g.P(:,:,i)), where v holds every sample
% offset by the mean of component i (one column per sample).
NM = size(g.x, 2);
NS = size(s, 2);
w = zeros(NM, NS);
for i = 1:NM
    v = s - repvec(g.x(:,i), NS);
    w(i,:) = g.w(i) * gauss_likelihood(v, g.P(:,:,i));
end
%
%
function x = repvec(x,N)
% Tile a column vector horizontally: the result has N copies of x(:,1) as its
% columns (size(x,1)-by-N). Implemented by indexing, which keeps x's class.
colidx = ones(1, N);
x = x(:, colidx);
function x = reprow(x,N)
% Tile a row vector vertically: the result has N copies of x(1,:) as its rows
% (N-by-size(x,2)). Implemented by indexing, which keeps x's class.
rowidx = ones(N, 1);
x = x(rowidx, :);
function x = checkzeros(x)
% Replace every exactly-zero element of x with one, leaving all other elements
% (including NaN) untouched. Used to make an array safe to divide by.
iszero = (x == 0);
x(iszero) = 1;
function P = checkP(P, mincov)
% Check a covariance matrix for collapse; if collapsed, reset it to identity.
%
% INPUTS:
%   P      - square covariance matrix
%   mincov - (optional) smallest acceptable eigenvalue of P, default eps
%
% OUTPUT:
%   P - the input unchanged, or eye(size(P)) if P was judged collapsed
%
% P is judged collapsed when it has any non-finite entry, or when the smallest
% eigenvalue of its symmetric part is below mincov. This replaces the earlier
% test det(P) < eps. The determinant shrinks with both scale and dimension
% (e.g. diag([1e-6 1e-6 1e-6]) has det 1e-18 yet is well conditioned). It can
% also stay large when one huge eigenvalue masks a collapsed direction. The
% eigenvalue test follows the spirit of the NetLab MINCOV check. Resetting to
% the component's initial covariance, as NetLab does, would need that matrix
% passed in.
if nargin < 2
    mincov = eps;
end
if isempty(P)
    return % nothing to check (eig of an empty matrix gives no eigenvalues)
end
% Short-circuit ensures eig is only evaluated on finite matrices.
if ~all(isfinite(P(:))) || min(eig((P + P')/2)) < mincov
    P = eye(size(P));
end
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -