% gaussmix.m
% (web-viewer label captured with this listing, not part of the code: "Font size:")
% --- Excerpt begins inside the diagonal-covariance EM iteration. The loop header
% --- and the setup of jx, k, n, xs, xs2, m, vi, lvm, im, wk, wp, wnj, wnb, nl, nb,
% --- c, th, ss, g1 and lpx happen before this view.
% E-step, first (possibly partial) chunk: data points 1..jx.
ii=1:jx;
% kk/km pair every data point with every mixture (mixture index varies fastest),
% so one vectorized expression evaluates all k*jx (mixture, point) combinations.
kk=repmat(ii,k,1);
km=repmat(1:k,1,jx);
% py is k x jx: weighted squared distance of each point from each mixture mean
% plus the per-mixture constant lvm. NOTE(review): vi and lvm are set outside this
% view; presumably vi holds -0.5./variances and lvm log(w)-0.5*sum(log(v)) — confirm.
py=reshape(sum((xs(kk(:),:)-m(km(:),:)).^2.*vi(km(:),:),2),k,jx)+lvm(:,wnj);
mx=max(py,[],1); % find normalizing factor for each data point to prevent underflow when using exp()
px=exp(py-mx(wk,:)); % find normalized probability of each mixture for each datapoint
ps=sum(px,1); % total normalized likelihood of each data point
px=px./ps(wk,:); % relative mixture probabilities for each data point (columns sum to 1)
% Log likelihood of each point; the -0.5*p*log(2*pi) and lsx terms are only
% subtracted after the loop, when final probabilities are requested.
lpx(ii)=log(ps)+mx;
pk=sum(px,2); % effective number of data points for each mixture (could be zero due to underflow)
% Responsibility-weighted accumulators for the mean and mean square
% (xs2 is defined outside this view; presumably xs.^2 — confirm).
sx=px*xs(ii,:);
sx2=px*xs2(ii,:);
% Remaining chunks, nb data points each.
for il=2:nl
ix=jx+1;
jx=jx+nb; % increment upper limit
ii=ix:jx;
kk=repmat(ii,k,1);
% im (defined outside this view) plays the role of km for an nb-point chunk.
py=reshape(sum((xs(kk(:),:)-m(im,:)).^2.*vi(im,:),2),k,nb)+lvm(:,wnb);
mx=max(py,[],1); % find normalizing factor for each data point to prevent underflow when using exp()
px=exp(py-mx(wk,:)); % find normalized probability of each mixture for each datapoint
ps=sum(px,1); % total normalized likelihood of each data point
px=px./ps(wk,:); % relative mixture probabilities for each data point (columns sum to 1)
lpx(ii)=log(ps)+mx;
pk=pk+sum(px,2); % effective number of data points for each mixture (could be zero due to underflow)
sx=sx+px*xs(ii,:);
sx2=sx2+px*xs2(ii,:);
end
% M-step.
g=sum(lpx); % total log probability summed over all data points
gg(j)=g; % convergence history
w=pk/n; % normalize to get the weights
% MATLAB 'if' on a vector is true only when every element is non-zero.
if pk % if all elements of pk are non-zero
m=sx./pk(:,wp);
v=sx2./pk(:,wp);
else
% Some mixture received no data: re-seed its mean at one of the worst-fitted points.
wm=pk==0; % mask indicating mixtures with zero weights
[vv,mk]=sort(lpx); % find the lowest probability data points (vv is unused)
m=zeros(k,p); % initialize means and variances to zero (variances are floored later)
v=m;
m(wm,:)=xs(mk(1:sum(wm)),:); % set zero-weight mixture means to worst-fitted data points
wm=~wm; % mask for non-zero weights
m(wm,:)=sx(wm,:)./pk(wm,wp); % recalculate means and variances for mixtures with a non-zero weight
v(wm,:)=sx2(wm,:)./pk(wm,wp);
end
v=max(v-m.^2,c); % variance = mean square - squared mean, floored at c
% Convergence: stop once the total log-likelihood gain is <= th (never on the
% first iteration). ss grants that many additional iterations before stopping.
if g-g1<=th && j>1
if ~ss, break; end % stop
ss=ss-1; % stop next time
end
end % end of the diagonal-covariance EM iteration loop
if sd && ~fv
    % Final probabilities were requested. The last pass evaluated lpx with the
    % parameters saved in m1/v1/w1, so report those figures and restore that set.
    hl2pi=0.5*p*log(2*pi);                  % per-point Gaussian normalizing constant
    pp=lpx'-hl2pi-lsx;                      % log probability of each data point
    gg=gg(1:j)/n-hl2pi-lsx;                 % mean log probability at each iteration
    g=gg(end);
    w=w1;                                   % restore the previous iteration's parameters
    v=v1;
    m=m1;
    mcen=sum(m,1)/k;                        % centroid of the k mixture means
    % spread of the means about their centroid, relative to the summed variances
    f=(m(:)'*m(:)-k*mcen(:)'*mcen(:))/sum(v(:));
end
if fv
    % Full covariances were requested: expand each mixture's row of diagonal
    % variances (v is k x p) into a p x p diagonal matrix, giving a p x p x k
    % array that seeds the full-covariance section.
    v1=v;
    v=zeros(p,p,k);
    for ik=1:k
        v(:,:,ik)=diag(v1(ik,:));
    end
else
    % Diagonal result: map means and variances back through the data scaling
    % (means -> m.*sx0+mx0, variances -> v.*sx0.^2).
    onek=ones(k,1);
    m=m.*sx0(onek,:)+mx0(onek,:);
    v=v.*repmat(sx0.^2,k,1);
end
end % closes a block that opens before this excerpt
if fv % check if full covariance matrices were requested
    %%%%%%%%%%%%%%%%%%%%%%%%%%%%%
    % Full Covariance matrices %
    %%%%%%%%%%%%%%%%%%%%%%%%%%%%%
    % Inside the EM loop each covariance is held as one row of the k x pl matrix v.
    % That row contains the pl=p*(p+1)/2 lower-triangular elements of the mixture's
    % p x p matrix in column-major order: lix gives their linear indices, and
    % rix/cix give their row/column indices. reshape(v(ik,lixi),p,p) rebuilds the
    % full symmetric matrix.
    pl=p*(p+1)/2;
    lix=1:p^2;
    cix=repmat(1:p,p,1);
    rix=cix';
    lix(cix>rix)=[];            % index of lower triangular elements
    cix=cix(lix);               % index of lower triangular columns
    rix=rix(lix);               % index of lower triangular rows
    lixi=zeros(p,p);
    lixi(lix)=1:pl;
    lixi=lixi';
    lixi(lix)=1:pl;             % reverse index to build full (symmetric) matrices
    v=reshape(v,p^2,k);         % v must hold k full p x p matrices on entry
    v=v(lix,:)';                % lower triangular in rows
    % If data size is large then do calculations in chunks
    nb=min(n,max(1,floor(memsize/(24*p*k)))); % chunk size for testing data points
    nl=ceil(n/nb);              % number of chunks
    jx0=n-(nl-1)*nb;            % size of first chunk
    %
    th=(l-floor(l))*n;          % stop when total log-likelihood gain <= frac(l)*n
    sd=(nargout > 3*(l~=0));    % = 1 if we are outputting log likelihood values
    l=floor(l)+sd;              % extra loop needed to calculate final G value
    %
    lpx=zeros(1,n);             % log probability of each data point
    wk=ones(k,1);
    wp=ones(1,p);
    wpl=ones(1,pl);             % 1 index for lower triangular matrix
    wnb=ones(1,nb);
    wnj=ones(1,jx0);
    % EM loop
    g=0;                        % dummy initial value for comparison
    gg=zeros(l+1,1);
    ss=sd;                      % initialize stopping count (0 or 1)
    vi=zeros(p*k,p);            % stack of k matrices -0.5*inv(v), each p x p
    vim=zeros(p*k,1);           % stack of k vectors of the form vi*m
    mtk=vim;                    % stack of k vectors of the form m
    lvm=zeros(k,1);             % log(w)-0.5*log(det(v)) for each mixture
    wpk=repmat((1:p)',k,1);     % row index that stacks each data column k times
    for j=1:l
        g1=g;                   % save previous log likelihood (2*pi factor omitted)
        m1=m;                   % save previous means, variances and weights
        v1=v;
        w1=w;
        for ik=1:k
            % these lines added for debugging only
            % vk=reshape(v(ik,lixi),p,p);
            % condk(ik)=cond(vk);
            %%%%%%%%%%%%%%%%%%%%
            [uvk,dvk]=eig(reshape(v(ik,lixi),p,p)); % convert lower triangular to full and find eigenvalues
            dvk=max(diag(dvk),c);                    % apply variance floor to eigenvalues
            vik=-0.5*uvk*diag(dvk.^(-1))*uvk';       % -0.5 times the (floored) inverse covariance
            vi((ik-1)*p+(1:p),:)=vik;                % vi contains all mixture inverses stacked on top of each other
            vim((ik-1)*p+(1:p))=vik*m(ik,:)';        % vim contains vi*m for all mixtures stacked on top of each other
            mtk((ik-1)*p+(1:p))=m(ik,:)';            % mtk contains all mixture means stacked on top of each other
            lvm(ik)=log(w(ik))-0.5*sum(log(dvk));    % log of weight times sqrt(det(inv(v))) for each mixture
        end
        %
        % % first do partial chunk
        %
        jx=jx0;
        ii=1:jx;
        xii=xs(ii,:).';
        % py(ik,i) = lvm(ik) - 0.5*(x_i-m_ik)'*inv(v_ik)*(x_i-m_ik), for all k mixtures at once
        py=reshape(sum(reshape((vi*xii-vim(:,wnj)).*(xii(wpk,:)-mtk(:,wnj)),p,jx*k),1),k,jx)+lvm(:,wnj);
        mx=max(py,[],1);        % find normalizing factor for each data point to prevent underflow when using exp()
        px=exp(py-mx(wk,:));    % find normalized probability of each mixture for each datapoint
        ps=sum(px,1);           % total normalized likelihood of each data point
        px=px./ps(wk,:);        % relative mixture probabilities for each data point (columns sum to 1)
        lpx(ii)=log(ps)+mx;
        pk=sum(px,2);           % effective number of data points for each mixture (could be zero due to underflow)
        sx=px*xs(ii,:);
        sx2=px*(xs(ii,rix).*xs(ii,cix)); % accumulator for variance calculation (lower tri cov matrix as a row)
        for il=2:nl
            ix=jx+1;
            jx=jx+nb;           % increment upper limit
            ii=ix:jx;
            xii=xs(ii,:).';
            py=reshape(sum(reshape((vi*xii-vim(:,wnb)).*(xii(wpk,:)-mtk(:,wnb)),p,nb*k),1),k,nb)+lvm(:,wnb);
            mx=max(py,[],1);    % find normalizing factor for each data point to prevent underflow when using exp()
            px=exp(py-mx(wk,:)); % find normalized probability of each mixture for each datapoint
            ps=sum(px,1);       % total normalized likelihood of each data point
            px=px./ps(wk,:);    % relative mixture probabilities for each data point (columns sum to 1)
            lpx(ii)=log(ps)+mx;
            pk=pk+sum(px,2);    % effective number of data points for each mixture (could be zero due to underflow)
            sx=sx+px*xs(ii,:);  % accumulator for mean calculation
            sx2=sx2+px*(xs(ii,rix).*xs(ii,cix)); % accumulator for variance calculation
        end
        g=sum(lpx);             % total log probability summed over all data points
        gg(j)=g;                % save convergence history
        w=pk/n;                 % normalize to get the column of weights
        if pk                   % if all elements of pk are non-zero
            m=sx./pk(:,wp);     % find mean and mean square
            v=sx2./pk(:,wpl);
        else
            wm=pk==0;           % mask indicating mixtures with zero weights
            [vv,mk]=sort(lpx);  % find the lowest probability data points
            m=zeros(k,p);       % initialize means and variances to zero (variances are floored later)
            v=zeros(k,pl);
            m(wm,:)=xs(mk(1:sum(wm)),:); % set zero-weight mixture means to worst-fitted data points
            wm=~wm;             % mask for non-zero weights
            m(wm,:)=sx(wm,:)./pk(wm,wp); % recalculate means and variances for mixtures with a non-zero weight
            v(wm,:)=sx2(wm,:)./pk(wm,wpl);
        end
        v=v-m(:,cix).*m(:,rix); % subtract off mean squared (eigenvalues are floored at c when next used)
        if g-g1<=th && j>1
            if ~ss, break; end  % stop
            ss=ss-1;            % stop next time
        end
    end
    if sd % we need to calculate the final probabilities
        pp=lpx'-0.5*p*log(2*pi)-lsx;      % log of total probability of each data point
        gg=gg(1:j)/n-0.5*p*log(2*pi)-lsx; % average log prob at each iteration
        g=gg(end);
        % gg' % *** DEBUG ONLY ***
        m=m1;   % back up to previous iteration: lpx, g and gg were computed from m1, v1, w1
        w=w1;
        vlt=v1; % lower-triangular covariances to return
    else
        vlt=v;  % lower-triangular covariances from the final M-step
    end
    % Convert each mixture's lower-triangular row (row ik) to a full p x p
    % matrix, flooring its eigenvalues at c. Both paths above share this loop.
    v=zeros(p,p,k);
    trv=0;      % sum of variance matrix traces (after flooring)
    for ik=1:k
        [uvk,dvk]=eig(reshape(vlt(ik,lixi),p,p)); % convert lower triangular to full and find eigenvectors
        dvk=max(diag(dvk),c);                      % apply variance floor
        v(:,:,ik)=uvk*diag(dvk)*uvk';              % reconstitute full matrix
        trv=trv+sum(dvk);
    end
    if sd
        mm=sum(m,1)/k;                             % centroid of the mixture means
        f=(m(:)'*m(:)-k*mm(:)'*mm(:))/trv;         % spread of the means relative to the total variance
    end
    m=m.*sx0(ones(k,1),:)+mx0(ones(k,1),:);        % unscale means
    v=v.*repmat(sx0'*sx0,[1 1 k]);                 % and covariances
end
% With l equal to zero, the first three outputs carry g, f and pp in place of m, v and w.
% NOTE(review): in the full-covariance branch, l has already been reassigned as
% floor(l)+sd. With an input l of 0 and any outputs requested, sd is 1, so l is
% no longer 0 here. Confirm that this test sees the value the caller passed.
if l==0
    w=pp;
    v=f;
    m=g;
end
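% (Web-viewer controls captured with this listing, not part of the MATLAB code.
%  Keyboard shortcuts: Copy code: Ctrl + C; Search code: Ctrl + F;
%  Full-screen mode: F11; Toggle theme: Ctrl + Shift + D; Show shortcuts: ?;
%  Increase font size: Ctrl + =; Decrease font size: Ctrl + -)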