⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 em_gm.m

📁 EM算法是机器学习领域中常用的一种算法
💻 M
字号:
function [W,M,V,L] = EM_GM(X,k,ltol,maxiter,pflag,Init)% [W,M,V,L] = EM_GM(X,k,ltol,maxiter,pflag,Init) % % EM algorithm for k multidimensional Gaussian mixture estimation%% Inputs:%   X(n,d) - input data, n=number of observations, d=dimension of variable%   k - maximum number of Gaussian components allowed%   ltol - percentage of the log likelihood difference between 2 iterations ([] for none)%   maxiter - maximum number of iteration allowed ([] for none)%   pflag - 1 for plotting GM for 1D or 2D cases only, 0 otherwise ([] for none)%   Init - structure of initial W, M, V: Init.W, Init.M, Init.V ([] for none)%% Ouputs:%   W(1,k) - estimated weights of GM%   M(d,k) - estimated mean vectors of GM%   V(d,d,k) - estimated covariance matrices of GM%   L - log likelihood of estimates%% Written by%   Patrick P. C. Tsui,%   PAMI research group%   Department of Electrical and Computer Engineering%   University of Waterloo, %   September, 2005%%%%% Validate inputs %%%%if nargin <= 1,    disp('EM_GM must have at least 2 inputs: X,k!/n')    returnelseif nargin == 2,    ltol = 0.1; maxiter = 1000; pflag = 0; Init = [];    err_X = Verify_X(X);    err_k = Verify_k(k);    if err_X | err_k, return; endelseif nargin == 3,    maxiter = 1000; pflag = 0; Init = [];    err_X = Verify_X(X);    err_k = Verify_k(k);    [ltol,err_ltol] = Verify_ltol(ltol);        if err_X | err_k | err_ltol, return; endelseif nargin == 4,    pflag = 0;  Init = [];    err_X = Verify_X(X);    err_k = Verify_k(k);    [ltol,err_ltol] = Verify_ltol(ltol);        [maxiter,err_maxiter] = Verify_maxiter(maxiter);    if err_X | err_k | err_ltol | err_maxiter, return; endelseif nargin == 5,     Init = [];    err_X = Verify_X(X);    err_k = Verify_k(k);    [ltol,err_ltol] = Verify_ltol(ltol);        [maxiter,err_maxiter] = Verify_maxiter(maxiter);    [pflag,err_pflag] = Verify_pflag(pflag);    if err_X | err_k | err_ltol | err_maxiter | err_pflag, return; endelseif nargin == 6,    err_X = Verify_X(X);    err_k = Verify_k(k);    [ltol,err_ltol] = Verify_ltol(ltol);        [maxiter,err_maxiter] = Verify_maxiter(maxiter);    [pflag,err_pflag] = Verify_pflag(pflag);    [Init,err_Init]=Verify_Init(Init);    if err_X | err_k | err_ltol | err_maxiter | err_pflag | err_Init, return; endelse    disp('EM_GM must have 2 to 6 inputs!');    returnend%%%% Initialize W, M, V,L %%%%t = cputime;if isempty(Init),      [W,M,V] = Init_EM(X,k); L = 0;    else    W = Init.W;    M = Init.M;    V = Init.V;endLn = Likelihood(X,k,W,M,V); % Initialize log likelihoodLo = 2*Ln;%%%% EM algorithm %%%%niter = 0;while (abs(100*(Ln-Lo)/Lo)>ltol) & (niter<=maxiter),    E = Expectation(X,k,W,M,V); % E-step        [W,M,V] = Maximization(X,k,E);  % M-step    Lo = Ln;    Ln = Likelihood(X,k,W,M,V);    niter = niter + 1;end L = Ln;%%%% Plot 1D or 2D %%%%if pflag==1,    [n,d] = size(X);    if d>2,        disp('Can only plot 1 or 2 dimensional applications!/n');        return    end        Plot_GM(X,k,W,M,V);    elapsed_time = sprintf('CPU time used for EM_GM: %5.2fs',cputime-t);    disp(elapsed_time);     disp(sprintf('Number of iterations: %d',niter-1));end%%%%%%%%%%%%%%%%%%%%%%%%%% End of EM_GM %%%%%%%%%%%%%%%%%%%%%%%%%%function E = Expectation(X,k,W,M,V)[n,d] = size(X);a = (2*pi)^(0.5*d);S = zeros(1,k);iV = zeros(d,d,k);for j=1:k,    if V(:,:,j)==zeros(d,d), V(:,:,j)=ones(d,d)*eps; end    S(j) = sqrt(det(V(:,:,j)));    iV(:,:,j) = inv(V(:,:,j));    endE = zeros(n,k);for i=1:n,        for j=1:k,        px = 0;        for m=1:k,            dXM = X(i,:)'-M(:,m);            px = px + W(m)*exp(-0.5*dXM'*iV(:,:,m)*dXM)/(a*S(m));        end        dXM = X(i,:)'-M(:,j);        pl = exp(-0.5*dXM'*iV(:,:,j)*dXM)/(a*S(j));        E(i,j) = W(j)*pl/px;    end    end%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% End of Expectation %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%function [W,M,V] = Maximization(X,k,E)[n,d] = size(X);W = zeros(1,k);M = zeros(d,k);V = zeros(d,d,k);for i=1:k,  % Compute weights    for j=1:n,        W(i) = W(i) + E(j,i);        M(:,i) = M(:,i) + E(j,i)*X(j,:)';    end    M(:,i) = M(:,i)/W(i);endfor i=1:k,    for j=1:n,        dXM = X(j,:)'-M(:,i);        V(:,:,i) = V(:,:,i) + E(j,i)*dXM*dXM';    end    V(:,:,i) = V(:,:,i)/W(i);endW = W/n;%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% End of Maximization %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%function L = Likelihood(X,k,W,M,V)[n,d] = size(X);a = (2*pi)^(0.5*d);S = zeros(1,k);iV = zeros(d,d,k);for j=1:k,    if V(:,:,j)==zeros(d,d), V(:,:,j)=ones(d,d)*eps; end    S(j) = sqrt(det(V(:,:,j)));    iV(:,:,j) = inv(V(:,:,j));endL = 0;for i=1:n,    for j=1:k,        dXM = X(i,:)'-M(:,j);                L = L -0.5*dXM'*iV(:,:,j)*dXM-log(a*S(j));    endend%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% End of Likelihood %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%function err_X = Verify_X(X)err_X = 1;[n,d] = size(X);if n<d,    disp('Input data must be n x d!/n');    returnenderr_X = 0;%%%%%%%%%%%%%%%%%%%%%%%%%%%%% End of Verify_X %%%%%%%%%%%%%%%%%%%%%%%%%%%%%function err_k = Verify_k(k)err_k = 1;if ~isnumeric(k) | ~isreal(k) | k<1,    disp('k must be a real integer >= 1!/n');    returnenderr_k = 0;%%%%%%%%%%%%%%%%%%%%%%%%%%%%% End of Verify_k %%%%%%%%%%%%%%%%%%%%%%%%%%%%%function [ltol,err_ltol] = Verify_ltol(ltol)err_ltol = 1;if isempty(ltol),    ltol = 0.1;elseif ~isreal(ltol) | ltol<=0,    disp('ltol must be a positive real number!');    returnenderr_ltol = 0;%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% End of Verify_ltol %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%function [maxiter,err_maxiter] = Verify_maxiter(maxiter)err_maxiter = 1;if isempty(maxiter),    maxiter = 1000;elseif ~isreal(maxiter) | maxiter<=0,    disp('ltol must be a positive real number!');    returnenderr_maxiter = 0;%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% End of Verify_maxiter %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%function [pflag,err_pflag] = Verify_pflag(pflag)err_pflag = 1;if isempty(pflag),    pflag = 0;elseif pflag~=0 & pflag~=1,    disp('Plot flag must be either 0 or 1!/n');    returnenderr_pflag = 0;%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% End of Verify_pflag %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%function [Init,err_Init] = Verify_Init(Init)err_Init = 1;if isempty(Init),    % Do nothing;elseif isstruct(Init),    [Wd,Wk] = size(Init.W);    [Md,Mk] = size(Init.M);    [Vd1,Vd2,Vk] = size(Init.V);    if Wk~=Mk | Wk~=Vk | Mk~=Vk,        disp('k in Init.W(1,k), Init.M(d,k) and Init.V(d,d,k) must equal!/n')        return    end    if Md~=Vd1 | Md~=Vd2 | Vd1~=Vd2,        disp('d in Init.W(1,k), Init.M(d,k) and Init.V(d,d,k) must equal!/n')        return    endelse    disp('Init must be a structure: W(1,k), M(d,k), V(d,d,k) or []!');    returnenderr_Init = 0;%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% End of Verify_Init %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%function [W,M,V] = Init_EM(X,k)[n,d] = size(X);[Ci,C] = kmeans(X,k,'Start','cluster', ...    'Maxiter',100, ...    'EmptyAction','drop', ...    'Display','off'); % Ci(nx1) - cluster indeices; C(k,d) - cluster centroid (i.e. mean)while sum(isnan(C))>0,    [Ci,C] = kmeans(X,k,'Start','cluster', ...        'Maxiter',100, ...        'EmptyAction','drop', ...        'Display','off');endM = C';Vp = repmat(struct('count',0,'X',zeros(n,d)),1,k);for i=1:n, % Separate cluster points    Vp(Ci(i)).count = Vp(Ci(i)).count + 1;    Vp(Ci(i)).X(Vp(Ci(i)).count,:) = X(i,:);endV = zeros(d,d,k);for i=1:k,    W(i) = Vp(i).count/n;    V(:,:,i) = cov(Vp(i).X(1:Vp(i).count,:));end%%%%%%%%%%%%%%%%%%%%%%%%%%%% End of Init_EM %%%%%%%%%%%%%%%%%%%%%%%%%%%%function Plot_GM(X,k,W,M,V)[n,d] = size(X);if d>2,    disp('Can only plot 1 or 2 dimensional applications!/n');    returnendS = zeros(d,k);R1 = zeros(d,k);R2 = zeros(d,k);for i=1:k,  % Determine plot range as 4 x standard deviations    S(:,i) = sqrt(diag(V(:,:,i)));    R1(:,i) = M(:,i)-4*S(:,i);    R2(:,i) = M(:,i)+4*S(:,i);endRmin = min(min(R1));Rmax = max(max(R2));R = [Rmin:0.001*(Rmax-Rmin):Rmax];clf, hold onif d==1,    Q = zeros(size(R));    for i=1:k,        P = W(i)*normpdf(R,M(:,i),sqrt(V(:,:,i)));        Q = Q + P;        plot(R,P,'r-'); grid on,    end    plot(R,Q,'k-');    xlabel('X');    ylabel('Probability density');else % d==2    plot(X(:,1),X(:,2),'r.');    for i=1:k,        Plot_Std_Ellipse(M(:,i),V(:,:,i));    end    xlabel('1^{st} dimension');    ylabel('2^{nd} dimension');    axis([Rmin Rmax Rmin Rmax])endtitle('Gaussian Mixture estimated by EM');%%%%%%%%%%%%%%%%%%%%%%%%%%%% End of Plot_GM %%%%%%%%%%%%%%%%%%%%%%%%%%%%function Plot_Std_Ellipse(M,V)[Ev,D] = eig(V);d = length(M);if V(:,:)==zeros(d,d),    V(:,:) = ones(d,d)*eps;endiV = inv(V);% Find the larger projectionP = [1,0;0,0];  % X-axis projection operatorP1 = P * 2*sqrt(D(1,1)) * Ev(:,1);P2 = P * 2*sqrt(D(2,2)) * Ev(:,2);if abs(P1(1)) >= abs(P2(1)),    Plen = P1(1);else    Plen = P2(1);endcount = 1;step = 0.001*Plen;Contour1 = zeros(2001,2);Contour2 = zeros(2001,2);for x = -Plen:step:Plen,    a = iV(2,2);    b = x * (iV(1,2)+iV(2,1));    c = (x^2) * iV(1,1) - 1;    Root1 = (-b + sqrt(b^2 - 4*a*c))/(2*a);    Root2 = (-b - sqrt(b^2 - 4*a*c))/(2*a);    if isreal(Root1),        Contour1(count,:) = [x,Root1] + M';        Contour2(count,:) = [x,Root2] + M';        count = count + 1;    endendContour1 = Contour1(1:count-1,:);Contour2 = [Contour1(1,:);Contour2(1:count-1,:);Contour1(count-1,:)];plot(M(1),M(2),'k+');plot(Contour1(:,1),Contour1(:,2),'k-');plot(Contour2(:,1),Contour2(:,2),'k-');%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% End of Plot_Std_Ellipse %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -