function [M, V, W, logProb] = gmmTrain(data, gaussianNum, gmmTrainParam)
% gmmTrain: Parameter training for gaussian mixture model (GMM)
% Usage: function [M, V, W, logProb] = gmmTrain(data, gaussianNum, gmmTrainParam)
% data: dim x dataNum matrix where each column is a data point
% gaussianNum: No. of Gaussians or initial centers
% gmmTrainParam: gmm training parameters (this can be obtained from gmmTrainParamSet.m)
% gmmTrainParam.dispOpt: Displaying info during training
% gmmTrainParam.useKmeans: Use k-means to find initial centers
% gmmTrainParam.maxIteration: Max. number of iterations
% gmmTrainParam.minImprove: Min. improvement over the previous iteration
% gmmTrainParam.minVariance: Min. variance for each mixture
% M: dim x gaussianNum matrix where each column is a mean vector
% V: 1 x gaussianNum vector where each element is a variance for a Gaussian
% W: 1 x gaussianNum vector where each element is a weighting factor for a Gaussian
%
% For examples, please refer to gmmDemo01.m & gmmDemo02.m.
% Roger Jang 20000610
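%
% A minimal usage sketch (for illustration only; assumes 2-D data and the
% default options returned by gmmTrainParamSet):
%	data=randn(2, 1000);					% dim x dataNum data matrix
%	trainParam=gmmTrainParamSet;				% default GMM training options
%	[M, V, W, logProb]=gmmTrain(data, 4, trainParam);	% fit a 4-component GMM
%	plot(logProb); xlabel('EM iteration'); ylabel('Log probability');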
if nargin<1, selfdemo; return; end
if nargin<3, gmmTrainParam=gmmTrainParamSet; end
% Error checking
[dim, dataNum] = size(data);
if dataNum<=gaussianNum
error('The given data count (%d) is no larger than the number of Gaussians (%d)!', dataNum, gaussianNum);
end
range=max(data, [], 2)-min(data, [], 2);
if any(range==0)
fprintf('Warning: Some dimensions have constant values. Perhaps you should remove those dimensions first.\n');
elseif max(range)/min(range)>1000000
fprintf('Warning: max(range)/min(range)=%g>1000000 ===> Perhaps you should normalize the data first.\n', max(range)/min(range));
end
if any(isnan(data(:))) || any(isinf(data(:)))
error('Some elements are NaN or Inf in the given data!');
end
% ====== If only one gaussian, no iterative training is needed.
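% (Closed-form ML solution: the sample mean and a single pooled spherical variance.)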
if gaussianNum==1
W=1;
M=mean(data, 2);
V=sum(sum((data-repmat(M, 1, dataNum)).^2))/(dim*dataNum); % A scalar (pooled spherical variance)
% logProb(1) = sum(gaussianLog(data, M, V*eye(dim)));
logProb(1) = sum(gmmEvalMex(data, M, V, W));
return
end
logProb = zeros(gmmTrainParam.maxIteration, 1); % Array for objective function
% ====== Set initial parameters
% ====== Set initial M
if length(gaussianNum)==1
% Here we try several methods to find the initial centers
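% Fallback order: k-means first; if it fails, LBG; if that also fails, randomly selected data points.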
if gmmTrainParam.useKmeans
if gmmTrainParam.dispOpt, fprintf('\tStart KMEANS to find the initial mu...\n'); end
% M = vqKmeansMex(data, gaussianNum, 0); % Method 1: Fast but less robust
M = vqKmeans(data, gaussianNum, 0); % Method 2: Slow but more robust
if any(any(~isfinite(M)))
M = vqLBG(data, gaussianNum, 0); % Try another method of vqLBG
end
if any(any(~isfinite(M)))
M = data(:, 1+floor(rand(gaussianNum,1)*dataNum)); % Try another method of random selection
end
else
M = data(:, 1+floor(rand(gaussianNum,1)*dataNum)); % Randomly select several data points as the centers
end
if any(any(~isfinite(M)))
error('Initial centers are not finite!');
end
else
% gaussianNum is in fact the initial centers
M = gaussianNum;
gaussianNum = size(M, 2);
end
% ====== Set initial V as the distance to the nearest center
distance=pairwiseSqrDist(M);
distance(1:(gaussianNum+1):gaussianNum^2)=inf; % Diagonal elements are inf
[V, index]=min(distance); % Initial variance for each Gaussian
V = max(V, gmmTrainParam.minVariance);
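% (Each initial variance is the squared distance from a center to its nearest
% neighboring center, floored at gmmTrainParam.minVariance.)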
% ====== Set initial W
W = ones(1, gaussianNum)/gaussianNum; % Weight for each Gaussian
% ====== Start EM iteration
if gmmTrainParam.dispOpt && dim==2, displayGmm(M, V, data); end
for i = 1:gmmTrainParam.maxIteration
% ====== Expectation step:
% P(i,j) is the probability of data(:,j) to the i-th Gaussian
[logProbs, P]=gmmEvalMex(data, M, V, W);
logProb(i)=sum(logProbs);
if gmmTrainParam.dispOpt
fprintf('\tGMM iteration: i = %d, log prob. = %f\n',i-1, logProb(i));
end
PW = diag(W)*P;
sumPW = sum(PW, 1);
BETA=PW./repmat(sumPW, gaussianNum, 1); % BETA(i,j) is beta_i(x_j)
sumBETA=sum(BETA, 2);
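% BETA(i,j)=W(i)*P(i,j)/sum_k W(k)*P(k,j) is the posterior (responsibility) of
% Gaussian i for data point j; sumBETA(i) is the effective count of Gaussian i.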
% ====== Maximization step: eqns (2.96) to (2.98) from Bishop p.67:
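% M(:,i) = sum_j BETA(i,j)*data(:,j) / sumBETA(i)			(2.96)
% V(i)   = sum_j BETA(i,j)*||data(:,j)-M(:,i)||^2 / (dim*sumBETA(i))	(2.97)
% W(i)   = sumBETA(i) / dataNum						(2.98)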
M = (data*BETA')./repmat(sumBETA', dim, 1);
DISTSQ = pairwiseSqrDist(M, data); % Distance of M to data
V = max((sum(BETA.*DISTSQ, 2)./sumBETA)/dim, gmmTrainParam.minVariance); % (2.97)
W = (1/dataNum)*sumBETA; % (2.98)
if gmmTrainParam.dispOpt && dim==2, displayGmm(M, V, data); end
if i>1, if logProb(i)-logProb(i-1)<gmmTrainParam.minImprove, break; end; end
end
%[prob, P]=gmmEval2(data, M, V, W);
%logProb(i)=sum(myLog(prob));
[logProbs, P]=gmmEval(data, M, V, W);
logProb(i)=sum(logProbs);
if gmmTrainParam.dispOpt, fprintf('\tGMM total iteration count = %d, log prob. = %f\n',i, logProb(i)); end
logProb(i+1:gmmTrainParam.maxIteration) = [];
% ====== Subfunctions ======
function displayGmm(M, V, data)
% Display function for EM algorithm
figureH=findobj(0, 'tag', mfilename);
if isempty(figureH)
figureH=figure;
set(figureH, 'tag', mfilename);
plot(data(1,:), data(2,:),'.r'); axis image
theta=linspace(-pi, pi, 21);
x=cos(theta); y=sin(theta);
sigma=sqrt(V);
for i=1:length(sigma)
circleH(i)=line(x*sigma(i)+M(1,i), y*sigma(i)+M(2,i), 'color', 'k', 'linewidth', 3);
end
set(circleH, 'tag', 'circleH', 'erasemode', 'xor');
else
circleH=findobj(figureH, 'tag', 'circleH');
theta=linspace(-pi, pi, 21);
x=cos(theta); y=sin(theta);
sigma=sqrt(V);
for i=1:length(sigma)
set(circleH(i), 'xdata', x*sigma(i)+M(1,i), 'ydata', y*sigma(i)+M(2,i));
end
drawnow
end
% ====== Self Demo ======
function selfdemo
gmmDemo02;