/*
em_ghmm : Expectation-Maximization (EM) algorithm for an HMM with multivariate Gaussian measurements
Usage
-------
[logl , PI , A , M , S] = em_ghmm(Z , PI0 , A0 , M0 , S0 , [options]);
Inputs
-------
Z Measurements (m x K x n1 x ... x nl)
PI0 Initial probabilities (d x 1) : Pr(x_1 = i) , i=1,...,d. PI0 can be (d x 1 x v1 x ... x vr)
A0 Initial state transition probability matrix, A0(i , j) = Pr(x_{k} = i | x_{k - 1} = j), such that
sum_{x_k}(A0) = 1 => sum(A0 , 1) = 1 (each column sums to 1). A0 can be (d x d x v1 x ... x vr).
M0 Initial mean vectors, one per state. M0 can be (m x 1 x d x v1 x ... x vr)
S0 Initial covariance matrices, one per state. S0 can be (m x m x d x v1 x ... x vr)
options nb_ite Number of iterations (default [30])
        update_PI Update PI (0/1 = no/[yes])
        update_A Update A (0/1 = no/[yes])
        update_M Update M (0/1 = no/[yes])
        update_S Update S (0/1 = no/[yes])
Outputs
-------
logl Final log-likelihood (n1 x ... x nl x v1 x ... x vr)
PI Estimated initial probabilities (d x 1 x n1 x ... x nl x v1 x ... x vr)
A Estimated state transition probability matrix (d x d x n1 x ... x nl x v1 x ... x vr)
M Estimated mean vectors (m x 1 x d x n1 x ... x nl x v1 x ... x vr)
S Estimated covariance matrices (m x m x d x n1 x ... x nl x v1 x ... x vr)
To compile
-----------
mex -output em_ghmm.dll em_ghmm.c
mex -f mexopts_intel10amd.bat -output em_ghmm.dll em_ghmm.c
Example 1
----------
d = 2;
m = 2;
L = 1;
R = 1;
Ntrain = 3000;
Ntest = 10000;
options.nb_ite = 30;
PI = [0.5 ; 0.5];
A = [0.95 0.05 ; 0.05 0.95];
M = cat(3 , [-1 ; -1] , [2 ; 2]);
S = cat(3 , [1 0.3 ; 0.3 0.8] , [0.7 0.6; 0.6 1]);
[Ztrain , Xtrain] = sample_ghmm(Ntrain , PI , A , M , S , L);
Xtrain = Xtrain - 1;
%%%%% initial parameters %%%%
PI0 = rand(d , 1 , R);
sumPI = sum(PI0);
PI0 = PI0./sumPI(ones(d , 1) , : , :);
A0 = rand(d , d , R);
sumA = sum(A0);
A0 = A0./sumA(ones(d , 1) , : , :);
M0 = randn(m , 1 , d , R);
S0 = repmat(cat(3 , [2 0 ; 0 2] , [3 0; 0 2]) , [1 , 1 , 1, R]);
%%%%% EM algorithm %%%%
[logl , PIest , Aest , Mest , Sest] = em_ghmm(Ztrain , PI0 , A0 , M0 , S0 , options);
[x , y] = ndellipse(M , S);
[xest , yest] = ndellipse(Mest , Sest);
Ltrain_est = likelihood_mvgm(Ztrain , Mest , Sest);
Xtrain_est = forward_backward(PIest , Aest , Ltrain_est);
Xtrain_est = Xtrain_est - 1;
ind1 = (Xtrain_est == 0);
ind2 = (Xtrain_est == 1);
Err_train = min(sum(Xtrain ~= Xtrain_est , 2)/Ntrain , sum(Xtrain ~= ~Xtrain_est , 2)/Ntrain);
figure(1) ,
h = plot(Ztrain(1 , ind1) , Ztrain(2 , ind1) , 'k+' , Ztrain(1 , ind2) , Ztrain(2 , ind2) , 'g+' , x , y , 'b' , xest , yest ,'r', 'linewidth' , 2);
legend([h(1) ; h(3:d:end)] , 'Train data' , 'True' , 'Estimated' , 'location' , 'best')
title(sprintf('Train data, Error rate = %4.2f%%' , Err_train*100))
%%%%% Test data %%%%
[Ztest , Xtest] = sample_ghmm(Ntest , PI , A , M , S , L);
Xtest = Xtest - 1;
Ltest_est = likelihood_mvgm(Ztest , Mest , Sest);
Xtest_est = forward_backward(PIest , Aest , Ltest_est);
Xtest_est = Xtest_est - 1;
ind1 = (Xtest_est == 0);
ind2 = (Xtest_est == 1);
Err_test = min(sum(Xtest ~= Xtest_est , 2)/Ntest , sum(Xtest ~= ~Xtest_est , 2)/Ntest);
figure(2),
h = plot(Ztest(1 , ind1) , Ztest(2 , ind1) , 'k+' , Ztest(1 , ind2) , Ztest(2 , ind2) , 'g+' , x , y , 'b' , xest , yest ,'r', 'linewidth' , 2);
legend([h(1) ; h(3:d:end)] , 'Test data' , 'True' , 'Estimated' , 'location' , 'best')
title(sprintf('Test data, Error rate = %4.2f%%' , Err_test*100))
Author : Sébastien PARIS