% gmmTrain.m
function [gmmParam, logProb] = gmmTrain(data, gaussianNumCovType, gmmTrainParam)
% gmmTrain: Parameter training for gaussian mixture model (GMM) via EM
% Usage: function [gmmParam, logProb] = gmmTrain(data, gaussianNumCovType, gmmTrainParam)
%	data: dim x dataNum matrix where each column is a data point
%	gaussianNumCovType: A two element vector indicating no. of Gaussians and type of covariance matrix.
%		This function will use gmmInitParamSet() to determine the initial parameters.
%		On the other hand, this argument could be a gmmParam which specifies the initial parameters.
%	gmmTrainParam: gmm training parameters (this can be obtained from gmmTrainParamSet.m)
%		gmmTrainParam.dispOpt: Displaying info during training
%		gmmTrainParam.useKmeans: Use k-means to find initial centers
%		gmmTrainParam.maxIteration: Max. number of iterations
%		gmmTrainParam.minImprove: Min. improvement over the previous iteration
%		gmmTrainParam.minVariance: Min. variance for each mixture
%	gmmParam: The final parameters for GMM
%	logProb: Vector of total log probabilities, one entry per EM iteration
%
% For example, please refer to
%	1-d example: gmmTrainDemo1d.m
%	2-d example: gmmTrainDemo2dCovType01.m, gmmTrainDemo2dCovType02.m, and gmmTrainDemo2dCovType03.
% Roger Jang 20000610, 20080726

if nargin<1, selfdemo; return; end
if nargin<2, gaussianNumCovType=[3, 1]; end
if nargin<3, gmmTrainParam=gmmTrainParamSet; end
if isnumeric(gaussianNumCovType)
	gaussianNum = gaussianNumCovType(1);
	covType = gaussianNumCovType(2);
else	% gaussianNumCovType is in fact the initial gmmParam
	gmmParam = gaussianNumCovType;
	gaussianNum = length(gmmParam);
end
% ====== Error checking
[dim, dataNum] = size(data);
if dataNum<=gaussianNum
	% FIX: the original sprintf passed dataNum/gaussianNum with no conversion
	% specifiers, so the counts never appeared in the message.
	error('The given data size (%d) is less than the Gaussian number (%d)!', dataNum, gaussianNum);
end
range=max(data, [], 2)-min(data, [], 2);
if any(range==0)
	fprintf('Warning: Some of dimensions has the same data. Perhaps you should remove data of the dimension first.\n');
elseif max(range)/min(range)>10000
	fprintf('Warning: max(range)/min(range)=%g>10000 ===> Perhaps you should normalize the data first.\n', max(range)/min(range));
end
if any(isnan(data(:))) || any(isinf(data(:)))	% short-circuit || for scalar logicals
	error('Some element is nan or inf in the given data!');
end
% ====== Set initial parameters (only when not given directly by the caller)
if isnumeric(gaussianNumCovType)
	gmmParam=gmmInitParamSet(data, gaussianNum, covType, gmmTrainParam);
end
% ====== Start EM iteration
logProb = zeros(gmmTrainParam.maxIteration, 1);	% Array for objective function
if gmmTrainParam.dispOpt && dim==2, displayGmm(data, gmmParam); end
for i = 1:gmmTrainParam.maxIteration
	% ====== Expectation step:
	% P(i,j) is the probability of data(:,j) to the i-th Gaussian
	[logProbs, P]=gmmEval(data, gmmParam);
	logProb(i)=sum(logProbs);
	if gmmTrainParam.dispOpt, fprintf('\tGMM iteration: %d/%d, log prob. = %f\n', i-1, gmmTrainParam.maxIteration, logProb(i)); end
	W = [gmmParam.w]';
	PW = repmat(W, 1, dataNum).*P;	% prior-weighted likelihoods
	sumPW = sum(PW, 1);
	% FIX: removed leftover "keyboard" debug statement that dropped the user
	% into the debugger before the error was even raised.
	if any(sumPW==0), error('some entries of sumPW==0! You need to make sure each point is at least covered by a Gaussian!'); end
	BETA=PW./repmat(sumPW, gaussianNum, 1);	% BETA(i,j) is beta_i(x_j), posterior of Gaussian i given x_j
	sumBETA=sum(BETA,2);
	% ====== Maximization step: eqns (2.96) to (2.98) from Bishop p.67:
	% === Compute mu (2.96)
	M = (data*BETA')./repmat(sumBETA', dim, 1);
	% Distribute the parameters to gmmParam
	meanCell=mat2cell(M, dim, ones(1, gaussianNum));
	[gmmParam.mu]=deal(meanCell{:});
	% === Compute w (2.98)
	W = (1/dataNum)*sumBETA;
	wCell=mat2cell(W', 1, ones(1, gaussianNum));
	[gmmParam.w]=deal(wCell{:});
	% === Compute sigma (2.97)
	DISTSQ = pairwiseSqrDistance(M, data);	% Squared distance of each mean to each data point
	if numel(gmmParam(1).sigma)==1	% identity covariance matrix times a constant for each Gaussian
		V = max((sum(BETA.*DISTSQ, 2)./sumBETA)/dim, gmmTrainParam.minVariance);
		for j=1:gaussianNum
			gmmParam(j).sigma=V(j);
		end
	elseif numel(gmmParam(1).sigma)==dim	% diagonal covariance matrix for each Gaussian
		for j=1:gaussianNum
			dataMinusMu = data-repmat(gmmParam(j).mu, 1, dataNum);
			weight = repmat(BETA(j,:), dim, 1);
			% FIX: floor each diagonal variance at minVariance, consistent with
			% the isotropic branch, to keep a mixture from collapsing onto a point.
			gmmParam(j).sigma=max(sum((weight.*dataMinusMu).*dataMinusMu, 2)/sum(BETA(j,:)), gmmTrainParam.minVariance);
		end
	else	% full covariance matrix for each Gaussian
		for j=1:gaussianNum
			dataMinusMu = data-repmat(gmmParam(j).mu, 1, dataNum);
			weight = repmat(BETA(j,:), dim, 1);
			gmmParam(j).sigma=(weight.*dataMinusMu)*dataMinusMu'/sum(BETA(j,:));
		end
	end
	% === Animation
	if gmmTrainParam.dispOpt && dim==2, displayGmm(data, gmmParam); end
	% ====== Check stopping criterion
	if i>1 && logProb(i)-logProb(i-1)<gmmTrainParam.minImprove, break; end
end
% Re-evaluate once so logProb(i) reflects the final (post-M-step) parameters
[logProbs, P]=gmmEval(data, gmmParam);
logProb(i)=sum(logProbs);
if gmmTrainParam.dispOpt, fprintf('\tGMM total iteration count = %d, log prob. = %f\n',i, logProb(i)); end
logProb(i+1:gmmTrainParam.maxIteration) = [];
% ====== Subfunctions ======
function displayGmm(data, gmmParam)
% Plot 2-D data with a half-height contour per Gaussian; on later calls the
% figure tagged with this file's name is reused so the contours animate.
figH=findobj(0, 'tag', mfilename);
if isempty(figH)
	% First call: create the figure, scatter the data, draw one contour per Gaussian
	figH=figure;
	set(figH, 'tag', mfilename);
	plot(data(1,:), data(2,:),'.r'); axis image
	for k=1:length(gmmParam)
		[cx, cy]=halfHeightContour(gmmParam(k));
		circleH(k)=line(cx, cy, 'color', 'k', 'linewidth', 3);
	end
	set(circleH, 'tag', 'circleH', 'erasemode', 'xor');
else
	% Later calls: update the existing contour lines in place and redraw
	circleH=findobj(figH, 'tag', 'circleH');
	for k=1:length(gmmParam)
		[cx, cy]=halfHeightContour(gmmParam(k));
		set(circleH(k), 'xdata', cx, 'ydata', cy);
	end
	drawnow
end
% ====== Obtain the contour data at half height of an Gaussian
function [xData, yData]=halfHeightContour(gaussianParam)
% Return the (x, y) contour where a 2-D Gaussian drops to half its peak
% height. The half-height is reached at squared Mahalanobis distance
% 2*log(2), so each ellipse axis is sqrt(2*log(2)*variance).
dim=length(gaussianParam.mu);
theta=linspace(-pi, pi, 21);
halfHeightScale=2*log(2);
sigmaSize=numel(gaussianParam.sigma);
if sigmaSize==1	% scalar variance => circular contour
	r=sqrt(halfHeightScale*gaussianParam.sigma);
	xData=gaussianParam.mu(1)+r*cos(theta);
	yData=gaussianParam.mu(2)+r*sin(theta);
elseif sigmaSize==dim	% diagonal covariance => axis-aligned ellipse
	rx=sqrt(halfHeightScale*gaussianParam.sigma(1));
	ry=sqrt(halfHeightScale*gaussianParam.sigma(2));
	xData=gaussianParam.mu(1)+rx*cos(theta);
	yData=gaussianParam.mu(2)+ry*sin(theta);
else	% full covariance => ellipse rotated along the eigenvector axes
	[eigVec, eigVal]=eig(gaussianParam.sigma);
	axis1=sqrt(halfHeightScale*eigVal(1,1));
	axis2=sqrt(halfHeightScale*eigVal(2,2));
	ellipse=eigVec*[axis1*cos(theta); axis2*sin(theta)];
	xData=ellipse(1,:)+gaussianParam.mu(1);
	yData=ellipse(2,:)+gaussianParam.mu(2);
end
% ====== Self Demo ======
function selfdemo
% Self demo: run the 2-D covariance-type-1 training demo.
% Invoked when gmmTrain is called with no input arguments.
gmmTrainDemo2dCovType01;