⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 hrest.c

📁 HMM的另一种经典的训练算法,需要的快下啊
💻 C
📖 第 1 页 / 共 3 页
字号:
            }            else {               if (t==1)                  Lr = hmm->transP[1][j];               else {                  Lr = LZERO;                  for (i=2; i<nStates; i++)                     if ((a_ij=hmm->transP[i][j]) > LSMALL)                        Lr = LAdd(Lr,alpha[i][t-1]+a_ij);               }               if (Lr>LSMALL) {                  Lr += mixp_j[t][s][m] + w + betaj[t] - pr;                  if (nStreams>1) { /* add contrib of parallel streams */                     strpt = stroutp[t][j];                     for (ss=1; ss<=nStreams; ss++)                        if (ss!=s) Lr += strpt[ss];                  }               }            }                        if (Lr > MINEARG) {               y = exp(Lr);                                 /* Update Weight Counter */               if (uFlags&UPMIXES) {                     idx = (hsKind==DISCRETEHS) ? obs.vq[s] : m;                  wa->occ += y; wa->c[idx] += y;               }                              /* Update Mean Counter */               if (uFlags&UPMEANS){                  ma->occ += y;                   for (k=1; k<=vSize; k++)                     ma->mu[k] += zot[k]*y; /* sum zero mean */               }                              /* Update Covariance Counter */               if (uFlags&UPVARS){                  va->occ += y;                  if (mpdf->ckind==DIAGC){                     for (k=1; k<=vSize; k++)                        va->cov.var[k] += zot[k] * zot[k] * y;                  } else{                     for (k=1; k<=vSize; k++)                        for (l=1; l<=k; l++)                           va->cov.inv[k][l] += zot[k] * zot[l] * y;                  }               }              }         }      }      if ((trace&(T_MAC|T_VAC))&&(hsKind!=DISCRETEHS)) {         ShowSegNum(seg);         printf("State %d, Stream %d, Mixture %d\n",j,s,m);         if (trace&T_MAC){            printf("MEAN OCC: %.2f\n",ma->occ);            ShowVector("MEAN ACC: ",ma->mu,10);         }         if (trace&T_VAC) {            if (mpdf->ckind==DIAGC){               printf("VAR OCC: %.2f\n",va->occ);               ShowVector("VAR ACC: ",va->cov.var,10);            } else {               printf("INV OCC: %.2f\n",va->occ);               ShowMatrix("INV ACC: ",va->cov.inv,10,10);            }         }      }   }   if (trace&T_WAC){      ShowSegNum(seg);      printf("State %d, Stream %d\n",j,s);      printf("WT OCC: %.2f\n",wa->occ);      ShowVector("WT ACC: ",wa->c,10);   }     }   /* UpPDFCounts: update output PDF counts for each stream of each state */void UpPDFCounts(LogDouble pr, int seg){   int j,s;   StateInfo *si;   StreamElem *se;   DVector alj,betj;   for (j=2; j<nStates; j++) {      si = hmm->svec[j].info;      alj = alpha[j]; betj = beta[j];      for (s=1,se = si->pdf+1; s<=nStreams; s++,se++)         UpStreamCounts(j,s,se,hset.swidth[s],pr,seg,alj,betj);   }}/* UpdateCounters: update the various counters */void UpdateCounters(LogDouble pr, int seg){   SetOccr(pr,seg);   if (uFlags&UPTRANS)       UpTranCounts(pr,seg);   if (uFlags&(UPMEANS|UPVARS|UPMIXES))      UpPDFCounts(pr,seg);}/* ------------------------- Model Update ----------------------- *//* RestTransP: reestimate transition probs */void RestTransP(void){   int i,j;   float occi,x,sum;   TrAcc *ta;   ta = (TrAcc *) GetHook(hmm->transP);      for (i=1;i<nStates;i++) {      hmm->transP[i][1] = LZERO;      occi = ta->occ[i];      if (occi == 0.0)         HError(2222,"RestTransP: Zero state %d occupation count",i);      sum = 0.0;      for (j=2;j<=nStates;j++) {         x = ta->tran[i][j]/occi;         hmm->transP[i][j] = x; sum += x;      }      for (j=2;j<=nStates;j++) {         x = hmm->transP[i][j]/sum;         hmm->transP[i][j] = (x<MINLARG) ? LZERO : log(x);      }   }   if (trace & T_TRE)      ShowMatrix("NEW TRANS: ",hmm->transP,10,10);}/* FloorMixes: apply floor to given mix set */void FloorMixes(MixtureElem *mixes, int M, float floor){   float sum,fsum,scale;   MixtureElem *me;   int m;      sum = fsum = 0.0;   for (m=1,me=mixes; m<=M; m++,me++) {      if (me->weight>floor)         sum += me->weight;      else {         fsum += floor; me->weight = floor;      }   }   if (fsum>1.0)      HError(2223,"FloorMixes: Floor sum too large");   scale = (1.0-fsum)/sum;   if (trace&T_WRE) printf("MIXW: ");   for (m=1,me=mixes; m<=M; m++,me++){      if (me->weight>floor)         me->weight *= scale;      if (trace&T_WRE) printf(" %.2f",me->weight);   }   if (trace&T_WRE) printf("\n");}  /* FloorTMMixes: apply floor to given tied mix set */void FloorTMMixes(Vector mixes, int M, float floor){   float sum,fsum,scale,fltWt;   int m;      sum = fsum = 0.0;   for (m=1; m<=M; m++) {      fltWt = mixes[m];      if (fltWt>floor)         sum += fltWt;      else {         fsum += floor;         mixes[m] = floor;      }   }   if (fsum>1.0) HError(2223,"FloorTMMixes: Floor sum too large");   scale = (1.0-fsum)/sum;   if (trace&T_WRE) printf("MIXW: ");   for (m=1; m<=M; m++){      fltWt = mixes[m];      if (fltWt>floor)         mixes[m] = fltWt*scale;      if (trace&T_WRE) printf(" %.2f",fltWt);   }}/* FloorDProbs: apply floor to given discrete prob set */void FloorDProbs(ShortVec mixes, int M, float floor){   float sum,fsum,scale,fltWt;   int m;      sum = fsum = 0.0;   for (m=1; m<=M; m++) {      fltWt = Short2DProb(mixes[m]);      if (fltWt>floor)         sum += fltWt;      else {         fsum += floor;         mixes[m] = DProb2Short(floor);      }   }   if (fsum>1.0) HError(2327,"FloorDProbs: Floor sum too large");   if (fsum == 0.0) return;   if (sum == 0.0) HError(2328,"FloorDProbs: No probabilities above floor");   scale = (1.0-fsum)/sum;   for (m=1; m<=M; m++){      fltWt = Short2DProb(mixes[m]);      if (fltWt>floor)         mixes[m] = DProb2Short(fltWt*scale);   }}/* RestMixWeights: reestimate the mixture weights */void RestMixWeights(int state, int s, StreamElem *se){   WtAcc *wa;   int m,M=0;   float x;   MixtureElem *me;      wa = (WtAcc *)se->hook;   if (wa->occ == 0.0)      HError(2222,"RestMixWeights: Zero weight occupation count");   switch (hsKind){   case TIEDHS:      M=hset.tmRecs[s].nMix;      break;   case PLAINHS:   case SHAREDHS:   case DISCRETEHS:      M=se->nMix;      break;   }   for (m=1; m<=M; m++){      x = wa->c[m]/wa->occ;      if (x>1.0)          HError(2290,"RestMixWeights: Mix wt>1 in %d.%d.%d",state,s,m);      switch (hsKind){      case DISCRETEHS:         se->spdf.dpdf[m] = (x>MINMIX) ? DProb2Short(x) : DLOGZERO;         break;      case TIEDHS:         se->spdf.tpdf[m] = (x>MINMIX) ? x : 0.0;         break;      case PLAINHS:      case SHAREDHS:         me=se->spdf.cpdf+m;         me->weight = (x>MINMIX) ? x : 0.0;         break;      }         }}/* RestMean: reestimate the given mean vector */void RestMean(Vector mean, int vSize){   int k;   MuAcc *ma;   float x;      ma = (MuAcc *)GetHook(mean);   if (ma->occ == 0.0)      HError(2222,"RestMean: Zero mean occupation count");   for (k=1; k<=vSize; k++){      x = mean[k] + ma->mu[k]/ma->occ;      ma->mu[k] = mean[k];  /* remember old mean */      mean[k] = x;   }   if (trace&T_MRE)      ShowVector("MEAN: ",mean,10);}/* RestCoVar: reestimate the given covariance and return FALSE              if any diagonal component == 0.0 */Boolean RestCoVar(MixPDF *mp, int vSize, Vector minV,                  Vector oldMean, Vector newMean, Boolean shared){   int k,l;   VaAcc *va;   float x,z;   float muDiffk,muDiffl;      va = (VaAcc *)GetHook(mp->cov.var);   if (va->occ == 0.0)      HError(2222,"RestCoVar: Zero variance occupation count");   if (mp->ckind==DIAGC){      for (k=1; k<=vSize; k++){         muDiffk = (shared)?0.0:newMean[k]-oldMean[k];         x = va->cov.var[k] / va->occ - muDiffk*muDiffk;         if (x<minV[k]) x = minV[k];         if (x<vDefunct) return FALSE;         mp->cov.var[k] = x;      }      FixDiagGConst(mp);      if (trace&T_VRE)         ShowVector("VARS: ",mp->cov.var,10);   } else {      for (k=1; k<=vSize; k++){         muDiffk = (shared)?0.0:newMean[k]-oldMean[k];         for (l=1; l<k; l++) {            muDiffl = (shared)?0.0:newMean[l]-oldMean[l];            x = va->cov.inv[k][l] / va->occ - muDiffk*muDiffl;            mp->cov.inv[k][l] = x;         }         z = va->cov.inv[k][k]/va->occ - muDiffk*muDiffk;         mp->cov.inv[k][k] = (z<minV[k])?minV[k]:z;      }      if (trace&T_VRE)         ShowTriMat("COVM: ",mp->cov.inv,10,10);      FixFullGConst(mp,CovInvert(mp->cov.inv,mp->cov.inv));   }   return TRUE;}/* RestStream: reestimate stream parameters */void RestStream(int state, int s, StreamElem *se, int vSize){   int m,M;   MixtureElem *me;   MixPDF *mp;   MuAcc *ma;   Boolean shared;   float wght;   if (trace&(T_WRE|T_MRE|T_VRE))      printf("State %d, Stream %d\n",state,s);   if (uFlags&UPMIXES)      RestMixWeights(state,s,se);   if ((hsKind != DISCRETEHS)&&(hsKind != TIEDHS)){ /*wts only DI'ETE & TIED*/      M=se->nMix;      for (m=1; m<=M; m++){         me = se->spdf.cpdf+m;         wght=me->weight;         mp=me->mpdf;         if (wght > MINMIX) {            if (trace&(T_MRE|T_VRE) && M>1)               printf("Mixture %d\n",m);            if (uFlags&UPMEANS)               RestMean(mp->mean,vSize);            /* NB old mean left in ma->mu */            if (uFlags&UPVARS){               shared = GetUse(mp->cov.var) > 1;               ma = (MuAcc *)GetHook(mp->mean);               if ( !RestCoVar(mp,vSize,vFloor[s],ma->mu,mp->mean,shared)) {                  if (M > 1) {                     HError(-2225,"RestStream: Defunct Mix %d.%d.%d",state,s,m);                     me->weight = 0.0;                  } else                     HError(2222,"RestStream: Zero Covariance in %d.%d",state,s);               }            }         }      }   }   if (hsKind == TIEDHS)      M=hset.tmRecs[s].nMix;   else      M=se->nMix;   if (M>1){      switch (hsKind){      case DISCRETEHS:         FloorDProbs(se->spdf.dpdf,M,mixWeightFloor);         break;      case TIEDHS:         FloorTMMixes(se->spdf.tpdf,M,mixWeightFloor);         break;      case PLAINHS:      case SHAREDHS:         FloorMixes(se->spdf.cpdf+1,M,mixWeightFloor);         break;      }        }}/* UpdateTheModel: use accumulated statistics to update model */void UpdateTheModel(void){   int j,s;   StateInfo *si;   StreamElem *se;   if (uFlags&UPTRANS)      RestTransP();   if (uFlags&(UPMEANS|UPVARS|UPMIXES))      for (j=2; j<nStates; j++) {         si = hmm->svec[j].info;         for (s=1,se = si->pdf+1; s<=nStreams; s++,se++)            RestStream(j,s,se,hset.swidth[s]);      }   if (uFlags&UPVARS)      FixAllGConsts(&hset);}/* ------------------------- Top Level Control ----------------------- *//* ReEstimateModel: top level of algorithm */void ReEstimateModel(void){   LogFloat segProb,oldP,newP,delta;   LogDouble ap,bp;   int converged,iteration,seg;   iteration=0;    oldP=LZERO;   do {        /*main re-est loop*/         ZeroAccs(&hset, uFlags); newP = 0.0; ++iteration;      nTokUsed = 0;      for (seg=1;seg<=nSeg;seg++) {         T=SegLength(segStore,seg);         SetOutP(seg);         if ((ap=SetAlpha(seg)) > LSMALL){            bp = SetBeta(seg);            if (trace & T_LGP)               printf("%d.  Pa = %e, Pb = %e, Diff = %e\n",seg,ap,bp,ap-bp);            segProb = (ap + bp) / 2.0;  /* reduce numeric error */            newP += segProb; ++nTokUsed;            UpdateCounters(segProb,seg);         } else            if (trace&T_TOP)                printf("Example %d skipped\n",seg);      }      if (nTokUsed==0)         HError(2226,"ReEstimateModel: No Usable Training Examples");      UpdateTheModel();      newP /= nTokUsed;      delta=newP-oldP; oldP=newP;      converged=(fabs(delta)<epsilon);       if (trace&T_TOP) {         printf("Ave LogProb at iter %d = %10.5f using %d examples",                iteration,oldP,nTokUsed);         if (iteration > 1)            printf("  change = %10.5f",delta);         printf("\n");         fflush(stdout);      }   } while ((iteration < maxIter) && !converged);   if (trace&T_TOP) {      if (converged)         printf("Estimation converged at iteration %d\n",iteration);      else         printf("Estimation aborted at iteration %d\n",iteration);      fflush(stdout);   }}/* ----------------------------------------------------------- *//*                      END:  HRest.c                          *//* ----------------------------------------------------------- */

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -