% gmmMleWrtGaussianNum.m
function [trainLp, testLp]=gmmMleWrtGaussianNum(trainData, testData, vecOfGaussianNum, covType, gmmTrainParam, plotOpt)
% gmmMleWrtGaussianNum: Train GMMs with varying numbers of Gaussian mixtures
%	and record the log probability on both the training and the test data.
%
%	Usage: [trainLp, testLp] = gmmMleWrtGaussianNum(trainData, testData, vecOfGaussianNum, covType, gmmTrainParam, plotOpt)
%
%	trainData:        data matrix used to train each GMM (one column per sample)
%	testData:         held-out data matrix used to evaluate each trained GMM
%	vecOfGaussianNum: vector of mixture counts to try, one GMM per entry
%	covType:          covariance type passed to gmmTrain (default: 1)
%	gmmTrainParam:    training parameters (default: gmmTrainParamSet)
%	plotOpt:          1 to plot log prob. vs. mixture count (default: 0)
%	trainLp(i):       best training log prob. reached with vecOfGaussianNum(i) mixtures
%	testLp(i):        total test-set log prob. under the i-th trained model
%
%	Calling the function with no arguments runs the built-in self demo.

if nargin<1, selfdemo; return; end
if nargin<4, covType=1; end
if nargin<5, gmmTrainParam=gmmTrainParamSet; end
if nargin<6, plotOpt=0; end
gaussianNumCount=length(vecOfGaussianNum);
% Preallocate outputs instead of growing them inside the loop
trainLp=zeros(1, gaussianNumCount);
testLp=zeros(1, gaussianNumCount);
for i=1:gaussianNumCount
	fprintf('%d/%d: No. of mixtures = %d\n', i, gaussianNumCount, vecOfGaussianNum(i));
	% ====== Training GMM model
	[gmmParam, lp] = gmmTrain(trainData, [vecOfGaussianNum(i), covType], gmmTrainParam);
	trainLp(i)=max(lp);	% lp holds the per-iteration log prob.; keep the best
	% ====== Total log prob. of the held-out test data under this model
	testLp(i)=sum(gmmEval(testData, gmmParam));
end
if plotOpt
	plot(vecOfGaussianNum, trainLp, '-o', vecOfGaussianNum, testLp, '-o');
	xlabel('No. of Gaussian mixtures');
	ylabel('Total log prob.');
	% Highlight the mixture count achieving the best test log prob.
	[maxTestLp, index]=max(testLp);
	line(vecOfGaussianNum(index), maxTestLp, 'marker', '*', 'color', 'r');
	legend('Training log prob.', 'Test log prob.', 'Location', 'SouthEast');
end
% ====== Self demo
function selfdemo
% selfdemo: Demonstrate gmmMleWrtGaussianNum on a synthetic 1-D dataset
% drawn from four Gaussian clusters of differing center, spread, and size.
dataNum = 100;
cluster1 = randn(1,2*dataNum);
cluster2 = randn(1,3*dataNum)/2+3;
cluster3 = randn(1,1*dataNum)/3-3;
cluster4 = randn(1,1*dataNum)+6;
data = [cluster1, cluster2, cluster3, cluster4];
% Show the overall data distribution in the upper panel
subplot(2,1,1); hist(data, 50);
% Interleaved split: odd-indexed samples train, even-indexed samples test
trainData=data(:, 1:2:end);
testData=data(:, 2:2:end);
vecOfGaussianNum=2:20;
covType=1;
gmmTrainParam=gmmTrainParamSet;
plotOpt=1;
% Lower panel: log prob. curves vs. number of mixtures
subplot(2,1,2);
[trainLp, testLp]=gmmMleWrtGaussianNum(trainData, testData, vecOfGaussianNum, covType, gmmTrainParam, plotOpt);
% (Removed: keyboard-shortcut help text copied from the code-hosting web page;
%  it was not part of the MATLAB source and would prevent the file from parsing.)