/* nbest-optimize.cc (page 1 of 4) */
/*
 * nbest-optimize --
 *	Optimize score combination for N-best rescoring
 */

#ifndef lint
static char Copyright[] = "Copyright (c) 2000-2006 SRI International.  All Rights Reserved.";
static char RcsId[] = "@(#)$Id: nbest-optimize.cc,v 1.45 2006/01/09 18:08:21 stolcke Exp $";
#endif

#include <iostream>
using namespace std;
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <locale.h>
#ifndef _MSC_VER
#include <unistd.h>
#endif
#include <math.h>
#if defined(sun) || defined(sgi)
#include <ieeefp.h>
#endif
#include <signal.h>
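
/*
 * without SIGALRM (e.g., on native Windows builds) timeout support and
 * the -max-time option below are compiled out
 */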

#ifndef SIGALRM
#define NO_TIMEOUT
#endif

#include "option.h"
#include "version.h"
#include "File.h"
#include "Vocab.h"
#include "zio.h"

#include "NullLM.h"
#include "RefList.h"
#include "NBestSet.h"
#include "WordAlign.h"
#include "WordMesh.h"
#include "Array.cc"
#include "LHash.cc"

#define DEBUG_TRAIN 1
#define DEBUG_ALIGNMENT	2
#define DEBUG_SCORES 3
#define DEBUG_RANKING 4
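
/*
 * debug levels are cumulative: -debug N enables all output guarded by
 * debug >= M for any M <= N
 */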

typedef float NBestScore;			/* same as LogP */

unsigned numScores;				/* number of score dimensions */
unsigned numFixedWeights;			/* number of fixed weights */
LHash<RefString,NBestScore **> nbestScores;	/* matrices of nbest scores,
						 * one matrix per nbest list */
LHash<RefString,WordMesh *> nbestAlignments;	/* nbest alignments */
LHash<RefString,unsigned *> nbestErrors;	/* nbest error counts */

Array<double> lambdas;				/* score weights */
Array<double> lambdaDerivs;			/* error derivatives wrt same */
Array<double> prevLambdaDerivs;
Array<double> prevLambdaDeltas;
Array<Boolean> fixLambdas;			/* lambdas to keep constant */
unsigned numRefWords;				/* number of reference words */
unsigned totalError;				/* number of word errors */
double totalLoss;				/* smoothed loss */
Array<double> bestLambdas;			/* lambdas at lowest error */
unsigned bestError;				/* lowest error count */

Array<double> lambdaSteps;			/* simplex step sizes  */
Array<double> simplex;				/* current simplex points  */

static int version = 0;
static int oneBest = 0;				/* optimize 1-best error */
static int oneBestFirst = 0;			/* 1-best then nbest error */
static int noReorder = 0;
static unsigned debug = 0;
static char *vocabFile = 0;
static int toLower = 0;
static int multiwords = 0;
static char *noiseTag = 0;
static char *noiseVocabFile = 0;
static char *hiddenVocabFile = 0;
static char *refFile = 0;
static char *errorsDir = 0;
static char *nbestFiles = 0;
static unsigned maxNbest = 0;
static char *printHyps = 0;
static char *nbestDirectory = 0;
static char **scoreDirectories = 0;
static char *writeRoverControl = 0;
static int quickprop = 0;

static double rescoreLMW = 8.0;
static double rescoreWTW = 0.0;
static double posteriorScale = 0.0;
static double posteriorScaleStep = 1.0;
static int combineLinear = 0;
static int nonNegative = 0;

static char *initLambdas = 0;
static char *initSimplex = 0;
static double alpha = 1.0;
static double epsilon = 0.1;
static double epsilonStepdown = 0.0;
static double minEpsilon = 0.0001;
static double minLoss = 0;
static double maxDelta = 1000;
static unsigned maxIters = 100000;
static double converge = 0.0001;
static unsigned maxBadIters = 10;
static unsigned maxAmoebaRestarts = 100000;
static unsigned maxTime = 0;

static int optRest;

static Option options[] = {
    { OPT_TRUE, "version", &version, "print version information" },
    { OPT_STRING, "refs", &refFile, "reference transcripts" },
    { OPT_STRING, "nbest-files", &nbestFiles, "list of training N-best files" },
    { OPT_UINT, "max-nbest", &maxNbest, "maximum number of hyps to consider" },
    { OPT_TRUE, "1best", &oneBest, "optimize 1-best error" },
    { OPT_TRUE, "1best-first", &oneBestFirst, "optimize 1-best error before full optimization" },
    { OPT_TRUE, "no-reorder", &noReorder, "don't reorder N-best hyps before aligning and align refs first" },
    { OPT_STRING, "errors", &errorsDir, "directory containing error counts" },
    { OPT_STRING, "vocab", &vocabFile, "set vocabulary" },
    { OPT_TRUE, "tolower", &toLower, "map vocabulary to lowercase" },
    { OPT_TRUE, "multiwords", &multiwords, "split multiwords in N-best hyps" },
    { OPT_STRING, "noise", &noiseTag, "noise tag to skip" },
    { OPT_STRING, "noise-vocab", &noiseVocabFile, "noise vocabulary to skip" },
    { OPT_STRING, "hidden-vocab", &hiddenVocabFile, "subvocabulary to be kept separate in mesh alignment" },
    { OPT_STRING, "write-rover-control", &writeRoverControl, "nbest-rover control output file" },
    { OPT_FLOAT, "rescore-lmw", &rescoreLMW, "rescoring LM weight" },
    { OPT_FLOAT, "rescore-wtw", &rescoreWTW, "rescoring word transition weight" },
    { OPT_FLOAT, "posterior-scale", &posteriorScale, "divisor for log posterior estimates" },
    { OPT_TRUE, "combine-linear", &combineLinear, "combine scores linearly (not log-linearly" },
    { OPT_TRUE, "non-negative", &nonNegative, "limit search to non-negative weights" },
    { OPT_STRING, "init-lambdas", &initLambdas, "initial lambda values" },
    { OPT_STRING, "init-amoeba-simplex", &initSimplex, "initial amoeba simplex points" },
    { OPT_FLOAT, "alpha", &alpha, "sigmoid slope parameter" },
    { OPT_FLOAT, "epsilon", &epsilon, "learning rate parameter" },
    { OPT_FLOAT, "epsilon-stepdown", &epsilonStepdown, "epsilon step-down factor" },
    { OPT_FLOAT, "min-epsilon", &minEpsilon, "minimum epsilon after step-down" },
    { OPT_FLOAT, "min-loss", &minLoss, "samples with loss below this are ignored" },
    { OPT_FLOAT, "max-delta", &maxDelta, "threshold to filter large deltas" },
    { OPT_UINT, "maxiters", &maxIters, "maximum number of learning iterations" },
    { OPT_UINT, "max-bad-iters", &maxBadIters, "maximum number of iterations without improvement" },
    { OPT_UINT, "max-amoeba-restarts", &maxAmoebaRestarts, "maximum number of Amoeba restarts" },
#ifndef NO_TIMEOUT
    { OPT_UINT, "max-time", &maxTime, "abort search after this many seconds" },
#endif
    { OPT_FLOAT, "converge", &converge, "minimum relative change in objective function" },
    { OPT_STRING, "print-hyps", &printHyps, "output file for final top hyps" },
    { OPT_TRUE, "quickprop", &quickprop, "use QuickProp gradient descent" },
    { OPT_UINT, "debug", &debug, "debugging level" },
    { OPT_REST, "-", &optRest, "indicate end of option list" },
};
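
/*
 * Typical invocation (a sketch; the file names are hypothetical):
 *
 *	nbest-optimize -nbest-files nbest.list -refs refs.txt \
 *		-1best-first -write-rover-control rover.ctl
 *
 * All flags correspond to entries in the options[] table above.
 */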

static Boolean abortSearch = false;

#ifndef NO_TIMEOUT
/*
 * deal with different signal handler types
 */
#ifndef _sigargs
#define _sigargs int
#endif

void catchAlarm(_sigargs)
{
    abortSearch = true;
}
#endif /* !NO_TIMEOUT */
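
/*
 * smoothed step function used as the loss:
 *	sigmoid(x) = 1 / (1 + exp(-alpha * x))
 * larger alpha gives a sharper approximation to a 0/1 error count; its
 * derivative alpha * s * (1 - s) reappears in computeDerivs() below
 */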

double
sigmoid(double x)
{
    return 1/(1 + exp(- alpha * x));
}

void
dumpScores(ostream &str, NBestSet &nbestSet)
{
    NBestSetIter iter(nbestSet);
    NBestList *nbest;
    RefString id;

    while ((nbest = iter.next(id)) != 0) {
	str << "id = " << id << endl;

	NBestScore ***scores = nbestScores.find(id);

	if (!scores) {
	    str << "no scores found!\n";
	} else {
	    for (unsigned j = 0; j < nbest->numHyps(); j ++) {
		str << "Hyp " << j << ":" ;
		for (unsigned i = 0; i < numScores; i ++) {
		    str << " " << (*scores)[i][j];
		}
		str << endl;
	    }
	}
    }
}

void
dumpAlignment(ostream &str, WordMesh &alignment)
{
    for (unsigned pos = 0; pos < alignment.length(); pos ++) {
	Array<HypID> *hypMap;
	VocabIndex word;

	str << "position " << pos << endl;

	WordMeshIter iter(alignment, pos);
	while ((hypMap = iter.next(word)) != 0) {
	    str << "  word = " << alignment.vocab.getWord(word) << endl;

	    for (unsigned k = 0; k < hypMap->size(); k ++) {
		str << " " << (*hypMap)[k];
	    }
	    str << endl;
	}
    }
}

/*
 * compute hypothesis score (weighted sum of log scores)
 */
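/*
 * In the default (log-linear) mode
 *	score(hyp) = sum_i lambda_i * scores[i][hyp]
 * With -combine-linear the weights are applied to the exponentiated
 * scores instead and the sum is converted back to a log:
 *	score(hyp) = ProbToLogP(sum_i lambda_i * LogPtoProb(scores[i][hyp]))
 * Results are cached per score matrix to speed up repeated evaluation.
 */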
LogP
hypScore(unsigned hyp, NBestScore **scores)
{
    static NBestScore **lastScores = 0;
    static Array<LogP> *cachedScores = 0;

    if (scores != lastScores) {
	delete cachedScores;
	cachedScores = new Array<LogP>;
	assert(cachedScores != 0);
	lastScores = scores;
    }
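
    /*
     * note: a cached value of 0.0 doubles as the "not yet computed"
     * marker, so a hyp whose combined score is exactly 0 is recomputed
     * on every call (harmless, merely redundant work)
     */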

    if (hyp < cachedScores->size()) {
	if ((*cachedScores)[hyp] != 0.0) {
	    return (*cachedScores)[hyp];
	}
    } else {
	for (unsigned j = cachedScores->size(); j < hyp; j ++) {
	    (*cachedScores)[j] = 0.0;
	}
    }

    LogP score;

    double *weights = lambdas.data(); /* bypass index range check for speed */

    if (combineLinear) {
	/* linear combination, even though probabilities are encoded as logs */
	Prob prob = 0.0;
	for (unsigned i = 0; i < numScores; i ++) {
	    prob += weights[i] * LogPtoProb(scores[i][hyp]);
	}
	score = ProbToLogP(prob);
    } else {
	/* log-linear combination */
	score = 0.0;
	for (unsigned i = 0; i < numScores; i ++) {
	    score += weightLogP(weights[i], scores[i][hyp]);
	}
    }
    return ((*cachedScores)[hyp] = score);
}

/*
 * compute summed hyp scores (sum of unnormalized posteriors of all hyps
 *	containing a word)
 *	isCorrect is set to true if hyps contains the reference (refID)
 *	The last parameter is used to collect auxiliary sums needed for
 *	derivatives
 */
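/*
 * In the log-linear case, with P_k = LogPtoProb(hypScore(k)) for each
 * hyp k other than the reference marker, the auxiliary sums are
 *	a[i] = sum_k P_k * scores[i][k]
 * i.e. posterior-weighted score sums; computeDerivs() divides them by
 * totalScore to obtain the gradient of log(totalScore) wrt lambda_i.
 */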
Prob
wordScore(Array<HypID> &hyps, NBestScore **scores, Boolean &isCorrect,
								Prob *a = 0)
{
    Prob totalScore = 0.0;
    isCorrect = false;

    if (a != 0) {
	for (unsigned i = 0; i < numScores; i ++) {
	    a[i] = 0.0;
	}
    }

    for (unsigned k = 0; k < hyps.size(); k ++) {
	if (hyps[k] == refID) {
	    /*
	     * This hyp represents the correct word string, but doesn't
	     * contribute to the posterior probability for the word.
	     */
	    isCorrect = true;
	} else {
	    Prob score = LogPtoProb(hypScore(hyps[k], scores));

	    totalScore += score;
	    if (a != 0) {
		for (unsigned i = 0; i < numScores; i ++) {
		    a[i] += weightLogP(score, scores[i][hyps[k]]);
		}
	    }
	}
    }

    return totalScore;
}


/*
 * compute loss and derivatives for a single nbest list
 */
void
computeDerivs(RefString id, NBestScore **scores, WordMesh &alignment)
{
    /* 
     * process all positions in alignment
     */
    for (unsigned pos = 0; pos < alignment.length(); pos++) {
	VocabIndex corWord = Vocab_None;
	Prob corScore = 0.0;
	Array<HypID> *corHyps = 0;

	VocabIndex bicWord = Vocab_None;
	Prob bicScore = 0.0;
	Array<HypID> *bicHyps = 0;

	if (debug >= DEBUG_RANKING) {
	    cerr << "   position " << pos << endl;
	}

	WordMeshIter iter(alignment, pos);

	Array<HypID> *hypMap;
	VocabIndex word;
	while ((hypMap = iter.next(word)) != 0) {
	    /*
	     * compute total score for word and check if it's the correct one
	     */
	    Boolean isCorrect;
	    Prob totalScore = wordScore(*hypMap, scores, isCorrect);

	    if (isCorrect) {
		corWord = word;
		corScore = totalScore;
		corHyps = hypMap;
	    } else {
		if (bicWord == Vocab_None || bicScore < totalScore) {
		    bicWord = word;
		    bicScore = totalScore;
		    bicHyps = hypMap;
		}
	    }
	}

	/*
	 * There must be a correct hyp
	 */
	assert(corWord != Vocab_None);

	if (debug >= DEBUG_RANKING) {
	    cerr << "      cor word = " << alignment.vocab.getWord(corWord)
		 << " score = " << corScore << endl;
	    cerr << "      bic word = " << (bicWord == Vocab_None ? "NONE" :
					    alignment.vocab.getWord(bicWord))
		 << " score = " << bicScore << endl;
	}

	unsigned wordError = (bicScore > corScore);
	double smoothError = 
			sigmoid(ProbToLogP(bicScore) - ProbToLogP(corScore));

	totalError += wordError;
	totalLoss += smoothError;

	/*
	 * If all word hyps are correct or incorrect, or loss is below a set
	 * threshold, then this sample cannot help us and we exclude it from
	 * the derivative computation
	 */
	if (bicScore == 0.0 || corScore == 0.0 || smoothError < minLoss) {
	    continue;
	}

	/*
	 * Compute the auxiliary vectors for derivatives
	 */
	Boolean dummy;
	makeArray(Prob, corA, numScores);
	wordScore(*corHyps, scores, dummy, corA);

	makeArray(Prob, bicA, numScores);
	wordScore(*bicHyps, scores, dummy, bicA);
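
	/*
	 * For the log-linear score combination the smoothed loss at this
	 * position is
	 *	L = sigmoid(ProbToLogP(bicScore) - ProbToLogP(corScore))
	 * and with the auxiliary sums a_i from wordScore() the chain rule
	 * gives
	 *	dL/dlambda_i = alpha * L * (1 - L)
	 *			* (bicA[i]/bicScore - corA[i]/corScore)
	 * (the log-base constants cancel within each quotient)
	 */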

	/*
	 * Accumulate derivatives
	 */
	double sigmoidDeriv = alpha * smoothError * (1 - smoothError);

	for (unsigned i = 0; i < numScores; i ++) {
	    double delta = (bicA[i] / bicScore - corA[i] / corScore);

	    if (fabs(delta) > maxDelta) {
		cerr << "skipping large delta " << delta
		     << " at id " << id
		     << " position " << pos
		     << " score " << i
		     << endl;
	    } else {
		lambdaDerivs[i] += sigmoidDeriv * delta;
	    }
	}
    }
}

/*
 * do a single pass over all nbest lists, computing loss function
 * and accumulating derivatives for lambdas.
 */
void
computeDerivs(NBestSet &nbestSet)
{
    /*
     * Initialize error counts and derivatives
     */
    totalError = 0;
    totalLoss = 0.0;

    for (unsigned i = 0; i < numScores; i ++) {
	lambdaDerivs[i] = 0.0;
    }

    NBestSetIter iter(nbestSet);
    NBestList *nbest;
    RefString id;

    while ((nbest = iter.next(id)) != 0) {
	NBestScore ***scores = nbestScores.find(id);
	assert(scores != 0);
	WordMesh **alignment = nbestAlignments.find(id);
	assert(alignment != 0);

	computeDerivs(id, *scores, **alignment);
    }
}

/*
 * compute 1-best word error for a single nbest list
 * Note: uses global lambdas variable (yuck!)
 */
double
compute1bestErrors(RefString id, NBestScore **scores, NBestList &nbest)
{
    unsigned numHyps = nbest.numHyps();
    unsigned bestHyp;
    LogP bestScore;

    if (numHyps == 0) {
	return 0.0;
    }

    /*
     * Find hyp with highest score
     */
    for (unsigned i = 0; i < numHyps; i ++) {
	LogP score = hypScore(i, scores);

	if (i == 0 || score > bestScore) {
	    bestScore = score;
	    bestHyp = i;
	}
    }

    return nbest.getHyp(bestHyp).numErrors;
}

/*
 * compute sausage word error for a single nbest list
 * Note: uses global lambdas variable (yuck!)
 */
double
computeSausageErrors(RefString id, NBestScore **scores, WordMesh &alignment)
{
    int result = 0;

    /* 
     * process all positions in alignment
     */
    for (unsigned pos = 0; pos < alignment.length(); pos++) {
	VocabIndex corWord = Vocab_None;
	Prob corScore = 0.0;
	Array<HypID> *corHyps = 0;

	VocabIndex bicWord = Vocab_None;
	Prob bicScore = 0.0;
	Array<HypID> *bicHyps = 0;

	if (debug >= DEBUG_RANKING) {
	    cerr << "   position " << pos << endl;
	}
