📄 em_vc.m
字号:
% EM_VC  Gaussian-mixture classifier trained with a model-selecting EM.
%
% Learns one Gaussian mixture per class with MGMMEM (a modified version of
% expectation-maximization which automatically selects the number of
% mixture components), then labels the training and test vectors with the
% maximum a-posteriori (MAP) rule, using class priors estimated from the
% training label frequencies.
%
% Usage:
%   [train_error, test_error, hatTrainLabels, hatTestLabels] = ...
%       EM_VC(trainFeatures, trainLabels, algo_pars, testFeatures, testLabels)
%
% Inputs:
%   trainFeatures  - training set, one column per vector
%   trainLabels    - labels of the training set
%   algo_pars      - ignored
%   testFeatures   - test set, one column per vector
%   testLabels     - labels of the test set
%
% Outputs:
%   train_error    - error rates on the training set: one entry per class
%                    followed by the total error rate (Nclasses+1 entries)
%   test_error     - error rates on the test set, same layout
%   hatTrainLabels - MAP labels assigned to the training vectors
%   hatTestLabels  - MAP labels assigned to the test vectors
%
% Notes:
%   The class list is obtained with FIND_CLASSES from the training and
%   test labels combined. A class with no training vectors gets no mixture
%   and is never predicted. A class with no vectors in a given set yields a
%   NaN (0/0) per-class error rate for that set.
function [train_error, test_error, hatTrainLabels, hatTestLabels] = ...
    EM_VC(trainFeatures, trainLabels, algo_pars, testFeatures, testLabels)

show_message('Patience, this might take a while');

testLabels  = testLabels(:);
trainLabels = trainLabels(:);

% Classes come from both label sets so train and test results share the
% same class indexing.
[Nclasses, classes] = find_classes([trainLabels; testLabels]);

% Fit one mixture per class and estimate the class priors.
options    = zeros(1,20);
mixtures   = cell(1, Nclasses);
classPrior = zeros(1, Nclasses);
for cl = 1:Nclasses
    inClass = find(trainLabels == classes(cl));
    classPrior(cl) = length(inClass) / length(trainLabels);
    if ~isempty(inClass)
        % NOTE(review): as in the original, the options vector returned for
        % one class is fed to the next call; confirm against MGMMEM whether
        % a fresh vector per class is intended.
        [mix, options] = mgmmem(trainFeatures(:, inClass)', options);
        mixtures{cl} = mix(1,1);
    end
end

% Now evaluates the probability of error
show_message('Computing Error Rates');

hatTrainLabels = map_labels(trainFeatures, mixtures, classPrior, classes);
train_error    = error_rates(hatTrainLabels, trainLabels, classes, Nclasses);

hatTestLabels = map_labels(testFeatures, mixtures, classPrior, classes);
test_error    = error_rates(hatTestLabels, testLabels, classes, Nclasses);
end


function show_message(msg)
% Prints MSG to the command window and, when a graphics object tagged
% 'Messages' exists, also shows MSG in it.
disp(msg);
hm = findobj('Tag', 'Messages');
if ~isempty(hm)
    set(hm, 'String', msg);
    refresh;
end
end


function hatLabels = map_labels(features, mixtures, classPrior, classes)
% Assigns to every column of FEATURES the class with the largest
% log-posterior (MAP rule).
logPost = class_log_posteriors(features, mixtures, classPrior);
% The explicit dimension matters: with a single class logPost is a row
% vector and max() would otherwise reduce over the samples.
[maxLogPost, maxind] = max(logPost, [], 1);
hatLabels = classes(maxind);
end


function logPost = class_log_posteriors(features, mixtures, classPrior)
% Returns an Nclasses-by-Nsam matrix of unnormalized class log-posteriors,
% log P(class) + log p(x | class), up to an additive constant shared by all
% classes (the (2*pi)^(Dim/2) Gaussian normalization is omitted, which does
% not change the MAP decision). Working in the log domain keeps samples far
% from every component from underflowing to a zero posterior for all
% classes, which would otherwise always select the first class.
Nsam     = size(features, 2);
Nclasses = numel(mixtures);
logPost  = -Inf(Nclasses, Nsam);
for cl = 1:Nclasses
    mix = mixtures{cl};
    if isempty(mix)
        continue;   % no training data: the class keeps -Inf and is never chosen
    end
    logComp = zeros(mix.ncentres, Nsam);
    for g = 1:mix.ncentres
        covmat = mix.covars(:,:,g);
        dev    = features - mix.centres(g,:)' * ones(1, Nsam);
        % Squared Mahalanobis distance of every column; covmat \ dev avoids
        % forming the explicit inverse.
        mahal  = sum(dev .* (covmat \ dev), 1);
        logComp(g,:) = log(mix.priors(g)) - mahal/2 - log_det(covmat)/2;
    end
    % log of the mixture likelihood plus the class prior
    logPost(cl,:) = log_sum_exp(logComp) + log(classPrior(cl));
end
end


function s = log_sum_exp(a)
% Column-wise log(sum(exp(a), 1)) computed without overflow/underflow.
m = max(a, [], 1);
m(~isfinite(m)) = 0;   % all -Inf (or any +Inf) columns must not turn into NaN
s = m + log(sum(exp(a - repmat(m, size(a,1), 1)), 1));
end


function ld = log_det(covmat)
% Log-determinant of COVMAT. Uses the Cholesky factor when COVMAT is
% positive definite (avoids det() under/overflow in high dimension) and
% falls back to log(det(COVMAT)) otherwise.
[R, p] = chol(covmat);
if p == 0
    ld = 2 * sum(log(diag(R)));
else
    ld = log(det(covmat));
end
end


function err = error_rates(hatLabels, labels, classes, Nclasses)
% Per-class error rates (entries 1..Nclasses) followed by the total error
% rate (entry Nclasses+1). Both label vectors are compared as columns so
% the result does not depend on the orientation of classes(maxind).
wrong = (hatLabels(:) ~= labels(:));
err   = zeros(Nclasses+1, 1);
for cl = 1:Nclasses
    inClass  = (labels(:) == classes(cl));
    nWrong   = sum(wrong & inClass);
    err(cl)  = nWrong / sum(inClass);
    err(Nclasses+1) = err(Nclasses+1) + nWrong;
end
err(Nclasses+1) = err(Nclasses+1) / length(labels);
end
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -