📄 htrain.c
字号:
}/* DumpVaAcc: dump variance acc to file f */static void DumpVaAcc(FILE *f, VaAcc *va, CovKind ck){ switch(ck){ case DIAGC: case INVDIAGC: WriteVector(f,va->cov.var,ldBinary); break; case FULLC: case LLTC: WriteTriMat(f,va->cov.inv,ldBinary); break; default: HError(7170,"DumpVaAcc: bad cov kind"); } WriteFloat(f,&(va->occ),1,ldBinary); if (!ldBinary) fprintf(f,"\n");}/* DumpTrAcc: dump transition acc to file f */static void DumpTrAcc(FILE *f, TrAcc *ta){ WriteMatrix(f,ta->tran,ldBinary); WriteVector(f,ta->occ,ldBinary); if (!ldBinary) fprintf(f,"\n");}/* DumpMarker: dump a marker into file f */static void DumpMarker(FILE *f){ int mark = 123456; WriteInt(f,&mark,1,ldBinary); if (!ldBinary) fprintf(f,"\n");}/* GetDumpFile: Process dump file name and open it */static FILE * GetDumpFile(char *name, int n){ char buf[MAXSTRLEN],num[20]; int i,j,k,len,nlen; FILE *f; sprintf(num,"%d",n); len = strlen(name); nlen = strlen(num); for (i=0,j=0; i<len; i++) { if (name[i] == '$') for (k=0; k<nlen; k++) buf[j++] = num[k]; else buf[j++] = name[i]; } buf[j] = '\0'; f = fopen(buf,"wb"); /* Binary file */ if (f==NULL) HError(7111,"GetDumpFile: cannot open acc dump file %s",buf); if (trace & T_ALD) printf("Dumping accumulators to file %s\n",buf); return f;}/* EXPORT->DumpAccs: Dump a copy of the accs in hset to fname. Any occurrence of the $ symbol in fname is replaced by n. The file is left open and returned */FILE * DumpAccs(HMMSet *hset, char *fname, int n){ return DumpAccsParallel(hset,fname,n,0); }FILE * DumpAccsParallel(HMMSet *hset, char *fname, int n, int index){ FILE *f; HLink hmm; HMMScanState hss; int m,s; MixPDF* mp; f = GetDumpFile(fname,n); NewHMMScan(hset, &hss); do { hmm = hss.hmm; DumpPName(f,hss.mac->id->name); WriteInt(f,(int *)&hmm->hook,1,ldBinary); while (GoNextState(&hss,TRUE)) { while (GoNextStream(&hss,TRUE)) { DumpWtAcc(f,((WtAcc *)hss.ste->hook)+index); if (hss.isCont){ while (GoNextMix(&hss,TRUE)) { if (!IsSeenV(hss.mp->mean)) { DumpMuAcc(f,((MuAcc *)GetHook(hss.mp->mean))+index); TouchV(hss.mp->mean); } if (!IsSeenV(hss.mp->cov.var)) { DumpVaAcc(f,((VaAcc *)GetHook(hss.mp->cov.var))+index,hss.mp->ckind); TouchV(hss.mp->cov.var); } } } } } if (!IsSeenV(hmm->transP)){ DumpTrAcc(f, ((TrAcc *) GetHook(hmm->transP))+index); TouchV(hmm->transP); } DumpMarker(f); } while (GoNextHMM(&hss)); EndHMMScan(&hss); if (hset->hsKind == TIEDHS){ for (s=1; s<=hset->swidth[0]; s++){ for (m=1; m<=hset->tmRecs[s].nMix; m++){ mp = hset->tmRecs[s].mixes[m]; DumpMuAcc(f,((MuAcc *)GetHook(mp->mean))+index); DumpVaAcc(f,((VaAcc *)GetHook(mp->cov.var))+index,mp->ckind); } } } return f;}/* LoadWtAcc: new inc of wt acc from file f */static void LoadWtAcc(Source *src, WtAcc *wa, int numMixtures){ int m; float f; Vector cTemp; cTemp = CreateVector(&gstack,numMixtures); ReadVector(src,cTemp,ldBinary); for (m=1;m<=numMixtures;m++){ if(!finite(cTemp[m])) HError(7191, "Infinite WtAcc!"); wa->c[m] += cTemp[m]; } ReadFloat(src,&f,1,ldBinary); wa->occ += f; FreeVector(&gstack,cTemp);}/* LoadMuAcc: new inc of mean acc from file f */static void LoadMuAcc(Source *src, MuAcc *ma, int vSize){ int k; Vector vTemp; float f; vTemp = CreateVector(&gstack,vSize); ReadVector(src,vTemp,ldBinary); for (k=1;k<=vSize;k++){ if(!finite(vTemp[k])) HError(7191, "Infinite MuAcc!"); ma->mu[k] += vTemp[k]; } ReadFloat(src,&f,1,ldBinary); ma->occ += f; FreeVector(&gstack,vTemp);}/* LoadVaAcc: p'th inc of variance acc from file f */static void LoadVaAcc(Source *src, VaAcc *va, int vSize, CovKind ck){ int k,kk; Vector vTemp; TriMat mTemp; float f; switch(ck){ case DIAGC: case INVDIAGC: vTemp = CreateVector(&gstack, vSize); ReadVector(src,vTemp,ldBinary); for (k=1;k<=vSize;k++){ if(!finite(vTemp[k])) HError(7191, "Infinite VaAcc!"); va->cov.var[k] += vTemp[k]; } FreeVector(&gstack, vTemp); break; case FULLC: case LLTC: mTemp = CreateTriMat(&gstack,vSize); ReadTriMat(src,mTemp,ldBinary); for (k=1;k<=vSize;k++) for (kk=1; kk<=k; kk++) { va->cov.inv[k][kk] += mTemp[k][kk]; } FreeTriMat(&gstack,mTemp); break; } ReadFloat(src,&f,1,ldBinary); va->occ += f;}/* LoadTrAcc: p'th inc of transition acc from file f */static void LoadTrAcc(Source *src, TrAcc *ta, int numStates){ int i,j; Matrix tTemp; Vector nTemp; tTemp = CreateMatrix(&gstack,numStates,numStates); nTemp = CreateVector(&gstack,numStates); ReadMatrix(src,tTemp,ldBinary); for (i=1;i<=numStates;i++) for (j=1;j<=numStates;j++) ta->tran[i][j] += tTemp[i][j]; ReadVector(src,nTemp,ldBinary); for (i=1;i<=numStates;i++) ta->occ[i] += nTemp[i]; FreeMatrix(&gstack,tTemp);}/* CheckPName: check dumped name matches hmm phys name */static void CheckPName(Source *src, char *pname){ int c; char buf[MAXSTRLEN]; ReadString(src,buf); c = GetCh(src); if (c != '\n') HError(7150,"CheckPName: Cant find EOL"); if (strcmp(pname,buf) != 0) HError(7150,"CheckPName: expected %s got %s",pname,buf);}/* CheckMarker: check file f has a marker next */static void CheckMarker(Source *src){ int mark = 123456, temp; ReadInt(src,&temp,1,ldBinary); if (temp != mark) HError(7150,"CheckMarker: Marker Expected in Dump File");}/* EXPORT->LoadAccs: inc accumulators in hset by vals in fname */Source LoadAccs(HMMSet *hset, char *fname){ return LoadAccsParallel(hset,fname,0); }Source LoadAccsParallel(HMMSet *hset, char *fname, int index){ Source src; HLink hmm; HMMScanState hss; int size,negs,m,s; MixPDF* mp; if (trace & T_ALD) printf("Loading accumulators from file %s\n",fname); if(InitSource(fname,&src,NoFilter)<SUCCESS) HError(7110,"LoadAccs: Can't open file %s", fname); NewHMMScan(hset, &hss); do { hmm = hss.hmm; CheckPName(&src,hss.mac->id->name); ReadInt(&src,&negs,1,ldBinary); negs += (int)hmm->hook; hmm->hook = (void *)negs; while (GoNextState(&hss,TRUE)) { while (GoNextStream(&hss,TRUE)) { size = hset->swidth[hss.s]; LoadWtAcc(&src,((WtAcc *)hss.ste->hook)+index,hss.M); if (hss.isCont){ while (GoNextMix(&hss,TRUE)) { if (!IsSeenV(hss.mp->mean)) { LoadMuAcc(&src,((MuAcc *)GetHook(hss.mp->mean))+index,size); TouchV(hss.mp->mean); } if (!IsSeenV(hss.mp->cov.var)) { LoadVaAcc(&src,((VaAcc *)GetHook(hss.mp->cov.var))+index, size,hss.mp->ckind); TouchV(hss.mp->cov.var); } } } } } if (!IsSeenV(hmm->transP)){ LoadTrAcc(&src, ((TrAcc *) GetHook(hmm->transP))+index,hss.N); TouchV(hmm->transP); } CheckMarker(&src); } while (GoNextHMM(&hss)); EndHMMScan(&hss); if (hset->hsKind == TIEDHS){ for (s=1; s<=hset->swidth[0]; s++){ size = hset->swidth[s]; for (m=1;m<=hset->tmRecs[s].nMix; m++){ mp = hset->tmRecs[s].mixes[m]; LoadMuAcc(&src,((MuAcc *)GetHook(mp->mean))+index,size); LoadVaAcc(&src,((VaAcc *)GetHook(mp->cov.var))+index,size,mp->ckind); } } } return src;}void RestorePDF(MixPDF *mp, int index){ int i,j; MuAcc *ma = ((MuAcc *)GetHook(mp->mean))+index; VaAcc *va = ((VaAcc*)GetHook(mp->cov.var))+index; int size = VectorSize(mp->mean); for(i=1;i<=size;i++){ ma->mu[i] += ma->occ * mp->mean[i]; } switch(mp->ckind){ case DIAGC: case INVDIAGC: for(i=1;i<=size;i++){ va->cov.var[i] += 2*ma->mu[i]*mp->mean[i] - va->occ*mp->mean[i]*mp->mean[i]; } break; case FULLC: case LLTC: for(i=1;i<=size;i++){ for(j=1;j<=i;j++){ va->cov.inv[i][j] += ma->mu[i]*mp->mean[j] + ma->mu[j]*mp->mean[i] - va->occ*mp->mean[i]*mp->mean[j]; } } break; default: HError(7191, "Unknown ckind [RestoreAccsParallel]"); }}void RestoreAccs(HMMSet *hset){ RestoreAccsParallel(hset,0); }void RestoreAccsParallel(HMMSet *hset, int index){ HMMScanState hss; int s,m,size; if(hset->hsKind==TIEDHS){ for (s=1; s<=hset->swidth[0]; s++){ size = hset->swidth[s]; for (m=1;m<=hset->tmRecs[s].nMix; m++) RestorePDF(hset->tmRecs[s].mixes[m], index); } } else { NewHMMScan(hset, &hss); while (GoNextMix(&hss,FALSE)) { RestorePDF(hss.mp, index); } EndHMMScan(&hss); }}double ScalePDF(MixPDF *mpdf, int vSize, int index, float wt){ float ans; MuAcc *ma = ((MuAcc*)GetHook(mpdf->mean))+index; VaAcc *va = ((VaAcc*)GetHook(mpdf->cov.var))+index; /*diagonal case, of course.*/ {/*Scale the mu.*/ int x; ma->occ *= wt; for(x=1;x<=vSize;x++) ma->mu[x] *= wt; } {/*Scale the var.*/ int x; ans = va->occ; va->occ *= wt; for(x=1;x<=vSize;x++) va->cov.var[x] *= wt; } return ans;}double ScaleAccs(HMMSet *hset, float wt){ return ScaleAccsParallel(hset,wt,0);}double ScaleAccsParallel(HMMSet *hset, float wt, int index){ HMMScanState hss; int s,m,size; float ans=0; if(hset->ckind != DIAGC || !(hset->hsKind==PLAINHS || hset->hsKind==SHAREDHS || hset->hsKind==TIEDHS)) HError(-1, "ScaleAccsParallel: wrong kind of hset."); /* Do gaussians. */ if(hset->hsKind==TIEDHS){ for (s=1; s<=hset->swidth[0]; s++){ size = hset->swidth[s]; for (m=1;m<=hset->tmRecs[s].nMix; m++) ans += ScalePDF(hset->tmRecs[s].mixes[m], size, index, wt); } } else { NewHMMScan(hset, &hss); while (GoNextMix(&hss,FALSE)) { size = hset->swidth[hss.s]; ans += ScalePDF(hss.mp, size, index, wt); } EndHMMScan(&hss); } /* Do weights. */ NewHMMScan(hset,&hss); while(GoNextState(&hss,FALSE)){ /*skip over hmm boundaries.*/ while(GoNextStream(&hss,TRUE)){ /*Don't skip over state boundaries.*/ StreamElem *ste = hss.ste; int m, nMix; WtAcc *wa = ((WtAcc*) hss.ste->hook)+index; switch(hset->hsKind){ case PLAINHS: case SHAREDHS: nMix = (ste->nMix>0?ste->nMix:-ste->nMix); wa->occ*=wt; /*take the value wa->occ to the desired value.*/ for(m=1;m<=nMix;m++){ wa->c[m] *= wt; /*scale the WtAcc->c[]*/ } break; case TIEDHS: nMix = hset->tmRecs[hss.s].nMix; for(m=1;m<=nMix;m++){ wa->c[m] *= wt; /*scale the WtAcc->c[]*/ } break; default: HError(1, "ScaleAccs- unknown hsKind."); } } } EndHMMScan(&hss); /* Do transitions. */ NewHMMScan(hset,&hss); do{ HLink hmm = hss.hmm; int i,j,N; /*This code taken from UpdateTransP*/ TrAcc *ta; if (!IsSeenV(hmm->transP)){ TouchV(hmm->transP); ta = ((TrAcc*)GetHook(hmm->transP))+index; if (ta==NULL) HError(1, "HTrain.c: ScaleAccs: null TransP."); N = hmm->numStates; for (i=1;i<N;i++) { ta->occ[i]*=wt; for (j=2;j<=N;j++) { ta->tran[i][j]*=wt; } } } } while(GoNextHMM(&hss)); EndHMMScan(&hss); return ans;}/* ------------------------ End of HTrain.c ----------------------- */
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -