⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 gaussmix.m

📁 voice box tool box for matlab
💻 M
📖 第 1 页 / 共 2 页
字号:
    % Full-covariance EM iterations. Each iteration scores every data point
    % against every mixture (first a partial chunk of jx0 points, then nl-1
    % chunks of nb points), accumulates the statistics pk, sx and sx2, and
    % re-estimates w, m and v. Each row of v holds the lower-triangular
    % covariance elements of one mixture; v(ik,lixi) reshaped to p-by-p gives
    % the full matrix.
    wpk=repmat((1:p)',k,1);             % dimension index 1..p repeated for each of the k mixtures
    for j=1:l
        g1=g;                           % save previous log likelihood (2*pi factor omitted)
        m1=m;                           % save previous means, variances and weights
        v1=v;
        w1=w;
        lvm=zeros(k,1);                 % log mixture scale factor: log(w) - 0.5*log(det(covariance))

        for ik=1:k

            % these lines added for debugging only
%             vk=reshape(v(ik,lixi),p,p);
%             condk(ik)=cond(vk);
            %%%%%%%%%%%%%%%%%%%%
            % use row ik: each mixture has its own covariance matrix
            vik=-0.5*pinv(reshape(v(ik,lixi),p,p));   % use pseudo inverse just in case
            vi((ik-1)*p+(1:p),:)=vik;           % stacked -0.5*inverse covariances of all mixtures
            vim((ik-1)*p+(1:p))=vik*m(ik,:)';   % stacked vik*mean products
            mtk((ik-1)*p+(1:p))=m(ik,:)';       % stacked means
            % -2*vik is pinv(covariance); the sum of the logs of its singular values
            % is its log determinant, i.e. -log(det(covariance)) for a nonsingular
            % covariance. Working in the log domain avoids overflow/underflow of det()
            % and gives the correct Gaussian normalization (2*pi factor omitted).
            lvm(ik)=log(w(ik))+0.5*sum(log(svd(-2*vik)));
        end
        %
        %         % first do partial chunk
        %
        jx=jx0;
        ii=1:jx;
        xii=x(ii,:).';
        % py(ik,i) = -0.5*(x_i-m_ik)'*pinv(V_ik)*(x_i-m_ik) + lvm(ik)
        py=reshape(sum(reshape((vi*xii-vim(:,wnj)).*(xii(wpk,:)-mtk(:,wnj)),p,jx*k),1),k,jx)+lvm(:,wnj);
        mx=max(py,[],1);                % find normalizing factor for each data point to prevent underflow when using exp()
        px=exp(py-mx(wk,:));            % find normalized probability of each mixture for each datapoint
        ps=sum(px,1);                   % total normalized likelihood of each data point
        px=px./ps(wk,:);                % relative mixture probabilities for each data point (columns sum to 1)
        lpx(ii)=log(ps)+mx;
        pk=sum(px,2);                   % effective number of data points for each mixture (could be zero due to underflow)
        sx=px*x(ii,:);                  % accumulator for mean calculation
        sx2=px*(x(ii,rix).*x(ii,cix));  % accumulator for variance calculation (lower tri cov matrix as a row)

        for il=2:nl
            ix=jx+1;
            jx=jx+nb;                   % increment upper limit
            ii=ix:jx;
            xii=x(ii,:).';
            py=reshape(sum(reshape((vi*xii-vim(:,wnb)).*(xii(wpk,:)-mtk(:,wnb)),p,nb*k),1),k,nb)+lvm(:,wnb);
            mx=max(py,[],1);                % find normalizing factor for each data point to prevent underflow when using exp()
            px=exp(py-mx(wk,:));            % find normalized probability of each mixture for each datapoint
            ps=sum(px,1);                   % total normalized likelihood of each data point
            px=px./ps(wk,:);                % relative mixture probabilities for each data point (columns sum to 1)
            lpx(ii)=log(ps)+mx;
            pk=pk+sum(px,2);                % effective number of data points for each mixture (could be zero due to underflow)
            sx=sx+px*x(ii,:);               % accumulator for mean calculation
            sx2=sx2+px*(x(ii,rix).*x(ii,cix));            % accumulator for variance calculation
        end
        g=sum(lpx);                    % total log probability summed over all data points
        gg(j)=g;                        % save convergence history
        w=pk/n;                         % normalize to get the column of weights
        if pk                       % if all elements of pk are non-zero
            m=sx./pk(:,wp);         % find mean and mean square
            v=sx2./pk(:,wpl);
        else
            wm=pk==0;                       % mask indicating mixtures with zero weights
            [vv,mk]=sort(lpx);             % find the lowest probability data points
            m=zeros(k,p);                   % initialize means and variances to zero (variances are floored later)
            v=zeros(k,pl);
            m(wm,:)=x(mk(1:sum(wm)),:);                % set zero-weight mixture means to worst-fitted data points
            wm=~wm;                         % mask for non-zero weights
            m(wm,:)=sx(wm,:)./pk(wm,wp);  % recalculate means and variances for mixtures with a non-zero weight
            v(wm,:)=sx2(wm,:)./pk(wm,wpl);
        end
        v=v-m(:,cix).*m(:,rix);                 % subtract off mean squared
        v(:,dix)=max(v(:,dix),c);                                   % force diagonal elements to be >= c
        if g-g1<=th && j>1
            if ~ss, break; end  %  stop
            ss=ss-1;       % stop next time
        end
    end
    % Finalize the full-covariance fit. When log likelihood outputs are wanted
    % (sd), add the omitted -0.5*p*log(2*pi) Gaussian term to the per-point and
    % per-iteration log likelihoods, restore the parameters that produced them,
    % and compute the measure f. Then expand v into a p-by-p-by-k array.
    if sd  % we need to calculate the final probabilities
        pp=lpx'-0.5*p*log(2*pi);   % log of total probability of each data point
        gg=gg(1:j)/n-0.5*p*log(2*pi);    % average log prob at each iteration
        g=gg(end);
        %     gg' % *** DEBUG ONLY ***
        % lpx and gg(j) were computed with the parameters saved at the start of
        % the last iteration (m1, v1, w1), before the re-estimation step, so those
        % are the parameters that match pp and g.
        m=m1;       % back up to previous iteration
        v=v1;
        w=w1;
        mm=sum(m,1)/k;                      % mean of the mixture means (1 x p)
        % mean over mixtures of m(:,rix).*m(:,cix): second moments of the mixture
        % means, laid out in the same (rix,cix) row format as v
        sm=sum(m(:,rix).*m(:,cix),1)/k;
        vm=sum(v,1)/k;                      % average covariance row over the mixtures
        % det(covariance of the mixture means) / det(average mixture covariance);
        % indexing with lixi expands a (rix,cix) row into a p-by-p matrix
        % (lixi is defined outside this view)
        f=det(sm(lixi)-mm'*mm)/det(vm(lixi));
    end
    % expand each row of lower-triangular covariance elements into a full
    % p-by-p matrix, giving a p-by-p-by-k array
    v=reshape(v(:,lixi)',[p,p,k]);
    % NOTE(review): the diagonal-covariance branch sets l=floor(l)+sd before its
    % EM loop, which makes "l==0" false whenever any output is requested (and,
    % when none is, f/pp are undefined here). If this branch's setup (outside
    % this view) does the same, this suppression never takes effect; the
    % requested l should be saved and tested instead - confirm.
    if l==0         % suppress the first three output arguments if l==0
        m=g;
        v=f;
        w=pp;
    end
else
    %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
    % Diagonal Covariance matrices  %
    %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
    % Here v(k,p) holds one variance per mixture and dimension. The EM loop
    % below scores the data in chunks, accumulates pk, sx and sx2, and
    % re-estimates w, m and v on each iteration.
    v=max(v,c);         % apply the lower bound


    % If data size is large then do calculations in chunks

    nb=min(n,max(1,floor(memsize/(8*p*k))));    % chunk size for testing data points
    nl=ceil(n/nb);                  % number of chunks
    jx0=n-(nl-1)*nb;                % size of first chunk

    im=repmat(1:k,1,nb); im=im(:);  % mixture index of each (mixture, point) row in a full chunk
    lreq=l;                         % keep the requested value: l is overwritten below
    th=(l-floor(l))*n;              % threshold on the increase of the total log likelihood g
    sd=(nargout > 3*(l~=0)); % = 1 if we are outputting log likelihood values
    l=floor(l)+sd;   % extra loop needed to calculate final G value

    lpx=zeros(1,n);             % log probability of each data point
    wk=ones(k,1);               % index vectors of ones used to replicate rows/columns
    wp=ones(1,p);
    wnb=ones(1,nb);
    wnj=ones(1,jx0);

    % EM loop

    g=0;                           % dummy initial value for comparison
    gg=zeros(l+1,1);
    ss=sd;                       % initialize stopping count (0 or 1)
    for j=1:l
        g1=g;                    % save previous log likelihood (2*pi factor omitted)
        m1=m;                       % save previous means, variances and weights
        v1=v;
        w1=w;
        lvm=log(w)-0.5*sum(log(v),2);   % calculate log of mixture scale factor to avoid overflow problems
        vi=-0.5*v.^(-1);                % exponent scale factors

        % first do partial chunk

        jx=jx0;
        ii=1:jx;
        kk=repmat(ii,k,1);              % data point index of each (mixture, point) row
        km=repmat(1:k,1,jx);            % mixture index of each (mixture, point) row
        py=reshape(sum((x(kk(:),:)-m(km(:),:)).^2.*vi(km(:),:),2),k,jx)+lvm(:,wnj);
        mx=max(py,[],1);                % find normalizing factor for each data point to prevent underflow when using exp()
        px=exp(py-mx(wk,:));  % find normalized probability of each mixture for each datapoint
        ps=sum(px,1);                   % total normalized likelihood of each data point
        px=px./ps(wk,:);                % relative mixture probabilities for each data point (columns sum to 1)
        lpx(ii)=log(ps)+mx;
        pk=sum(px,2);                   % effective number of data points for each mixture (could be zero due to underflow)
        sx=px*x(ii,:);                  % accumulator for mean calculation
        sx2=px*x2(ii,:);                % accumulator for variance calculation

        for il=2:nl
            ix=jx+1;
            jx=jx+nb;        % increment upper limit
            ii=ix:jx;
            kk=repmat(ii,k,1);
            py=reshape(sum((x(kk(:),:)-m(im,:)).^2.*vi(im,:),2),k,nb)+lvm(:,wnb);
            mx=max(py,[],1);                % find normalizing factor for each data point to prevent underflow when using exp()
            px=exp(py-mx(wk,:));  % find normalized probability of each mixture for each datapoint
            ps=sum(px,1);                   % total normalized likelihood of each data point
            px=px./ps(wk,:);                % relative mixture probabilities for each data point (columns sum to 1)
            lpx(ii)=log(ps)+mx;
            pk=pk+sum(px,2);                   % effective number of data points for each mixture (could be zero due to underflow)
            sx=sx+px*x(ii,:);
            sx2=sx2+px*x2(ii,:);
        end
        g=sum(lpx);                    % total log probability summed over all data points
        gg(j)=g;
        w=pk/n;                         % normalize to get the weights
        if pk                       % if all elements of pk are non-zero
            m=sx./pk(:,wp);
            v=sx2./pk(:,wp);
        else
            wm=pk==0;                       % mask indicating mixtures with zero weights
            [vv,mk]=sort(lpx);             % find the lowest probability data points
            m=zeros(k,p);                   % initialize means and variances to zero (variances are floored later)
            v=m;
            m(wm,:)=x(mk(1:sum(wm)),:);                % set zero-weight mixture means to worst-fitted data points
            wm=~wm;                         % mask for non-zero weights
            m(wm,:)=sx(wm,:)./pk(wm,wp);  % recalculate means and variances for mixtures with a non-zero weight
            v(wm,:)=sx2(wm,:)./pk(wm,wp);
        end
        v=max(v-m.^2,c);                % apply floor to variances

        if g-g1<=th && j>1
            if ~ss, break; end  %  stop
            ss=ss-1;       % stop next time
        end

    end
    if sd  % we need to calculate the final probabilities
        pp=lpx'-0.5*p*log(2*pi);   % log of total probability of each data point
        gg=gg(1:j)/n-0.5*p*log(2*pi);    % average log prob at each iteration
        g=gg(end);
        %     gg' % *** DEBUG ***
        % lpx and gg(j) were computed with the parameters saved at the start of
        % the last iteration, so return those to match pp and g
        m=m1;       % back up to previous iteration
        v=v1;
        w=w1;
        mm=sum(m,1)/k;              % mean of the mixture means
        % product over dimensions of the variance of the mixture means, divided
        % by the product over dimensions of the average mixture variance
        f=prod(sum(m.^2,1)/k-mm.^2)/prod(sum(v,1)/k);
    end
    % Test the REQUESTED l: l itself was set to floor(l)+sd above, so it is never
    % 0 when an output is requested. If l==0 was requested with at least one
    % output, sd is 1, so g, f and pp have all been computed here.
    if sd && lreq==0    % suppress the first three output arguments if l==0
        m=g;
        v=f;
        w=pp;
    end
end

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -