⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 trellis.cc

📁 这是一款很好用的工具包
💻 CC
字号:
/* * Trellis.cc -- *	Finite-state trellis dynamic programming.  This file contains functions for * Trellis and its associated classes: TrellisNode, TrellisSlice, TrellisNBest, * and TrellisIter. */#ifndef _Trellis_cc_#define _Trellis_cc_#ifndef lintstatic char Trellis_Copyright[] = "Copyright (c) 1995,1997,2001 SRI International.  All Rights Reserved.";static char Trellis_RcsId[] = "@(#)$Header: /home/srilm/devel/lm/src/RCS/Trellis.cc,v 1.21 2006/01/05 20:21:27 stolcke Exp $";#endif#include <iostream>using namespace std;#include <string.h>#include <stdlib.h>#include <assert.h>#include "Trellis.h"#include "LHash.cc"#define INSTANTIATE_TRELLIS(StateT) \    INSTANTIATE_LHASH(StateT,TrellisNode<StateT>); \    template class Trellis<StateT>template <class StateT>Trellis<StateT>::Trellis(unsigned len, unsigned numNbest)  : trellisSize(len), numNbest(numNbest){    assert(len > 0);    trellis = new TrellisSlice<StateT> [len];    assert(trellis != 0);    init(0);}template <class StateT>Trellis<StateT>::~Trellis(){    delete [] trellis;}template <class StateT>voidTrellis<StateT>::clear(){    /*     * This function used to clear the entire trellis, which is wasteful     * since we typically only ever use a small fraction of its full length.     * Clearing of old entries is now done incrementally, on-demand in     * TrellisSlice::init().     */    init();}template <class StateT>voidTrellis<StateT>::init(unsigned t){    assert(t < trellisSize);    currTime = t;    trellis[t].init();			// Initialize the time slice t.}template <class StateT>voidTrellis<StateT>::step(){    currTime ++;    assert(currTime < trellisSize);    trellis[currTime].init();}/* * Explicitly set the total and max probability of a path ending at the given * state */template <class StateT>voidTrellis<StateT>::setProb(StateT state, LogP prob){    TrellisSlice<StateT>& currSlice = trellis[currTime];    Boolean foundP;    TrellisNode<StateT> *node = currSlice.insert(state, foundP);    node->lprob = prob;    if (foundP) {	if (node->nbestSize() > 0 && prob > node->nbest[0].score) {	    node->nbest[0].score = prob;	}	return;    }    /*     * Not found.  Create a new entry.     */    node->backlpr = LogP_Zero;    node->backmax = LogP_Zero;    node->nbest.init(numNbest);    if (node->nbestSize() > 0) {	node->nbest[0].score = prob;    }}template <class StateT>LogPTrellis<StateT>::getLogP(StateT state, unsigned t){    assert(t <= currTime);    TrellisNode<StateT> *node = trellis[t].find(state);    return (node? node->lprob : LogP_Zero);}template <class StateT>LogPTrellis<StateT>::getMax(StateT state, unsigned t, LogP &backmax){    assert(t <= currTime);    TrellisNode<StateT> *node = trellis[t].find(state);    if (node && node->nbestSize() > 0) {	backmax = node->backmax;	return node->nbest[0].score;    }    backmax = LogP_Zero;    return LogP_Zero;}template<class StateT>voidTrellis<StateT>::update(StateT oldState, StateT newState, LogP trans){    assert(currTime > 0 && currTime < trellisSize);    TrellisSlice<StateT>& lastSlice = trellis[currTime-1];    TrellisSlice<StateT>& currSlice = trellis[currTime];    TrellisNode<StateT> *oldNode = lastSlice.find(oldState);    /*     * If the predecessor state doesn't exist its probability is     * implicitly zero and we have nothing to do!     */    if (!oldNode) {	return;    }    Boolean foundP;    TrellisNode<StateT> *newNode = currSlice.insert(newState, foundP);    LogP2 newProb = oldNode->lprob + trans;	// Accumulate total FW prob.    if (!foundP) {	newNode->lprob = (LogP)newProb;	newNode->backlpr = LogP_Zero;	newNode->backmax = LogP_Zero;	newNode->nbest.init(numNbest);    } else {	newNode->lprob = (LogP)AddLogP(newNode->lprob, newProb);    }    /*     * Update Viterbi related info.     */    for (unsigned i = 0; i < oldNode->nbestSize(); i++) {	LogP totalProb = oldNode->nbest[i].score + trans;	newNode->nbest.insert(Hyp<StateT>(totalProb, oldState, i));    }}template <class StateT>LogPTrellis<StateT>::sumLogP(unsigned t){    assert(t <= currTime);    return trellis[t].sum();}template <class StateT>StateTTrellis<StateT>::max(unsigned t){    assert(t <= currTime);    return trellis[t].max();}template <class StateT>voidTrellis<StateT>::setBackProb(StateT state, LogP prob){    TrellisNode<StateT> *node = trellis[backTime].find(state);    if (!node) {	cerr << "trying to set backward prob for nonexistent node " << state	     << " at time " << backTime << endl;	return;    }    node->backlpr = prob;    if (prob > node->backmax) {	node->backmax = prob;    }}template <class StateT>LogPTrellis<StateT>::getBackLogP(StateT state, unsigned t){    assert(t <= currTime);    TrellisNode<StateT> *node = trellis[t].find(state);    return (node? node->backlpr : LogP_Zero);}template <class StateT>voidTrellis<StateT>::initBack(unsigned t){    assert(t <= currTime);    backTime = t;}template <class StateT>voidTrellis<StateT>::stepBack(){    assert(backTime > 0);    backTime --;}template <class StateT>voidTrellis<StateT>::updateBack(StateT oldState, StateT newState, LogP trans){    assert(backTime != (unsigned)-1);	/* check for underflow */    TrellisSlice<StateT>& currSlice = trellis[backTime];    TrellisSlice<StateT>& nextSlice = trellis[backTime + 1];    TrellisNode<StateT> *nextNode = nextSlice.find(newState);    /*     * If the successor state doesn't exist its probability is     * implicitly zero and we have nothing to do!     */    if (!nextNode) {	return;    }    TrellisNode<StateT> *thisNode = currSlice.find(oldState);    if (!thisNode) {	cerr << "trying to update backward prob for nonexistent node "	     << oldState << " at time " << backTime << endl;	return;    }    /* Accumulate total backward prob    */    LogP2 thisProb = nextNode->backlpr + trans;    thisNode->backlpr = (LogP)AddLogP(thisNode->backlpr, thisProb);    LogP totalMax = nextNode->backmax + trans;    if (totalMax > thisNode->backmax) {	thisNode->backmax = totalMax;    }}//-------------Viterbi backtrace algorithms-------------------------------/* * Returns in "path" the most likely partial path of the given length, len. * We obtain this by calling the overloaded viterbi() with an unmapped * lastState, which causes it to default to the most likely last state. */template <class StateT>unsignedTrellis<StateT>::viterbi(StateT *path, unsigned len){    LogP dummy;    StateT lastState;    Map_noKey(lastState);    return nbest_viterbi(path, len, 0, dummy, lastState);}/* Same as viterbi(), but instead returns the nth best partial path */template <class StateT>unsignedTrellis<StateT>::nbest_viterbi(StateT *path, unsigned len, unsigned nth, LogP& score){    StateT lastState;    Map_noKey(lastState);    return nbest_viterbi(path, len, nth, score, lastState);}/* * If lastState is unmapped, this returns in "path" the Viterbi backtrace * of the nth best partial path of the given length from the n-best of all * the nbest lists in the required timeslice.  Alternately, lastState may be * mapped, in which case, the returned path is just the nth best partial * path of the given length that ends at the given state. */template <class StateT>unsignedTrellis<StateT>::nbest_viterbi(StateT *path, unsigned len, unsigned n,					    LogP &score, StateT lastState){    if (len > currTime + 1) {		  // Sanity check	len = currTime + 1;    }    assert(len > 0 && len <= trellisSize);    if (n >= numNbest) {	return 0;    }    StateT currState;    int currWhichbest;        /*     * Suppose lastState is explicitly given. i.e., mapped.  Then we     * backtrace from this state's nth best hyp.  Otherwise, we     * construct the nbest from the required time slice, determine     * which state actually ends the nth overall-best hyp and     * backtrace from that state.     */    if (Map_noKeyP(lastState)) {	TrellisNBestList<StateT>& nblist = trellis[len-1].nbest(numNbest);	currState = nblist[n].prev;	currWhichbest = nblist[n].whichbest;	score = nblist[n].score;    } else {	currState = lastState;	currWhichbest = n;	TrellisNode<StateT> *node = trellis[len-1].find(currState);	if (!node) {	    return 0;	}	score = node->nbest[n].score;    }    unsigned pos = len;    while (!Map_noKeyP(currState)) {	assert(pos > 0);	pos --;	path[pos] = currState;	TrellisNode<StateT> *currNode = trellis[pos].find(currState);	assert(currNode);	currState = currNode->nbest[currWhichbest].prev;	currWhichbest = currNode->nbest[currWhichbest].whichbest;    }    if (pos != 0) {		// Backtrace failed before reaching start	  return 0;    }    return len;}//------------------ Slice related functions -----------------------------------template<class StateT>ostream&operator<<(ostream& os, const TrellisSlice<StateT>& slice){    LHashIter<StateT, TrellisNode<StateT> > iter(slice.nodes);    TrellisNode<StateT>* node;    StateT state;    while (node = iter.next(state))    os << "  State: [" << state << "],\t" << node->nbestSize() << "-Best = "       << *node << endl;    return os;}template <class StateT>TrellisSlice<StateT>::~TrellisSlice(){    /*     * Destroy node structures and associated n-best lists     */    init();}/* * Initialization of a time slice. */template <class StateT>voidTrellisSlice<StateT>::init(){    LHashIter<StateT, TrellisNode<StateT> > iter(nodes);    TrellisNode<StateT> *node;    StateT state;    /*     * XXX: We need to explicitly destroy the nodes in the hash table,     * due to lossage in LHash, to cause n-best lists to be freed.     * Unfortunately gcc 2.8.1 has a bug that prevents us from calling      * ~TrellisNode(), so we make do with clear().     */    while (node = iter.next(state)) {#if __GNUC__ == 2 && __GNUC_MINOR__ <= 8	node->clear();#else	node->~TrellisNode<StateT>();#endif    }    nodes.clear(0);    /*     * The globalNbest list is cleared and left unexpanded.     * We only fill it in when asked for.     */    globalNbest.init(0);}/* * Returns the log of the sum of the probabilities of paths that end at * the current time slice. */template <class StateT>LogPTrellisSlice<StateT>::sum(){    LHashIter<StateT, TrellisNode<StateT> > iter(nodes);    TrellisNode<StateT> *node;    StateT state;    LogP2 logSum = LogP_Zero;    while (node = iter.next(state)) {	logSum = AddLogP(logSum, node->lprob);    }    return (LogP)logSum;}/* * Returns the state that ends the highest probability path at the current * time slice. */template <class StateT>StateTTrellisSlice<StateT>::max(){    LHashIter<StateT, TrellisNode<StateT> > iter(nodes);    TrellisNode<StateT> *node;    StateT state, maxState;    LogP maxProb = LogP_Zero;    Map_noKey(maxState);    while (node = iter.next(state)) {	if (Map_noKeyP(maxState) ||	    node->nbestSize() > 0 && node->nbest[0].score > maxProb)	{	    maxProb = node->nbest[0].score;	    maxState = state;	}    }    return maxState;}/* * Calculates the nbest list of paths ending at the current time slice. * Once this is calculated, it is stored in the globalNbest member to avoid * recomputation. The n-best list thus computed is the n-best of the union * of all the n-best hyps belonging to each state in this time-slice. * * To get the n-best paths over the entire trellis, we must first call this * function on the last time slice.  The nbest list thus obtained can then * be back-traced to obtain n-best of the best paths. */template<class StateT>TrellisNBestList<StateT>&TrellisSlice<StateT>::nbest(unsigned numNbest){    if (globalNbest.size() >= numNbest) {	return globalNbest;    }    globalNbest.init(numNbest);    LHashIter<StateT, TrellisNode<StateT> > iter(nodes);    TrellisNode<StateT> *node;    StateT state;    while (node = iter.next(state)) {	for (unsigned n = 0; n < node->nbestSize(); n++) {	    globalNbest.insert(Hyp<StateT>(node->nbest[n].score, state, n));	}    }    return globalNbest;}//-------------------TrellisNBestList functions--------------------------------template<class StateT>TrellisNBestList<StateT>::TrellisNBestList(unsigned num)  : numNbest(0), nblist(0){    init(num);}template<class StateT>TrellisNBestList<StateT>::~TrellisNBestList(){    delete [] nblist;}/* * allocate or clear an N-best list */template<class StateT>voidTrellisNBestList<StateT>::init(unsigned newSize){    StateT s;    Map_noKey(s);    Hyp<StateT> h(LogP_Zero, s, 0);    if (newSize == 0) {	delete [] nblist;	nblist = 0;    } else if (newSize > numNbest) {	delete [] nblist;	nblist = new Hyp<StateT> [newSize];	assert(nblist != 0);    }    numNbest = newSize;    /*     * clear entries     */    for (unsigned i = 0; i < numNbest; i++) {	nblist[i] = h;    }}/* * Moves n bytes from src to dst starting at the end.  This is useful * to "shift down" part of an array. */template<class T>inline void rmemmove(T *dst, T *src, unsigned n){    T *d = dst + n;    T *s = src + n;    while (n--) {	*(--d) = *(--s);    }}/* Returns the position where hyp would be inserted into the nbest * list.  This may be numNbest if the hyp is worse than the worst * hyp already in the list. */template<class StateT>inline unsignedTrellisNBestList<StateT>::findrank(const Hyp<StateT>& hyp) const{    unsigned low = 0, high = numNbest - 1;    while (low+1 < high) {	unsigned m = (high+low)/2;	if (nblist[m].score >= hyp.score) {	    low = m;	} else {	    high = m;	}    }    /*     * low+1 == high at this point, but it may be that low == n-1     * where n is the correct insertion point, e.g., when inserting     * 2.5 in (...,3,2,...).     */    while (low < numNbest && nblist[low].score >= hyp.score) {	low ++;    }    return low;}/* * insert(hyp) inserts the given hyp into the current nBestList if the score * of hyp is better (greater) than the score of the worst hyp in the list.  The * hyp is inserted before the very first hyp in the list that has a score * *worse* than it. */template<class StateT>void TrellisNBestList<StateT>::insert(const Hyp<StateT>& hyp){    unsigned i = findrank(hyp);    if (i < numNbest) {	rmemmove<Hyp<StateT> >(&nblist[i+1], &nblist[i], numNbest-i-1);	nblist[i] = hyp;    }}//------------------ Iteration over states in a trellis slice -----------------template <class StateT>TrellisIter<StateT>::TrellisIter(Trellis<StateT> &trellis, unsigned t) : sliceIter(trellis.trellis[t].nodes){    assert(t <= trellis.currTime);}template <class StateT>voidTrellisIter<StateT>::init(){    sliceIter.init();}template <class StateT>BooleanTrellisIter<StateT>::next(StateT &state, LogP &prob){    TrellisNode<StateT> *node = sliceIter.next(state);    if (!node) {	return false;    }    prob = node->lprob;    return true;}#endif /* _Trellis_cc_ */

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -