📄 latticenbest.cc
字号:
/*
* LatticeNBest.cc --
* N-best generation from lattices
*
* (Originally contributed by Dustin Hillard, University of Washington,
* Viterbi N-best added by Jing Zheng, SRI International.)
*/
/*
 * Embedded copyright and RCS identification strings; they are compiled in
 * unless the "lint" macro is defined.
 */
#ifndef lint
static char Copyright[] = "Copyright (c) 2004-2006 SRI International. All Rights Reserved.";
static char RcsId[] = "@(#)$Header: /home/srilm/devel/lattice/src/RCS/LatticeNBest.cc,v 1.25 2006/01/09 19:15:48 stolcke Exp $";
#endif
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <assert.h>
#include <errno.h>
#include "Lattice.h"
#include "Array.cc"
#include "LHash.cc"
#include "IntervalHeap.cc"
#include "zio.h"
#include "mkdir.h"
/*
 * Debug verbosity levels used with debug()/dout() in this file.
 * Each comment describes the macro that immediately follows it.
 */
#define DebugPrintFatalMessages 0
// large functionality listed in the options of the program
#define DebugPrintFunctionality 1
// outer loop of the large functions, or small functions
#define DebugPrintOutLoop 2
// inner loop of the large functions, or outer loop of small functions
// (previously defined twice; the redundant duplicate has been removed)
#define DebugPrintInnerLoop 3
// tracing of the A* search
#define DebugAStar 4
/* *************************
 * A path through the lattice, stored as a reversed linked list: each
 * element records one node and points to the element for the node
 * before it.  Hyps extended from the same parent point at the parent's
 * path, so paths sharing a prefix share its storage.
 * ************************* */
class LatticeNBestPath
{
// NOTE(review): presumably so hyps can manage numReferences -- confirm
friend class LatticeNBestHyp;
public:
// Create a path element for 'node' whose predecessor is 'predecessor'
// (computeNBest passes 0 for the element holding the initial node).
LatticeNBestPath(NodeIndex node, LatticeNBestPath *predecessor);
// NOTE(review): definition not visible here; presumably drops the
// reference held on 'pred'.  There is no copy control, so copying an
// instance would corrupt the reference counts -- do not copy.
~LatticeNBestPath();
void linkto(); // add a reference to this path
void release(); // remove a reference to this path (presumably frees it when none remain -- confirm)
// Store the node indices of this path in 'path'.  printDebugHyp() walks the
// result from index 0 upward; presumably initial node first, and the return
// value is presumably the path length (definition not visible -- confirm).
unsigned getPath(Array<NodeIndex> &path);
NodeIndex node; // lattice node at the end of this path
LatticeNBestPath *pred; // element for the preceding node (0 for the initial element)
private:
unsigned numReferences; // reference count, presumably maintained by linkto()/release()
};
/* *************************
 * An N-best hypothesis (a partial or complete path) held in the A* search
 * queue, together with the HTK scores accumulated along its path.  The
 * queue ordering is given by the nbestLess/nbestGreater/nbestEqual
 * functors defined further down in this file.
 * ************************* */
class LatticeNBestHyp
{
public:
// NOTE(review): constructor definition not visible; the arguments presumably
// initialize the correspondingly named members -- confirm.
LatticeNBestHyp(double score, LogP myForwardProb,
NodeIndex myNodeIndex, int mySuccIndex, Boolean endOfSent,
LatticeNBestPath *nbestPath, unsigned myWordCnt,
LogP myAcoustic, LogP myNgram, LogP myLanguage,
LogP myPron, LogP myDuration,
LogP myXscore1, LogP myXscore2, LogP myXscore3,
LogP myXscore4, LogP myXscore5, LogP myXscore6,
LogP myXscore7, LogP myXscore8, LogP myXscore9);
// NOTE(review): definition not visible; presumably releases nbestPath.
~LatticeNBestHyp();
double score; // score for path (forward prob plus total backward prob)
// -- the A* priority: forward score so far plus the best (Viterbi)
// backward score from this hyp's node to the end of the lattice
LogP forwardProb; // forward prob so far on this path, up to and including nodeIndex
Boolean endOfSent; // set when the path has reached the lattice's final node
LatticeNBestPath *nbestPath; // linked list of nodes to the start
NodeIndex nodeIndex; // lattice node at the end of this hyp's path
int succIndex; // -1 for a newly created hyp; computeNBest expands entry
// succIndex+1 of the node's sorted successor list next
// Accumulated HTK scores over path
unsigned wordCnt; // number of words (ignore non-words)
LogP acoustic; // acoustic model log score
LogP ngram; // ngram model log score
LogP language; // language model log score
LogP pron; // pronunciation log score
LogP duration; // duration log score
LogP xscore1; // extra score #1
LogP xscore2; // extra score #2
LogP xscore3; // extra score #3
LogP xscore4; // extra score #4
LogP xscore5; // extra score #5
LogP xscore6; // extra score #6
LogP xscore7; // extra score #7
LogP xscore8; // extra score #8
LogP xscore9; // extra score #9
// Output this hyp as N-best entry number hypNum through nbestOut.
// NOTE(review): meaning of the Boolean result not visible here -- confirm.
Boolean writeHyp(int hypNum, Lattice &lat, NBestOptions &nbestOut);
// Return a heap string used as the duplicate-detection key for this hyp;
// the caller releases it with free().  Presumably the hyp's word string with
// ignoreWords dropped and multiwords split at multiwordSeparator -- confirm.
char *getHypFeature(SubVocab &ignoreWords, Lattice &lat,
const char *multiwordSeparator);
};
/*
 * Write a one-line trace of an N-best hypothesis to the lattice's debug
 * stream: the hyp number, the current queue size, the hyp score, and each
 * word node on the hyp's path as "word(nodeIndex)", followed by
 * "{acoustic}{language}" when the node carries HTK scores.
 *
 *   lat        lattice whose node indices the hyp's path refers to
 *   numOutput  number printed as the hyp number
 *   numHyps    number of hyps currently in the search queue
 *   hyp        hypothesis to print
 */
static void
printDebugHyp(Lattice &lat, unsigned numOutput, unsigned numHyps,
              LatticeNBestHyp &hyp)
{
    lat.dout() << "Lattice::computeNBest: "
               << "Hyp " << numOutput
               << " " << numHyps
               << " : " << hyp.score << " ";

    Array<NodeIndex> path;
    hyp.nbestPath->getPath(path);

    for (int n = 0; n < (int)path.size(); n++) {
        LatticeNode *thisNode = lat.findNode(path[n]);
        assert(thisNode != 0);

        // nodes without a word (Vocab_None) are not printed
        if (thisNode->word == Vocab_None) {
            continue;
        }

        lat.dout() << lat.getWord(thisNode->word)
                   << "(" << path[n] << ")";
        if (thisNode->htkinfo) {
            lat.dout() << "{" << thisNode->htkinfo->acoustic
                       << "}{" << thisNode->htkinfo->language << "}";
        }
        // Always separate entries: previously the space was only emitted
        // after HTK scores, so words without htkinfo ran together.
        lat.dout() << " ";
    }
    lat.dout() << "\n";
}
/* *************************
* A-star N-best generation
* ************************* */
/*
 * Strict "less than" ordering of hyps for the search queue: lower score
 * first, ties broken by fewer words.  operator() is const so the functor
 * can also be invoked through a const comparator object.
 */
struct nbestLess {
    bool operator() (const LatticeNBestHyp *first,
                     const LatticeNBestHyp *second) const
    {
        return (first->score < second->score) ||
               (first->score == second->score &&
                first->wordCnt < second->wordCnt);
    }
};
/*
 * Strict "greater than" ordering of hyps for the search queue (mirror of
 * nbestLess): higher score first, ties broken by more words.  operator()
 * is const so the functor can also be invoked through a const object.
 */
struct nbestGreater {
    bool operator() (const LatticeNBestHyp *first,
                     const LatticeNBestHyp *second) const
    {
        return (first->score > second->score) ||
               (first->score == second->score &&
                first->wordCnt > second->wordCnt);
    }
};
/*
 * Equivalence test matching nbestLess/nbestGreater: two hyps are equal
 * when they have the same score and the same word count.  operator() is
 * const so the functor can also be invoked through a const object.
 */
struct nbestEqual {
    bool operator() (const LatticeNBestHyp *first,
                     const LatticeNBestHyp *second) const
    {
        return first->score == second->score &&
               first->wordCnt == second->wordCnt;
    }
};
/*
 * One outgoing transition of a lattice node, as tabulated for the A* search.
 */
struct SuccInfo {
NodeIndex to; // destination node of the transition
LogP bwScore; // transition weight plus the Viterbi backward score of 'to'
};
static int
compareSucc(const void *p1, const void *p2)
{
LogP s1 = ((const SuccInfo *) p1)->bwScore;
LogP s2 = ((const SuccInfo *) p2)->bwScore;
if (s1 == s2) return 0;
else if (s1 < s2) return 1;
else return -1;
}
struct NodeInfo {
int numSuccs;
SuccInfo *succs;
NodeInfo() { numSuccs = 0; succs = 0; };
~NodeInfo() { if (numSuccs) delete [] succs; succs = 0; };
void sortSuccs() {
if (numSuccs) qsort(succs, numSuccs, sizeof(SuccInfo), compareSucc);
};
};
static int
compareLogP(const void *p1, const void *p2)
{
LogP pr = (*(const LogP *)p1 - *(const LogP *)p2);
if (pr == 0) return 0;
else if (pr < 0) return 1;
else return -1;
}
/*
* Compute the top N word sequences with the highest-probability paths through the lattice
*/
Boolean
Lattice::computeNBest(unsigned N, NBestOptions &nbestOut, SubVocab &ignoreWords,
const char *multiwordSeparator, unsigned maxHyps,
unsigned nbestDuplicates)
{
/*
* Find the top N hyps in the lattice using A* search
* First compute the forward and backward max prob paths for each node
* Then select the top N paths with A* search (sorting with the computed
* max prob paths)
*/
/*
* topological sort
*/
unsigned numNodes = getNumNodes();
NodeIndex *sortedNodes = new NodeIndex[numNodes];
assert(sortedNodes != 0);
unsigned numReachable = sortNodes(sortedNodes);
if (numReachable != numNodes) {
dout() << "Lattice::computeNBest: warning: called with unreachable nodes\n";
}
if (sortedNodes[0] != initial) {
dout() << "Lattice::computeNBest: initial node is not first\n";
delete [] sortedNodes;
return LogP_Inf;
}
unsigned finalPosition = 0;
for (finalPosition = 1; finalPosition < numReachable; finalPosition ++) {
if (sortedNodes[finalPosition] == final) break;
}
if (finalPosition == numReachable) {
dout() << "Lattice::computeNBest: final node is not reachable\n";
delete [] sortedNodes;
return LogP_Inf;
}
/*
* compute fb viterbi probabilities
*/
LogP *viterbiBackwardProbs = new LogP[maxIndex];
LogP *viterbiForwardProbs = new LogP[maxIndex];
assert(viterbiBackwardProbs != 0 && viterbiForwardProbs != 0);
NodeInfo *nodeInfos = new NodeInfo[maxIndex];
assert(nodeInfos != 0);
LogP bestProb =
computeForwardBackwardViterbi(viterbiForwardProbs, viterbiBackwardProbs);
int i;
for (i = 0; i < (int)maxIndex; i++) {
LatticeNode *node = nodes.find(i);
if (!node) continue;
NodeInfo *ni = nodeInfos + i;
ni->numSuccs = node->outTransitions.numEntries();
ni->succs = new SuccInfo[ni->numSuccs];
assert(ni->succs != 0);
TRANSITER_T<NodeIndex, LatticeTransition> transIter(node->outTransitions);
NodeIndex toNode;
int j = 0;
while (LatticeTransition *trans = transIter.next(toNode)) {
LogP bwScore = trans->weight + viterbiBackwardProbs[toNode];
ni->succs[j].to = toNode;
ni->succs[j].bwScore = bwScore;
j ++;
}
ni->sortSuccs();
}
if (debug(DebugPrintFunctionality)) {
dout() << "Lattice::computeNBest: best FB prob: " << bestProb << endl;
}
/*
* select top nbest hyps from lattice with A* search
*/
if (debug(DebugPrintFunctionality)) {
dout() << "Lattice::computeNBest: writing nbest list\n";
}
IntervalHeap<LatticeNBestHyp*, nbestLess, nbestGreater, nbestEqual>
hyps(maxHyps > 0 ? maxHyps : (2 * N + 1));
LHash<VocabString, unsigned> hypsPrinted; // for duplicate removal
/*
*
* Implement A*
*
* 1 - Initialize priority queue with a null theory
* 2 - Pop the highest score hyp
* 3 - If end-of-sentence, output hyp (return to 2)
* 4 - Create new hyps for all outgoing word transitions
* 5 - Score each new extended hyp, re-insert to queue (possibly marking end-of-sentence)
* 6 - Go to 2
*
*/
// start queue with hyp containing only the initial node
LatticeNBestPath *initialPath = new LatticeNBestPath(initial, 0);
assert(initialPath != 0);
LatticeNBestHyp *initialHyp =
new LatticeNBestHyp(LogP_One, LogP_One, initial, -1,
false, initialPath, 0,
LogP_One, LogP_One, LogP_One, LogP_One,
LogP_One, LogP_One, LogP_One, LogP_One,
LogP_One, LogP_One, LogP_One, LogP_One,
LogP_One, LogP_One);
assert(initialHyp != 0);
hyps.push(initialHyp);
int outputHyps = 0;
Boolean firstPruned = true;
while (outputHyps < N && !hyps.empty()) {
LatticeNBestHyp *topHyp = hyps.top_max(); // get hyp
hyps.pop_max();
if (topHyp->endOfSent) {
// check to see if this hyp was already printed
char *feature =
topHyp->getHypFeature(ignoreWords, *this, multiwordSeparator);
Boolean isDuplicate;
unsigned *timesPrinted = hypsPrinted.insert(feature, isDuplicate);
if (isDuplicate) {
*timesPrinted += 1;
} else {
*timesPrinted = 1;
}
if (isDuplicate && *timesPrinted > nbestDuplicates) {
if (debug(DebugPrintOutLoop)) {
dout() << "Lattice::computeNBest: not outputting hypothesis because it matches previously printed one\n";
}
// debugging output
if (debug(DebugPrintInnerLoop)) {
printDebugHyp(*this, outputHyps, hyps.size(), *topHyp);
}
free(feature);
} else {
// output hyp
outputHyps++;
topHyp->writeHyp(outputHyps, *this, nbestOut);
// debugging output
if (debug(DebugPrintOutLoop)) {
printDebugHyp(*this, outputHyps, hyps.size(), *topHyp);
}
}
} else {
// top hyp is not an end of sentence, so extend it
// expand all outgoing paths from current node
LatticeNode *node = findNode(topHyp->nbestPath->node);
assert(node != 0);
NodeIndex nodeIndex = topHyp->nodeIndex;
int succIndex = topHyp->succIndex + 1;
NodeInfo *nodeInfo = &(nodeInfos[nodeIndex]);
if (succIndex < nodeInfo->numSuccs) {
NodeIndex toNodeIndex = nodeInfo->succs[succIndex].to;
// compute accumulated scores
double score = topHyp->forwardProb + nodeInfo->succs[succIndex].bwScore;
// compute the forward part of the score
LogP forwardProb = score - viterbiBackwardProbs[toNodeIndex];
unsigned cnt = topHyp->wordCnt; // word count (ignore non-words)
LogP acoustic = topHyp->acoustic; // acoustic model log score
LogP ngram = topHyp->ngram; // ngram model log score
LogP language = topHyp->language; // language model log score
LogP pron = topHyp->pron; // pronunciation log score
LogP duration = topHyp->duration; // duration log score
LogP xscore1 = topHyp->xscore1; // extra score #1
LogP xscore2 = topHyp->xscore2; // extra score #2
LogP xscore3 = topHyp->xscore3; // extra score #3
LogP xscore4 = topHyp->xscore4; // extra score #4
LogP xscore5 = topHyp->xscore5; // extra score #5
LogP xscore6 = topHyp->xscore6; // extra score #6
LogP xscore7 = topHyp->xscore7; // extra score #7
LogP xscore8 = topHyp->xscore8; // extra score #8
LogP xscore9 = topHyp->xscore9; // extra score #9
if (node->htkinfo) {
if (!ignoreWord(node->word) && // NULL and pause nodes
!ignoreWords.getWord(node->word) &&
!vocab.isNonEvent(node->word) && // <s> and other non-events
node->word != vocab.seIndex())
cnt += 1;
if (node->htkinfo->acoustic != HTK_undef_float)
acoustic += node->htkinfo->acoustic;
if (node->htkinfo->ngram != HTK_undef_float)
ngram += node->htkinfo->ngram;
if (node->htkinfo->language != HTK_undef_float)
language += node->htkinfo->language;
if (node->htkinfo->pron != HTK_undef_float)
pron += node->htkinfo->pron;
if (node->htkinfo->duration != HTK_undef_float)
duration += node->htkinfo->duration;
if (node->htkinfo->xscore1 != HTK_undef_float)
xscore1 += node->htkinfo->xscore1;
if (node->htkinfo->xscore2 != HTK_undef_float)
xscore2 += node->htkinfo->xscore2;
if (node->htkinfo->xscore3 != HTK_undef_float)
xscore3 += node->htkinfo->xscore3;
if (node->htkinfo->xscore4 != HTK_undef_float)
xscore4 += node->htkinfo->xscore4;
if (node->htkinfo->xscore5 != HTK_undef_float)
xscore5 += node->htkinfo->xscore5;
if (node->htkinfo->xscore6 != HTK_undef_float)
xscore6 += node->htkinfo->xscore6;
if (node->htkinfo->xscore7 != HTK_undef_float)
xscore7 += node->htkinfo->xscore7;
if (node->htkinfo->xscore8 != HTK_undef_float)
xscore8 += node->htkinfo->xscore8;
if (node->htkinfo->xscore9 != HTK_undef_float)
xscore9 += node->htkinfo->xscore9;
}
// add this node to path
LatticeNBestPath *thisPath =
new LatticeNBestPath(toNodeIndex, topHyp->nbestPath);
assert(thisPath != 0);
Boolean isFinal = (toNodeIndex == final);
LatticeNBestHyp *expandedHyp =
new LatticeNBestHyp(score, forwardProb, toNodeIndex, -1,
isFinal, thisPath, cnt,
acoustic, ngram, language, pron, duration,
xscore1, xscore2, xscore3, xscore4, xscore5,
xscore6, xscore7, xscore8, xscore9);
assert(expandedHyp != 0);
if (maxHyps > 0 && hyps.size() >= maxHyps) {
LatticeNBestHyp *pruneHyp = hyps.top_min(); // get hyp
hyps.pop_min();
delete pruneHyp;
if (debug(DebugPrintOutLoop) ||
firstPruned && debug(DebugPrintFunctionality))
{
dout() << "Lattice::computeNBest: max number of hyps reached, pruning lowest score hyp\n";
firstPruned = false;
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -