📄 hfb.c
字号:
x = LAdd(x,wt+mixp); outprobjs[m] = mixp; } } } else if (!pde) { /* Multiple Mixture Case - no shared mix case */ x = LZERO; for (m=1;m<=M;m++,me++) { wt = MixLogWeight(hset,me->weight); if (wt>LMINMIX){ mp = me->mpdf; mixp = MOutP(ApplyCompFXForm(mp,v,xform,&det,t),mp); mixp += det; x = LAdd(x,wt+mixp); outprobjs[m] = mixp; } } } else { /* Partial distance elimination */ /* first Gaussian computed exactly in PDE */ wt = MixLogWeight(hset,me->weight); mp = me->mpdf; otvs = ApplyCompFXForm(mp,v,xform,&det,t); mixp = IDOutP(otvs,VectorSize(otvs),mp); /* INVDIAGC assumed */ mixp += det; x = wt+mixp; outprobjs[1] = mixp; for (m=2,me=ste->spdf.cpdf+2;m<=M;m++,me++) { wt = MixLogWeight(hset,me->weight); if (wt>LMINMIX){ mp = me->mpdf; otvs = ApplyCompFXForm(mp,v,xform,&det,t); if (PDEMOutP(otvs,mp,&mixp,x-wt-det) == TRUE) { mixp += det; x = LAdd(x,wt+mixp); } outprobjs[m] = mixp; /* LZERO if PDEMOutP returns FALSE */ } } } outprobjs[0] = x; wa->prob = outprobjs; wa->time = t; } return outprobjs;} /* Setotprob: allocate and calculate the otprob matrix at time t.
 *
 *   ab      - alpha/beta work area; supplies otprob, al_qList and the abMem heap
 *   fbInfo  - supplies al_hset, al_inXForm and the skipstart/skipend trace window
 *   pbuf    - parameter buffer read with ReadAsTable(pbuf,t-1,&ot)
 *   ot      - observation, passed by value; the local copy is overwritten here
 *   t       - time index of the column being built
 *   S       - number of streams
 *   qHi,qLo - model range; qLo is decremented once when it is greater than 1
 *
 * otprob[t] is created with CreateOqprob for models qLo..qHi.  For every model
 * in that range whose otprob[t][q] is still NULL, a table is created with
 * CreateOjsprob and filled for states j = 2..Nq-1:
 *   TIEDHS / DISCRETEHS : a one-element vector holding SOutP(hset,s,&ot,ste)
 *   PLAINHS / SHAREDHS  : the vector returned by ShStrP for stream s
 * With S==1 the result is stored in slot 0.  With S>1 stream s is stored in
 * slot s; afterwards slot 0 element 0 receives the sum over all streams and
 * each outprobj[s][0] is replaced by (sum - outprobj[s][0]).
 *
 * NOTE(review): for S>1 slot 0 is written but not allocated in this function;
 * presumably CreateOjsprob provides it - confirm.
 * NOTE(review): the default: arm stores NULL, which the following
 * outprobj[s][0] / outprobj[0][0] accesses would dereference - confirm that no
 * other hsKind value can reach here.
 * NOTE(review): local p is assigned but never read in this function.
 */static void Setotprob(AlphaBeta *ab, FBInfo *fbInfo, ParmBuf pbuf, Observation ot, int t, int S, int qHi, int qLo){ int q,j,Nq,s; float ***outprob, **outprobj, *****otprob; StreamElem *ste; HLink hmm; LogFloat sum; PruneInfo *p; int skipstart, skipend; HMMSet *hset; hset = fbInfo->al_hset; skipstart = fbInfo->skipstart; skipend = fbInfo->skipend; p = ab->pInfo; otprob = ab->otprob; ReadAsTable(pbuf,t-1,&ot); if (hset->hsKind == TIEDHS) PrecomputeTMix(hset,&ot,pruneSetting.minFrwdP,0); if (trace&T_OUT && NonSkipRegion(skipstart,skipend,t)) printf(" Output Probs at time %d\n",t); if (qLo>1) --qLo; otprob[t] = CreateOqprob(&ab->abMem,qLo,qHi); for (q=qHi;q>=qLo;q--) { if (trace&T_OUT && NonSkipRegion(skipstart,skipend,t)) printf(" Q%2d: ",q); hmm = ab->al_qList[q]; Nq = hmm->numStates; if (otprob[t][q] == NULL) { outprob = otprob[t][q] = CreateOjsprob(&ab->abMem,Nq,S); for (j=2;j<Nq;j++){ ste=hmm->svec[j].info->pdf+1; 
sum = 0.0; outprobj = outprob[j]; for (s=1;s<=S;s++,ste++){ switch (hset->hsKind){ case TIEDHS: /* SOutP deals with tied mix calculation; falls through to share the DISCRETEHS code */ case DISCRETEHS: if (S==1) { outprobj[0] = NewOtprobVec(&ab->abMem,1); outprobj[0][0] = SOutP(hset,s,&ot,ste); } else { outprobj[s] = NewOtprobVec(&ab->abMem,1); outprobj[s][0] = SOutP(hset,s,&ot,ste); } break; /* Check that PLAINHS is handled correctly this way - efficient? */ case PLAINHS: case SHAREDHS: if (S==1) outprobj[0] = ShStrP(hset,ste,ot.fv[s],t,fbInfo->al_inXForm,&ab->abMem); else outprobj[s] = ShStrP(hset,ste,ot.fv[s],t,fbInfo->al_inXForm,&ab->abMem); break; default: if (S==1) outprobj[0] = NULL; else outprobj[s] = NULL; } if (S>1) sum += outprobj[s][0]; } if (S>1){ /* slot 0 gets the all-stream total; slot s becomes the total minus stream s */outprobj[0][0] = sum; for (s=1;s<=S;s++) outprobj[s][0] = sum - outprobj[s][0]; } if (trace&T_OUT && NonSkipRegion(skipstart,skipend,t)) { printf(" %d. ",j); PrLog(outprobj[0][0]); if (S>1){ printf("[ "); for (s=1;s<=S;s++) PrLog(outprobj[s][0]); printf("]"); } } } } if (trace&T_OUT && NonSkipRegion(skipstart,skipend,t)) printf("\n"); }}/* TraceAlphaBeta: print alpha/beta values at time t, also sum
 * alpha/beta product across states at t-, t, and t+.
 *
 *   ab          - supplies al_qList, qIds, alphat[q] and beta[t][q]
 *   t           - time index printed in the heading and used to select beta
 *   startq,endq - inclusive range of models to print
 *   pr          - log value subtracted from each of the three sums before printing
 *
 * For every model the alpha and beta of states 1..Nq are printed.  Three
 * log-sums (LAdd) of alpha+beta are accumulated over all models: over state 1
 * (printed with the "(-)" label), over states 2..Nq-1, and over state Nq
 * (printed with the "(+)" label).
 */static void TraceAlphaBeta(AlphaBeta *ab, int t, int startq, int endq, LogDouble pr){ int i,q,Nq; DVector aqt,bqt; HLink hmm; double summ,sump,sum; printf("Alpha/Betas at time %d\n",t); summ = sump = sum = LZERO; for (q=startq; q<=endq; q++) { hmm = ab->al_qList[q]; Nq = hmm->numStates; printf(" Q%2d: %5s alpha beta\n", q,ab->qIds[q]->name); aqt = ab->alphat[q]; bqt = ab->beta[t][q]; for (i=1;i<=Nq;i++){ printf(" "); PrLog(aqt[i]); printf(" "); PrLog(bqt[i]); printf("\n"); } summ = LAdd(summ,aqt[1]+bqt[1]); for (i=2;i<Nq;i++) sum = LAdd(sum,aqt[i]+bqt[i]); sump = LAdd(sump,aqt[Nq]+bqt[Nq]); } printf(" Sums of Products: "); PrLog(summ-pr); printf("(-) "); PrLog(sum-pr); printf(" "); PrLog(sump-pr); printf("(+)\n");} /* SetBeamTaper: set beam start and end points according to the minimum
 * duration of the models in the current sequence.
 *
 *   p    - prune info whose qHi[1..T] and qLo[1..T] arrays are filled in
 *   qDms - per-model counts, indexed 1..Q, compared against the frame counter i
 *   Q    - number of models in the sequence
 *   T    - number of frames
 *
 * Forward pass: starting at model 1, the model index advances each time the
 * frame counter i reaches qDms[q]; the index in force at frame t is stored in
 * qHi[t].  Once q reaches Q, dq is set to -1 so no further advance happens.
 * Backward pass: the mirror image, starting at model Q and frame T, fills
 * qLo[t].  A while loop (rather than an if) is used so that several models
 * whose qDms entry is 0 are stepped over within a single frame.
 */static void SetBeamTaper(PruneInfo *p, short *qDms, int Q, int T){ int q,dq,i,t; /* Set leading taper */ q=1;dq=qDms[q];i=0; for (t=1;t<=T;t++) { while (i==dq) { i=0; if (q<Q) q++,dq=qDms[q]; else dq=-1; } p->qHi[t]=q; i++; } /* Set trailing taper */q=Q;dq=qDms[q];i=0; for (t=T;t>=1;t--) { while (i==dq) { i=0; if (q>1) q--,dq=qDms[q]; else dq=-1; } p->qLo[t]=q; i++; } /* disabled debug dump of the taper: if (trace>1) for (t=1;t<=T;t++) printf("%d: %d to %d\n",t,p->qLo[t],p->qHi[t]); exit(1);*/}/* SetBeta: allocate and calculate beta and otprob matrices.
 *
 *   ab     - work area; beta, otprob, pInfo, qDms, al_qList and abMem are used
 *   fbInfo - supplies al_hset and the skipstart/skipend trace window
 *   utt    - supplies pbuf, ot, S, Q and T; utt->pr is written on the normal path
 *
 * Returns utt->pr, or LZERO when pruning raises endq above startq in some
 * column or when the final utt->pr is <= LSMALL.
 *
 * Column T is initialised first: the exit state of model Q gets 0.0, the exit
 * state of each earlier model is taken from the following model's exit beta
 * plus that model's transP[1][Nq].  Columns T-1..1 are then filled from
 * column t+1.  In each column the per-model maximum over internal states is
 * kept in maxP[q] and the column maximum in gMax; models at either end of the
 * range whose maxP lies more than p->pruneThresh below gMax have their
 * beta[t][q] set to NULL, and p->qHi[t] / p->qLo[t] are updated to the
 * surviving range.
 *
 * lNq and a1N carry the state count and transP[1][Nq] of the model processed
 * in the previous loop iteration (model q+1) into the current one.
 *
 * NOTE(review): maxP is created on gstack and is not released anywhere in
 * this function, including the early return - confirm who frees it.
 * NOTE(review): in this copy of the file the format string of the per-column
 * "Beta Beam" trace printf is split by a physical line break; a raw newline
 * inside a string literal is not valid C - check against a pristine copy.
 */static LogDouble SetBeta(AlphaBeta *ab, FBInfo *fbInfo, UttInfo *utt){ ParmBuf pbuf; int i,j,t,q,Nq,lNq=0,q_at_gMax,startq,endq; int S, Q, T; DVector bqt=NULL,bqt1,bq1t1,maxP, **beta; float ***outprob; LogDouble x,y,gMax,lMax,a,a1N=0.0; HLink hmm; PruneInfo *p; int skipstart, skipend; HMMSet *hset; hset = fbInfo->al_hset; skipstart = fbInfo->skipstart; skipend = fbInfo->skipend; pbuf=utt->pbuf; S=utt->S; Q=utt->Q; T=utt->T; p=ab->pInfo; beta=ab->beta; maxP = CreateDVector(&gstack, Q); /* for calculating beam width */ /* Last Column t = T */ p->qHi[T] = Q; endq = p->qLo[T]; Setotprob(ab,fbInfo,pbuf,utt->ot,T,S,Q,endq); beta[T] = CreateBetaQ(&ab->abMem,endq,Q,Q); gMax = LZERO; q_at_gMax = 0; /* max value of beta at time T */ for (q=Q; q>=endq; q--){ hmm = ab->al_qList[q]; Nq = hmm->numStates; bqt = beta[T][q] = NewBetaVec(&ab->abMem,Nq); bqt[Nq] = (q==Q)?0.0:beta[T][q+1][lNq]+a1N; /* lNq and a1N still describe model q+1 here */for (i=2;i<Nq;i++) bqt[i] = hmm->transP[i][Nq]+bqt[Nq]; outprob = ab->otprob[T][q]; x = LZERO; for (j=2; j<Nq; j++){ a = hmm->transP[1][j]; y = bqt[j]; if (a>LSMALL && y > LSMALL) x = LAdd(x,a+outprob[j][0][0]+y); } bqt[1] = x; lNq = Nq; a1N = hmm->transP[1][Nq]; if (x>gMax) { gMax = x; q_at_gMax = q; } } if (trace&T_PRU && NonSkipRegion(skipstart,skipend,T) && p->pruneThresh < NOPRUNE) printf("%d: Beta Beam %d->%d; gMax=%f at %d\n", T,p->qLo[T],p->qHi[T],gMax,q_at_gMax); /* Columns T-1 -> 1 */ for (t=T-1;t>=1;t--) { gMax = LZERO; q_at_gMax = 0; /* max value of beta at time t */ startq = p->qHi[t+1]; endq = (p->qLo[t+1]==1)?1:((p->qLo[t]>=p->qLo[t+1])?p->qLo[t]:p->qLo[t+1]-1); while (endq>1 && ab->qDms[endq-1]==0) endq--; /* start end-point at top of beta beam at t+1 */ /* unless this is outside the beam taper. */ /* + 1 to allow for state q+1[1] -> q[N] */ /* + 1 for each tee model preceding endq. */ Setotprob(ab,fbInfo,pbuf,utt->ot,t,S,startq,endq); beta[t] = CreateBetaQ(&ab->abMem,endq,startq,Q); for (q=startq;q>=endq;q--) { lMax = LZERO; /* max value of beta in model q */ hmm = ab->al_qList[q]; Nq = hmm->numStates; bqt = beta[t][q] = NewBetaVec(&ab->abMem,Nq); bqt1 = beta[t+1][q]; bq1t1 = (q==Q)?NULL:beta[t+1][q+1]; outprob = ab->otprob[t+1][q]; bqt[Nq] = (bq1t1==NULL)?LZERO:bq1t1[1]; if (q<startq && a1N>LSMALL) bqt[Nq]=LAdd(bqt[Nq],beta[t][q+1][lNq]+a1N); for (i=Nq-1;i>1;i--){ x = hmm->transP[i][Nq] + bqt[Nq]; if (q>=p->qLo[t+1]&&q<=p->qHi[t+1]) for (j=2;j<Nq;j++) { a = hmm->transP[i][j]; y = bqt1[j]; if (a>LSMALL && y>LSMALL) x = LAdd(x,a+outprob[j][0][0]+y); } bqt[i] = x; if (x>lMax) lMax = x; if (x>gMax) { gMax = x; q_at_gMax = q; } } outprob = ab->otprob[t][q]; x = LZERO; for (j=2; j<Nq; j++){ a = hmm->transP[1][j]; y = bqt[j]; if (a>LSMALL && y>LSMALL) x = LAdd(x,a+outprob[j][0][0]+y); } bqt[1] = x; maxP[q] = lMax; lNq = Nq; a1N = hmm->transP[1][Nq]; } while (gMax-maxP[startq] > p->pruneThresh) { beta[t][startq] = NULL; --startq; /* lower startq till thresh reached */ if (startq<1) HError(7323,"SetBeta: Beta prune failed sq < 1"); } while(p->qHi[t]<startq) { /* On taper */ beta[t][startq] = NULL; --startq; /* lower startq till the taper limit qHi[t] is reached */ if (startq<1) HError(7323,"SetBeta: Beta prune failed on taper sq < 1"); } p->qHi[t] = startq; while (gMax-maxP[endq]>p->pruneThresh){ beta[t][endq] = NULL; ++endq; /* raise endq till thresh reached */ if (endq>startq) { return(LZERO); } } p->qLo[t] = endq; if (trace&T_PRU && NonSkipRegion(skipstart,skipend,t) && p->pruneThresh < NOPRUNE) printf("%d: Beta Beam %d->%d; gMax=%f at 
%d\n", t,p->qLo[t],p->qHi[t],gMax,q_at_gMax); } /* Finally, set total prob pr from state 1 of the last beta vector filled in */ utt->pr = bqt[1]; if (utt->pr <= LSMALL) { return LZERO; } if (trace&T_TOP) { printf(" Utterance prob per frame = %e\n",utt->pr/T); fflush(stdout); } return utt->pr;}/* -------------------- Top Level of F-B Updating ---------------- *//* CheckData: check data file consistent with HMM definition.
 *
 *   hset         - HMM set supplying vecSize and pkind
 *   fn           - file name, used only in the error messages
 *   info         - buffer info supplying tgtVecSize and tgtPK
 *   twoDataFiles - when set, the parameter-kind comparison is skipped
 *
 * Raises HError 7350 when tgtVecSize differs from hset->vecSize, and, unless
 * twoDataFiles is set, when tgtPK differs from hset->pkind.
 */static void CheckData(HMMSet *hset, char *fn, BufferInfo *info, Boolean twoDataFiles) { if (info->tgtVecSize!=hset->vecSize) HError(7350,"CheckData: Vector size in %s[%d] is incompatible with hset [%d]", fn,info->tgtVecSize,hset->vecSize); if (!twoDataFiles){ if (info->tgtPK != hset->pkind) HError(7350,"CheckData: Parameterisation in %s is incompatible with hset", fn); }}/* ResetStacks: Reset all stacks used by StepBack function; this is a
 * single ResetHeap call on ab->abMem. */static void ResetStacks(AlphaBeta *ab){ ResetHeap(&ab->abMem);}/* StepBack: Step utterance from T to 1 calculating Beta matrix.
 *
 *   fbInfo - supplies the AlphaBeta work area (fbInfo->ab)
 *   utt    - supplies Q, T and tr; passed on to SetBeta
 *   datafn - data file name, used only in the error messages
 *
 * Returns TRUE when SetBeta yields a value above LSMALL; FALSE when
 * CreateInsts reports more than utt->T (qt > T) or when no pruning threshold
 * up to pruneSetting.pruneLim succeeds.
 *
 * The beta pass is first tried with pruneSetting.pruneInit.  Each failure
 * adds pruneSetting.pruneInc to the threshold and retries from a reset heap,
 * giving up once the threshold exceeds pruneLim or pruneInc is 0.0.
 * NOTE(review): both failure paths call HError with the negative code -7324
 * and then return FALSE, so HError is presumably non-fatal for negative
 * codes - confirm.
 */static Boolean StepBack(FBInfo *fbInfo, UttInfo *utt, char * datafn){ LogDouble lbeta; LogDouble pruneThresh; AlphaBeta *ab; PruneInfo *p; int qt; ResetObsCache(); ab = fbInfo->ab; pruneThresh=pruneSetting.pruneInit; do { ResetStacks(ab); InitPruneStats(ab); p = fbInfo->ab->pInfo; p->pruneThresh = pruneThresh; qt=CreateInsts(fbInfo,ab,utt->Q,utt->tr); if (qt>utt->T) { if (trace&T_TOP) printf(" Unable to traverse %d states in %d frames\n",qt,utt->T); HError(-7324,"StepBack: File %s - bad data or over pruning\n",datafn); return FALSE; } CreateBeta(ab,utt->T); SetBeamTaper(p,ab->qDms,utt->Q,utt->T); CreateOtprob(ab,utt->T); lbeta=SetBeta(ab,fbInfo,utt); if (lbeta>LSMALL) break; pruneThresh+=pruneSetting.pruneInc; if (pruneThresh>pruneSetting.pruneLim || pruneSetting.pruneInc==0.0) { if (trace&T_TOP) printf(" No path found in beta pass\n"); HError(-7324,"StepBack: File %s - bad data or over pruning\n",datafn); return FALSE; } if (trace&T_TOP) { printf("Retrying Beta pass at %5.1f\n",pruneThresh); } } while(pruneThresh<=pruneSetting.pruneLim); if 
(lbeta<LSMALL) HError(7323,"StepBack: Beta prune error"); return TRUE;}/* ---------------------- Statistics Accumulation -------------------- *//* UpTranParms: update the transition counters of given hmm.
 *
 *   fbInfo - supplies the AlphaBeta work area (otprob, occt, qIds) and trace window
 *   hmm    - model whose transP hook holds the TrAcc accumulator being updated
 *   t, q   - time and model index used to select otprob[t][q] and otprob[t+1][q]
 *   aqt    - alpha vector of this model at time t
 *   bqt    - beta vector of this model at time t
 *   bqt1   - beta vector of this model at time t+1, or NULL (internal
 *            transitions are then skipped)
 *   bq1t   - beta vector of the following model at time t, or NULL (the tee
 *            transition is then skipped)
 *   pr     - log value subtracted from every term before exponentiation
 *
 * ta->occ[i] is increased by ab->occt[i] for i = 1..N-1.  For each i in
 * 1..N-1 and j in 2..N a log term x is formed and exp(x) is added to
 * ta->tran[i][j] when x > MINEARG:
 *   entry    (i==1, j<N)  aqt[1]+a[1][j]+outprob[j][0][0]+bqt[j]-pr
 *   internal (i>1,  j<N)  aqt[i]+a[i][j]+outprob1[j][0][0]+bqt1[j]-pr
 *   exit     (i>1,  j==N) aqt[i]+a[i][N]+bqt[N]-pr
 *   tee      (i==1, j==N, a[1][N]>LSMALL) aqt[1]+a[1][N]+bq1t[1]-pr
 * where a is hmm->transP, outprob is taken at time t and outprob1 at t+1.
 * The optional trace prints rows 1..N of the accumulator, i.e. one more row
 * than the update loops touch.
 */static void UpTranParms(FBInfo *fbInfo, HLink hmm, int t, int q, DVector aqt, DVector bqt, DVector bqt1, DVector bq1t, LogDouble pr){ int i,j,N; Vector ti,ai; float ***outprob,***outprob1; double sum,x; TrAcc *ta; AlphaBeta *ab; N = hmm->numStates; ab = fbInfo->ab; ta = (TrAcc *) GetHook(hmm->transP); outprob = ab->otprob[t][q]; if (bqt1!=NULL) outprob1 = ab->otprob[t+1][q]; /* Bug fix: otprob[t+1] is only indexed when a t+1 beta vector was supplied */ else outprob1 = NULL; for (i=1;i<N;i++) ta->occ[i] += ab->occt[i]; for (i=1;i<N;i++) { ti = ta->tran[i]; ai = hmm->transP[i]; for (j=2;j<=N;j++) { if (i==1 && j<N) { /* entry transition */ x = aqt[1]+ai[j]+outprob[j][0][0]+bqt[j]-pr; if (x>MINEARG) ti[j] += exp(x); } else if (i>1 && j<N && bqt1!=NULL) { /* internal transition */ x = aqt[i]+ai[j]+outprob1[j][0][0]+bqt1[j]-pr; if (x>MINEARG) ti[j] += exp(x); } else if (i>1 && j==N) { /* exit transition */ x = aqt[i]+ai[N]+bqt[N]-pr; if (x>MINEARG) ti[N] += exp(x); } if (i==1 && j==N && ai[N]>LSMALL && bq1t != NULL){ /* tee transition */ x = aqt[1]+ai[N]+bq1t[1]-pr; if (x>MINEARG) ti[N] += exp(x); } } } if (trace&T_TRA && NonSkipRegion(fbInfo->skipstart,fbInfo->skipend,t)) { printf("Tran Counts at time %d, Model Q%d %s\n",t,q,ab->qIds[q]->name); for (i=1;i<=N;i++) { printf(" %d. Occ %8.2f: Trans ",i,ta->occ[i]); sum = 0.0; for (j=2; j<=N; j++) { x = ta->tran[i][j]; sum += x; printf("%7.2f ",x); } printf(" [%8.2f]\n",sum); } }}/* UpMixParms: update mu/va accs of given hmm */static void UpMixParms(FBInfo *fbInfo, int q, HLink hmm, HLink al_hmm, Observation ot, Observation ot2, int t, DVector aqt, DVector aqt1, DVector bqt, int S,
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -