📄 trainfbhmm.m
字号:
%% ------------------------------------------------------------------
%% Tail of the EM training loop, followed by post-loop finalisation.
%%
%% NOTE(review): this excerpt begins inside constructs that are opened
%% earlier in the file (a per-bin loop, an enclosing block, the loop over
%% individual samples kk, and the main iteration loop).  The indentation
%% below assumes that nesting -- confirm against the part of the file
%% that precedes this excerpt.
%%
%% What the visible code does, in order:
%%   1. finishes the per-sample accumulation of the gammaSum* statistics;
%%   2. updates sigmas, z (traces), u, and the time/scale transition
%%      parameters, each guarded by its G.update* flag;
%%   3. assembles the log likelihood for this iteration from terms that
%%      were computed before the updates;
%%   4. checks convergence / iteration limit and logs the outcome;
%%   5. after the loop, evaluates the likelihood once more and trims the
%%      bookkeeping arrays to the number of iterations actually run.
%% ------------------------------------------------------------------
                tmpGS6(kk,:) = sum(gammas.*tmpDat,2)';
                gammaSum6{bb}=tmpGS6;
            end

            for jj=1:M
                validStates=allValidStates{jj};
                numVal=length(validStates);

                %%%% X(j,j)
                gammaSum1{kk,jj}=gammaSum5(kk,validStates)';

                %%%% b
                for bb=1:G.numBins
                    tmpGS6 = gammaSum6{bb};
                    gammaSum2{bb,kk,jj}=tmpGS6(kk,validStates)';
                end
            end
        end

        if G.updateTime
            newD(kk,:) = updateTimeConst(G,alphas,betas,rhos,samplesMat(:,:,kk),tmpZ,kk);
        end

        if G.updateScale
            newScaleCounts(:,myClass) = newScaleCounts(:,myClass) + ...
                updateScaleConst(G,alphas,betas,rhos,samplesMat(:,:,kk),tmpZ,kk)';
        end

        if G.updateU
            % size(gammas)= numStates x numRealTimes
            for dd=1:G.numBins
                tmpDat = permute(repdata(:,dd,:),[3 1 2]);
                gammaSum3(dd,kk,:)=sum(gammas.*tmpDat,2)';
                %% this one could just be independent of the bin...
                gammaSum4(dd,kk,:)=sum(gammas,2)';
            end
        end
    end  %%END ITERATING OVER INDIVIDUAL SAMPLES, kk
    %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

    msg=sprintf('Iteration %d E-step CPU time for all time series %.2f min.',...
        numIt,totalElapsed(numIt));
    myPrint(fidErr,msg);

    %% ---------------- sigma update ----------------
    if G.updateSigma% & (numIt>minSigmaUpdateIt))
        %% NO, NO, this can really mess things up!!!
        if 0
            %% force the sigmas to be the same across all bins:
            %% (they aren't because z changes from bin to bin
            newSigmas = mean(newSigmas,1);
            newSigmas = repmat(newSigmas, [G.numBins 1]);
        end

        %% force sigmas to be shared in each class
        if 1%G.SHARE
            for cc=1:C
                classCols = G.class{cc};
                % Average across the members of the class separately for
                % each bin.  The dimension is given explicitly: without
                % it, a class with a single member (and more than one
                % bin) presents mean() with a vector, which collapses to
                % one scalar and would overwrite every bin with the
                % cross-bin average.
                classMean = mean(newSigmas(1:G.numBins,classCols),2);
                newSigmas(:,classCols) = repmat(classMean,[1 G.numPerClass(cc)]);
            end
            newSigmas = sqrt(newSigmas/G.numRealTimes);
        elseif ~HOLD_OUT
            %% enforce factor by which they must agree
            %% but if it is hold out, then don't need to enforce
            newSigmas = sqrt(newSigmas/G.numRealTimes);
            %newSigmas2=newSigmas;
            for b=1:G.numBins
                newSigmas2(b,:) = constrainSigmasIterate(...
                    G,newSigmas(b,:),G.sigmas(b,:));
            end
            %newSigmas-newSigmas2
            %keyboard;
            newSigmas = newSigmas2;
        elseif HOLD_OUT
            %% don't enforce the constraint between sigmas, just use
            %% their raw updates
            newSigmas = sqrt(newSigmas/G.numRealTimes);
        else
            error('case doesnt exist');
        end

        %temp=[G.sigmas' newSigmas' (G.sigmas'-newSigmas')];
        temp=[newSigmas];

        % disp the new sigmas
        %disp('Sigmas Updates');
        %num2str(newSigmas',5)
        formatStr = getFormatStr(G.numSamples,2);
        myStr=sprintf(formatStr,newSigmas');
        if fidErr
            fprintf(fidErr,'Sigma Updates\n%s\n',myStr);
        end

        %oldSigmas = G.sigmas;
        G.sigmas = newSigmas;
        %tempSigs = G.sigmas.^(-2);
        %figure,show(oldSigmas-G.sigmas);
        %imstats(oldSigmas-G.sigmas);
        %keyboard;
    end

    %% ---------------- z (trace) update, one bin at a time ----------------
    if G.updateZ
        for bb=1:G.numBins
            % Slice out this bin's entries and drop the leading bin
            % dimension before handing them to getNewZ.
            sqGamSum2 = permute(gammaSum2(bb,:,:),[2 3 1]);
            [currentTrace(:,:,bb),output,options,changeFlag] = ...
                getNewZ(G,samplesMat,allValidStates,...
                scaleFacs,scaleFacsSq,gammaSum1,...
                sqGamSum2,gammaSum5,...
                gammaSum6{bb},scalesExp,scalesExp2,bb);

            % The options are only written to the log once.
            if numIt==1
                myStr=printStruct(options);
                if fidErr
                    fprintf(fidErr,'%s\n',myStr);
                end
            end

            myStr=printStruct(output);
            if fidErr
                fprintf(fidErr,'%s\n',myStr);
            end
        end

        G.z=currentTrace;
        allTraces(numIt,:,:,:)=G.z;

        % NOTE(review): changeFlag holds the value returned for the last
        % bin only; earlier bins' flags are overwritten in the loop above.
        if ~changeFlag
            myStr=sprintf('newZ same as oldZ');
            if fidErr
                fprintf(fidErr,'%s\n',myStr);
            end
        end
    end %%end G.updateZ

    %% ---------------- u update ----------------
    if G.updateU
        newU= getNewU(G,gammaSum3,gammaSum4,gammaSum5,...
            gammaSum6,scalesExp);
        G.u=newU;
        if G.USE_CPM2
            %oldUmat = G.uMat;
            G.uMat = getUMat(G,G.u);
            %figure,show(G.uMat-oldUmat); colorbar;
        end
        %myStr=sprintf('myfval:%.4f\n',myfval(numIt,:));
        %fprintf(fidErr,'%s\n',myStr);
        %myStr=sprintf('u: %.4f\n',G.u');
        %myPrint(fidErr,myStr);
    end
    oldU=G.u;

    %% ---------------- time-transition update ----------------
    if G.updateTime
        if G.SHARE_TRANS
            % Share one set of values across all samples: average over
            % samples, then replicate the average for every sample.
            tmpD = mean(newD,1);
            newD = repmat(tmpD,[G.numSamples 1]);
        end
        myStr=sprintf('TimeTrans prob')

        % One '%.4f ' field per time step, newline-terminated, so each
        % row of newD is printed on its own line.
        tmpStr = [repmat('%.4f ',1,G.maxTimeSteps) '\n'];
        myStr=sprintf(tmpStr,newD');
        if fidErr
            fprintf(fidErr,'TimeTrans prob\n %s\n',myStr);
        end
    end

    %% ---------------- scale-transition update ----------------
    if G.updateScale
        newS = getNewS(G,newScaleCounts);
        myStr=sprintf('ScaleTrans prob: %.4f\n',newS)
        myStr=[myStr sprintf('\n')];
        if fidErr
            fprintf(fidErr,'%s\n',myStr);
        end
    end

    %% propagate the state transition updates to our data
    %% structure, G
    if G.updateTime
        if G.updateScale
            G = reviseG(G,newS,newD);
        else
            G = reviseG(G,G.S,newD);
        end
    elseif G.updateScale
        G = reviseG(G,newS,G.D);
    end

    %% these were computed at top of loop, before FB algorithm,
    %% since the state posteriors and hence mainLikes come from
    %% previous model, not updates just performed.
    newLike = sum(mainLikes(numIt,:));
    newLike = newLike + scalePriorTerm(numIt) + ...
        sum(timePriorTerm(numIt,:));
    newLike = newLike + nuTerm(numIt) + ...
        sum(smoothLikes(numIt,:)) + scaleCenterPriorTerm;

    %% This newLike is the likelihood
    likes(numIt) = newLike;

    %% check if parameters have converged
    didConvergeP = checkConverge(oldG,G);
    %didConvergeL = (newLike-oldLike)/abs(newLike) < G.thresh;
    didConvergeL = (newLike-oldLike) < G.thresh;

    G.badThresh=1e-8;
    if 1%~(HOLD_OUT && numIt==2)
        %badLikelihood = (oldLike-newLike)>1e-12;%eps;
        badLikelihood = (oldLike-newLike)>G.badThresh;%1e-8;%1e-12;%eps;
    else
        badLikelihood=0;
    end

    %%% if you want to check for parameter convergence:
    %didConverge = didConvergeP;% & didConvergeL;
    %%% if you want to check for parameter and likelihood convergence:
    %didConverge = didConvergeP && didConvergeL;
    %%% if you want to check for just likelihood convergence:
    didConverge = didConvergeL;

    msg=sprintf( 'log likelihood for all data: %.8e',likes(numIt));
    myPrint(fidErr,msg);
    if numIt>1
        msg=sprintf(' (difference from last iteration: %.8e)',diff(likes((numIt-1):numIt)));
        myPrint(fidErr,msg);
    end

    if 0%badLikelihood
        % Disabled interactive debugging branch for a decreasing likelihood.
        msg1 = 'LIKELIHOOD went down more than threshold!'
        msg2 = ['Change: ' num2str((oldLike-newLike),2)];
        diffLike = sprintf('%.2e',(oldLike-newLike));
        msg2 = [msg2 ' Difference in likelihood: ' diffLike]
        errorFound=1;
        disp('Likelihood went down... waiting');
        figure,semilogy(likes(1:numIt))
        keyboard;
    elseif didConverge
        keepGoing=0;
        msg1 = 'CONVERGED Likelihood';
        if (didConvergeP)
            msg1 = [msg1 ' P'];
        end
        if (didConvergeL)
            msg1 = [msg1 ' L'];
        end
        disp(msg1);
        msg2=['Change: ' num2str((newLike-oldLike),2)]
        errorFound=1; %bit of a misnomer...
    end

    if errorFound
        if fidErr
            fprintf(fidErr,'WARNING: %s %s\n\n',msg1,msg2);
        end
        errorFound=0;
    end

    % Report hitting the iteration limit only when this iteration did not
    % converge.  errorFound cannot be used for that test because the
    % logging block directly above has already reset it to 0.
    if numIt==G.maxIter && ~didConverge
        msg1=['Did not converge, max num iterations reached: ' num2str(G.maxIter)];
        myPrint(fidErr,msg1);
        if G.maxIter>1
            msg2=['Last change in likelihood=', num2str((newLike-oldLike),2)];
            disp(msg2);
        else
            msg2='';
        end
        keepGoing=0;
        errorFound=1;
    end

    if errorFound
        if fidErr
            fprintf(fidErr,'WARNING: %s %s\n\n',msg1,msg2);
        end
        errorFound=0;
    end

    %% stripG gets rid of huge redundant parts of the data structure
    %% which can easily be regenerated from the remaining bits
    allG{numIt}=stripG(G,1);

    oldLike=newLike;
    lastTime=cputime;
    elapsed(numIt)=(lastTime-firstTime)/60;   % minutes
    if fidErr
        fprintf(fidErr,'Iteration %.2f CPU Time: %.2f minutes\n',numIt,elapsed(numIt));
    end

    if saveWork && mod(numIt,saveFreq)==0
        eval(saveCmd);
    end
end

%% Some computation of the log likelihood is at the beginning of the
%% loop, so we need to do it once more here to get the final value.
lastLogLike = EMCPM_logLike(G,samplesMat)
likes(numIt+1)=lastLogLike;

% NOTE(review): the message below prints likes(numIt), i.e. the value
% recorded during the last iteration, not lastLogLike computed just above
% -- confirm which one is meant to be reported as the final value.
msg1 = ['FINAL EM Iteration: ' num2str(numIt), ...
    ', log likelihood=' printSci(likes(numIt),6) ' TOTAL CPU TIME=' num2str(sum(elapsed))];
myPrint(fidErr,msg1);

% Trim the bookkeeping arrays to the iterations that were actually run.
elapsed = elapsed(1:numIt);
logLikes = likes(1:(numIt+1));
smoothLikes=smoothLikes(1:numIt,:);
scalePriorTerm=scalePriorTerm(1:numIt);
timePriorTerm=timePriorTerm(1:numIt,:);
nuTerm = nuTerm(1:numIt);
mainLikes=mainLikes(1:numIt,:);
% NOTE(review): allTraces is written with four subscripts in the z update
% but indexed with three here, which merges its trailing dimensions when
% there is more than one bin -- confirm that callers expect that shape.
allTraces = allTraces(1:numIt,:,:);

if fidErr
    fclose(fidErr);
end
return;
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -