⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 hrest.c

📁 HMM的另一种经典的训练算法,需要的快下啊
💻 C
📖 第 1 页 / 共 3 页
字号:
}/* LoadFile: load whole file or segments into segStore */void LoadFile(char *fn){   BufferInfo info;   char labfn[80];   Transcription *trans;   long segStIdx,segEnIdx;   static int segIdx=1;  /* Between call handle on latest seg in segStore */     static int prevSegIdx=1;   HTime tStart, tEnd;   int k,i,s,ncas,nObs,segLen;   LLink p;   Observation obs;   if((pbuf=OpenBuffer(&bufferStack, fn, 10, dff, FALSE_dup, FALSE_dup))==NULL)      HError(2250,"LoadFile: Config parameters invalid");   GetBufferInfo(pbuf,&info);   CheckData(fn,info);   if (firstTime) InitSegStore(&info);   if (segId == NULL)  {   /* load whole parameter file */      nObs = ObsInBuffer(pbuf);      tStart = 0.0;      tEnd = (info.tgtSampRate * nObs);      LoadSegment(segStore, tStart, tEnd, pbuf);      if (nObs > maxT)          maxT=nObs;       if (nObs < minT)         minT=nObs;            segIdx++;   }   else {                  /* load segment of parameter file */      MakeFN(fn,labDir,labExt,labfn);      trans = LOpen(&transStack,labfn,lff);      ncas = NumCases(trans->head,segId);            nObs = 0;      if ( ncas > 0) {         for (i=1,nObs=0; i<=ncas; i++) {            p = GetCase(trans->head,segId,i);            segStIdx= (long) (p->start/info.tgtSampRate);            segEnIdx  = (long) (p->end/info.tgtSampRate);            if (segEnIdx >= ObsInBuffer(pbuf))                segEnIdx = ObsInBuffer(pbuf)-1;            if (((segEnIdx - segStIdx + 1 >= nStates-2) || !segReject) 		&& (segStIdx <= segEnIdx)) {	/* skip short segments */               LoadSegment(segStore, p->start, p->end, pbuf);               if (trace&T_LD1)                  printf("  loading seg %s %f[%ld]->%f[%ld]\n",segId->name,                         p->start,segStIdx,p->end,segEnIdx);               segLen = SegLength(segStore, segIdx);               nObs += segLen;               if (segLen > maxT)                   maxT=segLen;                if (segLen < minT)                  minT=segLen;               segIdx++;            }else if (trace&T_LD1)               printf("   seg %s %f->%f ignored\n",segId->name,                      p->start,p->end);         }              }      }   if (hset.hsKind == DISCRETEHS){      for (k=prevSegIdx; k<segIdx; k++){         segLen = SegLength(segStore, k);         for (i=1; i<=segLen; i++){            obs = GetSegObs(segStore, k, i);            for (s=1; s<=nStreams; s++){               if( (obs.vq[s] < 1) || (obs.vq[s] > maxMixInS[s]))                  HError(2250,"LoadFile: Discrete data value [ %d ] out of range in stream [ %d ] in file %s",obs.vq[s],s,fn);            }         }      }      prevSegIdx=segIdx;   }   if (trace&T_LD0)      printf(" %d observations loaded from %s\n",nObs,fn);   CloseBuffer(pbuf);   ResetHeap(&transStack);}/* ------------------------ Trace Functions -------------------- *//* ShowSegNum: if not already printed, print seg number */void ShowSegNum(int seg){   static int lastseg = -1;      if (seg != lastseg){      printf("---- Training Segment %d [%3d frames] ----\n",seg,T);      lastseg = seg;   }}   /* ------------------------- Alpha-Beta ------------------------ *//* SetOutP: Set the output and mix prob matrices */                        void SetOutP(int seg){   int i,t,m,mx,s,nMix=0;   StreamElem *se;   MixtureElem *me;   StateInfo *si;   Matrix mixp;   LogFloat x,prob,streamP;   Vector strp = NULL;   Observation obs;   TMixRec *tmRec = NULL;   float wght=0.0,tmp;   MixPDF *mpdf=NULL;   PreComp *pMix;      for (t=1;t<=T;t++) {      obs = GetSegObs(segStore, seg, t);      if (hsKind == TIEDHS)         PrecomputeTMix(&hset,&obs,tMPruneThresh,0);               if ((maxMixes>1) && (hsKind!=DISCRETEHS)){ /* Multiple Mix Case */         for (i=2;i<nStates;i++) {            prob = 0.0;            si = hmm->svec[i].info;            se = si->pdf+1;             mixp = mixoutp[i][t];            if (nStreams>1) strp = stroutp[t][i];            for (s=1;s<=nStreams;s++,se++){               switch (hsKind){         /* Get nMix */               case TIEDHS:                  tmRec = &(hset.tmRecs[s]);                  nMix = tmRec->nMix;                  break;               case PLAINHS:               case SHAREDHS:                  nMix = se->nMix;                  break;               }               streamP = LZERO;               for (mx=1;mx<=nMix;mx++) {                  m=(hsKind==TIEDHS)?tmRec->probs[mx].index:mx;                  switch (hsKind){      /* Get wght and mpdf */                  case TIEDHS:                     wght=se->spdf.tpdf[m];                     mpdf=tmRec->mixes[m];                     break;                  case PLAINHS:                  case SHAREDHS:                     me = se->spdf.cpdf+m;                     wght=me->weight;                     mpdf=me->mpdf;                     break;                  }                  if (wght>MINMIX){                     switch(hsKind) { /* Get mixture prob */                     case TIEDHS:                        tmp = tmRec->probs[mx].prob;                        x = (tmp>=MINLARG)?log(tmp)+tmRec->maxP:LZERO;                        break;                     case SHAREDHS :                         pMix = (PreComp *)mpdf->hook;                        if (pMix->time==t)                           x = pMix->prob;                        else {                           x = MOutP(obs.fv[s],mpdf);                           pMix->prob = x; pMix->time = t;                        }                        break;                     case PLAINHS :                         x=MOutP(obs.fv[s],mpdf);                        break;                     default:                        x=LZERO;                        break;                     }                     mixp[s][m]=x;                     streamP = LAdd(streamP,log(wght)+x);                  } else                     mixp[s][m]=LZERO;               }                              if (nStreams>1)                  strp[s]=streamP;               prob += streamP; /* note stream weights ignored */            }               outprob[i][t]=prob;         }      } else          if (nStreams>1) {      /* Single Mixture multiple stream */            for (i=2;i<nStates;i++) {               prob = 0.0;               si = hmm->svec[i].info;               se = si->pdf+1;               strp = stroutp[t][i];               for (s=1;s<=nStreams;s++,se++){                  streamP = SOutP(&hset,s,&obs,se);                  strp[s] = streamP;                  prob += streamP; /* note stream weights ignored */               }               outprob[i][t]=prob;            }         } else                 /* Single Mixture - Single Stream */            for (i=2;i<nStates;i++){               si = hmm->svec[i].info;               se = si->pdf+1;               if (hsKind==DISCRETEHS)                  outprob[i][t]=SOutP(&hset,1,&obs,se);               else                  outprob[i][t]=OutP(&obs,hmm,i);            }   }   if (trace  & T_OTP) {      ShowSegNum(seg);      ShowMatrix("OutProb",outprob,10,12);   }}/* SetAlpha: compute alpha matrix and return prob of given sequence */LogDouble SetAlpha(int seg){   int i,j,t;   LogDouble x,a;   alpha[1][1] = 0.0;   for (j=2;j<nStates;j++) {              /* col 1 from entry state */      a=hmm->transP[1][j];      if (a<LSMALL)         alpha[j][1] = LZERO;      else         alpha[j][1] = a+outprob[j][1];   }   alpha[nStates][1] = LZERO;      for (t=2;t<=T;t++) {             /* cols 2 to T */      for (j=2;j<nStates;j++) {         x=LZERO ;         for (i=2;i<nStates;i++) {            a=hmm->transP[i][j];            if (a>LSMALL)               x = LAdd(x,alpha[i][t-1]+a);         }         alpha[j][t]=x+outprob[j][t];      }      alpha[1][t] = alpha[nStates][t] = LZERO;   }   x = LZERO ;                      /* finally calc seg prob */   for (i=2;i<nStates;i++) {      a=hmm->transP[i][nStates];      if (a>LSMALL)         x=LAdd(x,alpha[i][T]+a);    }     alpha[nStates][T] = x;      if (trace  & T_ALF) {      ShowSegNum(seg);      ShowDMatrix("Alpha",alpha,10,12);       printf("LogP= %10.3f\n\n",x);   }   return x;}/* SetBeta: compute beta matrix */LogDouble SetBeta(int seg){   int i,j,t;   LogDouble x,a;   beta[nStates][T] = 0.0;   for (i=2;i<nStates;i++)                /* Col T from exit state */      beta[i][T]=hmm->transP[i][nStates];   beta[1][T] = LZERO;   for (t=T-1;t>=1;t--) {           /* Col t from col t+1 */      for (i=1;i<=nStates;i++)         beta[i][t]=LZERO ;      for (j=2;j<nStates;j++) {         x=outprob[j][t+1]+beta[j][t+1];         for (i=2;i<nStates;i++) {            a=hmm->transP[i][j];            if (a>LSMALL)               beta[i][t]=LAdd(beta[i][t],x+a);         }      }   }   x=LZERO ;   for (j=2;j<nStates;j++) {      a=hmm->transP[1][j];      if (a>LSMALL)         x=LAdd(x,beta[j][1]+a+outprob[j][1]);    }   beta[1][1] = x;   if (trace & T_BET) {      ShowSegNum(seg);      ShowDMatrix("Beta",beta,10,12);       printf("LogP=%10.3f\n\n",beta[1][1]);   }   return x;}/* --------------------- Record Statistics ---------------- *//* SetOccr: set the global occupation counters occr for current seg */void SetOccr(LogDouble pr, int seg){   int i,t;   DVector alpha_i,beta_i;   Vector a_i;   LogDouble x;      occr[1] = 1.0;   for (i=2;i<nStates;i++) {      alpha_i = alpha[i]; beta_i = beta[i];      a_i = hmm->transP[i];      x=LZERO ;      for (t=1;t<=T;t++)         x=LAdd(x,alpha_i[t]+beta_i[t]);      x -= pr;      if (x>MINEARG)          occr[i] = exp(x);      else         occr[i] = 0.0;   }   if (trace & T_OCC){      ShowSegNum(seg);      ShowVector("OCC: ",occr,20);   }}/* UpTranCounts: update the transition counters in ta */void UpTranCounts(LogDouble pr,int seg){   int i,j,t;   Matrix tran;   Vector tran_i,outprob_j,a_i,occ;   DVector alpha_i,beta_j;   LogDouble x,a_ij;   double y;   TrAcc *ta;      ta = (TrAcc *) GetHook(hmm->transP);   tran = ta->tran; occ = ta->occ;   for (i=2; i<nStates; i++)      occ[i] += occr[i];   tran_i = tran[1];          /* transitions 1->j    1<j<nStates */   a_i = hmm->transP[1];   for (j=2;j<nStates;j++) {      a_ij = a_i[j];      if (a_ij>LSMALL) {         x = a_ij + outprob[j][1] + beta[j][1] - pr;         if (x>MINEARG) {            y =  exp(x);            tran_i[j] += y; occ[1] += y;         }      }   }   for (i=2;i<nStates;i++) {        /* transitions i->j    1<i,j<nStates */      a_i = hmm->transP[i];      alpha_i = alpha[i];      tran_i = tran[i];      for (j=2;j<nStates;j++) {         a_ij=a_i[j];         if (a_ij>LSMALL) {            x=LZERO; beta_j=beta[j]; outprob_j=outprob[j];            for (t=1;t<=T-1;t++)               x=LAdd(x,alpha_i[t]+a_ij+outprob_j[t+1]+beta_j[t+1]);            x -= pr;            if (x>MINEARG)               tran_i[j] += exp(x);         }      }   }   for (i=2; i<nStates; i++) {    /* transitions i->nStates    1<i<nStates */      a_ij = hmm->transP[i][nStates];      if (a_ij>LSMALL) {         x = a_ij + alpha[i][T] - pr;         if (x>MINEARG)            tran[i][nStates] += exp(x);      }        }   if (trace & T_TAC){      ShowSegNum(seg);      ShowMatrix("TRAN: ",tran,10,10);      ShowVector("TOCC: ",occ,10);      fflush(stdout);   }}/* UpStreamCounts: update mean, cov & mixweight counts for given stream */void UpStreamCounts(int j, int s, StreamElem *se, int vSize, LogDouble pr, int seg,                    DVector alphj, DVector betaj){   int i,m,nMix=0,k,l,t,ss,idx;   MixtureElem *me;   MixPDF *mpdf=NULL;   MuAcc *ma=NULL;   WtAcc *wa;   VaAcc *va=NULL;   Matrix *mixp_j;   Vector ot, strpt;   LogFloat a_ij,w;   LogDouble Lr;   double y;   Observation obs;   TMixRec *tmRec = NULL;   float wght=0.0;      wa = (WtAcc *)se->hook;   switch (hsKind){       /* Get nMix */   case TIEDHS:      tmRec = &(hset.tmRecs[s]);      nMix = tmRec->nMix;      break;   case PLAINHS:   case SHAREDHS:      nMix = se->nMix;      break;         case DISCRETEHS:      nMix = 1;                /* Only one code selected per observation */      break;   }   mixp_j = (maxMixes>1) ? mixoutp[j] : NULL;   for (m=1; m<=nMix; m++) {      switch (hsKind){            /* Get mpdf, wght */      case TIEDHS:                        wght=se->spdf.tpdf[m];         mpdf=tmRec->mixes[m];         break;      case DISCRETEHS:         wght=1.0;                 /* weight for DISCRETEHS has to be 1 */         mpdf=NULL;         break;      case PLAINHS:      case SHAREDHS:         me = se->spdf.cpdf+m;         wght=me->weight;         mpdf=me->mpdf;         break;      }      if (hsKind!=DISCRETEHS){         ma = (MuAcc *)GetHook(mpdf->mean);         va = (VaAcc *)GetHook(mpdf->cov.var);      }      if (wght > MINMIX) {         w = log(wght);         for (t=1; t<=T; t++) {                     /* Get observation vec ot and zero mean zot */            obs = GetSegObs(segStore, seg, t);            ot = obs.fv[s];            if (hsKind!=DISCRETEHS)               for (k=1; k<=vSize; k++)                  zot[k] = ot[k] - mpdf->mean[k];                                             /* Compute state/mix occupation log likelihood */            if (nMix==1 || (hsKind==DISCRETEHS)){               Lr = alphj[t]+betaj[t] - pr;

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -