⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 trainfbhmm.m

📁 Continuous Profile Models (CPM) Matlab Toolbox.
💻 M
📖 第 1 页 / 共 2 页
字号:
%function [logLikes,allG,didConvergeL,errorFlag]=...%    trainFBHMM(G,samplesMat,errorLogFile,saveFile,HOLD_OUT,USE_CPM2)%% Use EM/forward-backward algorithm to train HMM model using% initTrace %% if HOLD_OUT==1, then sigmas are not constrained by each others% values (as is the case for training)%% 'logLikes' contains the log likelihood over training iterations% 'allG' contains the learned parameters at every iterations in%        a cell array (plus a bit of other junk)% 'didConvergeL' specifies whether the convergen threshold was met%                (not a big deal if it isn't, since this is a bit%                 arbitrary).%% if 'errorFlag' =1, then there was a problem of some sort (possibly% related to using too many 'numBins', or not enough 'lambda')function [logLikes,allG,didConvergeL,errorFlag]=...    trainFBHMM(G,samplesMat,errorLogFile,saveFile,HOLD_OUT,USE_CPM2)if G.numBins>24     warnStr=sprintf(...         ['****************************************************************\n'...         'You have chosen to use more than 24 features for each time point.'...         '\n\nDimensionality of each time point increases the CPU time \n'...         'roughly linearly, and one often need not use the full\n'...         'dimensionality to get good results.  Furthermore, in my\n'...         'useage so far, I have had numerical problems for more than\n'...         '24 dimensions (numBins), but this is problem-specific.'...         '\n\nChange this to a warning if you want to try more.\n'...         '****************************************************************'...         ]);     error(warnStr);endif exist('saveFile') & ~isempty(saveFile)        % save work after each iteration    saveWork=1;    %savevars = 'elapsed allG smoothLikes mainLikes likes numIt initTrace';    %saveCmd = ['save ' saveFile ' ' savevars];    saveCmd = ['save ' saveFile];    saveFreq=50;else    saveWork=0;enderrorFlag=0;lambda=G.lambda; nu=G.nu; initTrace=G.z;samplesMat=permute(samplesMat,[3,2,1]);%samplesMat = reshape(cell2mat(samples),[G.numBins G.numRealTimes G.numSamples]);%clear samples;if any(samplesMat(:))<0    error('CPM code does not handle negative values, please shift values upward and start again.');endcurrentTrace=initTrace;allTraces='';gammas='';alphas='';rhos='';smoothLikes='';mainLikes='';mytolx = '';% if isempty(errorLogFile)%     errorLogFile=['/u/jenn/phd/MS/matlabCode/workspaces/trainFBHMM_' filenameStamp '.LOG'];% endif ~isfield(G,'class')    error('G does not have the class variable in it');enderrorFound=0;%% file to log errors and messagesif ~isempty(errorLogFile)    cmd=['[fidErr,message]=fopen(errorLogFile,' '''a''' ');']    eval(cmd);    if (fidErr==-1)        error(['Unable to open file: ' errorLogFile]);    endelse    fidErr=0;endif ~isfield(G,'sigmas')        %disp('calling getInitSigmas');    [G.sigmas,G.varsigma,G.minSigma] = ...        getInitSigmas(G,samplesMat);end%figure,show(G.sigmas); colorbar; keyboard;printRunInfo(G,fidErr);%minSigmaUpdateIt=G.minSigmaUpdateIt;elapsed = zeros(1,G.maxIter);likes = zeros(1,G.maxIter);smoothLikes=zeros(G.maxIter,G.numClass);nuTerm=zeros(1,G.maxIter);timePriorTerm=zeros(G.maxIter,G.numSamples);scalePriorTerm=zeros(1,G.maxIter);mainLikes=zeros(G.maxIter,G.numSamples);oldLike=-Inf;numIt=0; keepGoing=1;M=G.numTaus;C = G.numClass;allTraces = zeros(G.maxIter, G.numTaus,G.numClass,G.numBins);%% scale vealue in real space corresponding to each statescalesExp=(2.^G.scales((G.stateToScaleTau(:,1))));scalesExp=scalesExp(:)';scalesExp2=(2.^G.scales((G.stateToScaleTau(:,1))));scalesExp2=scalesExp2(:)';oldU = G.u;%uCoeff = zeros(G.maxIter,G.numSamples,4);%myfval = zeros(G.maxIter,G.numSamples);%% precompute valid states for each time step (for updating Z)%% this never changes[allValidStates,scaleFacs,scaleFacsSq] = getValidStates(G);G.Jsparse = getJacobPattern(G);%keepGoing=(numIt<G.maxIter);keepGoing=1;if any(G.sigmas(:)==0)    warning('some sigmas=0');    keyboard;endwhile keepGoing    %tic;    firstTime = cputime;    if ~(exist('lastTime','var'))        lastTime=firstTime;    end    numIt = numIt+1;        if (mod(numIt,1)==0)        tmpStr=['\nIteration: ' num2str(numIt)];        myPrint(fidErr,tmpStr);    end    %% this is to keep track of old parameter values so that    %% we can check convergence based on changing parameters    oldG.sigmas = G.sigmas;    oldG.D = G.D;    oldG.S = G.S;    oldG.u = G.u;    oldG.z = currentTrace;    G.z=currentTrace;    ubar2 = getUbar2(G,G.u);    %% (RE-)INITIALIZE M-Step stuff    if G.updateZ || G.updateU        gammaSum1= cell(G.numSamples,M);        gammaSum2= cell(G.numBins,G.numSamples,M);        gammaSum5= sparse(G.numSamples,G.numStates);        %gammaSum6= sparse(G.numSamples,G.numStates);        gammaSum6 = cell(1,G.numBins);        for bb=1:G.numBins            gammaSum6{bb}=sparse(G.numSamples,G.numStates);        end    end    %% For updateSigma    newSigmas=zeros(size(G.sigmas));    %% For updateT    newD = zeros(G.numSamples,G.maxTimeSteps);    %% For updateScale    newScaleCounts = zeros(2,C);    %% For updateU        gammaSum3 = zeros(G.numBins,G.numSamples,G.numStates);    gammaSum4 = zeros(G.numBins,G.numSamples,G.numStates);        %% precompute this    %tempSigs = G.sigmas.^(-2);    %% The penalty terms of the likelihood should be computed    %% before we update the parameters.  Then, after iterating    %% through all of the samples and doing forward-backward,    %% we will have the 'normal' part of the likelihood to add to    %% these term.    smoothLikes(numIt,:) = getSmoothLike(G,G.z,G.u);    [timePriorTerm(numIt,:), scalePriorTerm(numIt)] = ...        getDirichletLike(G);    nuTerm(numIt) = getNuTerm(G);    scaleCenterPriorTerm = getScaleLike(G,G.u);    %tmp = ['nuTerm=' num2str(nuTerm(numIt),3)];    %tmp = [tmp '   ' 'lambdaTerm=' num2str(sum(smoothLikes(numIt,:)),3)];    %disp(tmp);    %% Now iterate through the samples, calculating the posterior    %% over hidden states and using these posteriors    %% in whatever computations we need, before throwing them out    %% to do the next sample.    %allGammas= cell(1,G.numSamples);% each G.numStates x G.numRealTimes    for kk=1:G.numSamples        %% E-step (obtaining gammas), using forward-backward in        %% standard way, with scaling tricks (not in log space)        myClass = getClass(G,kk);        clear alphas betas gammas rhos;        doback=1;  %% do backward pass                tmpZ = permute(G.z(:,myClass,:),[1 3 2]);                [mainLikes(numIt,kk),alphas,betas,rhos,FBerrorFlag]=...            FB(G,samplesMat(:,:,kk),kk,tmpZ,doback);        lastTimeLast = lastTime;        lastTime=cputime;        totalElapsed(numIt)=(lastTime-firstTime)/60;        elapsed(numIt)=(lastTime-lastTimeLast)/60;               msg=sprintf(['   Time series %d) E-step CPU Time: %.2f minutes'],...            kk,elapsed(numIt));        myPrint(fidErr,msg);               if FBerrorFlag==1            myStr=sprintf('ERROR: FBerrorFlag==1, means some rho==0, probably due to fixed, small sigma during cross validation on the hold out set calculations')            if fidErr fprintf(fidErr,'%s\n',myStr); end;            keyboard;            %return;        end        gammas = alphas.*betas;        %max(abs(sum(full(gammas),1)-1))        %allGammas{kk}=gammas;        %if (any(gammas(:)==0))        %  disp('some gammas=0');        %  [length(find(gammas(:)==0)) prod(size(gammas))]        %keyboard;        %end        %maxDiff(kk,numIt)=max(abs(sum(full(gammas),1)-1));        if max(abs(sum(full(gammas),1)-1))>1e-4%100*eps)            disp('trainFBHMM: gammas not exactly equal to 1');            disp('probably lambda is too small');            %maxDiff(kk,numIt)            errorFlag=1;            save gammaBug.mat            max(abs(sum(full(gammas),1)-1))            length(find(isinf(gammas)))            length(find(isinf(alphas)))            length(find(isinf(betas)))                        %return;            keyboard;        end        %% now use alphas, betas, gammas and rhos for this sample        %%and then throw them out to keep memory usage manageable.        %%Also, be careful to do updates in the right order.        %%update rules that are co-dependent need to be done in        %%order, with later ones using the latest updates.        %        %%Current order is:        %%sigma, z, u  (the others don't matter - they are independent)        if G.updateSigma | G.updateZ | G.updateU            clear tmpDat repdata;            repdata = repmat(samplesMat(:,:,kk)',[1 1 G.numStates]);        end        %% update sigma right here using gammas, since we use the        %% current value of u_k and z, this needs to be done before updating        %% u_k's, and z        if G.updateSigma             %% iterate over bins because repmat doesn't work with            %%sparse matrixes            %newSigmas = G.sigmas;            for bb=1:G.numBins                               if ~G.USE_CPM2                    augmentZ = G.u(kk)*G.z(G.stateToScaleTau(:,2),myClass,bb)';                else                                       uMat = G.uMat(kk,:);                    augmentZ = uMat.*G.z(G.stateToScaleTau(:,2),myClass,bb)';                end                augmentZ = augmentZ.*scalesExp;                augmentZ = repmat(augmentZ,[G.numRealTimes 1])';                                tmpDat = permute(repdata(:,bb,:),[3 1 2]);                                newSigmas(bb,kk) = sum(sum(gammas.*((tmpDat-augmentZ).^2)));                          end                                    if isnan(newSigmas(kk))                error(['ERROR: newSigmas(' num2str(kk) ')' '=nan']);                numIt=numIt+1;% bit of a hack to keep results even if                %keepGoing=0;                keyboard;            end        end        %% gather data we will need to update z        if G.updateZ || G.updateU            %% iterate over bins because repmat doesn't work with            %%sparse matrixes, and sparse matrixes can't be multiDim                        gammaSum5(kk,:) = sum(gammas,2)'; % same for each bb            for bb=1:G.numBins                tmpGS6 = gammaSum6{bb};                tmpDat = permute(repdata(:,bb,:),[3 1 2]);                

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -