% gmmtrainevalwrtgaussiannum.m
function [gmmData, recogRate1, recogRate2, validMixNumIndex]=gmmTrainEvalWrtGaussianNum(DS, TS, vecOfMixNum, covType, gmmTrainParam)
% gmmTrainEvalWrtGaussianNum: GMM training and test, w.r.t. varying number of mixtures
% Usage: [gmmData, recogRate1, recogRate2, validMixNumIndex]=gmmTrainEvalWrtGaussianNum(DS, TS, vecOfMixNum, covType, gmmTrainParam)
% DS: training set; DS.input holds one sample per column, DS.output holds the class label of each sample
% TS: test set, in the same format as DS
% vecOfMixNum: vector of numbers of mixtures (Gaussians per class) to try; each entry is one case
% covType: type of covariance matrix, 1: identity times a constant, 2: diagonal, 3: full
% gmmTrainParam: parameter for training GMM (default: gmmTrainParamSet). If it has a true field
%	plotOpt, the recognition rates are plotted against the number of mixtures.
% gmmData: GMM parameters
%	gmmData(i): in which each gmm has vecOfMixNum(i) gaussians
%	gmmData(i).gmm(j): gmm of class j at case i, where class j is the j-th entry of unique(DS.output)
%	gmmData(i).gmm(j).gmmParam(k): gaussian k of class j at case i
%	gmmData(i).gmm(j).gmmParam(k).mu: mean vector
%	gmmData(i).gmm(j).gmmParam(k).sigma: covariance matrix
%	gmmData(i).gmm(j).gmmParam(k).w: weight
%	gmmData(i).gmm(j).logProb: second output of gmmTrain (presumably the training log probability -- confirm in gmmTrain)
% recogRate1: inside-test recognition rate (column vector, one entry per valid case)
% recogRate2: outside-test recognition rate (column vector, one entry per valid case)
% validMixNumIndex: number of leading entries of vecOfMixNum that were actually used, i.e.
%	gmmData, recogRate1 and recogRate2 correspond to vecOfMixNum(1:validMixNumIndex).
%	We need this output since sometimes we are given a large number of mixtures which cannot be
%	used for GMM training at all. The first case for which gmmTrain throws an error for any class
%	stops the loop, and that case and all later ones are discarded.
%
% Each sample is assigned to the class whose GMM gives the largest gmmEval output.
%
% For example:
%	[DS, TS]=prData('wine');
%	vecOfMixNum=2:30;
%	covType=1;
%	gmmTrainParam=gmmTrainParamSet;
%	gmmTrainParam.plotOpt=1;
%	[gmmData, recogRate1, recogRate2, validMixNumIndex]=gmmTrainEvalWrtGaussianNum(DS, TS, vecOfMixNum, covType, gmmTrainParam);
% Roger Jang, 20070516
if nargin<1, selfdemo; return; end
if nargin<4, error('Not enough input arguments: DS, TS, vecOfMixNum and covType are required.'); end
if nargin<5, gmmTrainParam=gmmTrainParamSet; end
classLabel=unique(DS.output);
classNum=length(classLabel);
caseNum=length(vecOfMixNum);
recogRate1=zeros(caseNum, 1);
recogRate2=zeros(caseNum, 1);
dsNum=size(DS.input, 2);
tsNum=size(TS.input, 2);
fprintf('DS data count = %d, TS data count = %d\n', dsNum, tsNum);
% Use a throw-away name for the labels returned by classSize so that classLabel (from DS) is not
% overwritten, in particular not by the labels of the test set.
[dummyLabel, classSizeDS]=classSize(DS); fprintf('DS class data count = %s\n', mat2str(classSizeDS));
[dummyLabel, classSizeTS]=classSize(TS); fprintf('TS class data count = %s\n', mat2str(classSizeTS));
% ====== Perform training and compute recognition rates
gmmData=[];		% so that gmmData is defined even if the very first training fails
failedCaseIndex=0;	% index into vecOfMixNum of the first case whose training failed (0: none)
h=waitbar(0, 'Please wait...');
for j=1:caseNum
    fprintf('%d/%d: No. of Gaussian = %d ===> ', j, caseNum, vecOfMixNum(j));
    % ====== Training GMM model for each class
    for i=1:classNum
        theData=DS.input(:, DS.output==classLabel(i));
        try
            [gmmData(j).gmm(i).gmmParam, gmmData(j).gmm(i).logProb] = gmmTrain(theData, [vecOfMixNum(j), covType], gmmTrainParam);
        catch
            % Typically too many mixtures for the available data; record the case and report why
            failedCaseIndex=j;
            fprintf('Error out on vecOfMixNum(%d)=%d and errorClassIndex=%d: %s\n', j, vecOfMixNum(j), i, lasterr);
            break;
        end
    end
    if failedCaseIndex>0, break; end
    % ====== Compute inside-test and outside-test recognition rates
    recogRate1(j)=recogRateOf(DS, gmmData(j).gmm, classLabel);
    recogRate2(j)=recogRateOf(TS, gmmData(j).gmm, classLabel);
    fprintf('inside RR = %g%%, outside RR = %g%%\n', recogRate1(j)*100, recogRate2(j)*100);
    waitbar(j/caseNum, h);
end
close(h);
% ====== Discard the failed case (possibly partially trained) and all cases after it
if failedCaseIndex>0
    gmmData(failedCaseIndex:end)=[];
    recogRate1(failedCaseIndex:end)=[];
    recogRate2(failedCaseIndex:end)=[];
    vecOfMixNum(failedCaseIndex:end)=[];
end
validMixNumIndex=length(vecOfMixNum);
% ====== Plot the result
if isfield(gmmTrainParam, 'plotOpt') && gmmTrainParam.plotOpt
    plot(vecOfMixNum, recogRate1*100, 'o-', vecOfMixNum, recogRate2*100, 'square-'); grid on
    legend('Inside test', 'Outside test', 'Location', 'SouthEast');	% 'SouthEast' == old numeric location 4
    xlabel('No. of Gaussian mixtures'); ylabel('Recognition Rates (%)');
end

% ====== Recognition rate of a data set classified by per-class GMMs
function rr=recogRateOf(dataSet, gmm, classLabel)
% dataSet: data set with fields input (one sample per column) and output (class labels)
% gmm: gmm(k).gmmParam is the GMM of class classLabel(k)
% classLabel: class labels, in the same order as gmm
% rr: fraction of samples whose predicted label equals dataSet.output
sampleNum=size(dataSet.input, 2);
outProb=zeros(length(gmm), sampleNum);
for k=1:length(gmm)
    outProb(k,:)=gmmEval(dataSet.input, gmm(k).gmmParam);
end
% Max along dim 1 so that a single-class problem still yields one winner per sample
[maxValue, maxIndex]=max(outProb, [], 1);
% Map row indices back to the actual class labels (they need not be 1..classNum)
computedOutput=classLabel(maxIndex);
% Compare as column vectors so row/column orientation of the labels does not matter
rr=sum(computedOutput(:)==dataSet.output(:))/length(dataSet.output);
% ====== Self demo
function selfdemo
% selfdemo: demonstrate this function on the 'wine' data set, trying 2 to 30 Gaussians
% per class with covariance type 1 (identity times a constant) and plotting enabled.
[trainSet, testSet]=prData('wine');
trainOpt=gmmTrainParamSet;
trainOpt.plotOpt=1;
mixNumList=2:30;
covarianceType=1;
% Call through mfilename so the demo works whatever case the file name is saved in
[gmmResult, insideRR, outsideRR, validCount]=feval(mfilename, trainSet, testSet, mixNumList, covarianceType, trainOpt);
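% ---- Residue of the web page this file was copied from (keyboard-shortcut help panel) ----
%   Copy code: Ctrl + C
%   Search code: Ctrl + F
%   Full-screen mode: F11
%   Toggle theme: Ctrl + Shift + D
%   Show shortcuts: ?
%   Increase font size: Ctrl + =
%   Decrease font size: Ctrl + -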