⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 wordmesh.cc

📁 这是一款很好用的工具包
💻 CC
📖 第 1 页 / 共 3 页
字号:
/*
 * WordMesh.cc --
 *	Word Meshes (aka Confusion Networks aka Sausages)
 */

#ifndef lint
static char Copyright[] = "Copyright (c) 1995-2006 SRI International.  All Rights Reserved.";
static char RcsId[] = "@(#)$Header: /home/srilm/devel/lm/src/RCS/WordMesh.cc,v 1.37 2006/01/09 18:08:21 stolcke Exp $";
#endif

#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <assert.h>

#include "WordMesh.h"
#include "WordAlign.h"

#include "Array.cc"
#include "LHash.cc"
#include "SArray.cc"

/*
 * Special token used to represent an empty position in an alignment column.
 * The WordMesh constructor adds this string to the vocabulary and keeps the
 * resulting index in deleteIndex.
 */
const VocabString deleteWord = "*DELETE*";

/*
 * Construct an empty mesh (no alignment columns, zero total posterior).
 *	vocab, myname	are passed on to the MultiAlign base class
 *	distance	optional word distance metric consulted by alignError();
 *			may be null
 * As a side effect the deletion token is added to the vocabulary and its
 * index is remembered in deleteIndex.
 */
WordMesh::WordMesh(Vocab &vocab, const char *myname, VocabDistance *distance)
    : MultiAlign(vocab, myname),
      numAligns(0),
      totalPosterior(0.0),
      distance(distance)
{
    this->deleteIndex = vocab.addWord(deleteWord);
}

/*
 * Release the name string and every per-column table.
 * The values held in the wordInfo and hypMap tables are destroyed explicitly
 * before the table that contains them is deleted.
 */
WordMesh::~WordMesh()
{
    if (name != 0) {
	free(name);
    }

    for (unsigned pos = 0; pos < numAligns; pos ++) {
	VocabIndex key;

	delete aligns[pos];

	/*
	 * Tear down the word info entries of this column
	 */
	LHashIter<VocabIndex,NBestWordInfo> infoIter(*wordInfo[pos]);
	for (NBestWordInfo *info = infoIter.next(key);
	     info != 0;
	     info = infoIter.next(key))
	{
	    info->~NBestWordInfo();
	}
	delete wordInfo[pos];

	/*
	 * Tear down the hyp ID lists of this column
	 */
	LHashIter<VocabIndex,Array<HypID> > mapIter(*hypMap[pos]);
	for (Array<HypID> *ids = mapIter.next(key);
	     ids != 0;
	     ids = mapIter.next(key))
	{
	    ids->~Array();
	}
	delete hypMap[pos];
    }
}
   
/*
 * A mesh is empty as long as it contains no alignment columns.
 */
Boolean
WordMesh::isEmpty()
{
    if (numAligns > 0) {
	return false;
    }
    return true;
}

/*
 * alignment set to sort by posterior (parameter to comparePosteriors)
 * Must be set before comparePosteriors is handed to an iterator.
 */
static LHash<VocabIndex, Prob> *compareAlign;

/*
 * Sort function ordering words by decreasing posterior in compareAlign.
 * Returns -1 if w1 has the larger posterior, 1 if w2 has, 0 if they tie.
 *
 * The function is also used to order keys of tables other than compareAlign
 * itself (hyp ID and word info tables), which may contain words that have
 * no entry in the alignment column.  Such words are treated as having
 * posterior 0 instead of dereferencing a null lookup result.
 */
static int
comparePosteriors(VocabIndex w1, VocabIndex w2)
{
    Prob *prob1 = compareAlign->find(w1);
    Prob *prob2 = compareAlign->find(w2);

    Prob post1 = prob1 ? *prob1 : 0.0;
    Prob post2 = prob2 ? *prob2 : 0.0;

    Prob diff = post1 - post2;

    if (diff < 0.0) {
	return 1;
    } else if (diff > 0.0) {
	return -1;
    } else {
	return 0;
    }
}

/*
 * Write the mesh to file in text form.
 * A header (optional name, number of columns, total posterior) is followed,
 * for each column in sorted order, by
 *	- an "align" line listing words and posteriors by decreasing posterior
 *	- "posterior"/"transposterior" lines if they differ from the total
 *	- a "reference" line if one of the words carries the reference hyp ID
 *	- "hyps" lines with the non-reference hyp IDs of each word
 *	- "info" lines with word backtrace information
 * Always returns true.
 */
Boolean
WordMesh::write(File &file)
{
    if (name != 0) {
	fprintf(file, "name %s\n", name);
    }
    fprintf(file, "numaligns %u\n", numAligns);
    fprintf(file, "posterior %lg\n", totalPosterior);

    for (unsigned pos = 0; pos < numAligns; pos ++) {
	fprintf(file, "align %u", pos);

	/*
	 * pos is the output position, col the index of the column stored there
	 */
	const unsigned col = sortedAligns[pos];

	/*
	 * All iterators below sort their keys by posterior in this column
	 */
	compareAlign = aligns[col];
	LHashIter<VocabIndex,Prob> alignIter(*compareAlign, comparePosteriors);

	VocabIndex word;
	VocabIndex refWord = Vocab_None;

	for (Prob *prob = alignIter.next(word);
	     prob != 0;
	     prob = alignIter.next(word))
	{
	    fprintf(file, " %s %lg", vocab.getWord(word), *prob);

	    /*
	     * Remember the word if the reference hyp ID is attached to it
	     */
	    Array<HypID> *wordHyps = hypMap[col]->find(word);
	    if (wordHyps == 0) {
		continue;
	    }

	    for (unsigned j = 0; j < wordHyps->size(); j++) {
		if ((*wordHyps)[j] == refID) {
		    refWord = word;
		    break;
		}
	    }
	}
	fprintf(file, "\n");

	/*
	 * Column and transition posterior sums are only written
	 * when they deviate from the total posterior
	 */
	const Prob colPosterior = columnPosteriors[col];
	if (colPosterior != totalPosterior) {
	    fprintf(file, "posterior %u %lg\n", pos, colPosterior);
	}

	const Prob transPosterior = transPosteriors[col];
	if (transPosterior != totalPosterior) {
	    fprintf(file, "transposterior %u %lg\n", pos, transPosterior);
	}

	/*
	 * Reference word, if one was found above
	 */
	if (refWord != Vocab_None) {
	    fprintf(file, "reference %u %s\n", pos, vocab.getWord(refWord));
	}

	/*
	 * Hyp IDs attached to each word
	 */
	LHashIter<VocabIndex,Array<HypID> >
			mapIter(*hypMap[col], comparePosteriors);

	for (Array<HypID> *hypList = mapIter.next(word);
	     hypList != 0;
	     hypList = mapIter.next(word))
	{
	    /*
	     * The reference ID is already covered by the "reference" line,
	     * so the reference word needs more than one ID to get a line
	     */
	    const unsigned skipped = (word == refWord) ? 1 : 0;
	    if (hypList->size() <= skipped) {
		continue;
	    }

	    fprintf(file, "hyps %u %s", pos, vocab.getWord(word));

	    for (unsigned j = 0; j < hypList->size(); j++) {
		if ((*hypList)[j] != refID) {
		    fprintf(file, " %d", (int)(*hypList)[j]);
		}
	    }
	    fprintf(file, "\n");
	}

	/*
	 * Word backtrace info attached to each word
	 */
	LHashIter<VocabIndex,NBestWordInfo>
			infoIter(*wordInfo[col], comparePosteriors);

	for (NBestWordInfo *winfo = infoIter.next(word);
	     winfo != 0;
	     winfo = infoIter.next(word))
	{
	    fprintf(file, "info %u %s ", pos, vocab.getWord(word));
	    winfo->write(file);
	    fprintf(file, "\n");
	}
    }

    return true;
}

/*
 * Read a mesh from file, in the text format produced by write().
 * Recognized lines:
 *	numaligns N
 *	name NAME
 *	posterior P			(total posterior, applied to all columns)
 *	posterior I P			(posterior of column I)
 *	transposterior I P
 *	align I WORD P ...
 *	reference I WORD
 *	hyps I WORD ID ...
 *	info I WORD ...
 * Returns false, after reporting the file position, on a repeated numaligns
 * line, a column index that is not below numaligns, unparsable word info,
 * or an unknown keyword; returns true once the input is exhausted.
 *
 * NOTE(review): the initial loop frees aligns[] of an already populated mesh
 * but leaves numAligns (and wordInfo/hypMap) untouched, so a subsequent
 * "numaligns" line is rejected as repeated and the freed pointers stay in
 * place.  Reading into a non-empty mesh looks unsupported -- confirm.
 */
Boolean
WordMesh::read(File &file)
{
    for (unsigned i = 0; i < numAligns; i ++) {
	delete aligns[i];
    }
   
    char *line;

    totalPosterior = 1.0;

    while ((line = file.getline()) != 0) {
	/*
	 * All word fields are scanned with %100s, which stores up to 100
	 * characters plus the terminating null byte
	 */
	char arg1[100 + 1];
	double arg2;
	unsigned parsed;
	unsigned index;
	int consumed;		/* %n result: characters scanned so far */

	if (sscanf(line, "numaligns %u", &parsed) == 1) {
	    if (numAligns > 0) {
		file.position() << "repeated numaligns specification\n";
		return false;
	    }
	    numAligns = parsed;
		
	    for (unsigned i = 0; i < numAligns; i ++) {
		sortedAligns[i] = i;

		aligns[i] = new LHash<VocabIndex,Prob>;
		assert(aligns[i] != 0);

		wordInfo[i] = new LHash<VocabIndex,NBestWordInfo>;
		assert(wordInfo[i] != 0);

		hypMap[i] = new LHash<VocabIndex,Array<HypID> >;
		assert(hypMap[i] != 0);

		columnPosteriors[i] = transPosteriors[i] = totalPosterior;
	    }
	} else if (sscanf(line, "name %100s", arg1) == 1) {
	    if (name != 0) {
		free(name);
	    }
	    name = strdup(arg1);
	    assert(name != 0);
	} else if (sscanf(line, "posterior %100s %lg", arg1, &arg2) == 2 &&
	           // scan node index with %s so we fail if only one numerical
		   // arg is given (which case handled below)
		   sscanf(arg1, "%u", &index) == 1)
	    {
	    if (index >= numAligns) {
		file.position() << "position index exceeds numaligns\n";
		return false;
	    }

	    columnPosteriors[index] = arg2;
	} else if (sscanf(line, "transposterior %u %lg", &index, &arg2) == 2) {
	    if (index >= numAligns) {
		file.position() << "position index exceeds numaligns\n";
		return false;
	    }

	    transPosteriors[index] = arg2;
	} else if (sscanf(line, "posterior %lg", &arg2) == 1) {
	    /*
	     * Total posterior: also becomes the value for every column
	     */
	    totalPosterior = arg2;
	    for (unsigned j = 0; j < numAligns; j ++) {
		columnPosteriors[j] = transPosteriors[j] = arg2;
	    }
	} else if (sscanf(line, "align %u%n", &index, &consumed) == 1) {
	    if (index >= numAligns) {
		file.position() << "position index exceeds numaligns\n";
		return false;
	    }

	    /*
	     * Parse the word/posterior pairs following the column index
	     */
	    char *cp = line + consumed;
	    while (sscanf(cp, "%100s %lg%n", arg1, &arg2, &consumed) == 2) {
		VocabIndex word = vocab.addWord(arg1);

		*aligns[index]->insert(word) = arg2;
		
		cp += consumed;
	    }
	} else if (sscanf(line, "reference %u %100s", &index, arg1) == 2) {
	    if (index >= numAligns) {
		file.position() << "position index exceeds numaligns\n";
		return false;
	    }

	    VocabIndex refWord = vocab.addWord(arg1);

	    /*
	     * Records word as part of the reference string
	     */
	    Array<HypID> *hypList = hypMap[index]->insert(refWord);
	    (*hypList)[hypList->size()] = refID;
	} else if (sscanf(line, "hyps %u %100s%n", &index, arg1, &consumed) == 2){
	    if (index >= numAligns) {
		file.position() << "position index exceeds numaligns\n";
		return false;
	    }

	    VocabIndex word = vocab.addWord(arg1);
	    Array<HypID> *hypList = hypMap[index]->insert(word);

	    /*
	     * Parse and record hyp IDs
	     */
	    char *cp = line + consumed;
	    unsigned hypID;
	    while (sscanf(cp, "%u%n", &hypID, &consumed) == 1) {
		(*hypList)[hypList->size()] = hypID;
		*allHyps.insert(hypID) = hypID;

		cp += consumed;
	    }
	} else if (sscanf(line, "info %u %100s%n", &index, arg1, &consumed) == 2){
	    if (index >= numAligns) {
		file.position() << "position index exceeds numaligns\n";
		return false;
	    }

	    VocabIndex word = vocab.addWord(arg1);
	    NBestWordInfo *winfo = wordInfo[index]->insert(word);

	    /*
	     * The remainder of the line is parsed by NBestWordInfo itself
	     */
	    winfo->word = word;
	    if (!winfo->parse(line + consumed)) {
		file.position() << "invalid word info\n";
		return false;
	    }
	} else {
	    file.position() << "unknown keyword\n";
	    return false;
	}
    }
    return true;
}

/*
 * Compute expected error from aligning a word to an alignment column
 * if column == 0 : compute insertion cost
 * if word == deleteIndex : compute deletion cost
 */
double
WordMesh::alignError(const LHash<VocabIndex,Prob>* column,
		     Prob columnPosterior,
		     VocabIndex word)
{
    if (column == 0) {
	/*
	 * Compute insertion cost for word
	 */
	if (word == deleteIndex) {
	    return 0.0;
	} else {
	    if (distance) {
		return columnPosterior * distance->penalty(word);
	    } else {
		return columnPosterior;
	    }
	}
    } else {
	if (word == deleteIndex) {
	    /* 
	     * Compute deletion cost for alignment column
	     */
	    if (distance) {
		double deletePenalty = 0.0;

		LHashIter<VocabIndex,Prob> iter(*column);
		Prob *prob;
		VocabIndex alignWord;
		while (prob = iter.next(alignWord)) {
		    if (alignWord != deleteIndex) {
			deletePenalty += *prob * distance->penalty(alignWord);
		    }
		}
		return deletePenalty;
	    } else {
		Prob *deleteProb = column->find(deleteIndex);
		return  columnPosterior - (deleteProb ? *deleteProb : 0.0);
	    }
	} else {
	    /*
	     * Compute "substitution" cost of word in column
	     */
	    if (distance) {
		/*
		 * Compute distance to existing alignment as a weighted 
		 * combination of distances
		 */
		double totalDistance = 0.0;

	    	LHashIter<VocabIndex,Prob> iter(*column);
		Prob *prob;
		VocabIndex alignWord;
		while (prob = iter.next(alignWord)) {
		    if (alignWord == deleteIndex) {
			totalDistance +=
			    *prob * distance->penalty(word);
		    } else {
			totalDistance +=
			    *prob * distance->distance(alignWord, word);
		    }
		}

		return totalDistance;
	    } else {
	        Prob *wordProb = column->find(word);
		return columnPosterior - (wordProb ? *wordProb : 0.0);
	    }
	}
    }
}

/*
 * Compute expected error from aligning two alignment columns
 * if column1 == 0 : compute insertion cost
 * if column2 == 0 : compute deletion cost
 */
double
WordMesh::alignError(const LHash<VocabIndex,Prob>* column1, 
		     Prob columnPosterior,
		     const LHash<VocabIndex,Prob>* column2,
		     Prob columnPosterior2)
{
    if (column2 == 0) {
	return alignError(column1, columnPosterior, deleteIndex);
    } else {
	/*
	 * compute weighted sum of aligning each of the column2 entries,
	 * using column2 posteriors as weights
	 */
	double totalDistance = 0.0;

	LHashIter<VocabIndex,Prob> iter(*column2);
	Prob *prob;
	VocabIndex word2;
	while (prob = iter.next(word2)) {
	    double error = alignError(column1, columnPosterior, word2);

	    /*
	     * Handle case where one of the entries has posterior 1, but 
	     * others have small nonzero posteriors, too.  The small ones
	     * can be ignored in the sum total, and this shortcut makes the
	     * numerical computation symmetric with respect to the case
	     * where posterior 1 occurs in column1 (as well as speeding things
	     * up).
	     */
	    if (*prob == columnPosterior2) {
		return *prob * error;
	    } else {
		totalDistance += *prob * error;
	    }
	}
	return totalDistance;
    }
}

/*
 * Align new words to existing alignment columns, expanding them as required
 * (derived from WordAlign())
 * If hypID != 0, then *hypID will record a sentence hyp ID for the 
 * aligned words.
 */
void
WordMesh::alignWords(const VocabIndex *words, Prob score,
			    Prob *wordScores, const HypID *hypID)
{
    unsigned numWords = Vocab::length(words);
    NBestWordInfo *winfo = new NBestWordInfo[numWords + 1];
    assert(winfo != 0);

    /*
     * Fill word info array with word IDs and dummy info
     * Note: loop below also handles the final Vocab_None.
     */
    for (unsigned i = 0; i <= numWords; i ++) {
	winfo[i].word = words[i];
	winfo[i].invalidate();
	winfo[i].wordPosterior = 0.0;
	winfo[i].transPosterior = 0.0;
    }

    alignWords(winfo, score, wordScores, hypID);

    delete [] winfo;
}

/*
 * This is the generalized version of alignWords():
 *	- merges NBestWordInfo into the existing alignment
 *	- aligns word string between any two existing alignment positions
 *	- optionally returns the alignment positions assigned to aligned words
 *	- optionally returns the posterior probabilities of aligned words

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -