
gmmtrain.m

A program for data clustering and pattern recognition; it should be applicable in both biochemistry and chemistry. Hope it is useful to everyone, and thanks for your support.
function [M, V, W, logProb] = gmmTrain(data, gaussianNum, gmmTrainParam)
% gmmTrain: Parameter training for gaussian mixture model (GMM)
%	Usage: function [M, V, W, logProb] = gmmTrain(data, gaussianNum, gmmTrainParam)
%		data: dim x dataNum matrix where each column is a data point
%		gaussianNum: number of Gaussians, or a dim x gaussianNum matrix of initial centers
%		gmmTrainParam: gmm training parameters (this can be obtained from gmmTrainParamSet.m)
%			gmmTrainParam.dispOpt: Displaying info during training
%			gmmTrainParam.useKmeans: Use k-means to find initial centers
%			gmmTrainParam.maxIteration: Max. number of iterations
%			gmmTrainParam.minImprove: Min. improvement over the previous iteration
%			gmmTrainParam.minVariance: Min. variance for each mixture
%		M: dim x gaussianNum matrix where each column is a mean vector
%		V: 1 x gaussianNum vector where each element is the variance of a Gaussian
%		W: 1 x gaussianNum vector where each element is the weighting factor of a Gaussian
%		logProb: log probability of the data after each EM iteration
%
%	For example, please refer to gmmDemo01.m & gmmDemo02.m.
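%
%	A minimal usage sketch (the random 2-D data below is purely
%	illustrative; gmmTrainParamSet.m must be on the path):
%		data = randn(2, 500);                    % 500 two-dimensional points
%		gmmTrainParam = gmmTrainParamSet;        % default training parameters
%		[M, V, W, logProb] = gmmTrain(data, 4, gmmTrainParam);
%		plot(logProb);                           % monitor EM convergence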

% Roger Jang 20000610

if nargin<1, selfdemo; return; end
if nargin<3, gmmTrainParam=gmmTrainParamSet; end

% Error checking
[dim, dataNum] = size(data);
if (dataNum<=gaussianNum)
	error('The given data size (%d) must be larger than the number of Gaussians (%d)!', dataNum, gaussianNum);
end
range=max(data, [], 2)-min(data, [], 2);
if any(range==0)
	fprintf('Warning: Some dimensions have constant values. Perhaps you should remove those dimensions first.\n');
elseif max(range)/min(range)>1000000
	fprintf('Warning: max(range)/min(range)=%g>1000000 ===> Perhaps you should normalize the data first.\n', max(range)/min(range));
end
if any(isnan(data(:))) || any(isinf(data(:)))
	error('Some element is nan or inf in the given data!');
end

% ====== If only one gaussian, no iterative training is needed.
if gaussianNum==1
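	% The ML estimates are closed-form here: the sample mean and the
	% variance pooled over all dimensions (a single spherical Gaussian).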
	W=1;
	M=mean(data, 2);
	V=sum(sum((data-repmat(M, 1, dataNum)).^2))/(dim*dataNum);	% Pooled spherical variance (a scalar)
%	logProb(1) = sum(gaussianLog(data, M, V*eye(dim)));
	logProb(1) = sum(gmmEvalMex(data, M, V, W));
	return
end

logProb = zeros(gmmTrainParam.maxIteration, 1);   % Array for objective function

% ====== Set initial parameters
% ====== Set initial M
if length(gaussianNum)==1,
	% Here we try several methods to find the initial centers
	if gmmTrainParam.useKmeans
		if gmmTrainParam.dispOpt, fprintf('\tStart KMEANS to find the initial mu...\n'); end
%		M = vqKmeansMex(data, gaussianNum, 0);			% Method 1: Fast but less robust
		M = vqKmeans(data, gaussianNum, 0);			% Method 2: Slow but more robust
		if any(any(~isfinite(M)))
			M = vqLBG(data, gaussianNum, 0);			% Fall back to LBG vector quantization
		end
		if any(any(~isfinite(M)))
			M = data(:, 1+floor(rand(gaussianNum,1)*dataNum));	% Fall back to randomly selected data points
		end
	else
		M = data(:, 1+floor(rand(gaussianNum,1)*dataNum));	% Randomly select several data points as the centers
	end
	if any(any(~isfinite(M)))
		error('Initial centers are not finite!');
	end
else
	% gaussianNum is in fact the initial centers
	M = gaussianNum;
	gaussianNum = size(M, 2);
end
% ====== Set initial V as the squared distance to the nearest center
distance=pairwiseSqrDist(M);
distance(1:(gaussianNum+1):gaussianNum^2)=inf;	% Set diagonal elements to inf so each center ignores itself
V=min(distance);	% Initial variance of each Gaussian = squared distance to the nearest other center
V = max(V, gmmTrainParam.minVariance);
% ====== Set initial W
W = ones(1, gaussianNum)/gaussianNum;	% Weight for each Gaussian

% ====== Start EM iteration
if gmmTrainParam.dispOpt && dim==2, displayGmm(M, V, data); end
for i = 1:gmmTrainParam.maxIteration
	% ====== Expectation step:
	% P(i,j) is the likelihood of data(:,j) under the i-th Gaussian
	[logProbs, P]=gmmEvalMex(data, M, V, W);
	logProb(i)=sum(logProbs);

	if gmmTrainParam.dispOpt
		fprintf('\tGMM iteration: i = %d, log prob. = %f\n',i-1, logProb(i));
	end
	PW = diag(W)*P;
	sumPW = sum(PW, 1);
	BETA=PW./repmat(sumPW, gaussianNum, 1);		% BETA(i,j) is beta_i(x_j)
	sumBETA=sum(BETA, 2);
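	% BETA(i,j) = W(i)*P(i,j) / sum_k W(k)*P(k,j) is the posterior
	% probability (responsibility) of Gaussian i for data point j;
	% sumBETA(i) is the effective number of points assigned to Gaussian i.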

	% ====== Maximization step:  eqns (2.96) to (2.98) from Bishop p.67:
	M = (data*BETA')./repmat(sumBETA', dim, 1);
	DISTSQ = pairwiseSqrDist(M, data);					% Squared distance from each center to each data point
	V = max((sum(BETA.*DISTSQ, 2)./sumBETA)/dim, gmmTrainParam.minVariance);	% (2.97)
	W = (1/dataNum)*sumBETA;					% (2.98)
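	% Written out per Gaussian i, the three updates above are:
	%	mu_i      = sum_j beta_i(x_j)*x_j / sum_j beta_i(x_j)
	%	sigma_i^2 = sum_j beta_i(x_j)*||x_j-mu_i||^2 / (dim*sum_j beta_i(x_j))
	%	w_i       = sum_j beta_i(x_j) / dataNum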

	if gmmTrainParam.dispOpt && dim==2, displayGmm(M, V, data); end
	if i>1 && logProb(i)-logProb(i-1)<gmmTrainParam.minImprove, break; end
end
%[prob, P]=gmmEval2(data, M, V, W);
%logProb(i)=sum(myLog(prob));
[logProbs, P]=gmmEval(data, M, V, W);
logProb(i)=sum(logProbs);

if gmmTrainParam.dispOpt, fprintf('\tGMM total iteration count = %d, log prob. = %f\n',i, logProb(i)); end
logProb(i+1:gmmTrainParam.maxIteration) = [];

% ====== Subfunctions ======
function displayGmm(M, V, data)
% Display function for EM algorithm
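% Plots the 2-D data and, for each Gaussian, a circle of radius sigma
% (one standard deviation) centered at its mean; on subsequent calls the
% existing circles are moved rather than redrawn.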
figureH=findobj(0, 'tag', mfilename);
if isempty(figureH)
	figureH=figure;
	set(figureH, 'tag', mfilename);
	plot(data(1,:), data(2,:),'.r'); axis image
	theta=linspace(-pi, pi, 21);
	x=cos(theta); y=sin(theta);
	sigma=sqrt(V);
	for i=1:length(sigma)
		circleH(i)=line(x*sigma(i)+M(1,i), y*sigma(i)+M(2,i), 'color', 'k', 'linewidth', 3);
	end
	set(circleH, 'tag', 'circleH', 'erasemode', 'xor');
else
	circleH=findobj(figureH, 'tag', 'circleH');
	theta=linspace(-pi, pi, 21);
	x=cos(theta); y=sin(theta);
	sigma=sqrt(V);
	for i=1:length(sigma)
		set(circleH(i), 'xdata', x*sigma(i)+M(1,i), 'ydata', y*sigma(i)+M(2,i));
	end
	drawnow
end

% ====== Self Demo ======
function selfdemo
gmmDemo02;
