⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 hfb.c

📁 隐马尔科夫模型工具箱
💻 C
📖 第 1 页 / 共 4 页
字号:
   float zmean,zmeanlr,zmean2,tmp;   double Lr,steSumLr;   HMMSet *hset;   HSetKind hsKind;   AlphaBeta *ab;   StreamElem *ste;   MixtureElem *me;   MixPDF *mp;   MuAcc *ma;   VaAcc *va;   WtAcc *wa = NULL;   PreComp *pMix;   Boolean mmix=FALSE;  /* TRUE if multiple mixture */   float wght;   /* variables for 2-model reestimation */   Vector comp_prob;             /* array[1..M] of Component probability */   float norm;                   /* total mixture prob */   ab     = fbInfo->ab;   hset   = fbInfo->up_hset;   hsKind = fbInfo->hsKind;   if (trace&T_MIX && fbInfo->uFlags&UPMIXES &&        NonSkipRegion(fbInfo->skipstart,fbInfo->skipend,t)){      printf("Mixture Weights at time %d, model Q%d %s\n",             t,q,ab->qIds[q]->name);   }   if (fbInfo->twoModels)       comp_prob = CreateVector(&gstack,fbInfo->maxM);   N = hmm->numStates;   for (j=2;j<N;j++) {      if (fbInfo->maxM>1){         initx = hmm->transP[1][j] + aqt[1];         if (t>1)            for (i=2;i<N;i++){               a = hmm->transP[i][j];               if (a>LSMALL)                  initx = LAdd(initx,aqt1[i]+a);            }         initx += bqt[j] - pr;      }      if (trace&T_MIX && fbInfo->uFlags&UPMIXES &&           NonSkipRegion(fbInfo->skipstart,fbInfo->skipend,t))         printf("  State %d: ",j);      ste = hmm->svec[j].info->pdf+1;      outprob = ab->otprob[t][q][j];      for (s=1;s<=S;s++,ste++){         /* Get observation vector for this state/stream */         vSize = hset->swidth[s];         otvs = ot.fv[s];               switch (hsKind){         case TIEDHS:             /* if tied mixtures then we only */            tmRec = &(hset->tmRecs[s]); /* want to process the non-pruned */            M = tmRec->topM;            /* components */            mmix = TRUE;            break;         case DISCRETEHS:            M = 1;            mmix = FALSE;            break;         case PLAINHS:         case SHAREDHS:            M = ste->nMix;            mmix = (M>1);            break;         }         /* update weight occupation count */         wa = (WtAcc *) ste->hook; steSumLr = 0.0;         if (fbInfo->twoModels) { /* component probs of update hmm */             norm = LZERO;             for (mx=1; mx<=M; mx++) {                 me = ste->spdf.cpdf+mx;	mp=me->mpdf;                 wght = me->weight;                 comp_prob[mx]=log(wght)+MOutP(otvs,mp);                 norm = LAdd(norm,comp_prob[mx]);             }         }               for (mx=1;mx<=M;mx++) {             /* process mixtures */            switch (hsKind){    /* Get wght and mpdf */            case TIEDHS:               m=tmRec->probs[mx].index;               wght=ste->spdf.tpdf[m];                         mp=tmRec->mixes[m];               break;            case DISCRETEHS:               if (twoDataFiles)                  m=ot2.vq[s];               else                  m=ot.vq[s];               wght = 1.0;               mp=NULL;               break;            case PLAINHS:            case SHAREDHS:               m = mx;               me = ste->spdf.cpdf+m;               wght = MixWeight(hset,me->weight);               mp=me->mpdf;               break;            }            if (wght>MINMIX){               /* compute mixture likelihood  */               if (!mmix || (hsKind==DISCRETEHS)) /* For DISCRETEHS calcs are*/                  x = aqt[j]+bqt[j]-pr;           /* same as single mix*/               else if (fbInfo->twoModels) {      /* note: only SHAREDHS or PLAINHS */                  x = comp_prob[m]+aqt[j]+bqt[j]-pr-norm;               }               else {                  c_jm=log(wght);                  x = initx+c_jm;                  switch(hsKind) {                  case TIEDHS :                     tmp = tmRec->probs[mx].prob;                     prob = (tmp>=MINLARG)?log(tmp)+tmRec->maxP:LZERO;                     break;                  case SHAREDHS :                      pMix = (PreComp *)mp->hook;                     if (pMix->time==t)                        prob = pMix->prob;                     else {                        prob = MOutP(otvs,mp);                        pMix->prob = prob; pMix->time = t;                     }                     break;                  case PLAINHS :                      prob=MOutP(otvs,mp);                     break;                  default:                     x=LZERO;                     break;                  }                  x += prob;                  if (S>1)      /* adjust for parallel streams */                     x += outprob[s];               }               if (twoDataFiles){  /* switch to new data for mu & var est */                  otvs = ot2.fv[s];               }               if (-x<pruneSetting.minFrwdP) {                  Lr = exp(x);                  /* More diagnostics */                  /* if (Lr>0.000001 && ab->occt[j]>0.000001 &&                     (Lr/ab->occt[j])>1.00001)                     printf("Too big %d %d %s : %5.3f %10.2f %8.2f (%4.2f)\n",t,q,                     ab->qIds[q]->name,Lr/ab->occt[j],Lr,ab->occt[j],prob); */                              /* update occupation counts */                  steSumLr += Lr;                  /* update the adaptation statistic counts */                  if (fbInfo->uFlags&UPADAPT)                     AccAdaptFrame(Lr, otvs, mp, fbInfo->rt);                  /* update mean counts */                  if ((fbInfo->uFlags&UPMEANS) || (fbInfo->uFlags&UPVARS))                     mean = mp->mean;                   if ((fbInfo->uFlags&UPMEANS) && (fbInfo->uFlags&UPVARS)) {                     ma = (MuAcc *) GetHook(mean);                     va = (VaAcc *) GetHook(mp->cov.var);                     ma->occ += Lr;                     va->occ += Lr;                     mu_jm = ma->mu;                     if ((mp->ckind==DIAGC)||(mp->ckind==INVDIAGC)){                        var = va->cov.var;                        for (k=1;k<=vSize;k++) {                           zmean=otvs[k]-mean[k];                           zmeanlr=zmean*Lr;                           mu_jm[k] += zmeanlr;                           var[k] += zmean*zmeanlr;                        }                     } else {                        inv = va->cov.inv;                        for (k=1;k<=vSize;k++) {                           invk = inv[k];                           zmean=otvs[k]-mean[k];                           zmeanlr=zmean*Lr;                           mu_jm[k] += zmeanlr;                           for (kk=1;kk<=k;kk++) {                              zmean2 = otvs[kk]-mean[kk];                              invk[kk] += zmean2*zmeanlr;                           }                        }                     }                  }                  else if (fbInfo->uFlags&UPMEANS){                     ma = (MuAcc *) GetHook(mean);                     mu_jm = ma->mu;                     ma->occ += Lr;                     for (k=1;k<=vSize;k++)     /* sum zero mean */                        mu_jm[k] += (otvs[k]-mean[k])*Lr;                  }                  else if (fbInfo->uFlags&UPVARS){                     /* update covariance counts */                     va = (VaAcc *) GetHook(mp->cov.var);                     va->occ += Lr;                     if ((mp->ckind==DIAGC)||(mp->ckind==INVDIAGC)){                        var = va->cov.var;                        for (k=1;k<=vSize;k++) {                           zmean=otvs[k]-mean[k];                           var[k] += zmean*zmean*Lr;                        }                     } else {                        inv = va->cov.inv;                        for (k=1;k<=vSize;k++) {                           invk = inv[k];                           zmean=otvs[k]-mean[k];                                          for (kk=1;kk<=k;kk++) {                              zmean2 = otvs[kk]-mean[kk];                              invk[kk] += zmean*zmean2*Lr;                           }                        }                     }                  }                              /* update mixture weight counts */                  if (fbInfo->uFlags&UPMIXES) {                     wa->c[m] +=Lr;                     if (trace&T_MIX && NonSkipRegion(fbInfo->skipstart,fbInfo->skipend,t))                        printf("%3d. %7.2f",m,wa->c[m]);                  }               }            }            if (twoDataFiles){ /* Switch back to old data for prob calc */               otvs = ot.fv[s];            }             }               wa->occ += steSumLr;         if (trace&T_MIX && mmix && fbInfo->uFlags&UPMIXES &&              NonSkipRegion(fbInfo->skipstart,fbInfo->skipend,t))            printf("[%7.2f]\n",wa->occ);      }   }   if (fbInfo->twoModels)       FreeVector(&gstack,comp_prob);}/* -------------------- Top Level of F-B Updating ---------------- *//* StepForward: Step from 1 to T calc'ing Alpha columns and    accumulating statistic */static void StepForward(FBInfo *fbInfo, UttInfo *utt){   int q,t,start,end,negs;   DVector aqt,aqt1,bqt,bqt1,bq1t;   HLink al_hmm, up_hmm;   AlphaBeta *ab;   /* reset the memory heap for alpha for a new utterance */   /* ResetHeap(&(fbMemInfo.alphaStack)); */     ab = fbInfo->ab;   CreateAlpha(ab,fbInfo->al_hset,utt->Q); /* al_hset may be idential to up_hset */   InitAlpha(ab,&start,&end,utt->Q,fbInfo->skipstart,fbInfo->skipend);   ab->occa = NULL;   if (trace&T_OCC)       CreateTraceOcc(ab,utt);   for (q=1;q<=utt->Q;q++){             /* inc access counters */      up_hmm = ab->up_qList[q];      negs = (int)up_hmm->hook+1;      up_hmm->hook = (void *)negs;   }   for (t=1;t<=utt->T;t++) {      GetInputObs(utt, t, fbInfo->hsKind);      if (fbInfo->hsKind == TIEDHS)         PrecomputeTMix(fbInfo->al_hset,&(utt->ot),pruneSetting.minFrwdP,0);      if (t>1)         StepAlpha(ab,t,&start,&end,utt->Q,utt->T,utt->pr,                   fbInfo->skipstart,fbInfo->skipend);          if (trace&T_ALF && NonSkipRegion(fbInfo->skipstart,fbInfo->skipend,t))          TraceAlphaBeta(ab,t,start,end,utt->pr);          for (q=start;q<=end;q++) {          /* increment accs for each active model */         al_hmm = ab->al_qList[q];         up_hmm = ab->up_qList[q];         aqt = ab->alphat[q];         bqt = ab->beta[t][q];         bqt1 = (t==utt->T) ? NULL:ab->beta[t+1][q];         aqt1 = (t==1)      ? NULL:ab->alphat1[q];         bq1t = (q==utt->Q) ? NULL:ab->beta[t][q+1];         SetOcct(al_hmm,q,ab->occt,ab->occa,aqt,bqt,bq1t,utt->pr);         /* accumulate the statistics */         if (fbInfo->uFlags&(UPMEANS|UPVARS|UPMIXES|UPADAPT))            UpMixParms(fbInfo,q,up_hmm,utt->ot,utt->ot2,t,aqt,aqt1,bqt,                       utt->S, utt->twoDataFiles, utt->pr);         if (fbInfo->uFlags&UPTRANS)            UpTranParms(fbInfo,up_hmm,t,q,aqt,bqt,bqt1,bq1t,utt->pr);      }      if (trace&T_OCC && NonSkipRegion(fbInfo->skipstart,fbInfo->skipend,t))          TraceOcc(ab,utt,t);   }}/* load the labels into the UttInfo structure from file */void LoadLabs(UttInfo *utt, FileFormat lff, char * datafn,               char *labDir, char *labExt){   char labfn[255],buf1[255],buf2[255];   /* reset the heap for a new transcription */   ResetHeap(&utt->transStack);     MakeFN(datafn,labDir,labExt,labfn);   if (traceHFB || trace&T_TOP) {      printf(" Processing Data: %s; Label %s\n",             NameOf(datafn,buf1),NameOf(labfn,buf2));      fflush(stdout);   }   utt->tr = LOpen(&utt->transStack,labfn,lff);   utt->Q  = CountLabs(utt->tr->head);   if (utt->Q==0)      HError(-7325,"LoadUtterance: No labels in file %s",labfn);}/* load the data file(s) into the UttInfo structure */void LoadData(HMMSet *hset, UttInfo *utt, FileFormat dff,               char * datafn, char * datafn2){   BufferInfo info, info2;   int T2;   /* close any open buffers */   if (utt->pbuf != NULL) {      CloseBuffer(utt->pbuf);      if (utt->twoDataFiles)         CloseBuffer(utt->pbuf2);   }   /* reset the data stack for a new utterance */   ResetHeap(&utt->dataStack);   if (utt->twoDataFiles)      ResetHeap(&utt->dataStack2);   if (utt->twoDataFiles)      if(SetChannel("HPARM1")<SUCCESS)         HError(7350,"HFB: Channel parameters invalid");   if((utt->pbuf=OpenBuffer(&utt->dataStack,datafn,0,dff,                            FALSE_dup,FALSE_dup))==NULL)      HError(7350,"HFB: Config parameters invalid");   GetBufferInfo(utt->pbuf,&info);   if (utt->twoDataFiles){      if(SetChannel("HPARM2")<SUCCESS)         HError(7350,"HFB: Channel parameters invalid");           if((utt->pbuf2=OpenBuffer(&utt->dataStack2,datafn2,0,dff,                                FALSE_dup,FALSE_dup))==NULL)         HError(7350,"HFB: Config parameters invalid");      GetBufferInfo(utt->pbuf2,&info2);      CheckData(hset,datafn2,&info2,utt->twoDataFiles);      T2 = ObsInBuffer(utt->pbuf2);   }else      CheckData(hset,datafn,&info,utt->twoDataFiles);   utt->T = ObsInBuffer(utt->pbuf);   if (utt->twoDataFiles && (utt->T != T2))      HError(7326,"HFB: Paired training files must be same length for single pass retraining");    }/* Initialise the observation structures within UttInfo */void InitUttObservations(UttInfo *utt, HMMSet *al_hset,                          char * datafn, int * maxMixInS){   BufferInfo info, info2;   Boolean eSep;   int s, i;   if (utt->twoDataFiles)      if(SetChannel("HPARM1")<SUCCESS)         HError(7350,"HFB: Channel parameters invalid");   GetBufferInfo(utt->pbuf,&info);   if (utt->twoDataFiles){      if(SetChannel("HPARM2")<SUCCESS)         HError(7350,"HFB: Channel parameters invalid");      GetBufferInfo(utt->pbuf2,&info2);   }     SetStreamWidths(info.tgtPK,info.tgtVecSize,al_hset->swidth,&eSep);   utt->ot = MakeObservation(&gstack,al_hset->swidth,info.tgtPK,                             al_hset->hsKind==DISCRETEHS,eSep);   if (utt->twoDataFiles)       utt->ot2 = MakeObservation(&gstack,al_hset->swidth,info2.tgtPK,                                  al_hset->hsKind==DISCRETEHS,eSep);      if (al_hset->hsKind==DISCRETEHS){       for (i=0; i<utt->T; i++){         ReadAsTable(utt->pbuf,i,&utt->ot);         for (s=1; s<=utt->S; s++){             if( (utt->ot.vq[s] < 1) || (utt->ot.vq[s] > maxMixInS[s]))                 HError(7350,"LoadFile: Discrete data value [ %d ] out of range in seam [ %d ] in file %s",                        utt->ot.vq[s],s,datafn);         }      }   }   }/* FBFile: apply forward-backward to given utterance */Boolean FBFile(FBInfo *fbInfo, UttInfo *utt, char * datafn){   Boolean success;   if ((success = StepBack(fbInfo,utt,datafn)))      StepForward(fbInfo,utt);   ResetStacks(fbInfo->ab);   return success;}/* ----------------------------------------------------------- *//*                      END:  HFB.c                         *//* ----------------------------------------------------------- */

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -