⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 latticenbest.cc

📁 这是一款很好用的工具包
💻 CC
📖 第 1 页 / 共 3 页
字号:
/*
 * LatticeNBest.cc --
 *	N-best generation from lattices
 *
 * 	(Originally contributed by Dustin Hillard, University of Washington,
 *	Viterbi N-best added by Jing Zheng, SRI International.)
 */

#ifndef lint
static char Copyright[] = "Copyright (c) 2004-2006 SRI International.  All Rights Reserved.";
static char RcsId[] = "@(#)$Header: /home/srilm/devel/lattice/src/RCS/LatticeNBest.cc,v 1.25 2006/01/09 19:15:48 stolcke Exp $";
#endif

#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <assert.h>
#include <errno.h>

#include "Lattice.h"

#include "Array.cc"
#include "LHash.cc"
#include "IntervalHeap.cc"
#include "zio.h"
#include "mkdir.h"

/*
 * Debug verbosity levels used with debug() in this file.
 * Each level is defined exactly once (the earlier duplicate definition of
 * DebugPrintInnerLoop has been dropped).
 */
#define DebugPrintFatalMessages         0
// for large functionality listed in options of the program
#define DebugPrintFunctionality         1
// for out loop of the large functions or small functions
#define DebugPrintOutLoop               2
// for inner loop of the large functions or outloop of small functions
#define DebugPrintInnerLoop             3
// NOTE(review): presumably for tracing the A* search; no use is visible in
// this part of the file -- confirm
#define DebugAStar                      4


/* *************************
 * A path through the lattice (stored as reversed linked list)
 * ************************* */
class LatticeNBestPath
{
 // LatticeNBestHyp is granted access to the private reference count
 friend class LatticeNBestHyp;

 public:
  // Create a path element for `node' whose previous element is `predecessor'
  // (computeNBest passes 0 as predecessor for the element holding the
  // initial node)
  LatticeNBestPath(NodeIndex node, LatticeNBestPath *predecessor);
  ~LatticeNBestPath();
  void linkto();	// add a reference to this path
  void release();	// remove a reference to this path
  // Store the nodes of this path into `path'.
  // NOTE(review): the unsigned result is presumably the number of nodes
  // stored -- confirm against the definition
  unsigned getPath(Array<NodeIndex> &path);

  NodeIndex node;		// lattice node held by this path element
  LatticeNBestPath *pred;	// previous path element (toward the start)

 private:
  // reference count; see linkto() and release()
  unsigned numReferences;
};


/* *************************
 * An nbest hyp through lattice (and a custom sort function for the queue)
 * ************************* */
class LatticeNBestHyp
{
 // One (partial or complete) hypothesis on the A* priority queue used by
 // Lattice::computeNBest.  Hypotheses are ordered by `score', with
 // `wordCnt' as tie-breaker (see nbestLess/nbestGreater/nbestEqual).
 public:
  // The constructor arguments correspond one-to-one to the public data
  // members declared below.
  LatticeNBestHyp(double score, LogP myForwardProb, 
                  NodeIndex myNodeIndex, int mySuccIndex, Boolean endOfSent,
		  LatticeNBestPath *nbestPath, unsigned myWordCnt,
		  LogP myAcoustic, LogP myNgram, LogP myLanguage,
		  LogP myPron, LogP myDuration,
		  LogP myXscore1, LogP myXscore2, LogP myXscore3,
		  LogP myXscore4, LogP myXscore5, LogP myXscore6,
		  LogP myXscore7, LogP myXscore8, LogP myXscore9);
  ~LatticeNBestHyp();
  
  double score; // score for path (forward prob plus total backward prob)
  LogP forwardProb; // forward prob so far on this path
  Boolean endOfSent;			// true if the path has reached the final node
  LatticeNBestPath *nbestPath;		// linked list of nodes to the start

  // Node at which this hyp currently ends; computeNBest uses it to index
  // its per-node table of sorted successors.
  NodeIndex nodeIndex;  
  // Index into that successor table; computeNBest expands entry
  // succIndex + 1, and new hyps are created with -1.
  int succIndex;

  // Accumulated HTK scores over path
  unsigned wordCnt;                     // number of words (ignore non-words)
  LogP acoustic;			// acoustic model log score
  LogP ngram;				// ngram model log score
  LogP language;			// language model log score
  LogP pron;				// pronunciation log score
  LogP duration;			// duration log score
  LogP xscore1;			        // extra score #1
  LogP xscore2;			        // extra score #2
  LogP xscore3;			        // extra score #3
  LogP xscore4;			        // extra score #4
  LogP xscore5;			        // extra score #5
  LogP xscore6;			        // extra score #6
  LogP xscore7;			        // extra score #7
  LogP xscore8;			        // extra score #8
  LogP xscore9;			        // extra score #9

  // Write this hyp as entry number `hypNum' to the outputs in `nbestOut'.
  // NOTE(review): meaning of the Boolean result is not visible here
  // (computeNBest ignores it) -- confirm against the definition
  Boolean writeHyp(int hypNum, Lattice &lat, NBestOptions &nbestOut);
  // Build a string key for this hyp; computeNBest uses it to detect
  // duplicate hypotheses and releases it with free().
  char *getHypFeature(SubVocab &ignoreWords, Lattice &lat,
						const char *multiwordSeparator);
};


/*
 * Debugging aid: dump one hypothesis to the lattice's debug stream.
 * Prints the two counters and the hyp score, followed by every word node
 * on the hyp's path as "word(nodeIndex)"; when a node carries HTK info its
 * acoustic and language scores are appended in braces.
 */
static void
printDebugHyp(Lattice &lat, unsigned numOutput, unsigned numHyps,
						     LatticeNBestHyp &hyp)
{
  lat.dout() << "Lattice::computeNBest: Hyp " << numOutput << " " << numHyps
	     << " : " << hyp.score << " ";

  Array<NodeIndex> nodesOnPath;
  hyp.nbestPath->getPath(nodesOnPath);

  const int pathLength = (int)nodesOnPath.size();
  for (int pos = 0; pos < pathLength; pos++) {
    NodeIndex nodeIdx = nodesOnPath[pos];

    LatticeNode *pathNode = lat.findNode(nodeIdx);
    assert(pathNode != 0);

    // nodes without a word produce no output
    if (pathNode->word == Vocab_None) {
      continue;
    }

    lat.dout() << lat.getWord(pathNode->word) << "(" << nodeIdx << ")";

    if (pathNode->htkinfo) {
      lat.dout() << "{" << pathNode->htkinfo->acoustic << "}{"
		 << pathNode->htkinfo->language << "} ";
    }
  }

  lat.dout() << "\n";
}


/* *************************
 * A-star N-best generation
 * ************************* */

struct nbestLess {
  // "less than" for queued hyps: lower score first, and among equal scores
  // the hyp with fewer words
  bool operator() (const LatticeNBestHyp *first, const LatticeNBestHyp *second)
  {
    if (first->score != second->score) {
      return first->score < second->score;
    }
    return first->wordCnt < second->wordCnt;
  }
};
struct nbestGreater {
  // "greater than" for queued hyps: higher score first, and among equal
  // scores the hyp with more words
  bool operator() (const LatticeNBestHyp *first, const LatticeNBestHyp *second)
  {
    if (first->score != second->score) {
      return first->score > second->score;
    }
    return first->wordCnt > second->wordCnt;
  }
};
struct nbestEqual {
  // equality for queued hyps: both the score and the word count must match
  bool operator() (const LatticeNBestHyp *first, const LatticeNBestHyp *second)
  {
    if (first->score != second->score) {
      return false;
    }
    return first->wordCnt == second->wordCnt;
  }
};



/*
 * One outgoing transition of a lattice node, as tabulated by
 * Lattice::computeNBest.
 */
struct SuccInfo {
    NodeIndex to;	// destination node of the transition
    LogP bwScore;	// transition weight plus Viterbi backward prob of `to'
};

static int
compareSucc(const void *p1, const void *p2)
{
    LogP s1 = ((const SuccInfo *) p1)->bwScore;
    LogP s2 = ((const SuccInfo *) p2)->bwScore;

    if (s1 == s2) return 0;
    else if (s1 < s2) return 1;
    else return -1;
}

struct NodeInfo {
    int numSuccs;
    SuccInfo *succs;
    
    NodeInfo() { numSuccs = 0; succs = 0; };
    ~NodeInfo() { if (numSuccs) delete [] succs; succs = 0; };

    void sortSuccs() {
    	if (numSuccs) qsort(succs, numSuccs, sizeof(SuccInfo), compareSucc);
    };
};

static int
compareLogP(const void *p1, const void *p2) 
{
    LogP pr = (*(const LogP *)p1 - *(const LogP *)p2);
    
    if (pr == 0) return 0;
    else if (pr < 0) return 1;
    else return -1;
}

/*
 * Compute top N word sequences with highest probability paths through lattice
 */
Boolean 
Lattice::computeNBest(unsigned N, NBestOptions &nbestOut, SubVocab &ignoreWords,
			    const char *multiwordSeparator, unsigned maxHyps,
			    unsigned nbestDuplicates)
{
  /*
   * Find the top N hyps in the lattice using A* search
   *   First compute the forward and backward max prob paths for each node
   *   Then select the top N paths with A* search (sorting with the computed
   *   max prob paths)
   */

  /*
   * topological sort
   */
  unsigned numNodes = getNumNodes(); 

  NodeIndex *sortedNodes = new NodeIndex[numNodes];
  assert(sortedNodes != 0);
  unsigned numReachable = sortNodes(sortedNodes);
  
  if (numReachable != numNodes) {
    dout() << "Lattice::computeNBest: warning: called with unreachable nodes\n";
  }

  if (sortedNodes[0] != initial) {
    dout() << "Lattice::computeNBest: initial node is not first\n";
    delete [] sortedNodes;
    return LogP_Inf;
  }
  
  unsigned finalPosition = 0;
  for (finalPosition = 1; finalPosition < numReachable; finalPosition ++) {
    if (sortedNodes[finalPosition] == final) break;
  }
  if (finalPosition == numReachable) {
    dout() << "Lattice::computeNBest: final node is not reachable\n";
    delete [] sortedNodes;
    return LogP_Inf;
  }
  
  /*
   * compute fb viterbi probabilities
   */
  LogP *viterbiBackwardProbs = new LogP[maxIndex];
  LogP *viterbiForwardProbs  = new LogP[maxIndex];
  assert(viterbiBackwardProbs != 0 && viterbiForwardProbs != 0);

  NodeInfo *nodeInfos = new NodeInfo[maxIndex];
  assert(nodeInfos != 0);

  LogP bestProb =
      computeForwardBackwardViterbi(viterbiForwardProbs, viterbiBackwardProbs);
  
  int i;
  for (i = 0; i < (int)maxIndex; i++) {
      LatticeNode *node = nodes.find(i);      
    
      if (!node) continue;
      NodeInfo *ni = nodeInfos + i;
    
      ni->numSuccs = node->outTransitions.numEntries();
      ni->succs = new SuccInfo[ni->numSuccs];
      assert(ni->succs != 0);
      
      TRANSITER_T<NodeIndex, LatticeTransition> transIter(node->outTransitions);
      
      NodeIndex toNode;
      int j = 0;
      while (LatticeTransition *trans = transIter.next(toNode)) {
          LogP bwScore = trans->weight + viterbiBackwardProbs[toNode];
          
          ni->succs[j].to = toNode;
          ni->succs[j].bwScore = bwScore;
          j ++;
      }

      ni->sortSuccs();
    }
  
  if (debug(DebugPrintFunctionality)) {
      dout() << "Lattice::computeNBest: best FB prob: " << bestProb << endl;
  }
  
  /*
   * select top nbest hyps from lattice with A* search
   */
  
  if (debug(DebugPrintFunctionality)) {
    dout() << "Lattice::computeNBest: writing nbest list\n";
  }

  IntervalHeap<LatticeNBestHyp*, nbestLess, nbestGreater, nbestEqual>
				    hyps(maxHyps > 0 ? maxHyps : (2 * N + 1));

  LHash<VocabString, unsigned> hypsPrinted;	// for duplicate removal

  /*
   * 
   * Implement A* 
   *
   * 1 - Initialize priority queue with a null theory
   * 2 - Pop the highest score hyp
   * 3 - If end-of-sentence, output hyp (return to 2)
   * 4 - Create new hyps for all outgoing word transitions
   * 5 - Score each new extended hyp, re-insert to queue (possibly marking end-of-sentence)
   * 6 - Go to 2 
   *
   */

  // start queue with hyp containing only the initial node
  LatticeNBestPath *initialPath = new LatticeNBestPath(initial, 0);
  assert(initialPath != 0);

  LatticeNBestHyp *initialHyp =
  		new LatticeNBestHyp(LogP_One, LogP_One, initial, -1,
                                    false, initialPath, 0, 
				    LogP_One, LogP_One, LogP_One, LogP_One,
				    LogP_One, LogP_One, LogP_One, LogP_One,
				    LogP_One, LogP_One, LogP_One, LogP_One,
				    LogP_One, LogP_One);
  assert(initialHyp != 0);

  hyps.push(initialHyp);

  int outputHyps = 0;
  Boolean firstPruned = true;

  while (outputHyps < N && !hyps.empty()) {
    LatticeNBestHyp *topHyp = hyps.top_max(); // get hyp
    hyps.pop_max();
    if (topHyp->endOfSent) {
      // check to see if this hyp was already printed
      char *feature =
		topHyp->getHypFeature(ignoreWords, *this, multiwordSeparator);
      Boolean isDuplicate;
      unsigned *timesPrinted = hypsPrinted.insert(feature, isDuplicate);
      if (isDuplicate) {
	*timesPrinted += 1;
      } else {
	*timesPrinted = 1;
      }

      if (isDuplicate && *timesPrinted > nbestDuplicates) {
        if (debug(DebugPrintOutLoop)) {
	  dout() << "Lattice::computeNBest: not outputting hypothesis because it matches previously printed one\n";
	}
	
	// debugging output
	if (debug(DebugPrintInnerLoop)) {
	  printDebugHyp(*this, outputHyps, hyps.size(), *topHyp);
	}

        free(feature);
      } else {
        // output hyp
	outputHyps++;
	topHyp->writeHyp(outputHyps, *this, nbestOut);

	// debugging output
	if (debug(DebugPrintOutLoop)) {
	  printDebugHyp(*this, outputHyps, hyps.size(), *topHyp);
	}
      }
    } else {
      // top hyp is not an end of sentence, so extend it
      // expand all outgoing paths from current node

      LatticeNode *node = findNode(topHyp->nbestPath->node); 
      assert(node != 0);
      NodeIndex nodeIndex = topHyp->nodeIndex;
      int succIndex = topHyp->succIndex + 1;
      NodeInfo *nodeInfo = &(nodeInfos[nodeIndex]);
	
      if (succIndex < nodeInfo->numSuccs) {
        NodeIndex toNodeIndex = nodeInfo->succs[succIndex].to;
	// compute accumulated scores
        double score = topHyp->forwardProb + nodeInfo->succs[succIndex].bwScore;

        // compute the forward part of the score
	LogP forwardProb = score - viterbiBackwardProbs[toNodeIndex];

	unsigned cnt = topHyp->wordCnt;		// word count (ignore non-words)
	LogP acoustic = topHyp->acoustic;	// acoustic model log score
	LogP ngram = topHyp->ngram;		// ngram model log score
	LogP language = topHyp->language;	// language model log score
	LogP pron = topHyp->pron;		// pronunciation log score
	LogP duration = topHyp->duration;	// duration log score
	LogP xscore1 = topHyp->xscore1; 	// extra score #1
	LogP xscore2 = topHyp->xscore2; 	// extra score #2
	LogP xscore3 = topHyp->xscore3; 	// extra score #3
	LogP xscore4 = topHyp->xscore4; 	// extra score #4
	LogP xscore5 = topHyp->xscore5; 	// extra score #5
	LogP xscore6 = topHyp->xscore6; 	// extra score #6
	LogP xscore7 = topHyp->xscore7; 	// extra score #7
	LogP xscore8 = topHyp->xscore8; 	// extra score #8
	LogP xscore9 = topHyp->xscore9; 	// extra score #9

	if (node->htkinfo) {
	  if (!ignoreWord(node->word) &&	// NULL and pause nodes
	      !ignoreWords.getWord(node->word) &&
	      !vocab.isNonEvent(node->word) &&	// <s> and other non-events
	      node->word != vocab.seIndex())
	    cnt      += 1;

	  if (node->htkinfo->acoustic != HTK_undef_float) 
	    acoustic += node->htkinfo->acoustic;
	  if (node->htkinfo->ngram != HTK_undef_float) 
	    ngram    += node->htkinfo->ngram;
	  if (node->htkinfo->language != HTK_undef_float) 
	    language += node->htkinfo->language;
	  if (node->htkinfo->pron != HTK_undef_float)
	    pron     += node->htkinfo->pron;
	  if (node->htkinfo->duration != HTK_undef_float) 
	    duration += node->htkinfo->duration;
	  if (node->htkinfo->xscore1 != HTK_undef_float) 
	    xscore1  += node->htkinfo->xscore1;
	  if (node->htkinfo->xscore2 != HTK_undef_float) 
	    xscore2  += node->htkinfo->xscore2;
	  if (node->htkinfo->xscore3 != HTK_undef_float) 
	    xscore3  += node->htkinfo->xscore3;
	  if (node->htkinfo->xscore4 != HTK_undef_float) 
	    xscore4  += node->htkinfo->xscore4;
	  if (node->htkinfo->xscore5 != HTK_undef_float) 
	    xscore5  += node->htkinfo->xscore5;
	  if (node->htkinfo->xscore6 != HTK_undef_float) 
	    xscore6  += node->htkinfo->xscore6;
	  if (node->htkinfo->xscore7 != HTK_undef_float) 
	    xscore7  += node->htkinfo->xscore7;
	  if (node->htkinfo->xscore8 != HTK_undef_float) 
	    xscore8  += node->htkinfo->xscore8;
	  if (node->htkinfo->xscore9 != HTK_undef_float) 
	    xscore9  += node->htkinfo->xscore9;
	}
						    // add this node to path
	LatticeNBestPath *thisPath =
			new LatticeNBestPath(toNodeIndex, topHyp->nbestPath);
	assert(thisPath != 0);
    
        Boolean isFinal = (toNodeIndex == final);

	LatticeNBestHyp *expandedHyp =
	    new LatticeNBestHyp(score, forwardProb, toNodeIndex, -1,
				isFinal, thisPath, cnt, 
				    acoustic, ngram, language, pron, duration,
				xscore1, xscore2, xscore3, xscore4, xscore5,
				xscore6, xscore7, xscore8, xscore9);
	assert(expandedHyp != 0);
	if (maxHyps > 0 && hyps.size() >= maxHyps) {
	  LatticeNBestHyp *pruneHyp = hyps.top_min(); // get hyp
	  hyps.pop_min();
	  delete pruneHyp;
	  if (debug(DebugPrintOutLoop) ||
	      firstPruned && debug(DebugPrintFunctionality))
	  {
	    dout() << "Lattice::computeNBest: max number of hyps reached, pruning lowest score hyp\n";
	    firstPruned = false;
	  }

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -