📄 latticeexpand.cc
字号:
/*
* LatticeExpand.cc --
* Lattice expansion and LM rescoring algorithms
*
*/
#ifndef lint
static char Copyright[] = "Copyright (c) 1997-2006 SRI International. All Rights Reserved.";
static char RcsId[] = "@(#)$Header: /home/srilm/devel/lattice/src/RCS/LatticeExpand.cc,v 1.5 2006/01/06 05:34:22 stolcke Exp $";
#endif
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <assert.h>
#include "Lattice.h"
#include "LHash.cc"
#include "Map2.cc"
#include "Array.cc"
#ifdef INSTANTIATE_TEMPLATES
INSTANTIATE_MAP2(NodeIndex, VocabContext, NodeIndex);
INSTANTIATE_LHASH(VocabIndex,PackedNode);
#endif
/*
* If the intlog weights of two transitions differ by no more than this
* they are considered identical in PackedNodeList::packNodes().
*/
#define PACK_TOLERANCE 0
#define DebugPrintFatalMessages 0
#define DebugPrintFunctionality 1
// for large functionality listed in options of the program
#define DebugPrintOutLoop 2
// for out loop of the large functions or small functions
#define DebugPrintInnerLoop 3
// for inner loop of the large functions or outloop of small functions
#ifndef USE_SARRAY_MAP2
/*
 * Ordering function for word ngrams.
 * Forwards to SArray_compareKey() so that contexts stored in node expansion
 * maps are iterated in the same order regardless of the underlying
 * data structure.
 */
static int
ngramCompare(const VocabIndex *n1, const VocabIndex *n2)
{
    const int ordering = SArray_compareKey(n1, n2);
    return ordering;
}
#endif /* USE_SARRAY_MAP2 */
/*
 * Lattice::replaceWeights --
 *	Replace the weight on every link of this lattice with a weight
 *	obtained from the given LM.  Each transition is scored with
 *	lm.wordProb() using the word of its source node as the only
 *	context word.
 *
 *	The initial node contributes the sentence-start index as its word,
 *	and a transition into the final node is scored as sentence-end.
 *	Transitions into nodes whose word is Vocab_None (NULL nodes) or the
 *	LM's pause word do not receive a language model weight; they are
 *	set to LogP_One.
 *
 * Results:
 *	Always true.
 */
Boolean
Lattice::replaceWeights(LM &lm)
{
    if (debug(DebugPrintFunctionality)) {
        dout() << "Lattice::replaceWeights: "
               << "replacing weights with new LM\n";
    }

    NodeIndex fromIndex;
    LHashIter<NodeIndex, LatticeNode> allNodes(nodes);

    while (LatticeNode *fromNode = allNodes.next(fromIndex)) {
        // TODO (carried over from the original): need to check whether
        // the word is in the vocab
        const VocabIndex historyWord =
            (fromIndex == initial) ? vocab.ssIndex() : fromNode->word;

        NodeIndex toIndex;
        TRANSITER_T<NodeIndex,LatticeTransition> outgoing(fromNode->outTransitions);

        while (outgoing.next(toIndex)) {
            LatticeNode *target = nodes.find(toIndex);

            // the target node is only dereferenced when it is not the
            // final node
            const VocabIndex predictedWord =
                (toIndex == final) ? vocab.seIndex() : target->word;

            // default: no language model weight
            LogP newWeight = LogP_One;

            if (predictedWord != Vocab_None &&
                predictedWord != lm.vocab.pauseIndex())
            {
                VocabIndex history[2] = { historyWord, Vocab_None };
                newWeight = lm.wordProb(predictedWord, history);
            }

            setWeightTrans(fromIndex, toIndex, newWeight);
        }
    }

    return true;
}
/*
 * Compute the outgoing transition prob on demand.  This saves LM
 * computation for transitions that are cached.
 *
 * While packInput.lm is non-null the weight has not been computed yet:
 * outWeight is filled in from the LM (context words: wordName, then
 * fromWordName) and the lm pointer is cleared so that later calls on the
 * same PackInput do nothing.
 *
 * Always returns true, so the call can be used inside a boolean expression.
 */
static Boolean
computeOutWeight(PackInput &packInput)
{
    if (packInput.lm == 0) {
        // already computed (or no LM supplied): nothing to do
        return true;
    }

    VocabIndex history[3];
    history[2] = Vocab_None;
    history[1] = packInput.fromWordName;
    history[0] = packInput.wordName;

    packInput.outWeight =
        packInput.lm->wordProb(packInput.toWordName, history);

    // mark the weight as computed
    packInput.lm = 0;

    return true;
}
/* This function tries to pack together nodes in the lattice:
 * 1) for the non-self-loop case: only when the trigram prob exists,
 * the from nodes with the same wordName will be packed;
 * 2) for the self-loop case:
 * the from nodes with the same wordName will be packed,
 * regardless of whether the trigram prob exists.
 * But the bigram and trigram will have separate nodes,
 * which is reflected in two different out transitions from
 * the mid node to the two different toNodes (bigram and trigram).
 *
 * Input is the triple (fromNode, mid word, toNode) described by packInput.
 * A cache keyed by packInput.fromWordName remembers the mid node created
 * for an earlier triple so that it can be shared.  The outgoing weight is
 * computed lazily through computeOutWeight().
 *
 * Every visible path returns true.
 */
Boolean
PackedNodeList::packNodes(Lattice &lat, PackInput &packInput)
{
// look for a mid node already cached under this from-word
PackedNode *packedNode = packedNodesByFromNode.find(packInput.fromWordName);
// No cache hit: fall back to the most recently packed node if it leads to
// the same toNode and its outgoing weight matches within PACK_TOLERANCE
// (compared on the intlog scale).  computeOutWeight() always returns true;
// it is placed in the && chain so the weight is only computed when the
// toNode already matches.
if (!packedNode && lastPackedNode != 0 &&
(packInput.toNodeIndex == lastPackedNode->toNode &&
computeOutWeight(packInput) &&
abs(LogPtoIntlog(packInput.outWeight) -
LogPtoIntlog(lastPackedNode->outWeight)) <= PACK_TOLERANCE))
{
packedNode = lastPackedNode;
// remember the mid node before inserting into the cache below
// NOTE(review): presumably insert() may invalidate lastPackedNode,
// which would be why the index is copied out first -- confirm against LHash
NodeIndex midNode = packedNode->midNodeIndex;
// the fromNode could be different this time around, so we need to
// re-cache the mid-node
packedNode = packedNodesByFromNode.insert(packInput.fromWordName);
packedNode->midNodeIndex = midNode;
packedNode->toNode = packInput.toNodeIndex;
packedNode->outWeight = packInput.outWeight;
// only the ids 2 and 3 are kept; anything else is stored as 0
if (packInput.toNodeId == 2) {
packedNode->toNodeId = 2;
} else if (packInput.toNodeId == 3) {
packedNode->toNodeId = 3;
} else {
packedNode->toNodeId = 0;
}
lastPackedNode = packedNode;
}
if (packedNode) {
// an existing mid node is shared:
// only one transition is needed;
LatticeTransition t(packInput.inWeight, packInput.inFlag);
lat.insertTrans(packInput.fromNodeIndex, packedNode->midNodeIndex, t);
if (!packInput.toNodeId) {
// this is for non-self-loop node, no additional outgoing trans
// need to be added.
// NOTE(review): midNode is dereferenced without a null check
LatticeNode *midNode = lat.findNode(packedNode->midNodeIndex);
LatticeTransition * trans =
midNode->outTransitions.find(packInput.toNodeIndex);
// if it is another toNode, we need to create a link to it.
if (!trans) {
// it indicates that there is another ngram node needed.
computeOutWeight(packInput);
LatticeTransition t(packInput.outWeight, packInput.outFlag);
lat.insertTrans(packedNode->midNodeIndex, packInput.toNodeIndex, t);
if (debug(DebugPrintInnerLoop)) {
dout() << "PackedNodeList::packNodes: \n"
<< "insert (" << packInput.fromNodeIndex
<< ", " << packedNode->midNodeIndex << ", "
<< packInput.toNodeIndex << ")\n";
}
}
return true;
} else {
// self-loop case: only report the reuse here; the outgoing link is
// handled below
if (debug(DebugPrintInnerLoop)) {
dout() << "PackedNodeList::packNodes: \n"
<< "reusing (" << packInput.fromNodeIndex
<< ", " << packedNode->midNodeIndex << ", "
<< packInput.toNodeIndex << ")\n";
}
}
// the following part is for selfLoop case
// the toNode is for p(a | a, x) doesn't exist.
if (packInput.toNodeId == 2) {
// add the outgoing link only if no toNode has been recorded yet
if (!packedNode->toNode) {
computeOutWeight(packInput);
LatticeTransition t(packInput.outWeight, packInput.outFlag);
lat.insertTrans(packedNode->midNodeIndex, packInput.toNodeIndex, t);
packedNode->toNode = packInput.toNodeIndex;
}
return true;
}
// the toNode is for p(a | a, x) exists.
if (packInput.toNodeId == 3) {
// same handling as for id 2
if (!packedNode->toNode) {
computeOutWeight(packInput);
LatticeTransition t(packInput.outWeight, packInput.outFlag);
lat.insertTrans(packedNode->midNodeIndex, packInput.toNodeIndex, t);
packedNode->toNode = packInput.toNodeIndex;
}
return true;
}
// any other non-zero toNodeId falls through to the final return
} else {
// this is the first time to create triple.
// create a fresh mid node carrying the word and link it on both sides
NodeIndex newNodeIndex = lat.dupNode(packInput.wordName, markedFlag);
LatticeTransition t1(packInput.inWeight, packInput.inFlag);
lat.insertTrans(packInput.fromNodeIndex, newNodeIndex, t1);
computeOutWeight(packInput);
LatticeTransition t2(packInput.outWeight, packInput.outFlag);
lat.insertTrans(newNodeIndex, packInput.toNodeIndex, t2);
if (debug(DebugPrintInnerLoop)) {
dout() << "PackedNodeList::packNodes: \n"
<< "insert (" << packInput.fromNodeIndex
<< ", " << newNodeIndex << ", "
<< packInput.toNodeIndex << ")\n";
}
// cache the new mid node under the from-word for later sharing
packedNode = packedNodesByFromNode.insert(packInput.fromWordName);
packedNode->midNodeIndex = newNodeIndex;
packedNode->toNode = packInput.toNodeIndex;
packedNode->outWeight = packInput.outWeight;
// only the ids 2 and 3 are kept; anything else is stored as 0
if (packInput.toNodeId == 2) {
packedNode->toNodeId = 2;
} else if (packInput.toNodeId == 3) {
packedNode->toNodeId = 3;
} else {
packedNode->toNodeId = 0;
}
lastPackedNode = packedNode;
}
return true;
}
// *************************************************
// compact expansion to trigram
// *************************************************
/* Basic Algorithm:
* Try to expand self loop to accommodate trigram
* the basic idea has two steps:
* 1) ignore the loop edge and process other edge combinations
* just like in other cases, this is done in the main expandNodeToTrigram
* program
* 2) IN THIS PROGRAM:
* a) duplicate the loop node (called postNode);
* b) add an additional node (called preNode) between fromNode and the
* loop node (postNode);
* c) create links between fromNode, preNode, postNode and toNode; and
* create the loop edge on the loop node (postNode).
*/
/*
 * Lattice::initASelfLoopDB --
 *	First stage of filling in a SelfLoopDB: describe the self-loop node.
 *	Clears the pre/post node indices, records the node index, the flags
 *	of the given transition and the node's word, and stores in loopProb
 *	the LM probability of that word with itself as both context words.
 */
void
Lattice::initASelfLoopDB(SelfLoopDB &selfLoopDB, LM &lm,
                         NodeIndex nodeIndex, LatticeNode *node,
                         LatticeTransition *trans)
{
    // no expansion nodes exist yet
    selfLoopDB.postNodeIndex3 = 0;
    selfLoopDB.postNodeIndex2 = 0;
    selfLoopDB.preNodeIndex = 0;

    selfLoopDB.nodeIndex = nodeIndex;
    selfLoopDB.selfTransFlags = trans->flags;
    selfLoopDB.wordName = node->word;

    // the loop word predicts itself, with itself twice as history
    const VocabIndex loopWord = selfLoopDB.wordName;
    VocabIndex history[3] = { loopWord, loopWord, Vocab_None };

    selfLoopDB.loopProb = lm.wordProb(loopWord, history);
}
/*
 * Lattice::initBSelfLoopDB --
 *	Second stage of filling in a SelfLoopDB: describe one predecessor
 *	(from) node of the self-loop node.  Resets the preNode index, records
 *	the from node's index, word and transition flags, and computes two
 *	LM probabilities for the loop word (selfLoopDB.wordName, which must
 *	already be set):
 *	  prePostProb - context is (loop word, from word)
 *	  fromPreProb - context is (from word) only
 */
void
Lattice::initBSelfLoopDB(SelfLoopDB &selfLoopDB, LM &lm,
                         NodeIndex fromNodeIndex, LatticeNode * fromNode,
                         LatticeTransition *fromTrans)
{
    // reinitialize the preNode
    selfLoopDB.preNodeIndex = 0;

    selfLoopDB.fromNodeIndex = fromNodeIndex;
    selfLoopDB.fromWordName = fromNode->word;
    selfLoopDB.fromSelfTransFlags = fromTrans->flags;

    const VocabIndex loopWord = selfLoopDB.wordName;
    const VocabIndex fromWord = selfLoopDB.fromWordName;

    // prob for the link between preNode and postNode
    VocabIndex longHistory[3] = { loopWord, fromWord, Vocab_None };
    selfLoopDB.prePostProb = lm.wordProb(loopWord, longHistory);

    // prob for the link between fromNode and preNode
    VocabIndex shortHistory[3] = { fromWord, Vocab_None, Vocab_None };
    selfLoopDB.fromPreProb = lm.wordProb(loopWord, shortHistory);
}
/*
 * Lattice::initCSelfLoopDB --
 *	Third stage of filling in a SelfLoopDB: record the successor (to)
 *	node index and the flags of the transition leading to it.
 */
void
Lattice::initCSelfLoopDB(SelfLoopDB &selfLoopDB, NodeIndex toNodeIndex,
                         LatticeTransition *toTrans)
{
    selfLoopDB.selfToTransFlags = toTrans->flags;
    selfLoopDB.toNodeIndex = toNodeIndex;
}
/*
* creating an expansion network for a self loop node.
* the order in which the network is created is reverse:
* 1) build the part of the network starting from postNode to toNode
* 2) use PackedNodeList class function to build the part of
* the network starting from fromNode to PostNode.
*
*/
Boolean
Lattice::expandSelfLoop(LM &lm, SelfLoopDB &selfLoopDB,
PackedNodeList &packedSelfLoopNodeList)
{
unsigned id = 0;
NodeIndex postNodeIndex, toNodeIndex = selfLoopDB.toNodeIndex;
LogP fromPreProb = selfLoopDB.fromPreProb;
LogP prePostProb = selfLoopDB.prePostProb;
VocabIndex wordName = selfLoopDB.wordName;
if (debug(DebugPrintOutLoop)) {
dout() << "Lattice::expandSelfLoop: "
<< "nodeIndex (" << selfLoopDB.nodeIndex << ")\n";
}
// create the part of the network from postNode to toNode
// if it doesn't exist.
// first compute the probs of the links in that part.
VocabIndex context[3];
context[0] = wordName;
context[1] = wordName;
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -