
📄 em_ghmm.c

📁 MEX (C) implementation, callable from MATLAB, of the EM algorithm, the classic algorithm for training the parameters of the stochastic process model underlying an HMM
💻 C

/*

 em_ghmm : Expectation-Maximization algorithm for an HMM with Multivariate Gaussian measurements


 Usage  
 -------


 [logl , PI , A , M , S]  = em_ghmm(Z , PI0 , A0 , M0 , S0 , [options]);


 Inputs
 -------

 Z             Measurements (m x K x n1 x ... x nl)

 PI0           Initial probabilities (d x 1) : Pr(x_1 = i) , i=1,...,d. PI0 can be (d x 1 x v1 x ... x vr)

 A0            Initial state transition probabilities matrix Pr(x_{k} = i | x_{k-1} = j) such
               that sum_{x_k}(A0) = 1, i.e. sum(A0 , 1) = 1. A0 can be (d x d x v1 x ... x vr).

 M0            Initial mean vector. M0 can be (m x 1 x d x v1 x ... x vr)

 S0            Initial covariance matrix. S0 can be (m x m x d x v1 x ... x vr)

 options       nb_ite         Number of iterations (default [30])
               update_PI      Update PI (0/1 = no/[yes])
               update_A       Update A  (0/1 = no/[yes])
               update_M       Update M  (0/1 = no/[yes])
               update_S       Update S  (0/1 = no/[yes])
               (see the sketch just below this list)


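 A minimal sketch of how the options structure could be filled in. The field
 names are those listed above; fields left unset are assumed to fall back to
 their bracketed defaults:

    options.nb_ite          = 50;      % run 50 EM iterations instead of the default 30
    options.update_S        = 0;       % keep the covariances fixed at S0 (assumption: 0 disables the update)
    [logl , PI , A , M , S] = em_ghmm(Z , PI0 , A0 , M0 , S0 , options);
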

 Outputs
 -------

 logl          Final log-likelihood (n1 x ... x nl x v1 x ... x vr)

 PI            Estimated initial probabilities (d x 1 x n1 x ... x nl x v1 x ... x vr)

 A             Estimated state transition probabilities matrix (d x d x n1 x ... x nl x v1 x ... x vr)

 M             Estimated mean vectors (m x 1 x d x n1 x ... x nl x v1 x ... x vr)

 S             Estimated covariance matrices (m x m x d x n1 x ... x nl x v1 x ... x vr)


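 As a concrete reading of the size conventions above (an illustration only,
 not taken from the original documentation): with Z of size (2 x 500 x 10),
 i.e. 10 sequences of 500 bivariate measurements, and PI0 , A0 , M0 , S0
 carrying a trailing dimension of 5 (5 random initializations), the outputs
 would be

    logl : (10 x 5)
    PI   : (2 x 1 x 10 x 5)
    A    : (2 x 2 x 10 x 5)
    M    : (2 x 1 x 2 x 10 x 5)
    S    : (2 x 2 x 2 x 10 x 5)
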

To compile
-----------


mex -output em_ghmm.dll em_ghmm.c


mex -f mexopts_intel10amd.bat -output em_ghmm.dll em_ghmm.c
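
On recent MATLAB releases the platform-specific MEX extension is chosen
automatically, so the plain invocation below should also work (producing
e.g. em_ghmm.mexw64 on 64-bit Windows):

mex em_ghmm.c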


Example 1
----------


   d                                   = 2;
   m                                   = 2;
   L                                   = 1;
   R                                   = 1;
   Ntrain                              = 3000;
   Ntest                               = 10000;
   options.nb_ite                      = 30;

   PI                                  = [0.5 ; 0.5];
   A                                   = [0.95 0.05 ; 0.05 0.95];
   M                                   = cat(3 , [-1 ; -1] , [2 ; 2]);
   S                                   = cat(3 , [1 0.3 ; 0.3 0.8] , [0.7 0.6; 0.6 1]);

   [Ztrain , Xtrain]                   = sample_ghmm(Ntrain , PI , A , M , S , L);
   Xtrain                              = Xtrain - 1;

   %%%%% initial parameters %%%%

   PI0                                 = rand(d , 1 , R);
   sumPI                               = sum(PI0);
   PI0                                 = PI0./sumPI(ones(d , 1) , : , :);
   
   A0                                  = rand(d , d , R);
   sumA                                = sum(A0);
   A0                                  = A0./sumA(ones(d , 1) , : , :);

   M0                                  = randn(m , 1 , d , R);
   S0                                  = repmat(cat(3 , [2 0 ; 0 2] , [3 0; 0 2]) , [1 , 1 , 1, R]);

   %%%%% EM algorithm %%%%

   [logl , PIest , Aest , Mest , Sest] = em_ghmm(Ztrain , PI0 , A0 , M0 , S0 , options);


   [x , y]                             = ndellipse(M , S);
   [xest , yest]                       = ndellipse(Mest , Sest);

   Ltrain_est                          = likelihood_mvgm(Ztrain , Mest , Sest);
   Xtrain_est                          = forward_backward(PIest , Aest , Ltrain_est);
   Xtrain_est                          = Xtrain_est - 1;

   ind1                                = (Xtrain_est == 0);
   ind2                                = (Xtrain_est == 1);

   Err_train                           = min(sum(Xtrain ~= Xtrain_est , 2)/Ntrain , sum(Xtrain ~= ~Xtrain_est , 2)/Ntrain);

   figure(1) , 
   h                                   = plot(Ztrain(1 , ind1) , Ztrain(2 , ind1) , 'k+' , Ztrain(1 , ind2) , Ztrain(2 , ind2) , 'g+' , x , y , 'b' , xest  , yest ,'r', 'linewidth' , 2);
   legend([h(1) ; h(3:m:end)] , 'Train data' , 'True'  , 'Estimated' , 'location' , 'best')
   title(sprintf('Train data, Error rate = %4.2f%%' , Err_train*100))

   %%%%% Test data  %%%%


   [Ztest , Xtest]                     = sample_ghmm(Ntest , PI , A , M , S , L);
   Xtest                               = Xtest - 1;


   Ltest_est                           = likelihood_mvgm(Ztest , Mest , Sest);
   Xtest_est                           = forward_backward(PIest , Aest , Ltest_est);
   Xtest_est                           = Xtest_est - 1;


   ind1                                = (Xtest_est == 0);
   ind2                                = (Xtest_est == 1);

   Err_test                            = min(sum(Xtest ~= Xtest_est , 2)/Ntest , sum(Xtest ~= ~Xtest_est , 2)/Ntest);

   figure(2), 
   h                                   = plot(Ztest(1 , ind1) , Ztest(2 , ind1) , 'k+' , Ztest(1 , ind2) , Ztest(2 , ind2) , 'g+' , x , y , 'b' , xest  , yest ,'r', 'linewidth' , 2);
   legend([h(1) ; h(3:m:end)] , 'Test data' , 'True'  , 'Estimated' , 'location' , 'best')
   title(sprintf('Test data, Error rate = %4.2f%%' , Err_test*100))
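
   The trailing v1 x ... x vr dimensions also allow several random restarts to
   be trained in a single call and the best one kept afterwards. A hedged
   sketch building on Example 1 (R > 1 and the squeezing of the singleton
   sequence dimension in the outputs are my assumptions; d , m , Ztrain ,
   options are as defined above):

   R                                   = 5;                 % number of random restarts (assumption)
   PI0                                 = rand(d , 1 , R);
   PI0                                 = PI0./repmat(sum(PI0 , 1) , [d , 1 , 1]);      % columns sum to 1
   A0                                  = rand(d , d , R);
   A0                                  = A0./repmat(sum(A0 , 1) , [d , 1 , 1]);        % columns sum to 1
   M0                                  = randn(m , 1 , d , R);
   S0                                  = repmat(cat(3 , [2 0 ; 0 2] , [3 0 ; 0 2]) , [1 , 1 , 1 , R]);

   [logl , PIest , Aest , Mest , Sest] = em_ghmm(Ztrain , PI0 , A0 , M0 , S0 , options);

   [best , r]                          = max(logl(:));      % restart with the highest final log-likelihood
   PIbest                              = PIest(: , : , r);
   Abest                               = Aest(: , : , r);
   Mbest                               = Mest(: , : , : , r);
   Sbest                               = Sest(: , : , : , r);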



   Author : Sébastien PARIS
