⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 gmmtrain.m

📁 一个关于数据聚类和模式识别的程序,在生物化学,化学中应该都可以用到.希望对大家有用,谢谢支持
💻 M
字号:
function [gmmParam, logProb] = gmmTrain(data, gaussianNumCovType, gmmTrainParam)
% gmmTrain: Parameter training for gaussian mixture model (GMM)
%	Usage: function [gmmParam, logProb] = gmmTrain(data, gaussianNumCovType, gmmTrainParam)
%		data: dim x dataNum matrix where each column is a data point
%		gaussianNumCovType: A two element vector indicating no. of Gaussians and type of covariance matrix
%			(default: [3, 1]). This function will use gmmInitParamSet() to determine the initial parameters.
%			On the other hand, this argument could be a gmmParam which specifies the initial parameters.
%		gmmTrainParam: gmm training parameters (default: gmmTrainParamSet; see gmmTrainParamSet.m)
%			gmmTrainParam.dispOpt: Displaying info during training (plus an animated plot when dim==2)
%			gmmTrainParam.useKmeans: Use k-means to find initial centers
%			gmmTrainParam.maxIteration: Max. number of iterations
%			gmmTrainParam.minImprove: Min. improvement over the previous iteration
%			gmmTrainParam.minVariance: Min. variance for each mixture (applied to scalar and diagonal covariances)
%		gmmParam: The final parameters for GMM (struct array; this function updates its mu, w and sigma fields)
%		logProb: Vector of log probabilities during training; the last entry is the log prob. of the final gmmParam
%
%	The covariance structure re-estimated in the M-step follows the size of gmmParam(1).sigma:
%	a scalar (one variance shared by all dimensions), a dim-element vector (diagonal covariance),
%	or a dim x dim matrix (full covariance).
%	Calling gmmTrain with no input arguments runs a demo.
%
%	For example, please refer to
%		1-d example: gmmTrainDemo1d.m
%		2-d example: gmmTrainDemo2dCovType01.m, gmmTrainDemo2dCovType02.m, and gmmTrainDemo2dCovType03.m

% Roger Jang 20000610, 20080726

if nargin<1, selfdemo; return; end
if nargin<2, gaussianNumCovType=[3, 1]; end
if nargin<3, gmmTrainParam=gmmTrainParamSet; end

if isnumeric(gaussianNumCovType)
	gaussianNum = gaussianNumCovType(1);
	covType = gaussianNumCovType(2);
else	% gaussianNumCovType is in fact the gmmParam.
	gmmParam = gaussianNumCovType;
	gaussianNum = length(gmmParam);
end

% ======= Error checking
[dim, dataNum] = size(data);
if dataNum<=gaussianNum
	error('The given data size (%d) is not larger than the Gaussian number (%d)!', dataNum, gaussianNum);
end
% Reject non-finite data before the range test below, whose ratio would be meaningless for it.
if ~all(isfinite(data(:)))
	error('Some element is nan or inf in the given data!');
end
dataRange=max(data, [], 2)-min(data, [], 2);
if any(dataRange==0)
	fprintf('Warning: Some of dimensions has the same data. Perhaps you should remove data of the dimension first.\n');
elseif max(dataRange)/min(dataRange)>10000
	fprintf('Warning: max(range)/min(range)=%g>10000 ===> Perhaps you should normalize the data first.\n', max(dataRange)/min(dataRange));
end

% ====== Set initial parameters
if isnumeric(gaussianNumCovType)
	gmmParam=gmmInitParamSet(data, gaussianNum, covType, gmmTrainParam);
end

% ====== Start EM iteration
logProb = zeros(gmmTrainParam.maxIteration, 1);   % Array for objective function
showAnimation = gmmTrainParam.dispOpt && dim==2;
if showAnimation, displayGmm(data, gmmParam); end

for i = 1:gmmTrainParam.maxIteration
	% ====== Expectation step:
	% P(i,j) is the probability of data(:,j) to the i-th Gaussian
	[logProbs, P]=gmmEval(data, gmmParam);
	logProb(i)=sum(logProbs);
	if gmmTrainParam.dispOpt, fprintf('\tGMM iteration: %d/%d, log prob. = %f\n', i-1, gmmTrainParam.maxIteration, logProb(i)); end
	W = [gmmParam.w]';
	PW = repmat(W, 1, dataNum).*P;
	sumPW = sum(PW, 1);
	% A point with zero total weighted density would make its responsibilities 0/0.
	if any(sumPW==0)
		error('some entries of sumPW==0! You need to make sure each point is at least covered by a Gaussian!');
	end
	BETA=PW./repmat(sumPW, gaussianNum, 1);		% BETA(i,j) is beta_i(x_j)
	sumBETA=sum(BETA,2);
	% NOTE(review): a Gaussian whose sumBETA is 0 yields NaN parameters below; not handled here.

	% ====== Maximization step:  eqns (2.96) to (2.98) from Bishop p.67:
	% === Compute mu
	M = (data*BETA')./repmat(sumBETA', dim, 1);
	meanCell=num2cell(M, 1);		% one dim x 1 mean per Gaussian
	[gmmParam.mu]=deal(meanCell{:});
	% === Compute w
	W = (1/dataNum)*sumBETA;					% (2.98)
	wCell=num2cell(W);
	[gmmParam.w]=deal(wCell{:});
	% === Compute sigma; the covariance form is taken from the size of the current sigma
	sigmaSize=numel(gmmParam(1).sigma);
	if sigmaSize==1		% identity covariance matrix times a constant for each Gaussian
		DISTSQ = pairwiseSqrDistance(M, data);		% Distance of M to data
		V = max((sum(BETA.*DISTSQ, 2)./sumBETA)/dim, gmmTrainParam.minVariance);	% (2.97)
		for j=1:gaussianNum
			gmmParam(j).sigma=V(j);
		end
	elseif sigmaSize==dim	% diagonal covariance matrix for each Gaussian
		for j=1:gaussianNum
			dataMinusMu = data-repmat(gmmParam(j).mu, 1, dataNum);
			weight = repmat(BETA(j,:), dim, 1);
			% Weighted per-dimension variance, floored like the scalar case so no axis collapses to 0
			gmmParam(j).sigma=max(sum((weight.*dataMinusMu).*dataMinusMu, 2)/sum(BETA(j,:)), gmmTrainParam.minVariance);
		end
	else	% full covariance matrix for each Gaussian (minVariance is not applied here)
		for j=1:gaussianNum
			dataMinusMu = data-repmat(gmmParam(j).mu, 1, dataNum);
			weight = repmat(BETA(j,:), dim, 1);
			gmmParam(j).sigma=(weight.*dataMinusMu)*dataMinusMu'/sum(BETA(j,:));
		end
	end
	% === Animation
	if showAnimation, displayGmm(data, gmmParam); end

	% ====== Check stopping criterion
	if i>1 && logProb(i)-logProb(i-1)<gmmTrainParam.minImprove, break; end
end
% Replace the last entry with the log prob. of the parameters actually returned
logProbs=gmmEval(data, gmmParam);
logProb(i)=sum(logProbs);

if gmmTrainParam.dispOpt, fprintf('\tGMM total iteration count = %d, log prob. = %f\n',i, logProb(i)); end
logProb(i+1:gmmTrainParam.maxIteration) = [];

% ====== Subfunctions ======
function displayGmm(data, gmmParam)
% displayGmm: Plot 2-D data and the half-height contour of each Gaussian, reusing one figure
%	data: matrix whose first two rows are plotted as red dots
%	gmmParam: struct array of Gaussians; each element is drawn via halfHeightContour()
% The figure is tagged with this file's name and every contour line with 'circleH', so later calls
% only move the existing contour lines. If a tagged figure exists but holds a different number of
% contour lines (e.g. left over from a run with another Gaussian count), it is cleared and rebuilt.
% When the line count matches, the previously plotted data points are kept as they are.
gaussianNum=length(gmmParam);
figureH=findobj(0, 'tag', mfilename);
needRebuild=isempty(figureH);
if ~needRebuild
	figureH=figureH(1);
	oldCircleH=findobj(figureH, 'tag', 'circleH');
	needRebuild=length(oldCircleH)~=gaussianNum;
end
if needRebuild
	if isempty(figureH)
		figureH=figure;
		set(figureH, 'tag', mfilename);
	else
		figure(figureH); clf;	% clf keeps the figure's tag
	end
	plot(data(1,:), data(2,:),'.r'); axis image
	for i=1:gaussianNum
		[xData, yData]=halfHeightContour(gmmParam(i));
		circleH(i)=line(xData, yData, 'color', 'k', 'linewidth', 3);
	end
	% 'erasemode' is not set: that property is no longer supported by MATLAB graphics (R2014b and later).
	set(circleH, 'tag', 'circleH');
else
	for i=1:gaussianNum
		[xData, yData]=halfHeightContour(gmmParam(i));
		set(oldCircleH(i), 'xdata', xData, 'ydata', yData);
	end
	drawnow
end

% ====== Obtain the contour data at half height of a Gaussian
function [xData, yData]=halfHeightContour(gaussianParam)
% halfHeightContour: Points where a 2-D Gaussian's density drops to 50% of its peak
%	gaussianParam: one Gaussian with fields mu (mean) and sigma, where sigma is either a scalar
%		(variance shared by both axes), a vector with one variance per axis, or a full covariance matrix
%	xData, yData: 1 x 21 row vectors tracing the closed contour (first and last angles are -pi and pi)
% Along a direction with variance s, the density falls to half its peak at distance sqrt(2*log(2)*s).
halfLevel=2*log(2);
angles=linspace(-pi, pi, 21);
pointDim=length(gaussianParam.mu);
covar=gaussianParam.sigma;
switch numel(covar)
	case 1		% shared variance: a circle
		radius=sqrt(halfLevel*covar);
		curve=[radius*cos(angles); radius*sin(angles)];
	case pointDim	% one variance per axis: an axis-aligned ellipse
		curve=[sqrt(halfLevel*covar(1))*cos(angles); sqrt(halfLevel*covar(2))*sin(angles)];
	otherwise	% full covariance: an ellipse whose axes follow the eigenvectors
		[eigVec, eigVal]=eig(covar);
		curve=eigVec*[sqrt(halfLevel*eigVal(1,1))*cos(angles); sqrt(halfLevel*eigVal(2,2))*sin(angles)];
end
xData=curve(1,:)+gaussianParam.mu(1);
yData=curve(2,:)+gaussianParam.mu(2);

% ====== Self Demo ======
function selfdemo
% selfdemo: Demo run when gmmTrain is called with no input arguments.
% It simply invokes gmmTrainDemo2dCovType01, which is defined outside this file
% (presumably a 2-D demo script; confirm before calling it with parentheses).
gmmTrainDemo2dCovType01;

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -