📄 gmmeval.m
字号:
function [logProb, gaussianProb] = gmmEval(data, mu, sigma, w)
% gmmEval: Evaluation of a GMM (Gaussian mixture model) with isotropic components
% Usage: [logProb, gaussianProb] = gmmEval(data, mu, sigma, w);
% data: dim x dataNum matrix where each column is a data point
% mu: dim x gaussianNum matrix where each column is a mean vector
% sigma: 1 x gaussianNum vector; sigma(i) is the variance of the i-th Gaussian,
%	i.e. its covariance matrix is sigma(i)*eye(dim)
% w: 1 x gaussianNum vector of mixing weights (expected to sum to 1)
% logProb: 1 x dataNum vector; logProb(j) is the log of the mixture density at data(:,j)
% gaussianProb: gaussianNum x dataNum matrix; gaussianProb(i,j) is the (unweighted)
%	density of data(:,j) under the i-th Gaussian (This is for gmmTrain.m only)
%
% A single Gaussian (length(w)==1) goes through the same computation as a
% mixture, so logProb is always 1 x dataNum and gaussianProb is always defined.
% The per-point mixture is combined with a numerically stable log-sum-exp, so
% points far from every mean yield large negative but finite log-densities.
%
% For example, to plot a GMM over 1-D data:
% mu = [-5 0 5];
% sigma = [1 4 3];
% w = [0.1, 0.5, 0.4];
% x=linspace(-10, 10, 101);
% logProb = gmmEval(x, mu, sigma, w);
% prob=exp(logProb);
% plot(x, prob, '.-');
% line(x, w(1)*gaussian(x, mu(1), sigma(1)), 'color', 'r');
% line(x, w(2)*gaussian(x, mu(2), sigma(2)), 'color', 'm');
% line(x, w(3)*gaussian(x, mu(3), sigma(3)), 'color', 'g');
%
% Another example, to plot a GMM over 2-D data as a surface:
% mu = [-3 3; 3, -3; 3, 3]';
% sigma = [1 3 2];
% w = [0.2, 0.3, 0.5];
% bound = 8;
% pointNum = 31;
% x = linspace(-bound, bound, pointNum);
% y = linspace(-bound, bound, pointNum);
% [xx, yy] = meshgrid(x, y);
% data = [xx(:), yy(:)]';
% logProb = gmmEval(data, mu, sigma, w);
% zz = reshape(exp(logProb), pointNum, pointNum);
% mesh(xx, yy, zz); axis tight; box on
%
% NOTE: MATLAB dispatches on the file name; if this file is saved as
% gmmeval.m, call it with that spelling on case-sensitive systems.
% Roger Jang, 20000602

if nargin==0, selfdemo; return; end

% A one-component "mixture" should carry weight 1; report it but still
% evaluate with the weight given.
if length(w)==1
	if w~=1
		fprintf('Error in the case of single gaussian where w is not equal to 1!\n');
	end
end

[dim, dataNum] = size(data);
gaussianNum = length(sigma);
log2pi = log(2*pi);

% logGaussianProb(i,j) = log N(data(:,j); mu(:,i), sigma(i)*I), unweighted.
logGaussianProb = zeros(gaussianNum, dataNum);
for i = 1:gaussianNum
	dataMinusMu = data - repmat(mu(:,i), 1, dataNum);
	logGaussianProb(i,:) = (-sum(dataMinusMu.*dataMinusMu, 1)/sigma(i) - dim*(log2pi+log(sigma(i))))/2;
end

% Add the log mixing weight of each component to its row.
logWeighted = logGaussianProb + repmat(log(w(:)), 1, dataNum);

% Column-wise log-sum-exp: log(sum(exp(a))) = m + log(sum(exp(a - m))), m = max(a).
% Subtracting the max keeps exp() from underflowing to 0 for far-away points.
maxLog = max(logWeighted, [], 1);
% A column whose max is not finite (e.g. all -Inf from zero weights) would give
% -Inf - (-Inf) = NaN; shifting by 0 instead yields the correct log(0) = -Inf.
maxLog(~isfinite(maxLog)) = 0;
logProb = maxLog + log(sum(exp(logWeighted - repmat(maxLog, gaussianNum, 1)), 1));

if nargout>1
	gaussianProb = exp(logGaussianProb);	% This output is necessary for gmmTrain.m!
end
% ====== Self demo
function selfdemo
% selfdemo: Draw two demonstration mixtures with this file's GMM evaluator.
% Figure 1 shows a 1-D mixture together with each weighted component.
% Figure 2 shows the density surface of a 2-D mixture.
% The evaluator is invoked through feval(mfilename, ...), so the demo works
% regardless of how the file name is capitalised.

% ---- 1-D mixture: overall density plus each weighted component
means = [-5 0 5];
variances = [1 4 3];
weights = [0.1, 0.5, 0.4];
xGrid = linspace(-10, 10, 101);
mixtureDensity = exp(feval(mfilename, xGrid, means, variances, weights));
figure;
plot(xGrid, mixtureDensity, '.-');
componentColors = {'r', 'm', 'g'};
for k = 1:numel(componentColors)
	% gaussian() comes from elsewhere in the toolbox; presumably it returns the
	% density at each xGrid point -- confirm against its own help text.
	line(xGrid, weights(k)*gaussian(xGrid, means(k), variances(k)), 'color', componentColors{k});
end

% ---- 2-D mixture: density surface over a square grid
centers = [-3 3; 3, -3; 3, 3]';
variances = [1 3 2];
weights = [0.2, 0.3, 0.5];
halfWidth = 8;
gridSize = 31;
axisGrid = linspace(-halfWidth, halfWidth, gridSize);
[gx, gy] = meshgrid(axisGrid, axisGrid);
points = [gx(:), gy(:)]';
density = reshape(exp(feval(mfilename, points, centers, variances, weights)), gridSize, gridSize);
figure;
mesh(gx, gy, density);
axis tight;
box on
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -