📄 hinit.c
字号:
} CopyVector(thisP,lastP); if (trace & T_VIT) ShowP(segIdx,lastP); } /* column[segLen]--> exit state(numStates) */ bestPrevState=2; tranP = hmmLink->transP[2][nStates]; prevP = lastP[2]; bestP=(tranP<LSMALL) ? LZERO : tranP+prevP; for (prevState=3;prevState<nStates;prevState++) { tranP = hmmLink->transP[prevState][nStates]; prevP = lastP[prevState]; currP = (tranP<LSMALL) ? LZERO : tranP+prevP; if (currP > bestP) { bestPrevState=prevState; bestP=currP; } } /* bestPrevState now gives last internal state along best state sequence */ if (bestP<LSMALL) HError(2126,"ViterbiAlign: No path found in %d'th segment",segNum); if (trace & T_VIT) { ShowTraceBack(segLen,traceBack); printf(" bestP = %12.5f via state %d\n",bestP,bestPrevState); fflush(stdout); } DoTraceBack(segLen,states,bestPrevState); if (mixes!=NULL) /* ie not DISCRETE */ FindBestMixes(segNum,segLen,states,mixes); ResetHeap( &traceBackStack ); return bestP; }/* ----------------- Update Count Routines --------------------------- *//* UpdateCounts: using frames in seg i and alignment in states/mixes */void UpdateCounts(int segNum, int segLen, IntVec states,IntVec *mixes){ int M=0,i,j,k,s,m,state,last; StreamElem *ste; MixPDF *mp = NULL; WtAcc *wa; MuAcc *ma; VaAcc *va; TrAcc *ta; Vector v; Observation obs; TMixRec *tmRec = NULL; float x,y; last = 1; /* last before 1st emitting state must be 1 */ ta = (TrAcc *)GetHook(hmmLink->transP); for (i=1; i<=segLen; i++){ state = states[i]; if (trace&T_CNT) printf(" Seg %d -> state %d\n",i,state); if (uFlags&(UPMEANS|UPVARS|UPMIXES)){ obs = GetSegObs(segStore, segNum, i); if (hset.hsKind == TIEDHS) PrecomputeTMix(&hset, &obs, 50.0, 0); ste = hmmLink->svec[state].info->pdf+1; for (s=1; s<=nStreams; s++,ste++){ if (hset.hsKind==DISCRETEHS){ m = obs.vq[s]; v = NULL; } else { v = obs.fv[s]; m = mixes[s][i]; } switch(hset.hsKind){ case TIEDHS: tmRec = &(hset.tmRecs[s]); M = tmRec->nMix; break; case PLAINHS: case SHAREDHS: case DISCRETEHS: M = ste->nMix; break; } if (m<1 || m > M) HError(2170,"UpdateCounts: mix/vq idx out of range[%d]",m); if (trace&T_CNT) printf(" stream %d -> mix %d[%d]\n",s,m,M); /* update mixture weight */ if (M>1 && (uFlags&UPMIXES)) { wa = (WtAcc *)ste->hook; wa->occ += 1.0; wa->c[m] += 1.0; if (trace&T_CNT) printf(" mix wt -> %.1f\n",wa->c[m]); } if (hset.hsKind==DISCRETEHS) continue; /* update state/mixture component */ switch(hset.hsKind){ case PLAINHS: case SHAREDHS: mp = ste->spdf.cpdf[m].mpdf; break; case TIEDHS: mp = tmRec->mixes[m]; break; } ma = (MuAcc *)GetHook(mp->mean); va = (VaAcc *)GetHook(mp->cov.var); ma->occ += 1.0; va->occ += 1.0; for (j=1; j<=hset.swidth[s]; j++) { x = v[j] - mp->mean[j]; ma->mu[j] += x; if (uFlags&UPVARS) switch(mp->ckind){ case DIAGC: va->cov.var[j] += x*x; break; case FULLC: for (k=1; k<=j; k++){ y = v[k]-mp->mean[k]; va->cov.inv[j][k] += x*y; } break; default: HError(2124,"UpdateCounts: bad cov kind %d\n", mp->ckind); } } if (trace&T_CNT) { ShowVector(" mean ->",ma->mu,6); if (uFlags&UPVARS) { if (mp->ckind==DIAGC) ShowVector(" var ->",va->cov.var,6); else ShowTriMat(" cov ->",va->cov.inv,6,6); } fflush(stdout); } } } /* update transition probs */ if (uFlags&UPTRANS){ ta->occ[last] += 1.0; ta->tran[last][state] += 1.0; last = state; if (i==segLen){ /* remember final state */ ta->occ[state] += 1.0; ta->tran[state][nStates] += 1.0; } if (trace&T_CNT) { ShowMatrix(" tran ->",ta->tran,6,6); fflush(stdout); } } }}/* ----------------- Update Parameters --------------------------- *//* UpWeights: update given mixture weights */void UpWeights(int i, int s, int M, WtAcc *wa, StreamElem *ste){ int m; float sum=0.0; if (wa->occ == 0.0) HError(2127,"UpWeights: zero occ i=%d/s=%d",i,s); for (m=1; m<=M; m++){ sum += wa->c[m]; switch(hset.hsKind){ case PLAINHS: case SHAREDHS: ste->spdf.cpdf[m].weight = wa->c[m] / wa->occ; break; case TIEDHS: ste->spdf.tpdf[m] = wa->c[m] / wa->occ; break; } } if (fabs(sum-wa->occ)/sum > 0.001) HError(2190,"UpWeights: mix weight sum error");}/* UpMeans: update mean, leave old mean in acc */void UpMeans(int i, int s, int m, int size, MuAcc *ma, Vector mean){ int k; float x; if (ma->occ == 0.0) HError(2127,"UpMeans: zero occ i=%d/s=%d/m=%d",i,s,m); for (k=1; k<=size; k++){ x = mean[k] + ma->mu[k]/ma->occ; ma->mu[k] = mean[k]; /* remember old mean */ if (uFlags&UPMEANS) mean[k] = x; }}/* UpVars: update variances, apply correction if covariance is not shared */void UpVars(int i, int s, int m, int size, VaAcc *va, Vector oldMean, Vector newMean, Boolean shared, MixPDF *mp){ int j,k; float x,y,z; Vector floor; if (va->occ == 0.0) HError(2127,"UpVars: zero occ i=%d/s=%d/m=%d",i,s,m); floor=vFloor[s]; switch(mp->ckind){ case DIAGC: for (j=1; j<=size; j++){ x = (shared)?0.0:newMean[j]-oldMean[j]; z = va->cov.var[j]/va->occ - x*x; mp->cov.var[j] = (z<floor[j])?floor[j]:z; } FixDiagGConst(mp); break; case FULLC: for (j=1; j<=size; j++){ x = (shared)?0.0:newMean[j]-oldMean[j]; for (k=1; k<j; k++) { y = (shared)?0.0:newMean[k]-oldMean[k]; mp->cov.inv[j][k] = va->cov.inv[j][k]/va->occ - x*y; } z = va->cov.inv[j][j]/va->occ - x*x; mp->cov.inv[j][j] = (z<floor[j])?floor[j]:z; } FixFullGConst(mp,CovInvert(mp->cov.inv,mp->cov.inv)); break; default: HError(2124,"UpVars: bad cov kind %d",mp->ckind); }}/* UpTrans: update transition parameters */void UpTrans(TrAcc *ta, Matrix tr){ int i,j; float occi,x,sum; for (i=1; i<nStates; i++){ occi = ta->occ[i]; if (occi == 0.0) HError(2127,"UpTrans: zero occ in state %d",i); sum = 0.0; tr[i][1] = LZERO; for (j=2;j<=nStates;j++) { x = ta->tran[i][j]/occi; tr[i][j] = x; sum += x; } if (fabs(sum-1.0) > 0.001) HError(2190,"UpTrans: row %d, sum=%f",i,sum,occi); for (j=2;j<=nStates;j++) { x = tr[i][j]/sum; tr[i][j] = (x<MINLARG) ? LZERO : log(x); } }}/* UpDProbs: update given mixture weights */void UpDProbs(int i, int s, int M, WtAcc *wa, ShortVec dw){ int m; float x,sum=0.0; if (wa->occ == 0.0) HError(2127,"UpDProbs: zero occ i=%d/s=%d",i,s); for (m=1; m<=M; m++){ sum += wa->c[m]; x = wa->c[m] / wa->occ; if (x<mixWeightFloor) x = mixWeightFloor; dw[m] = DProb2Short(x); } if (fabs(sum-wa->occ)/sum > 0.001) HError(2190,"UpDProbs: dprob weight sum error");}/* UpdateParameters: in hmm using counts in accumulators */void UpdateParameters(void){ HMMScanState hss; int size; StreamElem *ste; WtAcc *wa; MuAcc *ma = NULL; VaAcc *va; TrAcc *ta; Boolean hFound = FALSE,shared; NewHMMScan(&hset,&hss); do if (hmmLink == hss.hmm){ hFound = TRUE; while (GoNextState(&hss,TRUE)) { while (GoNextStream(&hss,TRUE)) { ste = hss.ste; if (hss.M>1 && (uFlags&UPMIXES)){ wa = (WtAcc *)ste->hook; if (hset.hsKind == DISCRETEHS) UpDProbs(hss.i,hss.s,hss.M,wa,ste->spdf.dpdf); else UpWeights(hss.i,hss.s,hss.M,wa,ste); } size = hset.swidth[hss.s]; if (hss.isCont && (uFlags&(UPMEANS|UPVARS)))/*PLAINHS or SHAREDHS*/ while (GoNextMix(&hss,TRUE)) { if (!IsSeenV(hss.mp->mean)) { ma = (MuAcc *)GetHook(hss.mp->mean); UpMeans(hss.i,hss.s,hss.m,size,ma,hss.mp->mean); /* NB old mean left in ma->mu */ TouchV(hss.mp->mean); } if (!IsSeenV(hss.mp->cov.var)) { if (uFlags&UPVARS) { va = (VaAcc *)GetHook(hss.mp->cov.var); shared = GetUse(hss.mp->cov.var) > 1; UpVars(hss.i,hss.s,hss.m,size,va,ma->mu,hss.mp->mean, shared,hss.mp); } TouchV(hss.mp->cov.var); } } } } if (!IsSeenV(hmmLink->transP)) { if (uFlags&UPTRANS){ ta = (TrAcc *)GetHook(hmmLink->transP); UpTrans(ta,hmmLink->transP); } TouchV(hmmLink->transP); } } while (!hFound && GoNextHMM(&hss)); EndHMMScan(&hss); if (!hFound) HError(2129,"UpdateParameters: hmm not found");}/* ----------------- Top Level of Estimation Procedure --------------- *//* CreateMixes: create array[1..S][1..segLen] of mix component index */IntVec *CreateMixes(MemHeap *x,int segLen){ IntVec *mixes; int s; mixes = (IntVec*)New(x,sizeof(IntVec)*nStreams); --mixes; for (s=1; s<=nStreams; s++) mixes[s] = CreateIntVec(x,segLen); return mixes;}/* EstimateModel: top level of iterative estimation process */void EstimateModel(void){ LogFloat totalP,newP,delta; Boolean converged = FALSE; int i,iter,numSegs,segLen; IntVec states; /* array[1..numSegs] of State */ IntVec *mixes; /* array[1..S][1..numSegs] of MixComp */ if (trace&T_TOP) printf("Starting Estimation Process\n"); if (newModel){ UniformSegment(); } totalP=LZERO; for (iter=1; !converged && iter<=maxIter; iter++){ ZeroAccs(&hset, uFlags); /* Clear all accumulators */ numSegs = NumSegs(segStore); /* Align on each training segment and accumulate stats */ for (newP=0.0,i=1;i<=numSegs;i++) { segLen = SegLength(segStore,i); states = CreateIntVec(&gstack,segLen); mixes = (hset.hsKind==DISCRETEHS)?NULL: CreateMixes(&gstack,segLen); newP += ViterbiAlign(i,segLen,states,mixes); if (trace&T_ALN) ShowAlignment(i,segLen,states,mixes); UpdateCounts(i,segLen,states,mixes); FreeIntVec(&gstack,states); /* disposes mixes too */ } /* Update parameters or quit */ newP /= (float)numSegs; delta = newP - totalP; converged = (iter>1) && (fabs(delta) < epsilon); if (!converged) UpdateParameters(); totalP = newP; if (trace & T_TOP){ printf("Iteration %d: Average LogP =%12.5f",iter,totalP); if (iter > 1) printf(" Change =%12.5f\n",delta); else printf("\n"); fflush(stdout); } } if (trace&T_TOP) { if (converged) printf("Estimation converged at iteration %d\n",iter); else printf("Estimation aborted at iteration %d\n",iter); fflush(stdout); }}/* ------------------------- Save Model ----------------------- *//* SaveModel: save HMMSet containing one model */void SaveModel(char *outfn){ if (outfn != NULL) macroLink->id = GetLabId(outfn,TRUE); if(SaveHMMSet(&hset,outDir,NULL,NULL,saveBinary)<SUCCESS) HError(2111,"SaveModel: SaveHMMSet failed");}/* ----------------------------------------------------------- *//* END: HInit.c *//* ----------------------------------------------------------- */
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -