📄 lm_3g.c
字号:
FILE *fp; E_INFO("Dumping LM to %s\n", file); if ((fp = fopen(file, "wb")) == NULL) { E_ERROR("Cannot create file %s\n", file); return 0; }#define fwrite_int32(f,v) fwrite(&v,sizeof(v),1,f) k = strlen(darpa_hdr) + 1; fwrite_int32(fp, k); fwrite(darpa_hdr, sizeof(char), k, fp); k = strlen(lmfile) + 1; fwrite_int32(fp, k); fwrite(lmfile, sizeof(char), k, fp); /* Write version# and LM file modification date */ k = -1; fwrite_int32(fp, k); /* version # */ fwrite_int32(fp, k); /* ignore modification date */ /* Write file format description into header */ for (i = 0; fmtdesc[i] != NULL; i++) { k = strlen(fmtdesc[i]) + 1; fwrite_int32(fp, k); fwrite(fmtdesc[i], sizeof(char), k, fp); } /* Pad it out for alignment purposes */ k = ftell(fp) & 3; if (k > 0) { k = 4 - k; fwrite_int32(fp, k); fwrite(&zero, 1, k, fp); } fwrite_int32(fp, zero); fwrite_int32(fp, model->ucount); fwrite_int32(fp, model->bcount); fwrite_int32(fp, model->tcount); for (i = 0; i <= model->ucount; i++) fwrite(&(model->unigrams[i]), sizeof(model->unigrams[i]), 1, fp); for (i = 0; i <= model->bcount; i++) fwrite(&(model->bigrams[i]), sizeof(model->bigrams[i]), 1, fp); for (i = 0; i < model->tcount; i++) fwrite(&(model->trigrams[i]), sizeof(model->trigrams[i]), 1, fp); fwrite_int32(fp, model->n_prob2); for (i = 0; i < model->n_prob2; i++) fwrite_int32(fp, model->prob2[i].l); if (model->tcount > 0) { fwrite_int32(fp, model->n_bo_wt2); for (i = 0; i < model->n_bo_wt2; i++) fwrite_int32(fp, model->bo_wt2[i].l); fwrite_int32(fp, model->n_prob3); for (i = 0; i < model->n_prob3; i++) fwrite_int32(fp, model->prob3[i].l); k = (model->bcount + 1) / BG_SEG_SZ + 1; fwrite_int32(fp, k); for (i = 0; i < k; i++) fwrite_int32(fp, model->tseg_base[i]); } k = 0; for (i = 0; i < model->ucount; i++) k += strlen(word_str[i]) + 1; fwrite_int32(fp, k); for (i = 0; i < model->ucount; i++) fwrite(word_str[i], sizeof(char), strlen(word_str[i]) + 1, fp); fclose(fp); return 0;}voidlmSetStartSym(char const *sym)/*----------------------------* * Description - reconfigure the start symbol */{ start_sym = ckd_salloc(sym);}voidlmSetEndSym(char const *sym)/*----------------------------* * Description - reconfigure the end symbol */{ end_sym = ckd_salloc(sym);}/* * Convert probs and backoff weights to LOG quantities, add language weight * and insertion penalty. */static voidlm_set_param(lm_t * model, double lw, double uw, double wip, int32 word_pair){ int32 i; int32 tmp1, tmp2; int32 logUW, logOneMinusUW, logUniform; const int16 *at = fe_logadd_table; int32 ts = fe_logadd_table_size; model->lw = FLOAT2LW(lw); model->invlw = FLOAT2LW(1.0 / lw); model->uw = uw; model->log_wip = LOG(wip); E_INFO("%8.2f = Language Weight\n", LW2FLOAT(model->lw)); E_INFO("%8.2f = Unigram Weight\n", model->uw); E_INFO("%8d = LOG (Insertion Penalty (%.2f))\n", model->log_wip, wip); logUW = LOG(model->uw); logOneMinusUW = LOG(1.0 - model->uw); logUniform = LOG(1.0 / (model->ucount - 1)); /* -1 for ignoring <s> */ if (word_pair) E_FATAL("word-pair LM not implemented\n"); for (i = 0; i < model->ucount; i++) { model->unigrams[i].bo_wt1.l = (LOG10TOLOG(UG_BO_WT_F(model, i)) * lw); /* Interpolate LM unigram prob with uniform prob (except start_sym) */ if (strcmp(word_str[i], start_sym) == 0) { model->unigrams[i].prob1.l = (LOG10TOLOG(UG_PROB_F(model, i)) * lw) + model->log_wip; } else { tmp1 = (LOG10TOLOG(UG_PROB_F(model, i))) + logUW; tmp2 = logUniform + logOneMinusUW; FAST_ADD(tmp1, tmp1, tmp2, at, ts); model->unigrams[i].prob1.l = (tmp1 * lw) + model->log_wip; } } for (i = 0; i < model->n_prob2; i++) { model->prob2[i].l = (LOG10TOLOG(model->prob2[i].f) * lw) + model->log_wip; } if (model->tcount > 0) { for (i = 0; i < model->n_bo_wt2; i++) { model->bo_wt2[i].l = (LOG10TOLOG(model->bo_wt2[i].f) * lw); } } if (model->tcount > 0) { for (i = 0; i < model->n_prob3; i++) { model->prob3[i].l = (LOG10TOLOG(model->prob3[i].f) * lw) + model->log_wip; } }}#define BINARY_SEARCH_THRESH 16int32lm3g_ug_score(int32 wid){ int32 lwid; if ((lwid = lmp->dictwid_map[wid]) < 0) E_FATAL("dictwid[%d] not in LM\n", wid); lm_last_access_type = LM3G_ACCESS_UG; return (lmp->unigrams[lwid].prob1.l + lmp->inclass_ugscore[wid]);}/* Locate a specific bigram within a bigram list */static int32find_bg(bigram_t * bg, int32 n, int32 w){ int32 i, b, e; /* Binary search until segment size < threshold */ b = 0; e = n; while (e - b > BINARY_SEARCH_THRESH) { i = (b + e) >> 1; if (bg[i].wid < w) b = i + 1; else if (bg[i].wid > w) e = i; else return i; } /* Linear search within narrowed segment */ for (i = b; (i < e) && (bg[i].wid != w); i++); return ((i < e) ? i : -1);}/* w1, w2 are dictionary (base-)word ids */int32lm3g_bg_score(int32 w1, int32 w2){ int32 lw1, lw2, i, n, b, score; lm_t *lm; bigram_t *bg; lm = lmp; /* lm->n_bg_score++; */ if ((lw1 = lm->dictwid_map[w1]) < 0) E_FATAL("dictwid[%d] not in LM\n", w1); if ((lw2 = lm->dictwid_map[w2]) < 0) E_FATAL("dictwid[%d] not in LM\n", w2); b = FIRST_BG(lm, lw1); n = FIRST_BG(lm, lw1 + 1) - b; bg = lm->bigrams + b; if ((i = find_bg(bg, n, lw2)) >= 0) { score = lm->prob2[bg[i].prob2].l; lm_last_access_type = LM3G_ACCESS_BG; } else { /* lm->n_bg_bo++; */ score = lm->unigrams[lw1].bo_wt1.l + lm->unigrams[lw2].prob1.l; lm_last_access_type = LM3G_ACCESS_UG; } score += lm->inclass_ugscore[w2];#if 0 printf(" %5d %5d -> %8d (%16s %16s)\n", w1, w2, score, word_dict->dict_list[UG_MAPID(lm, lw1)]->word, word_dict->dict_list[UG_MAPID(lm, lw2)]->word);#endif return (score);}static voidload_tginfo(lm_t * lm, int32 lw1, int32 lw2){ int32 i, n, b, t; bigram_t *bg; tginfo_t *tginfo; /* First allocate space for tg information for bg lw1,lw2 */ tginfo = (tginfo_t *) listelem_alloc(sizeof(tginfo_t)); tginfo->w1 = lw1; tginfo->tg = NULL; tginfo->next = lm->tginfo[lw2]; lm->tginfo[lw2] = tginfo; /* Locate bigram lw1,lw2 */ b = lm->unigrams[lw1].bigrams; n = lm->unigrams[lw1 + 1].bigrams - b; bg = lm->bigrams + b; if ((n > 0) && ((i = find_bg(bg, n, lw2)) >= 0)) { tginfo->bowt = lm->bo_wt2[bg[i].bo_wt2].l; /* Find t = Absolute first trigram index for bigram lw1,lw2 */ b += i; /* b = Absolute index of bigram lw1,lw2 on disk */ t = FIRST_TG(lm, b); tginfo->tg = lm->trigrams + t; /* Find #tg for bigram w1,w2 */ tginfo->n_tg = FIRST_TG(lm, b + 1) - t; } else { /* No bigram w1,w2 */ tginfo->bowt = 0; tginfo->n_tg = 0; }}/* Similar to find_bg */static int32find_tg(trigram_t * tg, int32 n, int32 w){ int32 i, b, e; b = 0; e = n; while (e - b > BINARY_SEARCH_THRESH) { i = (b + e) >> 1; if (tg[i].wid < w) b = i + 1; else if (tg[i].wid > w) e = i; else return i; } for (i = b; (i < e) && (tg[i].wid != w); i++); return ((i < e) ? i : -1);}/* w1, w2, w3 are dictionary wids */int32lm3g_tg_score(int32 w1, int32 w2, int32 w3){ int32 lw1, lw2, lw3, i, n, score; lm_t *lm; trigram_t *tg; tginfo_t *tginfo, *prev_tginfo; lm = lmp; if ((lm->tcount <= 0) || (w1 < 0)) return (lm3g_bg_score(w2, w3)); /* lm->n_tg_score++; */ if ((lw1 = lm->dictwid_map[w1]) < 0) E_FATAL("dictwid[%d] not in LM\n", w1); if ((lw2 = lm->dictwid_map[w2]) < 0) E_FATAL("dictwid[%d] not in LM\n", w2); if ((lw3 = lm->dictwid_map[w3]) < 0) E_FATAL("dictwid[%d] not in LM\n", w3); prev_tginfo = NULL; for (tginfo = lm->tginfo[lw2]; tginfo; tginfo = tginfo->next) { if (tginfo->w1 == lw1) break; prev_tginfo = tginfo; } if (!tginfo) { load_tginfo(lm, lw1, lw2); tginfo = lm->tginfo[lw2]; } else if (prev_tginfo) { prev_tginfo->next = tginfo->next; tginfo->next = lm->tginfo[lw2]; lm->tginfo[lw2] = tginfo; } tginfo->used = 1; /* Trigrams for w1,w2 now pointed to by tginfo */ n = tginfo->n_tg; tg = tginfo->tg; if ((i = find_tg(tg, n, lw3)) >= 0) { score = lm->prob3[tg[i].prob3].l + lm->inclass_ugscore[w3]; lm_last_access_type = LM3G_ACCESS_TG; } else { /* lm->n_tg_bo++; */ score = tginfo->bowt + lm3g_bg_score(w2, w3); }#if 0 printf("%5d %5d %5d -> %8d (%16s %16s %16s)\n", w1, w2, w3, score, word_dict->dict_list[UG_MAPID(lm, lw1)]->word, word_dict->dict_list[UG_MAPID(lm, lw2)]->word, word_dict->dict_list[UG_MAPID(lm, lw3)]->word);#endif return (score);}voidlm3g_cache_reset(void){ int32 i; lm_t *lm; tginfo_t *tginfo, *next_tginfo, *prev_tginfo; lm = lmp; for (i = 0; i < lm->ucount; i++) { prev_tginfo = NULL; for (tginfo = lm->tginfo[i]; tginfo; tginfo = next_tginfo) { next_tginfo = tginfo->next; if (!tginfo->used) { /* lm->n_tg_inmem -= tginfo->n_tg; */ listelem_free((void *) tginfo, sizeof(tginfo_t)); if (prev_tginfo) prev_tginfo->next = next_tginfo; else lm->tginfo[i] = next_tginfo; /* n_tgfree++; */ } else { tginfo->used = 0; prev_tginfo = tginfo; } } }}voidlm3g_cache_stats_dump(FILE * file){ /* FIXME: does nothing! */}voidlm_next_frame(void){}int32lm3g_raw_score(int32 score){ score -= lmp->log_wip; score = LWMUL(score, lmp->invlw); return score;}int32lm3g_access_type(void){ return lm_last_access_type;}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -