⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 em_algorithm.m

📁 EM分群,matlab程式碼,用來分群用的
💻 M
字号:
function [weights, pars, EM_exflag] = EM_algorithm(data_R, data_T, module_type, init_pars, init_weights, opts, curr_fig)
% EM_ALGORITHM  Fit a finite mixture of parametric densities by EM iterations.
%
%   [weights, pars, EM_exflag] = EM_algorithm(data_R, data_T, module_type, ...
%                                    init_pars, init_weights, opts, curr_fig)
%
% Inputs
%   data_R       Samples at which every component pdf is evaluated.
%                NOTE(review): the weighted sums below assume a column
%                vector - confirm against the caller.
%   data_T       Samples at which 1 - cdf is evaluated; only used when
%                opts.USECENS > 0 (presumably censoring times - confirm).
%   module_type  Cell array of component names, one per mixture component.
%                Each name is resolved through EM_Handle_list. Without
%                censoring only 'InvGauss', 'Gam' and 'LogNorm' have an
%                M-step here; any other name leaves pars{k} unchanged.
%   init_pars    Cell array with the initial parameter vector of each component.
%   init_weights Initial mixture weights, one per component.
%   opts         Struct; the fields read here are USECENS, PERCSTOP, MAXIT,
%                DOPLOT, DISP_STEPS, FIG_NBINS, FIG_BIN_CTRS, FIG_BIN_SIZE
%                and FIG_YMAX.
%   curr_fig     Figure in which the fit is drawn.
%
% Outputs
%   weights      Mixture weights after the last completed iteration.
%   pars         Cell array of component parameter vectors.
%   EM_exflag    1 when the iteration loop ran to its stopping rule,
%                0 when fminsearch / fzero reported failure (the function
%                then returns immediately; pars may be partially updated).
%
% The loop stops when the relative change of the expected negative
% log-likelihood drops to opts.PERCSTOP or below, or when the iteration
% counter reaches opts.MAXIT.

N_R = length(data_R);
N_T = length(data_T);

N   = N_R + N_T;

n_mix   = length(module_type);

% Resolve the handles of every component: element 1 is used as the pdf,
% element 2 as the cdf.
for ii=1:n_mix

    y = EM_Handle_list(module_type{ii});
    pdf_handle(ii) = y(1);
    cdf_handle(ii) = y(2);

end


%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

h_evolution = [];
h_sing      = [];

weights     = init_weights;
pars        = init_pars;

pdf_est      = zeros(opts.FIG_NBINS,1);
f_i_k        = zeros(N_R,n_mix);

% log(data_R) is needed by the closed-form 'Gam' and 'LogNorm' updates;
% compute it once instead of once per iteration.
if any(strcmp('Gam',module_type)) | any(strcmp(module_type,'LogNorm') )
    log_data_R = log(data_R);
end

if opts.USECENS>0
    S_i_k = zeros(N_T,n_mix);
end



perc_var_L  = 1;
nn          = 1;

figure(curr_fig);


%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% calculate likelihood for initial moment estimate

% Responsibilities of each component for the data_R samples.
for k=1:n_mix
    f_i_k(:,k) = feval(pdf_handle(k), data_R, pars{k});
    alpha_i_k(:,k) = weights(k)*f_i_k(:,k);
end
NORM_ALPHA = sum(alpha_i_k,2);

for k=1:n_mix
    alpha_i_k(:,k)   = alpha_i_k(:,k)./NORM_ALPHA;
end

% Responsibilities for the data_T samples, based on the survival 1 - cdf.
if opts.USECENS>0
    for k=1:n_mix
        S_i_k(:,k) = 1 - feval(cdf_handle(k), data_T, pars{k});
        beta_i_k(:,k)  = weights(k)*S_i_k(:,k);
    end
    NORM_BETA  = sum(beta_i_k,2);
    for k=1:n_mix
        beta_i_k(:,k)    = beta_i_k(:,k)./NORM_BETA;
    end
end



old_L       = 0;
for k=1:n_mix
    if opts.USECENS>0
        % Only data_T samples with a positive responsibility contribute.
        indici = find(beta_i_k(:,k)>0);
        old_L = old_L + ...
            (1/N)*negloglik_cens(pars{k}, pdf_handle(k), cdf_handle(k), ...
            alpha_i_k(:,k), beta_i_k(indici,k), data_R, data_T(indici));
    else
        old_L = old_L + (1/N_R)*negloglik(pars{k}, pdf_handle(k), cdf_handle(k), alpha_i_k(:,k), data_R);
    end
end


%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% EM iterations

while abs(perc_var_L)>opts.PERCSTOP & nn<opts.MAXIT


    % ---- E-step: responsibilities under the current weights / pars ----
    for k=1:n_mix
        f_i_k(:,k) = feval(pdf_handle(k), data_R, pars{k});
        alpha_i_k(:,k) = weights(k)*f_i_k(:,k);
    end
    NORM_ALPHA = sum(alpha_i_k,2);
    for k=1:n_mix
        alpha_i_k(:,k)   = alpha_i_k(:,k)./NORM_ALPHA;
    end
    SUM_alpha_i_k = sum(alpha_i_k,1);
    weights_new = (1/N_R) * (SUM_alpha_i_k);


    if opts.USECENS>0
        for k=1:n_mix
            S_i_k(:,k) = 1 - feval(cdf_handle(k), data_T, pars{k});
            beta_i_k(:,k)  = weights(k)*S_i_k(:,k);
        end
        NORM_BETA  = sum(beta_i_k,2);
        for k=1:n_mix
            beta_i_k(:,k)    = beta_i_k(:,k)./NORM_BETA;
        end
        SUM_beta_i_k  = sum(beta_i_k,1);
        % With censoring the weights are normalised over both sample sets.
        weights_new = (1/N) * (SUM_alpha_i_k + SUM_beta_i_k);
    end



    % ---- M-step: update the parameters of each component ----
    new_L = 0;
    for k = 1:n_mix

        % Responsibility-weighted mean and variance of data_R.
        mu_est  =   sum(alpha_i_k(:,k).* data_R, 1) / SUM_alpha_i_k(k);
        var_est =   sum(alpha_i_k(:,k).* (data_R - mu_est).^2, 1) / SUM_alpha_i_k(k);

        % Moment-based parameter guess (mean, standard deviation -> pars).
        par_est =   EM_Invert_pars(mu_est, sqrt(var_est), module_type{k});

        if opts.USECENS>0

            indici = find(beta_i_k(:,k)>0);

            % No closed form with censoring: minimise the weighted negative
            % log-likelihood numerically, starting from the moment guess.
            [p_bar,fval,exflag] = fminsearch(@negloglik_cens,par_est,optimset('Display','off'), ...
                 pdf_handle(k), cdf_handle(k), alpha_i_k(:,k), beta_i_k(indici,k), data_R, data_T(indici));

            if exflag>0
                pars{k} = p_bar;
            else
                fprintf(1,'Fail to solve nonlinear equation at step %d \t module %d. \n',nn,k);
                EM_exflag = 0;
                return
            end

        else

            switch module_type{k}

            case 'InvGauss'

                    pars{k}(1) = par_est(1);
                    % Weighted mean of 1/x.
                    av_invX = sum(alpha_i_k(:,k) .* (1./data_R))/SUM_alpha_i_k(k);
                    pars{k}(2) = 1/(av_invX - 1/mu_est);

            case 'Gam'

                    % Weighted mean of log(x); the first parameter is the
                    % root of gamma_zero for c = mean(log x) - log(mean x).
                    av_logX = sum(alpha_i_k(:,k) .* log_data_R)/SUM_alpha_i_k(k);
                    c = av_logX - log(mu_est);
                    [a_bar,fval,exflag] = fzero(@gamma_zero, par_est(1),optimset('Display','off'), c);

                    if exflag>0
                        pars{k}(1) = a_bar;
                        pars{k}(2) = mu_est/a_bar;
                    else
                        fprintf(1,'Fail to solve nonlinear equation at step %d \t module %d. \n',nn,k);
                        EM_exflag = 0;
                        return
                    end

            case 'LogNorm'

                    % Weighted mean and standard deviation of log(x),
                    % using the logarithms computed before the loop.
                    mi = sum(alpha_i_k(:,k).* log_data_R, 1) / SUM_alpha_i_k(k);
                    ssig2 = sum(alpha_i_k(:,k).* (log_data_R - mi).^2, 1) / SUM_alpha_i_k(k);
                    pars{k} = [mi sqrt(ssig2)];


            end % end 'switch' module_type (no censoring)


        end % end if-clause on censoring

        % Accumulate the expected negative log-likelihood with the new pars.
        if opts.USECENS>0
            new_L = new_L + ...
                (1/N)*negloglik_cens(pars{k}, pdf_handle(k), cdf_handle(k), ...
                alpha_i_k(:,k), beta_i_k(indici,k), data_R, data_T(indici));
        else
            new_L = new_L + (1/N_R)*negloglik(pars{k}, pdf_handle(k), cdf_handle(k), alpha_i_k(:,k), data_R);
        end


    end % end for-loop on modules


    % Relative change used by the stopping rule of the while-loop.
    perc_var_L = (new_L-old_L)/old_L;

    weights = weights_new;


    if opts.DOPLOT > 0

        figure(curr_fig);

        hold on;
        % Remove the curves drawn at the previous iteration.
        delete(h_evolution);
        delete(h_sing);

        pdf_est = zeros(1,opts.FIG_NBINS);
        single_module = zeros(1,opts.FIG_NBINS);
        for kj = 1:n_mix
            single_module = weights(kj)*feval(pdf_handle(kj),opts.FIG_BIN_CTRS, pars{kj});
            pdf_est = pdf_est + single_module;
            h_sing(kj)=plot(opts.FIG_BIN_CTRS, single_module*opts.FIG_BIN_SIZE, 'k:','EraseMode','none');
        end
        h_evolution = plot(opts.FIG_BIN_CTRS, pdf_est*opts.FIG_BIN_SIZE, ':r');
        ylim([0 opts.FIG_YMAX]);

        drawnow;

    end % end if-clause on DOPLOT

    % Bracket concatenation is used instead of strcat because strcat strips
    % trailing whitespace from char inputs and would glue 'Iterations' to
    % the iteration number.
    stringa_titolo = ['Iterations ',num2str(nn),'/',num2str(opts.MAXIT),...
        ', Expected NegLogLik =',num2str(new_L), ', Rel. change =',num2str(perc_var_L)];
    title(stringa_titolo);drawnow;

%---------------------------------------------------------------------------------------


    % Optional per-iteration trace: step, then weight and parameters of
    % every component, then likelihood and its relative change.
    if opts.DISP_STEPS > 0
        fprintf(1,'%d \t\t',nn);
        for k=1:n_mix
            fprintf(1,'%f \t',weights(k));
            for jj = 1:length(pars{k})
                fprintf(1,'%f \t',pars{k}(jj));
            end
            fprintf(1,'\t');
        end
        fprintf(1,'%f\t%f\n',new_L,perc_var_L);
    end


    nn = nn + 1;

    old_L = new_L;

end % end EM iterations

EM_exflag = 1;

% Report the final model on standard output.
fprintf(1,'\n\n Estimated model: \n\n');


for k=1:n_mix

    fprintf(1,'Module %d : %s \n',k, module_type{k});
    fprintf(1,'-------------\n\n');
    fprintf(1,'\t\tWeight  = %f ',weights(k));

    for jj = 1:length(pars{k})
        fprintf(1,'\t\tPar_%d = %f',jj, pars{k}(jj));
    end
    fprintf(1,'\n\n\n');
end


% Draw the final mixture and its components (done regardless of opts.DOPLOT).
figure(curr_fig);

hold on;
delete(h_evolution);
delete(h_sing);

pdf_est = zeros(1,opts.FIG_NBINS);
single_module = zeros(1,opts.FIG_NBINS);
for kj = 1:n_mix
    single_module = weights(kj)*feval(pdf_handle(kj),opts.FIG_BIN_CTRS, pars{kj});
    pdf_est = pdf_est + single_module;
    h_sing(kj)=plot(opts.FIG_BIN_CTRS, single_module*opts.FIG_BIN_SIZE, 'k:','EraseMode','none');
end

h_evolution = plot(opts.FIG_BIN_CTRS, pdf_est*opts.FIG_BIN_SIZE, ':r');
ylim([0 opts.FIG_YMAX]);
drawnow;
hold off;


%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
 function y = gamma_zero(a, p)
 % GAMMA_ZERO  Residual p - psi(a) + log(a), evaluated at a.
 %   EM_algorithm hands this function to fzero (as the unknown a) with p
 %   set to the weighted mean of log(x) minus the log of the weighted mean.
 offset_minus_digamma = p - psi(a);
 y = offset_minus_digamma + log(a);

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

function y = negloglik(pars, fhandle, f_cum_handle, ALPHA_k, X_R)
% NEGLOGLIK  Weighted negative log-likelihood of one mixture component.
%   pars          parameter vector forwarded to fhandle
%   fhandle       density function, called as feval(fhandle, X_R, pars)
%   f_cum_handle  accepted for signature compatibility; not used here
%   ALPHA_k       per-sample weights
%   X_R           samples
%   Samples whose density is not strictly positive are left out of the sum.

density = feval(fhandle, X_R, pars);

% Logical mask of the samples with a strictly positive density.
has_mass = density > 0;

weighted_log = ALPHA_k(has_mass) .* log(density(has_mass));

y = -sum(weighted_log);

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

function y = negloglik_cens(pars, fhandle, f_cum_handle, ALPHA_k, BETA_k, X_R, X_T)
% NEGLOGLIK_CENS  Weighted negative log-likelihood of one mixture component
% with a density part and a survival part.
%   pars          parameter vector forwarded to fhandle and f_cum_handle
%   fhandle       density function, called as feval(fhandle, X_R, pars)
%   f_cum_handle  cumulative distribution function, called as
%                 feval(f_cum_handle, X_T, pars)
%   ALPHA_k       weights of the X_R samples
%   BETA_k        weights of the X_T samples
%   X_R           samples contributing through log(density)
%   X_T           samples contributing through log(1 - cdf)
%
%   y = -sum(ALPHA_k .* log(p)) - sum(BETA_k .* log(S)),  S = 1 - cdf
%
%   Terms with a zero weight in ALPHA_k are skipped: when the density is
%   also zero, 0 * log(0) would evaluate to NaN and invalidate the whole
%   sum. A positively weighted sample with zero density still gives +Inf.
%   X_T samples with S <= 0 are skipped.

p = feval(fhandle, X_R, pars);

S = 1-feval(f_cum_handle, X_T, pars);

% Convention 0 * log(0) = 0: drop zero-weight terms before taking the log.
weighted = (ALPHA_k ~= 0);

indexes = find(S>0);

y = -sum(ALPHA_k(weighted).*log(p(weighted))) - sum(BETA_k(indexes).*log(S(indexes)));

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -