📄 lattice.cc
字号:
LHashIter<VocabIndex,Prob> wordsIter(*words);
Prob *prob;
VocabIndex word;
while (prob = wordsIter.next(word)) {
// insert a HTK lattice link (which is a node between two NULL
// nodes) for each word in this alignment column,
// also put the log of the posterior prob in xscore1
NodeIndex newNode = getMaxIndex();
// replace the *DELETE* word with NULL node
if (word == inputMesh.deleteIndex) {
insertNode(NullNodeName, newNode);
} else {
insertNode(vocab.getWord(word), newNode);
}
if (word == vocab.ssIndex()) {
haveSentStart = true;
}
if (word == vocab.seIndex()) {
haveSentEnd = true;
}
LatticeNode *node = findNode(newNode);
assert(node != 0);
HTKWordInfo *linkinfo = new HTKWordInfo;
assert(linkinfo != 0);
node->htkinfo = htkinfos[htkinfos.size()] = linkinfo;
node->htkinfo->xscore1 = ProbToLogP(*prob);
LatticeTransition inTrans(ProbToLogP(*prob), 0);
LatticeTransition outTrans(LogP_One, 0);
insertTrans(fromNodeIndex, newNode, inTrans);
insertTrans(newNode, toNodeIndex, outTrans);
}
// the next column of words will follow this column
fromNodeIndex = toNodeIndex;
}
// make last node the final one
setFinal(fromNodeIndex);
// provide sentence start/end tags if not found in lattice
if (!haveSentStart) {
LatticeNode *node = findNode(initial);
assert(node != 0);
node->word = vocab.ssIndex();
}
if (!haveSentEnd) {
LatticeNode *node = findNode(final);
assert(node != 0);
node->word = vocab.seIndex();
}
return true;
}
/*
 * Write the lattice in "compact" PFSG format.
 *
 * Unlike writePFSG(), which declares every index 0..maxIndex-1 as a node,
 * this renumbers the nodes that actually exist to consecutive integers
 * (in nodeSort order) so no placeholder nodes are emitted.
 * The initial and final nodes are always written as NULL nodes.
 * Transition weights are written as integer logs (LogPtoIntlog), clamped
 * from below to minIntlog when limitIntlogs is set.
 *
 * Always returns true (write errors are not checked here).
 */
Boolean
Lattice::writeCompactPFSG(File &file)
{
    if (debug(DebugPrintFunctionality)) {
        dout() << "Lattice::writeCompactPFSG: writing ";
    }

    if (duration != 0.0) {
        fprintf(file, "name %s(duration=%lg)\n", name, duration);
    } else {
        fprintf(file, "name %s\n", name);
    }

    /*
     * We remap the internal node indices to consecutive unsigned integers
     * to allow a compact output representation.
     * We iterate over all nodes, renumbering them, and also counting the
     * number of transitions overall.
     */
    // map nodeIndex to unsigned
    LHash<NodeIndex,unsigned> nodeMap;

    unsigned numNodes = 0;
    unsigned numTransitions = 0;

    // the node count is unsigned: print it with a matching conversion
    fprintf(file, "nodes %u", (unsigned)getNumNodes());

    LHashIter<NodeIndex, LatticeNode> nodeIter(nodes, nodeSort);
    NodeIndex nodeIndex;
    while (LatticeNode *node = nodeIter.next(nodeIndex)) {
        *nodeMap.insert(nodeIndex) = numNodes ++;
        numTransitions += node->outTransitions.numEntries();

        // initial and final nodes are always emitted as NULL nodes
        fprintf(file, " %s", (nodeIndex == initial || nodeIndex == final) ?
                                NullNodeName : getWord(node->word));
    }
    fprintf(file, "\n");

    // initial/final must refer to existing nodes, otherwise the map lookup
    // fails; check instead of dereferencing a null pointer
    if (initial != NoNode) {
        unsigned *initialId = nodeMap.find(initial);
        assert(initialId != 0);
        fprintf(file, "initial %u\n", *initialId);
    }
    if (final != NoNode) {
        unsigned *finalId = nodeMap.find(final);
        assert(finalId != 0);
        fprintf(file, "final %u\n", *finalId);
    }

    fprintf(file, "transitions %u\n", numTransitions);

    if (debug(DebugPrintFunctionality)) {
        dout() << numNodes << " nodes, "
               << numTransitions << " transitions\n";
    }

    // second pass: emit transitions using the renumbered node ids
    nodeIter.init();
    while (LatticeNode *node = nodeIter.next(nodeIndex)) {
        unsigned *fromNodeId = nodeMap.find(nodeIndex);
        assert(fromNodeId != 0);

        NodeIndex toNode;
        TRANSITER_T<NodeIndex,LatticeTransition>
            transIter(node->outTransitions);
        while (LatticeTransition *trans = transIter.next(toNode)) {
            unsigned int *toNodeId = nodeMap.find(toNode);
            assert(toNodeId != 0);

            int logToPrint = LogPtoIntlog(trans->weight);
            if (limitIntlogs && logToPrint < minIntlog) {
                logToPrint = minIntlog;
            }

            fprintf(file, "%u %u %d\n", *fromNodeId, *toNodeId, logToPrint);
        }
    }

    fprintf(file, "\n");

    return true;
}
/*
 * Write the lattice in PFSG format using the internal node indices as-is.
 *
 * Every index 0..maxIndex-1 is declared as a node.  Indices with no
 * corresponding node (presumably gaps left by removed nodes), as well as
 * the initial and final nodes, are written as NULL nodes.  See
 * writeCompactPFSG() for an output without such placeholders.
 * Transition weights are written as integer logs, clamped from below to
 * minIntlog when limitIntlogs is set.
 *
 * Always returns true (write errors are not checked here).
 */
Boolean
Lattice::writePFSG(File &file)
{
    if (debug(DebugPrintFunctionality)) {
        dout() << "Lattice::writePFSG: writing ";
    }

    if (duration != 0.0) {
        fprintf(file, "name %s(duration=%lg)\n", name, duration);
    } else {
        fprintf(file, "name %s\n", name);
    }

    NodeIndex nodeIndex;
    unsigned numTransitions = 0;

    // maxIndex bounds unsigned node indices (initial/final are printed with
    // %u below); use a matching conversion here too
    fprintf(file, "nodes %u", (unsigned)maxIndex);

    for (nodeIndex = 0; nodeIndex < maxIndex; nodeIndex ++) {
        LatticeNode *node = nodes.find(nodeIndex);

        if (node) {
            numTransitions += node->outTransitions.numEntries();
        }

        // missing indices become NULL placeholder nodes
        fprintf(file, " %s",
                (nodeIndex == initial || nodeIndex == final || node == 0) ?
                    NullNodeName : getWord(node->word));
    }
    fprintf(file, "\n");

    if (initial != NoNode) {
        fprintf(file, "initial %u\n", initial);
    }
    if (final != NoNode) {
        fprintf(file, "final %u\n", final);
    }

    fprintf(file, "transitions %u\n", numTransitions);

    if (debug(DebugPrintFunctionality)) {
        dout() << maxIndex << " nodes, " << numTransitions << " transitions\n";
    }

    LHashIter<NodeIndex, LatticeNode> nodeIter(nodes, nodeSort);
    while (LatticeNode *node = nodeIter.next(nodeIndex)) {
        NodeIndex toNode;
        TRANSITER_T<NodeIndex,LatticeTransition>
            transIter(node->outTransitions);
        while (LatticeTransition *trans = transIter.next(toNode)) {
            int logToPrint = LogPtoIntlog(trans->weight);
            if (limitIntlogs && logToPrint < minIntlog) {
                logToPrint = minIntlog;
            }

            fprintf(file, "%u %u %d\n",
                    nodeIndex,
                    toNode,
                    logToPrint);
        }
    }

    fprintf(file, "\n");

    return true;
}
unsigned
Lattice::getNumTransitions()
{
unsigned numTransitions = 0;
LHashIter<NodeIndex, LatticeNode> nodeIter(nodes);
NodeIndex nodeIndex;
while (LatticeNode *node = nodeIter.next(nodeIndex)) {
numTransitions += node->outTransitions.numEntries();
}
return numTransitions;
}
// this is for debugging purpose
Boolean
Lattice::printNodeIndexNamePair(File &file)
{
if (debug(DebugPrintFunctionality)) {
dout() << "Lattice::printNodeIndexNamePair: "
<< "printing Index-Name pairs!\n";
}
LHashIter<NodeIndex, LatticeNode> nodeIter(nodes, nodeSort);
NodeIndex nodeIndex;
while (LatticeNode *node = nodeIter.next(nodeIndex)) {
fprintf(file, "%d %s (%d)\n", nodeIndex,
getWord(node->word), node->word);
}
return true;
}
/*
 * Read a file containing one or more PFSGs written back to back.
 *
 * Each PFSG is parsed with readPFSG().  Whitespace (newlines, spaces, tabs,
 * carriage returns) separating consecutive PFSGs is skipped.  Reading
 * stops at end of file or at the first PFSG that fails to parse.
 *
 * Returns the result of the last readPFSG() call, or false if the file
 * contained no PFSG at all.
 */
Boolean
Lattice::readPFSGFile(File &file)
{
    Boolean val = false;    // nothing read yet (e.g. empty file)
    int c;

    while ((c = fgetc(file)) != EOF) {
        // we only peeked at the character: give it back to the parser.
        // ungetc also works on non-seekable streams, unlike fseek.
        ungetc(c, file);

        val = readPFSG(file);
        if (!val) {
            // a failed parse may not consume any input; stop rather than
            // retrying at the same position forever
            break;
        }

        // Skip separators between PFSGs.  Each character is read exactly
        // once per test, and EOF is never "pushed back".
        do {
            c = fgetc(file);
        } while (c == '\n' || c == ' ' || c == '\t' || c == '\r');

        if (c != EOF) {
            ungetc(c, file);
        }
    }

    return val;
}
/*
 * Write this lattice to a PFSG file; counterpart of readPFSGFile().
 *
 * Previously this was an empty stub that wrote nothing yet reported
 * success.  It now writes the lattice with writePFSG() and returns that
 * function's result.
 */
Boolean
Lattice::writePFSGFile(File &file)
{
    return writePFSG(file);
}
/* **************************************************
some more complex functions of Lattice class
************************************************** */
// *****************************************************
// *******************algorithm*************************
// Node-removal algorithm used by removeAllXNodes():
// iterate over all nodes;
//   if nodeIndex is the initial or final node, skip it;
//   if nodeIndex is not labelled with the word being removed
//   (e.g. a NULL node), skip it.
//
// Otherwise, for the node X being removed:
//   for each incoming transition s -> X (self-loops excluded),
//     take its weight;
//     for each outgoing transition X -> t (self-loops excluded),
//       combine its weight with the incoming weight and with the
//       weight of X's self-loop, if any,
//       and insert a new transition s -> t with that weight;
//   finally remove X together with all of its transitions.
/*
 * Return the weight contributed by an optional self-loop on node nodeIndex.
 *
 * If the node has a self-loop with log10 probability w, the loop can be
 * taken any number of times.  The total contribution is the geometric
 * series sum_{k>=0} 10^(k*w) = 1 / (1 - 10^w), i.e. -log10(1 - 10^w).
 * Without a self-loop, unit() is returned.
 *
 * A loop weight >= 0 (probability >= 1) makes the series diverge; the
 * formula would give +inf or NaN.  Such loops are ignored and unit() is
 * returned.
 *
 * The node must exist; otherwise the program exits.
 */
LogP
Lattice::detectSelfLoop(NodeIndex nodeIndex)
{
    LatticeNode *node = nodes.find(nodeIndex);
    if (!node) {
        if (debug(DebugPrintFatalMessages)) {
            dout() << "Fatal Error in Lattice::detectSelfLoop: "
                   << nodeIndex << "\n";
        }
        exit(-1);
    }

    LatticeTransition *trans = node->outTransitions.find(nodeIndex);
    if (!trans) {
        // no self-loop: neutral weight
        return unit();
    }

    LogP weight = combWeights(trans->weight, unit());

    // probability >= 1: the series diverges, so ignore the loop rather than
    // returning inf (weight == 0) or NaN (weight > 0, log of a negative)
    if (weight >= 0) {
        return unit();
    }

    // -log10(1 - 10^weight), evaluated in double precision
    return -log10(1.0 - pow(10.0, (double)weight));
}
/*
 * Remove every node labelled with xWord, except the initial and final
 * nodes, splicing the lattice back together.
 *
 * For each removed node X and for every pair of transitions s -> X and
 * X -> t (self-loops on X excluded), a new transition s -> t is inserted.
 * Its weight combines the two transition weights and X's self-loop weight
 * (see detectSelfLoop()).
 *
 * Flags of each new transition:
 *  - it is constructed with pauseTFlag when xWord is the pause word,
 *    recording where a pause node was eliminated;
 *  - it inherits the union of the flags of the two joined transitions;
 *  - directTFlag is cleared, and restored only when removing NULL nodes
 *    (xWord == Vocab_None) and both joined transitions carried it.
 *
 * Always returns true.
 */
Boolean
Lattice::removeAllXNodes(VocabIndex xWord)
{
    if (debug(DebugPrintFunctionality)) {
        dout() << "Lattice::removeAllXNodes: "
               << "removing all " << getWord(xWord) << endl;
    }

    // NOTE(review): removeNode() below deletes entries from `nodes` while
    // nodeIter is iterating over it; this relies on LHashIter tolerating
    // removal of the current entry -- confirm against LHash.
    LHashIter<NodeIndex, LatticeNode> nodeIter(nodes);
    NodeIndex nodeIndex;
    while (LatticeNode *node = nodeIter.next(nodeIndex)) {
        if (debug(DebugPrintInnerLoop)) {
            dout() << "Lattice::removeAllXNodes: processing nodeIndex "
                   << nodeIndex << "\n";
        }

        // the initial and final nodes are never removed
        if (nodeIndex == final || nodeIndex == initial) {
            continue;
        }

        if (node->word == xWord) {
            // this node carries xWord (it is a NULL node when
            // xWord == Vocab_None) and will be spliced out
            if (debug(DebugPrintInnerLoop)) {
                dout() << "Lattice::removeAllXNodes: "
                       << "remove node " << nodeIndex << "\n";
            }

            // weight contributed by the node's self-loop (unit() if none)
            LogP loopweight = detectSelfLoop(nodeIndex);

            // remove the current node, all the incoming and outgoing edges
            // and create new edges
            // Notice that all the edges are recorded in two places:
            // inTransitions and outTransitions
            TRANSITER_T<NodeIndex,LatticeTransition>
                inTransIter(node->inTransitions);
            NodeIndex fromNodeIndex;
            while (LatticeTransition *inTrans = inTransIter.next(fromNodeIndex)) {
                if (debug(DebugPrintInnerLoop)) {
                    dout() << "Lattice::removeAllXNodes: "
                           << " fromNodeIndex " << fromNodeIndex << "\n";
                }

                LogP inWeight = inTrans->weight;

                // the self-loop is already accounted for in loopweight
                if (fromNodeIndex == nodeIndex) {
                    continue;
                }

                TRANSITER_T<NodeIndex,LatticeTransition>
                    outTransIter(node->outTransitions);
                NodeIndex toNodeIndex;
                while (LatticeTransition *trans = outTransIter.next(toNodeIndex)) {
                    if (debug(DebugPrintInnerLoop)) {
                        dout() << "Lattice::removeAllXNodes: "
                               << " toNodeIndex " << toNodeIndex << "\n";
                    }

                    // self-loop again: handled via loopweight
                    if (toNodeIndex == nodeIndex) {
                        continue;
                    }

                    // loopweight is 1 in the prob domain and
                    // loopweight is 0 in the log domain, if no loop
                    // for the current node
                    LogP weight = combWeights(inWeight, trans->weight);
                    weight = combWeights(weight, loopweight);

                    unsigned flag = 0;
                    // record where pause nodes were eliminated
                    if (xWord != Vocab_None && xWord == vocab.pauseIndex()) {
                        flag = pauseTFlag;
                    }

                    LatticeTransition t(weight, flag);
                    // new transition inherits properties from both parents
                    t.flags |= inTrans->flags | trans->flags;
                    // ... except for "direct (non-pause) connection"
                    t.flags &= ~directTFlag;
                    // a non-pause connection is carried over if we are removing
                    // a null-node and each of the joined transitions was direct
                    // (note: & binds tighter than &&, so each test is a bit test)
                    if (xWord == Vocab_None &&
                        inTrans->flags&directTFlag && trans->flags&directTFlag)
                    {
                        t.flags |= directTFlag;
                    }

                    // bypass edge s -> t replacing s -> X -> t; presumably
                    // insertTrans() updates both s's outTransitions and t's
                    // inTransitions (not visible here -- confirm)
                    insertTrans(fromNodeIndex, toNodeIndex, t);
                }
            } // end of inserting new edges

            // deleting xWord node
            removeNode(nodeIndex);
        } // end of processing xWord node
    }

    return true;
}
Boolean
Lattice::recoverPause(NodeIndex nodeIndex, Boolean loop, Boolean all)
{
if (debug(DebugPrintOutLoop)) {
dout() << "Lattice::recoverPause: "
<< "processing nodeIndex " << nodeIndex << "\n";
}
// this array is created to avoid inserting new elements into
// temporary index, while iterating over it.
TRANS_T<NodeIndex,LatticeTransition> newTransitions;
// going throught all the successive nodes of the current node (nodeIndex)
LatticeNode *node = findNode(nodeIndex);
// see if we want to insert a pause after this word unconditionally
Boolean alwaysInsertPause = all && !ignoreWord(node->word);
TRANSITER_T<NodeIndex,LatticeTransition>
outTransIter(node->outTransitions);
NodeIndex toNodeIndex;
while (LatticeTransition *trans = outTransIter.next(toNodeIndex)) {
// processing nodes at the next level
LatticeNode *toNode = findNode(toNodeIndex);
LogP weight = trans->weight;
Boolean direct = trans->getFlag(directTFlag);
// if we're inserting pauses everywhere OR
// if the current edge is a pause edge. insert a pause node
// and its two edges.
if (alwaysInsertPause && toNode->word != vocab.pauseIndex() ||
trans->getFlag(pauseTFlag)) {
NodeIndex newNodeIndex = dupNode(vocab.pauseIndex(), 0);
LatticeNode *newNode = findNode(newNodeIndex);
LatticeTransition *newTrans = newTransitions.insert(newNodeIndex);
newTrans->flags = 0;
newTrans->weight = weight;
LatticeTransition t(unit(), 0);
insertTrans(newNodeIndex, toNodeIndex, t);
// add self-loop
if (loo
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -