📄 hrest.c
字号:
} else { if (t==1) Lr = hmm->transP[1][j]; else { Lr = LZERO; for (i=2; i<nStates; i++) if ((a_ij=hmm->transP[i][j]) > LSMALL) Lr = LAdd(Lr,alpha[i][t-1]+a_ij); } if (Lr>LSMALL) { Lr += mixp_j[t][s][m] + w + betaj[t] - pr; if (nStreams>1) { /* add contrib of parallel streams */ strpt = stroutp[t][j]; for (ss=1; ss<=nStreams; ss++) if (ss!=s) Lr += strpt[ss]; } } } if (Lr > MINEARG) { y = exp(Lr); /* Update Weight Counter */ if (uFlags&UPMIXES) { idx = (hsKind==DISCRETEHS) ? obs.vq[s] : m; wa->occ += y; wa->c[idx] += y; } /* Update Mean Counter */ if (uFlags&UPMEANS){ ma->occ += y; for (k=1; k<=vSize; k++) ma->mu[k] += zot[k]*y; /* sum zero mean */ } /* Update Covariance Counter */ if (uFlags&UPVARS){ va->occ += y; if (mpdf->ckind==DIAGC){ for (k=1; k<=vSize; k++) va->cov.var[k] += zot[k] * zot[k] * y; } else{ for (k=1; k<=vSize; k++) for (l=1; l<=k; l++) va->cov.inv[k][l] += zot[k] * zot[l] * y; } } } } } if ((trace&(T_MAC|T_VAC))&&(hsKind!=DISCRETEHS)) { ShowSegNum(seg); printf("State %d, Stream %d, Mixture %d\n",j,s,m); if (trace&T_MAC){ printf("MEAN OCC: %.2f\n",ma->occ); ShowVector("MEAN ACC: ",ma->mu,10); } if (trace&T_VAC) { if (mpdf->ckind==DIAGC){ printf("VAR OCC: %.2f\n",va->occ); ShowVector("VAR ACC: ",va->cov.var,10); } else { printf("INV OCC: %.2f\n",va->occ); ShowMatrix("INV ACC: ",va->cov.inv,10,10); } } } } if (trace&T_WAC){ ShowSegNum(seg); printf("State %d, Stream %d\n",j,s); printf("WT OCC: %.2f\n",wa->occ); ShowVector("WT ACC: ",wa->c,10); } } /* UpPDFCounts: update output PDF counts for each stream of each state */void UpPDFCounts(LogDouble pr, int seg){ int j,s; StateInfo *si; StreamElem *se; DVector alj,betj; for (j=2; j<nStates; j++) { si = hmm->svec[j].info; alj = alpha[j]; betj = beta[j]; for (s=1,se = si->pdf+1; s<=nStreams; s++,se++) UpStreamCounts(j,s,se,hset.swidth[s],pr,seg,alj,betj); }}/* UpdateCounters: update the various counters */void UpdateCounters(LogDouble pr, int seg){ SetOccr(pr,seg); if (uFlags&UPTRANS) UpTranCounts(pr,seg); if (uFlags&(UPMEANS|UPVARS|UPMIXES)) UpPDFCounts(pr,seg);}/* ------------------------- Model Update ----------------------- *//* RestTransP: reestimate transition probs */void RestTransP(void){ int i,j; float occi,x,sum; TrAcc *ta; ta = (TrAcc *) GetHook(hmm->transP); for (i=1;i<nStates;i++) { hmm->transP[i][1] = LZERO; occi = ta->occ[i]; if (occi == 0.0) HError(2222,"RestTransP: Zero state %d occupation count",i); sum = 0.0; for (j=2;j<=nStates;j++) { x = ta->tran[i][j]/occi; hmm->transP[i][j] = x; sum += x; } for (j=2;j<=nStates;j++) { x = hmm->transP[i][j]/sum; hmm->transP[i][j] = (x<MINLARG) ? LZERO : log(x); } } if (trace & T_TRE) ShowMatrix("NEW TRANS: ",hmm->transP,10,10);}/* FloorMixes: apply floor to given mix set */void FloorMixes(MixtureElem *mixes, int M, float floor){ float sum,fsum,scale; MixtureElem *me; int m; sum = fsum = 0.0; for (m=1,me=mixes; m<=M; m++,me++) { if (me->weight>floor) sum += me->weight; else { fsum += floor; me->weight = floor; } } if (fsum>1.0) HError(2223,"FloorMixes: Floor sum too large"); scale = (1.0-fsum)/sum; if (trace&T_WRE) printf("MIXW: "); for (m=1,me=mixes; m<=M; m++,me++){ if (me->weight>floor) me->weight *= scale; if (trace&T_WRE) printf(" %.2f",me->weight); } if (trace&T_WRE) printf("\n");} /* FloorTMMixes: apply floor to given tied mix set */void FloorTMMixes(Vector mixes, int M, float floor){ float sum,fsum,scale,fltWt; int m; sum = fsum = 0.0; for (m=1; m<=M; m++) { fltWt = mixes[m]; if (fltWt>floor) sum += fltWt; else { fsum += floor; mixes[m] = floor; } } if (fsum>1.0) HError(2223,"FloorTMMixes: Floor sum too large"); scale = (1.0-fsum)/sum; if (trace&T_WRE) printf("MIXW: "); for (m=1; m<=M; m++){ fltWt = mixes[m]; if (fltWt>floor) mixes[m] = fltWt*scale; if (trace&T_WRE) printf(" %.2f",fltWt); }}/* FloorDProbs: apply floor to given discrete prob set */void FloorDProbs(ShortVec mixes, int M, float floor){ float sum,fsum,scale,fltWt; int m; sum = fsum = 0.0; for (m=1; m<=M; m++) { fltWt = Short2DProb(mixes[m]); if (fltWt>floor) sum += fltWt; else { fsum += floor; mixes[m] = DProb2Short(floor); } } if (fsum>1.0) HError(2327,"FloorDProbs: Floor sum too large"); if (fsum == 0.0) return; if (sum == 0.0) HError(2328,"FloorDProbs: No probabilities above floor"); scale = (1.0-fsum)/sum; for (m=1; m<=M; m++){ fltWt = Short2DProb(mixes[m]); if (fltWt>floor) mixes[m] = DProb2Short(fltWt*scale); }}/* RestMixWeights: reestimate the mixture weights */void RestMixWeights(int state, int s, StreamElem *se){ WtAcc *wa; int m,M=0; float x; MixtureElem *me; wa = (WtAcc *)se->hook; if (wa->occ == 0.0) HError(2222,"RestMixWeights: Zero weight occupation count"); switch (hsKind){ case TIEDHS: M=hset.tmRecs[s].nMix; break; case PLAINHS: case SHAREDHS: case DISCRETEHS: M=se->nMix; break; } for (m=1; m<=M; m++){ x = wa->c[m]/wa->occ; if (x>1.0) HError(2290,"RestMixWeights: Mix wt>1 in %d.%d.%d",state,s,m); switch (hsKind){ case DISCRETEHS: se->spdf.dpdf[m] = (x>MINMIX) ? DProb2Short(x) : DLOGZERO; break; case TIEDHS: se->spdf.tpdf[m] = (x>MINMIX) ? x : 0.0; break; case PLAINHS: case SHAREDHS: me=se->spdf.cpdf+m; me->weight = (x>MINMIX) ? x : 0.0; break; } }}/* RestMean: reestimate the given mean vector */void RestMean(Vector mean, int vSize){ int k; MuAcc *ma; float x; ma = (MuAcc *)GetHook(mean); if (ma->occ == 0.0) HError(2222,"RestMean: Zero mean occupation count"); for (k=1; k<=vSize; k++){ x = mean[k] + ma->mu[k]/ma->occ; ma->mu[k] = mean[k]; /* remember old mean */ mean[k] = x; } if (trace&T_MRE) ShowVector("MEAN: ",mean,10);}/* RestCoVar: reestimate the given covariance and return FALSE if any diagonal component == 0.0 */Boolean RestCoVar(MixPDF *mp, int vSize, Vector minV, Vector oldMean, Vector newMean, Boolean shared){ int k,l; VaAcc *va; float x,z; float muDiffk,muDiffl; va = (VaAcc *)GetHook(mp->cov.var); if (va->occ == 0.0) HError(2222,"RestCoVar: Zero variance occupation count"); if (mp->ckind==DIAGC){ for (k=1; k<=vSize; k++){ muDiffk = (shared)?0.0:newMean[k]-oldMean[k]; x = va->cov.var[k] / va->occ - muDiffk*muDiffk; if (x<minV[k]) x = minV[k]; if (x<vDefunct) return FALSE; mp->cov.var[k] = x; } FixDiagGConst(mp); if (trace&T_VRE) ShowVector("VARS: ",mp->cov.var,10); } else { for (k=1; k<=vSize; k++){ muDiffk = (shared)?0.0:newMean[k]-oldMean[k]; for (l=1; l<k; l++) { muDiffl = (shared)?0.0:newMean[l]-oldMean[l]; x = va->cov.inv[k][l] / va->occ - muDiffk*muDiffl; mp->cov.inv[k][l] = x; } z = va->cov.inv[k][k]/va->occ - muDiffk*muDiffk; mp->cov.inv[k][k] = (z<minV[k])?minV[k]:z; } if (trace&T_VRE) ShowTriMat("COVM: ",mp->cov.inv,10,10); FixFullGConst(mp,CovInvert(mp->cov.inv,mp->cov.inv)); } return TRUE;}/* RestStream: reestimate stream parameters */void RestStream(int state, int s, StreamElem *se, int vSize){ int m,M; MixtureElem *me; MixPDF *mp; MuAcc *ma; Boolean shared; float wght; if (trace&(T_WRE|T_MRE|T_VRE)) printf("State %d, Stream %d\n",state,s); if (uFlags&UPMIXES) RestMixWeights(state,s,se); if ((hsKind != DISCRETEHS)&&(hsKind != TIEDHS)){ /*wts only DI'ETE & TIED*/ M=se->nMix; for (m=1; m<=M; m++){ me = se->spdf.cpdf+m; wght=me->weight; mp=me->mpdf; if (wght > MINMIX) { if (trace&(T_MRE|T_VRE) && M>1) printf("Mixture %d\n",m); if (uFlags&UPMEANS) RestMean(mp->mean,vSize); /* NB old mean left in ma->mu */ if (uFlags&UPVARS){ shared = GetUse(mp->cov.var) > 1; ma = (MuAcc *)GetHook(mp->mean); if ( !RestCoVar(mp,vSize,vFloor[s],ma->mu,mp->mean,shared)) { if (M > 1) { HError(-2225,"RestStream: Defunct Mix %d.%d.%d",state,s,m); me->weight = 0.0; } else HError(2222,"RestStream: Zero Covariance in %d.%d",state,s); } } } } } if (hsKind == TIEDHS) M=hset.tmRecs[s].nMix; else M=se->nMix; if (M>1){ switch (hsKind){ case DISCRETEHS: FloorDProbs(se->spdf.dpdf,M,mixWeightFloor); break; case TIEDHS: FloorTMMixes(se->spdf.tpdf,M,mixWeightFloor); break; case PLAINHS: case SHAREDHS: FloorMixes(se->spdf.cpdf+1,M,mixWeightFloor); break; } }}/* UpdateTheModel: use accumulated statistics to update model */void UpdateTheModel(void){ int j,s; StateInfo *si; StreamElem *se; if (uFlags&UPTRANS) RestTransP(); if (uFlags&(UPMEANS|UPVARS|UPMIXES)) for (j=2; j<nStates; j++) { si = hmm->svec[j].info; for (s=1,se = si->pdf+1; s<=nStreams; s++,se++) RestStream(j,s,se,hset.swidth[s]); } if (uFlags&UPVARS) FixAllGConsts(&hset);}/* ------------------------- Top Level Control ----------------------- *//* ReEstimateModel: top level of algorithm */void ReEstimateModel(void){ LogFloat segProb,oldP,newP,delta; LogDouble ap,bp; int converged,iteration,seg; iteration=0; oldP=LZERO; do { /*main re-est loop*/ ZeroAccs(&hset, uFlags); newP = 0.0; ++iteration; nTokUsed = 0; for (seg=1;seg<=nSeg;seg++) { T=SegLength(segStore,seg); SetOutP(seg); if ((ap=SetAlpha(seg)) > LSMALL){ bp = SetBeta(seg); if (trace & T_LGP) printf("%d. Pa = %e, Pb = %e, Diff = %e\n",seg,ap,bp,ap-bp); segProb = (ap + bp) / 2.0; /* reduce numeric error */ newP += segProb; ++nTokUsed; UpdateCounters(segProb,seg); } else if (trace&T_TOP) printf("Example %d skipped\n",seg); } if (nTokUsed==0) HError(2226,"ReEstimateModel: No Usable Training Examples"); UpdateTheModel(); newP /= nTokUsed; delta=newP-oldP; oldP=newP; converged=(fabs(delta)<epsilon); if (trace&T_TOP) { printf("Ave LogProb at iter %d = %10.5f using %d examples", iteration,oldP,nTokUsed); if (iteration > 1) printf(" change = %10.5f",delta); printf("\n"); fflush(stdout); } } while ((iteration < maxIter) && !converged); if (trace&T_TOP) { if (converged) printf("Estimation converged at iteration %d\n",iteration); else printf("Estimation aborted at iteration %d\n",iteration); fflush(stdout); }}/* ----------------------------------------------------------- *//* END: HRest.c *//* ----------------------------------------------------------- */
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -