⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 chmmtrain.m

📁 CHMMBOX, version 1.2, Iead Rezek, Oxford University, Feb 2001 Matlab toolbox for max. aposteriori e
💻 M
字号:
function [chmm]=chmmtrain(X,Y,T,chmm)
% function [chmm]=chmmtrain(X,Y,T,chmm)
%
% Train a Coupled Hidden Markov Model (two coupled chains) with the
% Baum-Welch / MAP EM algorithm, using node clustering: the two hidden
% chains are merged into one "cartesian" chain with K(1)*K(2) joint states.
% Joint state (a,b) -- state a of chain one, state b of chain two -- has
% index a+K(1)*(b-1) (column-major, chain one varies fastest).
%
% INPUTS:
%
% X,Y - observation sequences of the first and second chain (one row per
%       time step); both must have the same number of rows N
% T   - length of each sequence (N must be a multiple of T).
%       Pass [] to treat the whole data as a single sequence (T=N).
% chmm.hmmone    definitions of first chain
% chmm.hmmtwo    definitions of second chain
% for each chain (hmm = chmm.hmmone or chmm.hmmtwo):
%   hmm.K        - number of states (required)
%   hmm.obsmodel - observation model, e.g. 'Gauss' (required;
%                  'GaussCom','AR' or 'LIKE' not yet implemented)
%   hmm.P        - state transition matrix, 3 dimensional, normalised over
%                  dimension 1 (next state): hmmone.P is K(1) x K(1) x K(2),
%                  hmmtwo.P is K(2) x K(1) x K(2). Random if absent.
%   hmm.Pi       - initial state probabilities (random if absent)
%   hmm.priors.Dir_alpha, hmm.priors.Dir3d_alpha - Dirichlet prior counts
%                  for Pi and P (read here, not created in this file)
%   hmm.train    - struct (required) with optional fields
%     .init      - observation model already initialised (1 or 0)? (default 0)
%     .obsupdate - per-state flags: update the obsmodel? (default all ones)
%     .pupdate   - update transition matrix (1 or 0)? (default 1; must be
%                  the same for both chains)
%
% chmm.train.cyc - maximum number of cycles of Baum-Welch (default 50)
% chmm.train.tol - termination tol (prop. change in log posterior)
%                  (default 0.0001)
%
% OUTPUTS
% chmm.Pi  - joint initial state probabilities, 1 x K(1)*K(2)
% chmm.P   - joint state transition matrix, P(j,k)=P(next=j | current=k)
% chmm.K   - number of joint states, K(1)*K(2)
% chmm.hmmone.P/.Pi, chmm.hmmtwo.P/.Pi - per-chain marginals of the above
% chmm.hmmone/hmmtwo.state(k).$$ - observation model parameters
% chmm.priors.Dir_alpha/Dir2d_alpha - joint Dirichlet prior counts
% chmm.LPtrain - training log posterior of the last cycle
% chmm.data.Xtrain, chmm.data.Ytrain, chmm.data.T - training data used
%

% Termination settings; chmm.train itself is optional.
cyc=50;
tol=0.0001;
if isfield(chmm,'train')
  if isfield(chmm.train,'cyc')
    cyc=chmm.train.cyc;
  end
  if isfield(chmm.train,'tol')
    tol=chmm.train.tol;
  end
end

% Documented default: a single sequence covering all the data.
if isempty(T)
  T=size(X,1);
end

% checking parameters and assigning defaults (N becomes the number of
% sequences of length T)
[chmm,N,K,traininit,updateobs,updatep,P,Pi]=paramchk(chmm,X,Y,T);

LP=[];
lpost=0;
K_cart=K(1)*K(2);
alpha=zeros(T,K_cart);
beta=zeros(T,K_cart);
% gamma holds per-sequence posteriors as struct fields; it must be a
% struct, since field assignment to a numeric array is an error.
gamma=struct('joint',[],'hmmone',[],'hmmtwo',[]);

% Merge the chains: transition probabilities, initial state probabilities
% and their Dirichlet prior counts, all over the joint state space.
P_cart=reshape(joinpdf(P.hmmone,P.hmmtwo,[2 3]),K_cart,K_cart);
Pi_cart=reshape(Pi.hmmone'*Pi.hmmtwo,1,K_cart);
chmm.priors.Dir2d_alpha=...
    reshape(joinpdf(chmm.hmmone.priors.Dir3d_alpha,...
                    chmm.hmmtwo.priors.Dir3d_alpha,[2 3]),K_cart,K_cart);
chmm.priors.Dir_alpha=...
    reshape(chmm.hmmone.priors.Dir_alpha'*chmm.hmmtwo.priors.Dir_alpha,...
            1,K_cart);

% The transition probabilities have the form
%   P_cart(j,k) = P(S_t = j | S_t' = k) = P(next state | current state),
% rows/columns ordered {Sa,Sb}, {~Sa,Sb}, {Sa,~Sb}, {~Sa,~Sb} for K=[2;2].

for cycle=1:cyc

  %%%% FORWARD-BACKWARD (E step), accumulated over the N sequences
  Gamma.joint=[];
  Gamma.hmmone=[];
  Gamma.hmmtwo=[];
  Gammasum.joint=zeros(1,K_cart);
  Gammasum.hmmone=zeros(1,K(1));
  Gammasum.hmmtwo=zeros(1,K(2));
  Scale=zeros(T,1);
  Xi=zeros(T-1,K_cart,K_cart);

  for n=1:N
    Bone = obslike(X,T,n,chmm.hmmone);
    Btwo = obslike(Y,T,n,chmm.hmmtwo);

    % Augmenting: Bcomp(t,a+K(1)*(b-1)) = Bone(t,a)*Btwo(t,b), the
    % observation likelihood of every joint state.
    Bcomp=zeros(T,K_cart);
    for i=1:T
      Bcomp(i,:)=reshape(Bone(i,:)'*Btwo(i,:),1,K_cart);
    end

    % Forward pass, each alpha(t,:) rescaled to sum to 1; the log of the
    % scale factors sums to the sequence log likelihood.
    scale=zeros(T,1);
    alpha(1,:)=Pi_cart.*Bcomp(1,:);
    scale(1)=sum(alpha(1,:));
    alpha(1,:)=alpha(1,:)/scale(1);
    for i=2:T
      % columns of P_cart index the current state, hence P_cart*alpha'
      alpha(i,:)=(P_cart*alpha(i-1,:)')'.*Bcomp(i,:);
      scale(i)=sum(alpha(i,:));
      alpha(i,:)=alpha(i,:)/scale(i);
    end

    % Backward pass, scaled with the same factors
    beta(T,:)=ones(1,K_cart)/scale(T);
    for i=T-1:-1:1
      beta(i,:)=(beta(i+1,:).*Bcomp(i+1,:))*P_cart/scale(i);
    end

    % State posteriors: joint, then the marginal of each chain
    gamma.joint=alpha.*beta;
    gamma.joint=mddiv(gamma.joint,sum(gamma.joint,2),2); % normalise over states
    gammasum.joint=sum(gamma.joint,1);                   % sum over time
    gamma3=reshape(gamma.joint,T,K(1),K(2));
    gamma.hmmone=squeeze(sum(gamma3,3));
    gammasum.hmmone=sum(gamma.hmmone,1);
    gamma.hmmtwo=squeeze(sum(gamma3,2));
    gammasum.hmmtwo=sum(gamma.hmmtwo,1);

    % Pairwise posteriors xi(t,j,k) of (next state j, current state k)
    xi=zeros(T-1,K_cart,K_cart);
    for i=1:T-1
      xit=P_cart.*((beta(i+1,:).*Bcomp(i+1,:))'*alpha(i,:));
      xi(i,:,:)=xit./sum(xit(:));
    end

    Scale=Scale+log(scale);
    Gamma.joint=[Gamma.joint; gamma.joint];
    Gammasum.joint=Gammasum.joint+gammasum.joint;
    Gamma.hmmone=[Gamma.hmmone; gamma.hmmone];
    Gammasum.hmmone=Gammasum.hmmone+gammasum.hmmone;
    Gamma.hmmtwo=[Gamma.hmmtwo; gamma.hmmtwo];
    Gammasum.hmmtwo=Gammasum.hmmtwo+gammasum.hmmtwo;
    Xi=Xi+xi;
  end

  %%%% evaluate likelihood and priors
  oldlpost=lpost;
  lik=sum(Scale);
  lprior=evalmodelprior(chmm.hmmone);          % compute parameter probs under priors
  lprior=[lprior evalmodelprior(chmm.hmmtwo)]; % for both chains

  % Initial state and transition probabilities under their Dirichlet
  % priors: one term per column (current joint state) of P_cart.
  lprior=[lprior dirichlet(Pi_cart,chmm.priors.Dir_alpha,1)];
  for l=1:K_cart
    lprior=[lprior dirichlet(P_cart(:,l),chmm.priors.Dir2d_alpha(:,l),1)];
  end
  lpost=lik+sum(lprior);                       % log posterior
  LP=[LP; lik lprior];

  %%%% M STEP
  % transition matrix: expected counts plus Dirichlet pseudo-counts,
  % normalised over the future state (dimension 1)
  sxi=squeeze(sum(Xi,1));                      % sum over time
  if updatep
    sxi=sxi+(chmm.priors.Dir2d_alpha-1);
    P_cart=mddiv(sxi,sum(sxi,1),1);
  end

  % initial state probabilities: posterior at the first time step of every
  % sequence (rows 1, T+1, 2T+1, ... of Gamma.joint) plus pseudo-counts
  Pi_cart=sum(Gamma.joint(1:T:N*T,:),1)+chmm.priors.Dir_alpha-1;
  Pi_cart=Pi_cart./sum(Pi_cart);
  Pi.hmmone=sum(reshape(Pi_cart,K(1),K(2)),2)';
  Pi.hmmtwo=sum(reshape(Pi_cart,K(1),K(2)),1);

  % Observation model, only for chains with at least one state to update
  if sum(updateobs.hmmone) > 0
    chmm.hmmone=obsupdate(X,T,Gamma.hmmone,Gammasum.hmmone,chmm.hmmone,updateobs.hmmone);
  end
  if sum(updateobs.hmmtwo) > 0
    chmm.hmmtwo=obsupdate(Y,T,Gamma.hmmtwo,Gammasum.hmmtwo,chmm.hmmtwo,updateobs.hmmtwo);
  end

  % Convergence: stop when the gain over the baseline (cycle 2) grows by
  % less than a factor (1+tol), or when the log posterior is not finite.
  fprintf('cycle %i log posterior = %f ',cycle,lpost);
  if (cycle<=2)
    lpostbase=lpost;
  elseif (lpost<oldlpost)
    fprintf('violation');
  elseif ((lpost-lpostbase)<(1 + tol)*(oldlpost-lpostbase) | ~isfinite(lpost))
    fprintf('\n');
    break;
  end
  fprintf('\n');
end

% Marginalise the joint transition matrix back onto each chain: the 4-D
% view is (next_a, next_b, current_a, current_b); sum out the other
% chain's next state.
P4=reshape(P_cart,K(1),K(2),K(1),K(2));
chmm.hmmone.P=squeeze(sum(P4,2));
chmm.hmmtwo.P=squeeze(sum(P4,1));
chmm.hmmone.Pi=Pi.hmmone;
chmm.hmmtwo.Pi=Pi.hmmtwo;
chmm.P=P_cart;
chmm.Pi=Pi_cart;
chmm.K=K_cart;
chmm.LPtrain=lpost;
chmm.data.Xtrain=X;
chmm.data.Ytrain=Y;
chmm.data.T=T;

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

function [chmm,N,K,traininit,updateobs,updatep,P,Pi]=paramchk(chmm,X,Y,T)
%
% Copy in and check existence of parameters for the chmm data structure
%
% Inputs: chmm, data series X and Y, sequence length T
%
% Outputs:
%   chmm       ; chmm with both chains checked by scparamchk
%   N          ; number of sequences of length T in the data
%   K          ; [K(1); K(2)] number of states of each chain
%   traininit  ; per-chain initialisation flags
%   updateobs  ; per-chain observation model update flags
%   updatep    ; transition matrix update flag (common to both chains)
%   P          ; per-chain transition probabilities (random if absent)
%   Pi         ; per-chain initial state probabilities

% The first hmm chain
[chmm.hmmone,params.hmmone]=scparamchk(chmm.hmmone,X);
% Now the second hmm chain
[chmm.hmmtwo,params.hmmtwo]=scparamchk(chmm.hmmtwo,Y);

% consistency check for the 2 chains
if params.hmmone.N~=params.hmmtwo.N
  error('Time series must be of equal length');
end
N=params.hmmone.N;

% hidden states
K=[params.hmmone.K; params.hmmtwo.K];

% training initialisation
traininit.hmmone=params.hmmone.init;
traininit.hmmtwo=params.hmmtwo.init;

% transition matrices: random ones need both chains' K, so they are
% created here rather than in scparamchk; normalised over dimension 1
if params.hmmone.P==-1
  P.hmmone=rand(K(1),K(1),K(2));
  P.hmmone=mddiv(P.hmmone,sum(P.hmmone,1),1);
else
  P.hmmone=chmm.hmmone.P;
end
if params.hmmtwo.P==-1
  P.hmmtwo=rand(K(2),K(1),K(2));
  P.hmmtwo=mddiv(P.hmmtwo,sum(P.hmmtwo,1),1);
else
  P.hmmtwo=chmm.hmmtwo.P;
end

% update observation models
updateobs.hmmone=params.hmmone.updateobs;
updateobs.hmmtwo=params.hmmtwo.updateobs;

% update state transition prop. (the joint matrix is shared, so both
% chains must agree)
if (params.hmmone.updatep ~= params.hmmtwo.updatep)
  error('Transition Probabilities must be updated jointly');
end
updatep=params.hmmone.updatep;

% prior
Pi.hmmone=params.hmmone.Pi;
Pi.hmmtwo=params.hmmtwo.Pi;

if (rem(N,T)~=0)
  error('Data matrix length must be multiple of sequence length T');
end
N=N/T;

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

function [hmm,params]=scparamchk(hmm,X)
% function [hmm,params]=scparamchk(hmm,X)
%
% Copy in and check existence of parameters for a single chain hmm data
% structure. Raises an error if obsmodel, K or train is missing.
%
% Input hmm, data series X of that chain
%
% Output params:
%        params.p          ; time-series dimension (columns of X)
%        params.N          ; time-series length (rows of X)
%        params.K          ; dimension of state-space
%        params.init       ; initialisation flag (default 0)
%        params.updateobs  ; update observation model, per state
%                            (default ones(1,K))
%        params.updatep    ; update state transition matrix (default 1)
%        params.P          ; -1 if hmm.P is absent (paramchk creates it),
%                            1 otherwise
%        params.Pi         ; initial state probabilities (random if absent)
% and hmm with the observation model initialised by obsinit unless
% params.init is set.

if ~isfield(hmm,'obsmodel')
  error('Error in hmm_train: obsmodel not specified');
end

params.p=length(X(1,:));
params.N=length(X(:,1));

if ~isfield(hmm,'K')
  error('Error in hmmtrain: K not specified');
end
params.K=hmm.K;

if ~isfield(hmm,'train')
  error('Error in hmmtrain: hmm.train not specified');
end

if ~isfield(hmm.train,'init')
  params.init=0;
else
  params.init=hmm.train.init;
end

if ~isfield(hmm.train,'obsupdate')
  params.updateobs=ones(1,hmm.K);  % update observation models for all states
else
  params.updateobs=hmm.train.obsupdate;
end

if ~isfield(hmm.train,'pupdate')
  params.updatep=1;
else
  params.updatep=hmm.train.pupdate;
end

if ~isfield(hmm,'P')
  params.P=-1;   % must be done outside: needs info from the other chain
else
  params.P=1;
end

if ~isfield(hmm,'Pi')
  params.Pi=rand(1,hmm.K);
  params.Pi=params.Pi./sum(params.Pi);
else
  params.Pi=hmm.Pi;
end

if ~params.init
  hmm=obsinit(X,hmm);
end

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -