📄 mgmmem.m
字号:
function [mix, options, errlog] = mgmmem(x, options)
% MGMMEM Component-wise EM for a Gaussian mixture with component pruning.
%
%   Modified EM algorithm that uses the Figueiredo/Jain algorithm in
%   "Unsupervised Learning of Finite Mixture Models".
%
%   Description
%   [MIX, OPTIONS, ERRLOG] = MGMMEM(X, OPTIONS) fits a Gaussian mixture
%   model MIX with full covariance matrices to the data X, in which each
%   row is one observation. The model is created and initialised inside
%   this function, so there is no need to call GMM and GMMINIT first:
%   every starting component gets an equal prior, a centre drawn at
%   random from the rows of X, and a diagonal covariance built from the
%   data variance.
%
%   Parameters are updated one component at a time (component-wise EM),
%   cycling through the components. The prior update subtracts N/2 from
%   each component's total responsibility (eq. 18 of the cited paper),
%   where N is the number of free parameters of one component. When a
%   component's prior becomes zero on its turn, it is removed from the
%   mixture.
%
%   OPTIONS(1) is set to 1 to display error values after every step. If
%   OPTIONS(1) is set to 0, then only warning messages are displayed. If
%   OPTIONS(1) is -1, then nothing is displayed. Error values are stored
%   in ERRLOG whenever ERRLOG is requested as an output.
%
%   OPTIONS(3) is a measure of the absolute precision required of the
%   error function. If the change in error between two successive steps
%   is less than this value, the function returns immediately (without
%   the final refinement/pruning pass described below).
%
%   OPTIONS(5) is set to 1 if a covariance matrix is reset to its
%   original value when any of its singular values are too small (less
%   than MIN_COVAR which has the value eps). With the default value of
%   0 no action is taken.
%
%   OPTIONS(10) is the starting number of components. If it is 0, the
%   default floor(NDATA / (10 * XDIM^2)), clamped to [1, NDATA], is
%   used. Whenever components are removed, OPTIONS(10) is updated to the
%   number of remaining components.
%
%   OPTIONS(14) is the maximum number of sweeps; default 100. The loop
%   performs at most OPTIONS(14) * (current number of components)
%   single-component steps.
%
%   The optional return value OPTIONS contains the final error value
%   (i.e. the negative data log likelihood) in OPTIONS(8). ERRLOG(n) is
%   the error of the model at the start of step n.
%
%   After the main loop ends (without early convergence), one update of
%   all components is performed with the current posteriors, and every
%   component whose prior is zero is discarded.
%
%   See also
%   GMM, GMMINIT, GMMEM
%

%   Code modified from gmmem:
%   Copyright (c) Ian T Nabney (1996-2001)

printall = 0;   % set to 1 for debugging traces

[ndata, xdim] = size(x);

% Sort out the options
if options(14)
  niters = options(14);
else
  niters = 100;
end

if options(10)
  startNComponents = options(10);
else
  startNComponents = 0;
end
if startNComponents == 0
  % Heuristic default, clamped to the range [1, ndata].
  startNComponents = floor(ndata / (xdim * xdim * 10));
  startNComponents = max(startNComponents, 1);
  startNComponents = min(startNComponents, ndata);
  if printall == 1
    fprintf('Starting with %d components\n', startNComponents);
  end
end

% ----------------
% Initialise the model
if printall == 1
  disp('Initializing Parameter Vector');
end
mix = gmm(xdim, startNComponents, 'full');
mx = mean(x, 1);
centred = x - ones(ndata, 1) * mx;
% Shared initial covariance: the diagonal of the data scatter matrix,
% scaled down by ndata * 10 * xdim.
cov0 = diag(diag(centred' * centred) / (ndata * 10 * xdim));
for j = 1:startNComponents
  mix.priors(j) = 1 ./ startNComponents;
  mix.centres(j, :) = x(floor(rand(1, 1) * (ndata - .001)) + 1, :);
  mix.covars(:, :, j) = cov0;
end

errstring = consist(mix, 'gmm', x);
if ~isempty(errstring)
  error(errstring);
end

% More argument processing
verbosity = options(1);
store = (nargout > 2);
if store
  % Store the error values to return them (grows if more steps are run).
  errlog = zeros(1, niters);
end
test = (options(3) > 0.0);          % test error change for termination
check_covars = (options(5) >= 1);   % ensure that covariances don't collapse
if check_covars
  if verbosity >= 0
    disp('check_covars is on');
  end
  MIN_COVAR = eps;   % minimum singular value of a covariance matrix
end
init_covars = mix.covars;

% Number of free parameters of one component. The model is always
% 'full': xdim for the mean plus xdim*(xdim+1)/2 for the symmetric
% covariance matrix.
N = xdim + xdim * (xdim + 1) / 2;
if printall == 1
  fprintf('N = %d\n', N);
end

% Main loop: component-wise EM. Each step updates only one component.
n = 0;
while n < niters * mix.ncentres
  n = n + 1;

  % E step: posteriors and activations under the current parameters.
  [post, act] = gmmpost(mix, x);

  % Calculate error value if needed
  if verbosity || store || test
    prob = act * (mix.priors)';
    % Error value is negative log likelihood of data
    e = -sum(log(prob + (prob == 0)));
    if store
      errlog(n) = e;
    end
    if verbosity > 0
      fprintf(1, 'Cycle %4d Error %11.6f\n', n, e);
    end
    if test
      if n > 1 && abs(e - eold) < options(3)
        options(8) = e;
        return;
      else
        eold = e;
      end
    end
  end

  % Penalised prior update (eq. 18 of the cited paper): subtract N/2 from
  % each component's total responsibility and clip at zero.
  mass = sum(post, 1);   % column sums
  new_pr = normalisePriors(max(0, mass - N/2), mass);
  mix.priors = new_pr;   % already normalised; may contain zeros

  % Only this component's centre and covariance are updated in this step.
  k = mod(n, mix.ncentres) + 1;
  % Adding (new_pr == 0) avoids dividing by zero for an empty component.
  mix.centres(k, :) = (post(:, k)' * x) ./ (mass(k) + (new_pr(k) == 0));

  if new_pr(k) > 0
    diffs = x - ones(ndata, 1) * mix.centres(k, :);
    diffs = diffs .* (sqrt(post(:, k)) * ones(1, xdim));
    mix.covars(:, :, k) = (diffs' * diffs) / mass(k);
    if check_covars
      % Ensure that no covariance is too small
      for j = 1:mix.ncentres
        if new_pr(j) > 0 && min(svd(mix.covars(:, :, j))) < MIN_COVAR
          mix.covars(:, :, j) = init_covars(:, :, j);
        end
      end
    end
  else
    % The prior of component k was driven to zero: remove the component.
    if printall == 1
      fprintf('Removing component, left with %d\n', mix.ncentres - 1);
    end
    mix = selectComponents(mix, [1:k-1, k+1:mix.ncentres]);
    options(10) = mix.ncentres;
  end
end

% Final pass: update all components with the current posteriors.
post = gmmpost(mix, x);
mass = sum(post, 1);   % column sums
new_pr = normalisePriors(max(0, mass - N/2), mass);
new_c = post' * x;
mix.priors = new_pr;   % already normalised
foo = new_c ./ ((mass + (new_pr == 0))' * ones(1, xdim));
for j = 1:mix.ncentres
  if mix.priors(j) > 0
    mix.centres(j, :) = foo(j, :);
    diffs = x - ones(ndata, 1) * mix.centres(j, :);
    diffs = diffs .* (sqrt(post(:, j)) * ones(1, xdim));
    mix.covars(:, :, j) = (diffs' * diffs) / mass(j);
  end
end

% Throw away the components with zero priors.
keep = find(mix.priors > 0);
if numel(keep) < mix.ncentres
  mix = selectComponents(mix, keep);
  options(10) = mix.ncentres;
end

options(8) = -sum(log(gmmprob(mix, x)));
if verbosity >= 0
  fprintf('Warning: Maximum number of iterations %d has been exceeded\n', niters);
end


function pr = normalisePriors(pr, mass)
% NORMALISEPRIORS Scale the penalised priors PR so that they sum to one.
%   If the N/2 penalty has zeroed every prior, fall back to the
%   unpenalised responsibility totals MASS instead of dividing by zero.
if sum(pr) == 0
  pr = mass / sum(mass);
else
  pr = pr / sum(pr);
end


function mix = selectComponents(mix, keep)
% SELECTCOMPONENTS Restrict a 'full'-covariance mixture to components KEEP.
%   KEEP is a vector of component indices in the order they should
%   appear. Priors are copied without renormalisation, and NWTS is
%   recomputed for full covariance matrices.
mix.ncentres = numel(keep);
mix.priors = mix.priors(keep);
mix.centres = mix.centres(keep, :);
mix.covars = mix.covars(:, :, keep);
mix.nwts = mix.ncentres + mix.ncentres*mix.nin + ...
  mix.ncentres*mix.nin*mix.nin;
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -