% gmmtrain.m
function [M, V, W, logProb] = gmmTrain(data, gaussianNum, dispOpt)
% gmmTrain: Parameter training for a Gaussian mixture model (GMM) via EM.
% Usage: [M, V, W, logProb] = gmmTrain(data, gaussianNum, dispOpt)
%	data:        dim x dataNum matrix where each column is a data point
%	gaussianNum: scalar no. of Gaussians, OR a dim x k matrix of initial centers
%	dispOpt:     nonzero to display info during training (default 0)
%	M:       dim x gaussianNum matrix where each column is a mean vector
%	V:       1 x gaussianNum vector; spherical variance of each Gaussian
%	W:       1 x gaussianNum vector; mixture weight of each Gaussian
%	logProb: total log-likelihood after each iteration (unused slots removed)
% Roger Jang 20000610
if nargin==0, selfdemo; return; end
if nargin<3, dispOpt=0; end
maxLoopCount = 10;	% Max. no. of EM iterations
minImprove = 1e-6;	% Stop when log-likelihood gain falls below this
minVariance = 1e-6;	% Floor on each Gaussian's variance
logProb = zeros(maxLoopCount, 1);	% Objective-function (log-likelihood) history
[dim, dataNum] = size(data);
% ====== Set initial means M
if length(gaussianNum)==1
	% Use k-means clustering to find initial centers
	fprintf('Using KMEANS to find the initial mu...\n');
	M = vqKmeans(data, gaussianNum, 0);
	fprintf('Done with kmeans!\n');
	% Fail loudly instead of dropping into the debugger (was: keyboard)
	if any(any(~isfinite(M)))
		error('gmmTrain: k-means returned non-finite centers.');
	end
else
	% gaussianNum is in fact the matrix of initial centers
	M = gaussianNum;
	gaussianNum = size(M, 2);
end
% ====== Set initial V as the squared distance to the nearest other center
if gaussianNum==1
	V=1;
else
	distance=vecdist(M');
	distance(1:(gaussianNum+1):gaussianNum^2)=inf;	% Ignore self-distance on the diagonal
	[minDistance, index]=min(distance);
	V=minDistance.^2;	% Initial variance for each Gaussian
end
% ====== Set initial W: uniform mixture weights
W = ones(1, gaussianNum)/gaussianNum;
if dispOpt && dim==2, displayGmm(M, V, data); end
for i = 1:maxLoopCount
	% === Expectation step:
	% P(i,j) is the likelihood of data(:,j) under the i-th Gaussian
	[prob, P]=gmmEval(data, M, V, W);
	logProb(i)=sum(log(prob));
	if dispOpt, fprintf('i = %d, log prob. = %f\n',i-1, logProb(i)); end
	PW = diag(W)*P;
	BETA=PW./(ones(gaussianNum,1)*sum(PW));	% BETA(i,j) is beta_i(x_j), the posterior of Gaussian i for point j
	sumBETA=sum(BETA,2);
	% === Maximization step: eqns (2.96) to (2.98) from Bishop p.67:
	M = (data*BETA')./(ones(dim,1)*sumBETA');
	DISTSQ = vecdist(M', data').^2;	% Squared distance of each mean to each data point
	V = max((sum(BETA.*DISTSQ, 2)./sumBETA)/dim, minVariance);	% (2.97), floored at minVariance
	W = (1/dataNum)*sumBETA;	% (2.98)
	if dispOpt && dim==2, displayGmm(M, V, data); end
	% Early stop when improvement drops below threshold.
	% && short-circuits so logProb(0) is never indexed when i==1.
	if i>1 && logProb(i)-logProb(i-1)<minImprove, break; end
end
% Recompute the final log-likelihood with the last M-step's parameters
[prob, P]=gmmEval(data, M, V, W);
logProb(i)=sum(log(prob));
fprintf('Iteration count = %d, log prob. = %f\n',i, logProb(i));
logProb(i+1:maxLoopCount) = [];	% Drop unused trailing slots
% ====== Self Demo ======
function selfdemo
% Self demo: train a 2-D GMM on the dcdata dataset, then plot the
% resulting mixture density as a mesh surface and a contour map.
[sampleData, centerCount] = dcdata(2);
sampleData = sampleData';
plotOpt = 1;
[mu, var1, weight, lp] = feval(mfilename, sampleData, centerCount, plotOpt);
% Evaluate the trained GMM density on a regular grid covering the data
gridSize = 40;
xVec = linspace(min(sampleData(1,:)), max(sampleData(1,:)), gridSize);
yVec = linspace(min(sampleData(2,:)), max(sampleData(2,:)), gridSize);
[xGrid, yGrid] = meshgrid(xVec, yVec);
gridPoints = [xGrid(:) yGrid(:)]';
density = gmmEval(gridPoints, mu, var1, weight);
densityGrid = reshape(density, gridSize, gridSize);
figure; mesh(xGrid, yGrid, densityGrid); axis tight; box on; rotate3d on
figure; contour(xGrid, yGrid, densityGrid, 30); axis image
% ====== Other subfunctions ======
function displayGmm(M, V, data)
% Display function for the EM algorithm: draw each Gaussian as a circle of
% radius sqrt(V(i)) centered at M(:,i) over the 2-D data points.
% Only called from gmmTrain when dim==2.
% Circle geometry is shared by both branches (was duplicated in each).
theta=linspace(-pi, pi, 21);	% Sample points on the unit circle
x=cos(theta); y=sin(theta);
sigma=sqrt(V);	% One-standard-deviation radius per Gaussian
figureH=findobj(0, 'tag', mfilename);
if isempty(figureH)
	% First call: create the tagged figure, plot the data, create the circles
	figureH=figure;
	set(figureH, 'tag', mfilename);
	colordef black
	plot(data(1,:), data(2,:),'.r'); axis image
	for i=1:length(sigma)
		circleH(i)=line(x*sigma(i)+M(1,i), y*sigma(i)+M(2,i), 'color', 'y');
	end
	% NOTE(review): 'erasemode' was removed in MATLAB R2014b; this call
	% errors on modern releases — guard or drop it there.
	set(circleH, 'tag', 'circleH', 'erasemode', 'xor');
else
	% Subsequent calls: move the existing circles to the new means/variances.
	% NOTE(review): findobj returns handles in reverse creation order, so
	% circleH(i) may not correspond to Gaussian i — verify if per-Gaussian
	% identity ever matters (all circles share one style, so display is OK).
	circleH=findobj(figureH, 'tag', 'circleH');
	for i=1:length(sigma)
		set(circleH(i), 'xdata', x*sigma(i)+M(1,i), 'ydata', y*sigma(i)+M(2,i));
	end
	drawnow
end