/*
* HMMofNgrams.cc --
* Hidden Markov Model of Ngram distributions
*
*/
#ifndef lint
static char Copyright[] = "Copyright (c) 1997-2006 SRI International. All Rights Reserved.";
static char RcsId[] = "@(#)$Header: /home/srilm/devel/lm/src/RCS/HMMofNgrams.cc,v 1.13 2006/01/05 20:21:27 stolcke Exp $";
#endif
#include <iostream>
using namespace std;
#include <stdlib.h>
#include "HMMofNgrams.h"
#include "Trellis.cc"
#include "Array.cc"
#define DEBUG_PRINT_WORD_PROBS 2 /* from LM.cc */
#define DEBUG_READ_STATS 3
#define DEBUG_PRINT_VITERBI 2
#define DEBUG_TRANSITIONS 4
#define DEBUG_STATE_PROBS 5
#define NO_LM "." /* pseudo-filename for null LM */
#define INLINE_LM "-" /* pseudo-file for inline LM */
const unsigned maxTransPerState = 1000;
HMMofNgrams::HMMofNgrams(Vocab &vocab, unsigned order)
    : LM(vocab), order(order), trellis(maxWordsPerLine + 2 + 1), savedLength(0)
{
    /*
     * Remove standard vocab items not applicable to state space
     */
    stateVocab.remove(stateVocab.unkIndex());
    stateVocab.remove(stateVocab.ssIndex());
    stateVocab.remove(stateVocab.seIndex());
    stateVocab.remove(stateVocab.pauseIndex());

    /*
     * Add initial, final states
     */
    initialState = stateVocab.addWord("INITIAL");
    states.insert(initialState);
    finalState = stateVocab.addWord("FINAL");
    states.insert(finalState);
}
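
/*
 * A minimal usage sketch (not from the original sources; the filenames
 * and the guard macro are made up for illustration):
 */
#ifdef HMM_OF_NGRAMS_EXAMPLE
static void
exampleUsage()
{
    Vocab vocab;
    HMMofNgrams hmm(vocab, 3);              // trigram Ngram in each state

    File in("mixture.hmm", "r", false);     // don't exit on open failure
    if (in.error() || !hmm.read(in)) {
        cerr << "cannot load HMM\n";
        return;
    }

    File out("mixture.out.hmm", "w");
    hmm.write(out);                         // component Ngrams written inline
}
#endif /* HMM_OF_NGRAMS_EXAMPLE */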
HMMofNgrams::~HMMofNgrams()
{
    LHashIter<HMMIndex,HMMState> iter(states);
    HMMState *state;
    HMMIndex index;

    while ((state = iter.next(index))) {
        /*
         * Destroy state contents explicitly (destructor call only,
         * no deallocation), presumably because LHash does not run
         * element destructors itself
         */
        state->~HMMState();
    }
}
/*
 * Propagate changes to Debug state to component models
 */
void
HMMofNgrams::debugme(unsigned level)
{
    LHashIter<HMMIndex,HMMState> iter(states);
    HMMState *state;
    HMMIndex index;

    while ((state = iter.next(index))) {
        if (state->ngram) {
            state->ngram->debugme(level);
        }
    }
    Debug::debugme(level);
}
ostream &
HMMofNgrams::dout(ostream &stream)
{
    LHashIter<HMMIndex,HMMState> iter(states);
    HMMState *state;
    HMMIndex index;

    while ((state = iter.next(index))) {
        if (state->ngram) {
            state->ngram->dout(stream);
        }
    }
    return Debug::dout(stream);
}
/*
 * Read HMMofNgrams from file.
 * File format: 1 line per state, containing
 *
 *	state name
 *	Ngram model file name
 *	follow-state1 transition-prob1
 *	follow-state2 transition-prob2
 *	etc.
 */
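/*
 * For illustration, a hypothetical model description in this format
 * (state names and Ngram filenames are made up; "." marks the LM-less
 * initial/final states, and "-" would make the Ngram follow inline in
 * this file):
 *
 *	INITIAL .		NEWS 0.7	SPORTS 0.3
 *	NEWS news.3bo		SPORTS 0.4	FINAL 0.6
 *	SPORTS sports.3bo	NEWS 0.4	FINAL 0.6
 *	FINAL .
 */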
Boolean
HMMofNgrams::read(File &file, Boolean limitVocab)
{
    char *line;
    VocabString fields[maxTransPerState + 3];

    while ((line = file.getline())) {
        unsigned numFields =
            vocab.parseWords(line, fields, maxTransPerState + 3);

        if (numFields == maxTransPerState + 3) {
            file.position() << "too many fields\n";
            return false;
        }
        if (numFields < 2 || numFields % 2 != 0) {
            file.position() << "wrong number of fields\n";
            return false;
        }

        HMMIndex stateIndex = stateVocab.addWord(fields[0]);

        /*
         * Clear all current transitions
         */
        states.insert(stateIndex)->transitions.clear();

        /*
         * Read transitions out of state
         */
        for (unsigned i = 2; i < numFields; i += 2) {
            HMMIndex toIndex = stateVocab.addWord(fields[i]);
            states.insert(toIndex);

            if (toIndex == initialState) {
                file.position() << "illegal transition to initial state\n";
                return false;
            }

            double prob;
            if (sscanf(fields[i + 1], "%lf", &prob) != 1) {
                file.position() << "bad transition prob "
                                << fields[i + 1] << endl;
                return false;
            }

            *(states.find(stateIndex)->transitions.insert(toIndex)) = prob;
        }

        /*
         * Read LM for state
         */
        if (stateIndex == initialState || stateIndex == finalState) {
            if (strcmp(fields[1], NO_LM) != 0) {
                file.position() << "ngram not allowed on initial/final state\n";
                return false;
            }
        } else {
            HMMState *state = states.insert(stateIndex);

            /*
             * Check for identity of the Ngram filenames.
             * If they are identical (and not "-") assume that the models
             * are the same and avoid reloading them.
             */
            if (state->ngramName &&
                strcmp(state->ngramName, INLINE_LM) != 0 &&
                strcmp(state->ngramName, fields[1]) == 0)
            {
                if (debug(DEBUG_READ_STATS)) {
                    dout() << "reusing state ngram " << state->ngramName
                           << endl;
                }
            } else {
                if (state->ngramName) free(state->ngramName);
                state->ngramName = strdup(fields[1]);
                assert(state->ngramName != 0);

                delete state->ngram;
                state->ngram = new Ngram(vocab, order);
                assert(state->ngram != 0);

                state->ngram->debugme(debuglevel());

                Boolean status;
                if (strcmp(state->ngramName, INLINE_LM) == 0) {
                    status = state->ngram->read(file, limitVocab);
                } else {
                    File ngramFile(state->ngramName, "r", false);

                    if (ngramFile.error()) {
                        file.position() << "error opening Ngram file "
                                        << state->ngramName << endl;
                        return false;
                    }
                    status = state->ngram->read(ngramFile, limitVocab);
                }

                if (!status) {
                    file.position() << "bad Ngram file " << fields[1] << endl;
                    return false;
                }
            }
        }
    }

    /*
     * Ensure that all states (except initial and final) have
     * an LM defined
     */
    LHashIter<HMMIndex,HMMState> iter(states);
    HMMState *state;
    HMMIndex stateIndex;

    while ((state = iter.next(stateIndex))) {
        if (stateIndex != initialState && stateIndex != finalState &&
            !state->ngram)
        {
            file.position() << "no LM defined for state "
                            << stateVocab.getWord(stateIndex) << endl;
            return false;
        }
    }
    return true;
}
void
HMMofNgrams::write(File &file)
{
    LHashIter<HMMIndex,HMMState> iter(states);
    HMMState *state;
    HMMIndex stateIndex;

    while ((state = iter.next(stateIndex))) {
        VocabString stateName = stateVocab.getWord(stateIndex);

        fprintf(file, "%s %s", stateName, (state->ngram ? INLINE_LM : NO_LM));

        LHashIter<HMMIndex, Prob> transIter(state->transitions);
        HMMIndex toIndex;
        Prob *prob;

        while ((prob = transIter.next(toIndex))) {
            fprintf(file, " %s %lf",
                    stateVocab.getWord(toIndex), (double)*prob);
        }
        fprintf(file, "\n");

        /*
         * Output the Ngram inline
         */
        if (state->ngram) {
            state->ngram->write(file);
        }
    }
}
/*
 * LM state change: re-read the model from file
 */
void
HMMofNgrams::setState(const char *state)
{
    char fileName[201];

    if (sscanf(state, " %200s ", fileName) != 1) {
        cerr << "no filename found in state info\n";
    } else {
        File lmFile(fileName, "r", false);

        if (lmFile.error()) {
            cerr << "error opening HMM file " << fileName << endl;
            return;
        }

        if (!read(lmFile)) {
            cerr << "failed to read HMM from " << fileName << endl;
            return;
        }
    }
}
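
/*
 * Usage sketch (hypothetical filename): an application can swap in a
 * different model at runtime with
 *
 *	hmm.setState("models/meeting.hmm");
 *
 * which re-reads the entire HMM; note that read() above reuses state
 * Ngrams whose filenames are unchanged, so shared component models
 * are not reloaded.
 */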
/*
* Forward algorithm for prefix probability computation
*/
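/*
 * In outline (our notation, paraphrasing; not from the original
 * sources): each state emits a word string distributed according to
 * its Ngram model.  The trellis tracks alpha_t(s), the joint
 * probability of the first t prefix words and the HMM currently
 * emitting from state s; a word either extends the current state's
 * Ngram string, or that string ends (the Ngram generates </s>) and a
 * transition prob leads into a successor state whose Ngram starts
 * anew.  The conditional probability of a single word can then be
 * formed as a difference of log prefix probabilities, which is what
 * the contextProb result parameter supports.
 */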
LogP
HMMofNgrams::prefixProb(VocabIndex word, const VocabIndex *context,
                        LogP &contextProb, TextStats &stats)
{
    /*
     * pos points to the column currently being computed (we skip over the
     * initial <s> token)
     * prefix points to the tail of context that is used for conditioning
     * the current word.
     */
    unsigned pos;
    int prefix;

    if (context == 0) {
        /*
         * Reset the computation to the last iteration in the loop below
         */
        pos = prevPos;
        prefix = 0;
        context = prevContext;

        trellis.init(pos);
    } else {
        unsigned len = Vocab::length(context);
        assert(len <= maxWordsPerLine);

        /*
         * Save these for possible recomputation with different
         * word argument, using same context
         */