⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 gmmtrain.m

📁 gmmTrain: Parameter training for gaussian mixture model (GMM)
💻 M
字号:
function [M, V, W, logProb] = gmmTrain(data, gaussianNum, dispOpt)
% gmmTrain: Parameter training for gaussian mixture model (GMM) via EM
%	Usage: [M, V, W, logProb] = gmmTrain(data, gaussianNum, dispOpt)
%		data: dim x dataNum matrix where each column is a data point
%		gaussianNum: No. of Gaussians (scalar), or a dim x gaussianNum matrix
%			whose columns are the initial centers. Note that a 1 x 1 matrix is
%			always read as a count, so a single 1-D center cannot be passed.
%		dispOpt: Option for displaying info during training (default 0).
%			When nonzero, progress messages are printed and, for 2-D data,
%			the Gaussians are animated in a figure.
%		M: dim x gaussianNum matrix where each column is a mean vector
%		V: gaussianNum x 1 vector; V(k) is the single variance that the k-th
%			Gaussian uses for every dimension (the M-step averages over dim)
%		W: gaussianNum x 1 vector; W(k) is the weighting factor of the k-th Gaussian
%		logProb: column vector with sum(log(prob)) (prob as returned by gmmEval)
%			at each EM iteration; the last entry is re-evaluated for the
%			returned parameters
%	Calling gmmTrain with no input arguments runs a self demo.

% Roger Jang 20000610

if nargin==0, selfdemo; return; end
if nargin<3 || isempty(dispOpt), dispOpt=0; end

maxLoopCount = 10;	% Max. iteration
minImprove = 1e-6;	% Min. improvement of logProb between iterations
minVariance = 1e-6;	% Min. variance (floor that keeps every Gaussian non-degenerate)
logProb = zeros(maxLoopCount, 1);   % Array for objective function
[dim, dataNum] = size(data);

% ====== Set initial M
%M = data(1+floor(rand(gaussianNum,1)*dataNum),:);	% Randomly select several data points as the centers
if length(gaussianNum)==1
	% Scalar gaussianNum: use k-means centers as the initial means
	if dispOpt, fprintf('Using KMEANS to find the initial mu...\n'); end
%	M = vqKmeansMex(data, gaussianNum, 0);
	M = vqKmeans(data, gaussianNum, 0);
%	M = vqLBG(data, gaussianNum, 0);
	if dispOpt, fprintf('Done with kmeans!\n'); end
	if any(~isfinite(M(:)))
		% Previously dropped into the debugger via "keyboard"; fail loudly instead
		error('gmmTrain:nonFiniteInit', 'Initial centers from vqKmeans contain NaN or Inf.');
	end
else
	% gaussianNum is in fact the initial centers
	M = gaussianNum;
	gaussianNum = size(M, 2);
end

% ====== Set initial V as the squared distance to the nearest other center
if gaussianNum==1
	V = 1;
else
	distance = vecdist(M');
	distance(1:(gaussianNum+1):gaussianNum^2) = inf;	% Diagonal elements are inf (ignore self-distance)
	V = min(distance).^2;	% Initial variance for each Gaussian
end
% Coincident initial centers would yield a zero variance; apply the same floor as the M-step
V = max(V, minVariance);

% ====== Set initial W
W = ones(1, gaussianNum)/gaussianNum;	% Equal weight for each Gaussian

if dispOpt && dim==2, displayGmm(M, V, data); end
for i = 1:maxLoopCount
	% Expectation step:
	% P(k,j) is the value gmmEval returns for data(:,j) under the k-th Gaussian
	[prob, P] = gmmEval(data, M, V, W);
	logProb(i) = sum(log(prob));
	if dispOpt, fprintf('i = %d, log prob. = %f\n', i-1, logProb(i)); end
	PW = diag(W)*P;
	BETA = PW./(ones(gaussianNum,1)*sum(PW));	% BETA(k,j) is beta_k(x_j); each column sums to 1
	sumBETA = sum(BETA, 2);

	% Maximization step:  eqns (2.96) to (2.98) from Bishop p.67:
	M = (data*BETA')./(ones(dim,1)*sumBETA');				% (2.96)
	DISTSQ = vecdist(M', data').^2;					% Distance of M to data
	V = max((sum(BETA.*DISTSQ, 2)./sumBETA)/dim, minVariance);	% (2.97), column vector
	W = (1/dataNum)*sumBETA;					% (2.98), column vector

	if dispOpt && dim==2, displayGmm(M, V, data); end
	if i>1 && logProb(i)-logProb(i-1)<minImprove, break; end
end
% The parameters were updated after the last E-step, so re-evaluate them and
% store the result in the final slot
prob = gmmEval(data, M, V, W);
logProb(i) = sum(log(prob));
if dispOpt, fprintf('Iteration count = %d, log prob. = %f\n', i, logProb(i)); end
logProb(i+1:maxLoopCount) = [];	% Drop unused slots after early convergence

% ====== Self Demo ======
function selfdemo
% selfdemo: Train a GMM on the 2-D sample set from dcdata(2) with progress
% display on, then plot the fitted mixture returned by gmmEval on a 40 x 40
% grid as a 3-D mesh and as a contour map.
[sampleData, gaussianNum] = dcdata(2);
sampleData = sampleData';	% transpose so that each column is one data point
showProgress = 1;
[M, V, W, trainLogProb] = feval(mfilename, sampleData, gaussianNum, showProgress);

% Build a regular grid spanning the bounding box of the data
gridSize = 40;
xAxis = linspace(min(sampleData(1,:)), max(sampleData(1,:)), gridSize);
yAxis = linspace(min(sampleData(2,:)), max(sampleData(2,:)), gridSize);
[xGrid, yGrid] = meshgrid(xAxis, yAxis);
gridPoints = [xGrid(:) yGrid(:)]';
density = reshape(gmmEval(gridPoints, M, V, W), gridSize, gridSize);

figure;
mesh(xGrid, yGrid, density); axis tight; box on; rotate3d on
figure;
contour(xGrid, yGrid, density, 30); axis image

% ====== Other subfunctions ======
function displayGmm(M, V, data)
% displayGmm: Draw, or update in place, the current GMM estimate over 2-D data.
%	M: matrix whose column i is the mean of the i-th Gaussian (rows 1 and 2 are plotted)
%	V: vector of variances; Gaussian i is drawn as a circle of radius sqrt(V(i))
%	data: matrix whose columns are data points (rows 1 and 2 are plotted as red dots)
% The figure is tagged with mfilename and reused between calls, so successive
% EM iterations move the same circles instead of opening new windows.
sigma = sqrt(V);
figureH = findobj(0, 'tag', mfilename);
circleH = [];
if ~isempty(figureH)
	figureH = figureH(1);	% guard against several figures carrying the tag
	circleH = findobj(figureH, 'tag', 'circleH');
end

% Rebuild from scratch when there is no figure yet, or when the existing one was
% drawn for different data or a different number of Gaussians. Otherwise the
% circle indexing below would fail and stale data points would stay on screen.
needRebuild = isempty(figureH) || numel(circleH)~=numel(sigma) ...
	|| ~isequal(getappdata(figureH, 'gmmData'), data);
if needRebuild
	if isempty(figureH)
		figureH = figure;
		set(figureH, 'tag', mfilename);
		% Apply the black color scheme to this figure only; a bare
		% "colordef black" changes the root defaults for every later figure.
		colordef(figureH, 'black');
	else
		set(0, 'CurrentFigure', figureH);
		clf(figureH);
	end
	setappdata(figureH, 'gmmData', data);
	plot(data(1,:), data(2,:), '.r'); axis image
	for i = 1:numel(sigma)
		[x, y] = gaussianCircle(M, sigma, i);
		% 'erasemode' is no longer supported by MATLAB graphics (R2014b+), so it is not set
		line(x, y, 'color', 'y', 'tag', 'circleH');
	end
else
	for i = 1:numel(sigma)
		[x, y] = gaussianCircle(M, sigma, i);
		set(circleH(i), 'xdata', x, 'ydata', y);
	end
end
drawnow

function [x, y] = gaussianCircle(M, sigma, i)
% gaussianCircle: 21-point circle of radius sigma(i) centred at (M(1,i), M(2,i)).
theta = linspace(-pi, pi, 21);
x = cos(theta)*sigma(i) + M(1,i);
y = sin(theta)*sigma(i) + M(2,i);

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -