⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 ngram-class.cc

📁 这是一款很好用的工具包
💻 CC
📖 第 1 页 / 共 2 页
字号:
/*
 * ngram-class --
 *	Induce class ngram models from counts
 *
 */

#ifndef lint
static char Copyright[] = "Copyright (c) 1999-2006 SRI International.  All Rights Reserved.";
static char RcsId[] = "@(#)$Id: ngram-class.cc,v 1.26 2006/01/05 20:21:27 stolcke Exp $";
#endif

#include <iostream>
using namespace std;
#include <stdlib.h>
#include <locale.h>
#include <assert.h>

#include "option.h"
#include "version.h"
#include "zio.h"
#include "File.h"
#include "Debug.h"
#include "Prob.h"
#include "Vocab.h"
#include "SubVocab.h"
#include "TextStats.h"
#include "NgramStats.h"
#include "LHash.cc"
#include "Map2.cc"
#include "Array.cc"
#include "Trie.cc"

/*
 * Explicit template instantiations, compiled only when the build defines
 * INSTANTIATE_TEMPLATES.  The INSTANTIATE_* macros are not defined in this
 * file; presumably they come from the container sources included above.
 */
#ifdef INSTANTIATE_TEMPLATES
INSTANTIATE_MAP1(VocabIndex,VocabIndex,LogP);

#ifdef USE_SARRAY
// XXX: avoid multiple definitions with NgramLM
INSTANTIATE_LHASH(VocabIndex,LogP);
#endif /* USE_SARRAY */

#endif /* INSTANTIATE_TEMPLATES */

/*
 * Debugging levels, tested in this file as debug(LEVEL).
 * DEBUG_TRACE_MERGE guards the tracing output written to dout().
 */
#define DEBUG_TEXTSTATS		1
#define DEBUG_TRACE_MERGE	2
#define DEBUG_PRINT_CONTRIBS	3	// in interactive mode

/*
 * Command-line settings.  Each variable below is bound to one entry of the
 * options[] table, which pairs an option type and name with the address of
 * the variable and a help string.
 */
static int version = 0;
static char *vocabFile = 0;
static int toLower = 0;
static char *noclassFile = 0;
static char *countsFile = 0;
static char *textFile = 0;
static char *classesFile = 0;
static char *classCountsFile = 0;
static unsigned numClasses = 1;
static int fullMerge = 0;
static int interact = 0;
// NOTE(review): declared int but registered as OPT_UINT below -- confirm the
// option parser treats both the same way.
static int debug = 0;
static int saveFreq = 0;

static Option options[] = {
    { OPT_TRUE, "version", &version, "print version information" },
    { OPT_UINT, "debug", &debug, "debugging level" },
    { OPT_STRING, "vocab", &vocabFile, "vocab file" },
    { OPT_TRUE, "tolower", &toLower, "map vocabulary to lowercase" },
    { OPT_STRING, "noclass-vocab", &noclassFile, "vocabulary not to be classed" },
    { OPT_STRING, "counts", &countsFile, "counts file to read" },
    { OPT_STRING, "text", &textFile, "text file to count" },
    { OPT_UINT, "numclasses", &numClasses, "number of classes to induce" },
    // -full and -incremental share one flag: the former sets fullMerge,
    // the latter is an OPT_FALSE entry on the same variable.
    { OPT_TRUE, "full", &fullMerge, "perform full greedy merging" },
    { OPT_FALSE, "incremental", &fullMerge, "perform incremental greedy merging" },
    { OPT_TRUE, "interact", &interact, "perform interactive merging" },
    { OPT_STRING, "class-counts", &classCountsFile, "class N-gram count output" },
    { OPT_STRING, "classes", &classesFile, "class definitions output" },
    { OPT_INT, "save", &saveFreq, "save classes/counts every this many iterations" },
};

/*
 * n * log(n) of a count.  A zero count yields LogP_One directly, so the
 * logarithm is never evaluated at zero.
 */
static inline LogP
NlogN(NgramCount count)
{
    if (count != 0) {
	return count * ProbToLogP(count);
    }

    return LogP_One;
}

/*
 * Many-to-one word-to-class mapping: every classed word is assigned to
 * exactly one class (see wordToClass).  The object also holds the class
 * unigram/bigram counts and the cached contribution tables that merge()
 * keeps up to date while classes are merged.
 */
class UniqueWordClasses: public Debug
{
public:
    /*
     * classVocab has to be a SubVocab whose base vocabulary is vocab;
     * the constructor asserts this.
     */
    UniqueWordClasses(Vocab &vocab, SubVocab &classVocab);
    ~UniqueWordClasses() {};

    VocabIndex newClass();		// create a new class
    void initialize(NgramStats &counts, SubVocab &noclassVocab);
					// initialize classes from word counts;
					// words in noclassVocab stay unclassed
    void merge(VocabIndex c1, VocabIndex c2);	// merge class c2 into c1
    LogP bestMerge(Vocab &mergeSet, VocabIndex &c1, VocabIndex &c2);
						// single best merge step
    void fullMerge(unsigned numClases);		// full greedy merging
    void incrementalMerge(unsigned numClases);	// incremental merging

    void writeClasses(File &file);	// write class definitions
    void writeCounts(File &file)	// write class ngram counts
	{ classNgramCounts.write(file, 0, true); };

    NgramCount getCount(VocabIndex c1)		// class unigram count
	{ return getCount(c1, Vocab_None); };		
    inline NgramCount getCount(VocabIndex c1, VocabIndex c2);
						// class bigram count
						// (0 if no such entry)
    inline NgramCount getCountR(VocabIndex c1, VocabIndex c2);
						// reverse bigram count
						// (0 if no such entry)

    LogP totalLogP();			// total log likelihood
    LogP diffLogP(VocabIndex c1, VocabIndex c2);
					// log likelihood difference for 
					// class merging

    void computeClassContribs();	// recompute classContribs vector
    void computeMergeContribs();	// recompute mergeContribs matrix

    void getStats(TextStats &stats);	// fill in counts and totalLogP()

    void writeContribs(File &file);	// dump contrib vector and matrix

    Vocab &vocab;			// base (word) vocabulary
    Vocab &classVocab;			// class vocabulary (bound to the
					// SubVocab given to the constructor)

protected:
    LHash<VocabIndex,VocabIndex> wordToClass;	// word->class map
    LHash<VocabIndex,NgramCount> wordCounts;	// word counts
    NgramStats classNgramCounts;		// trie of (C,...,C) counts
    NgramStats classNgramCountsR;		// trie of reversed counts,
    LHash<VocabIndex,LogP> classContribs;	// class contributions to
						// total log likelihood
    Map2<VocabIndex,VocabIndex,LogP> mergeContribs;
						// merge-pair contributions
						// to delta log likelihood
    void computeMergeContrib(VocabIndex c1);	// recompute mergeContribs
    LogP computeClassContrib(VocabIndex c);	// recompute classContribs
    LogP computeMergeContrib(VocabIndex c1, VocabIndex c2);	// same

    // Fold the counts of class c2 into class c1 within one count trie;
    // merge() applies it to both the forward and the reversed trie.
    void mergeCounts(NgramStats &counts, VocabIndex c1, VocabIndex c2);
    
    unsigned genSymCount;		// running number used by newClass()
					// to build class names
};

/*
 * Bind the word and class vocabularies and start with empty tables.
 * Both count tries are built over the word vocabulary with a second
 * constructor argument of 2 (presumably the maximum N-gram order).
 */
UniqueWordClasses::UniqueWordClasses(Vocab &vocab, SubVocab &classVocab)
    : vocab(vocab),
      classVocab(classVocab),
      classNgramCounts(vocab, 2),
      classNgramCountsR(vocab, 2),
      genSymCount(0)
{
    // the class vocabulary has to sit on top of this very Vocab object
    assert(&classVocab.baseVocab() == &vocab);
}

/*
 * Generate the next class label of the form CLASS-nnnnn, add it to the
 * class vocabulary and return whatever classVocab.addWord() yields for it.
 */
VocabIndex
UniqueWordClasses::newClass()
{
    char label[30];

    genSymCount += 1;
    sprintf(label, "CLASS-%05u", genSymCount);

    // a generated label must not already be present in the vocabulary
    assert(vocab.getIndex(label) == Vocab_None);

    return classVocab.addWord(label);
}

/*
 * Set up the initial classes and class counts from word N-gram counts.
 *
 * Every word with a unigram count that is not listed in noclassVocab gets a
 * class of its own (created on first sight); words in noclassVocab stand for
 * themselves in the class counts.  Unigram counts go into both count tries
 * and into wordCounts; bigram counts go into classNgramCounts as is and into
 * classNgramCountsR with the two positions swapped.
 *
 * Exits with status 1 if a classed word occurs in a bigram but was not seen
 * among the unigrams.
 */
void
UniqueWordClasses::initialize(NgramStats &counts, SubVocab &noclassVocab)
{
    /*
     * The unclassed words have to come from the same base vocabulary
     */
    assert(&noclassVocab.baseVocab() == &vocab);

    VocabIndex words[3];
    NgramCount *wordCount;

    /*
     * Pass 1: unigrams
     */
    NgramsIter unigramIter(counts, words, 1);

    while ((wordCount = unigramIter.next()) != 0) {
	VocabIndex classUnigram[2];
	classUnigram[1] = Vocab_None;

	if (noclassVocab.getWord(words[0]) == 0) {
	    /*
	     * Word is to be classed: look up its class, creating one
	     * the first time the word is seen.
	     */
	    Boolean known;
	    VocabIndex *slot = wordToClass.insert(words[0], known);

	    if (!known) {
		*slot = newClass();

		if (debug(DEBUG_TRACE_MERGE)) {
		    dout() << "\tcreating " << classVocab.getWord(*slot)
			   << " for word " << vocab.getWord(words[0])
			   << endl;
		}
	    }
	    classUnigram[0] = *slot;
	} else {
	    /*
	     * Word is exempt from classing and represents itself
	     */
	    classUnigram[0] = words[0];
	}

	*classNgramCounts.insertCount(classUnigram) += *wordCount;
	*classNgramCountsR.insertCount(classUnigram) += *wordCount;

	*wordCounts.insert(words[0]) += *wordCount;
    }

    /*
     * Pass 2: bigrams
     */
    NgramsIter bigramIter(counts, words, 2);

    while ((wordCount = bigramIter.next()) != 0) {
	VocabIndex classBigram[3];

	/* first position */
	if (noclassVocab.getWord(words[0]) == 0) {
	    VocabIndex *firstClass = wordToClass.find(words[0]);

	    if (firstClass == 0) {
		cerr << "word 1 in bigram \"" << (vocab.use(), words)
		     << "\" has no unigram count\n";
		exit(1);
	    }
	    classBigram[0] = *firstClass;
	} else {
	    classBigram[0] = words[0];
	}

	/* second position */
	if (noclassVocab.getWord(words[1]) == 0) {
	    VocabIndex *secondClass = wordToClass.find(words[1]);

	    if (secondClass == 0) {
		cerr << "word 2 in bigram \"" << (vocab.use(), words)
		     << "\" has no unigram count\n";
		exit(1);
	    }
	    classBigram[1] = *secondClass;
	} else {
	    classBigram[1] = words[1];
	}

	classBigram[2] = Vocab_None;
	*classNgramCounts.insertCount(classBigram) += *wordCount;

	/* same count again, keyed by the reversed class pair */
	Vocab::reverse(classBigram);
	*classNgramCountsR.insertCount(classBigram) += *wordCount;
    }
}

/*
 * Fold the counts of class c2 into class c1 inside one count trie.
 *
 * The trie is modified while it is being traversed, so the three steps
 * below have to stay in this order.
 */
void
UniqueWordClasses::mergeCounts(NgramStats &counts,
					VocabIndex c1, VocabIndex c2)
{
    VocabIndex unigram[2]; unigram[1] = Vocab_None;
    VocabIndex unigram2[2]; unigram2[1] = Vocab_None;
    VocabIndex bigram[3]; bigram[2] = Vocab_None;
    NgramCount *count;

    /*
     * Update Counts 
     * 1) add row c2 to c1 row
     *    (iter2 is given the prefix c2 and reports each entry through
     *    unigram2 -- presumably the words following c2; each reported
     *    count is added to the bigram (c1, that word))
     */
    unigram[0] = c2;
    NgramsIter iter2(counts, unigram, unigram2, 1);
    while(count = iter2.next()) {
	bigram[0] = c1;
	bigram[1] = unigram2[0];
	*counts.insertCount(bigram) += *count;
    }

    /*
     * 2) Remove row c2
     *    (if removeCount() hands back a count for c2, it is added to
     *    the unigram count of c1)
     */
    unigram[0] = c1;
    unigram2[0] = c2;
    count = counts.removeCount(unigram2);
    if (count != 0) {
	*counts.insertCount(unigram) += *count;
    }

    /*
     * 3) add column c2 to column c1 and remove column c2
     *    (iter3 writes each unigram into unigram[0]; for each of them the
     *    bigram (x, c2) is removed and its count added to (x, c1))
     */
    NgramsIter iter3(counts, unigram, 1);
    while (count = iter3.next()) {
	bigram[0] = unigram[0];
	bigram[1] = c2;
	// NOTE(review): count2 is dereferenced after removeCount() -- this
	// assumes the returned pointer stays valid; confirm in NgramStats.
	NgramCount *count2 = counts.removeCount(bigram);

	if (count2 != 0) {
	    bigram[1] = c1;
	    *counts.insertCount(bigram) += *count2;
	}
    }
}

/*
 * Merge class c2 into class c1; c2 is removed from the class vocabulary.
 *
 * Besides reassigning the member words, this updates the cached
 * classContribs and mergeContribs tables incrementally.  Those updates read
 * the class counts through getCount(), so they have to run before
 * mergeCounts() changes the counts at the end of the function.
 */
void
UniqueWordClasses::merge(VocabIndex c1, VocabIndex c2)
{
    /*
     * Destructively merge c2 into c1 ...
     */
    assert(c1 != c2);

    /*
     * Make sure both c1 and c2 are classes
     */
    assert(classVocab.getWord(c1) != 0);
    assert(classVocab.getWord(c2) != 0);

    /*
     * Update class membership
     * (every word currently mapped to c2 is remapped to c1)
     */
    LHashIter<VocabIndex,VocabIndex> iter1(wordToClass);
    VocabIndex *clasz;
    VocabIndex word;

    while (clasz = iter1.next(word)) {
	if (*clasz == c2) {
	    *clasz = c1;
	}
    }

    /*
     * Update class contribs vector
     * (the entries for c1 and c2 themselves are dropped; each remaining
     * entry gains NlogN of its combined counts with c1 and c2, in both
     * directions, and loses the four separate NlogN terms)
     */
    classContribs.remove(c1);
    classContribs.remove(c2);
    {
	LHashIter<VocabIndex,LogP> iter(classContribs);
	VocabIndex clasz;	// class index; shadows the pointer above
	LogP *logp;

	while (logp = iter.next(clasz)) {
	    *logp += NlogN(getCount(clasz, c1) + getCount(clasz, c2)) 
		   + NlogN(getCount(c1, clasz) + getCount(c2, clasz))
		   - NlogN(getCount(clasz, c1)) - NlogN(getCount(clasz, c2))
		   - NlogN(getCount(c1, clasz)) - NlogN(getCount(c2, clasz));
	}
    }

    /*
     * Update merge contribs matrix
     * (entries keyed by c1 or c2, in either position, are dropped; every
     * remaining pair (class1, class2) is adjusted for the counts that the
     * pair shares with c1 and c2)
     */
    mergeContribs.remove(c1);
    mergeContribs.remove(c2);
    {
	Map2Iter<VocabIndex,VocabIndex,LogP> iter1(mergeContribs);
	VocabIndex class1;

	while(iter1.next(class1)) {
	    mergeContribs.remove(class1, c1);
	    mergeContribs.remove(class1, c2);

	    Map2Iter2<VocabIndex,VocabIndex,LogP> iter2(mergeContribs, class1);
	    VocabIndex class2;
	    LogP *logp;

	    while (logp = iter2.next(class2)) {
		*logp += NlogN(getCount(class1,c1) + getCount(class2,c1) +
			       getCount(class1,c2) + getCount(class2,c2))
		       + NlogN(getCount(c1,class1) + getCount(c1,class2) +
			       getCount(c2,class1) + getCount(c2,class2))
		       - NlogN(getCount(class1,c1) + getCount(class2,c1))
		       - NlogN(getCount(class1,c2) + getCount(class2,c2))
		       - NlogN(getCount(c1,class1) + getCount(c1,class2))
		       - NlogN(getCount(c2,class1) + getCount(c2,class2));
	    }
	}
    }

    /*
     * Update counts
     * (forward and reversed trie alike)
     */
    mergeCounts(classNgramCounts, c1, c2);
    mergeCounts(classNgramCountsR, c1, c2);

    /*
     * Get rid of old class
     */
    classVocab.remove(c2);

}

/*
 * Write the class definitions, one "CLASS PROB WORD" line per classed word,
 * where PROB is the word count divided by the unigram count of its class
 * (0 if the class count is zero).  Words without a class are left out.
 */
void
UniqueWordClasses::writeClasses(File &file)
{
    /*
     * Step 1: group the classed words by class, with their probabilities
     */
    Map2<VocabIndex,VocabIndex,Prob> membership;

    LHashIter<VocabIndex,NgramCount> countIter(wordCounts);
    VocabIndex countedWord;
    NgramCount *occurrences;

    while ((occurrences = countIter.next(countedWord)) != 0) {
	VocabIndex *owner = wordToClass.find(countedWord);

	if (owner != 0) {
	    /*
	     * look up the unigram count of the owning class
	     */
	    VocabIndex classUnigram[2];
	    classUnigram[0] = *owner;
	    classUnigram[1] = Vocab_None;

	    NgramCount *classTotal = classNgramCounts.findCount(classUnigram);
	    assert(classTotal != 0);

	    /* an empty class cannot hold a word with a nonzero count */
	    assert(*classTotal != 0 || *occurrences == 0);

	    Prob share;
	    if (*classTotal == 0) {
		share = 0.0;
	    } else {
		share = ((Prob)*occurrences) / *classTotal;
	    }

	    *membership.insert(*owner, countedWord) = share;
	}
    }

    /*
     * Step 2: dump the class expansions in sorted class order
     */
    VocabIter classIter(classVocab, true);
    VocabIndex currentClass;

    while (classIter.next(currentClass)) {
	Map2Iter2<VocabIndex,VocabIndex,Prob>
				memberIter(membership, currentClass);
	VocabIndex member;
	Prob *memberProb;

	while ((memberProb = memberIter.next(member)) != 0) {
	    fprintf(file, "%s %lg %s\n", classVocab.getWord(currentClass),
					*memberProb, vocab.getWord(member));
	}
    }
}

/*
 * Dump the cached contribution tables in text form: first one line per
 * classContribs entry, then one line per mergeContribs pair, each section
 * introduced by a "=== ... ===" marker line.
 */
void
UniqueWordClasses::writeContribs(File &file)
{
    LogP *value;

    /* per-class contributions */
    fprintf(file, "=== class contribs ===\n");

    LHashIter<VocabIndex,LogP> classIter(classContribs);
    VocabIndex cls;

    while ((value = classIter.next(cls)) != 0) {
	fprintf(file, "%s %lg\n", vocab.getWord(cls), *value);
    }

    /* per-pair merge contributions */
    fprintf(file, "=== merge contribs ===\n");

    Map2Iter<VocabIndex,VocabIndex,LogP> rowIter(mergeContribs);
    VocabIndex rowClass;

    while (rowIter.next(rowClass)) {
	Map2Iter2<VocabIndex,VocabIndex,LogP> colIter(mergeContribs, rowClass);
	VocabIndex colClass;

	while ((value = colIter.next(colClass)) != 0) {
	    fprintf(file, "%s %s %lg\n", vocab.getWord(rowClass),
					vocab.getWord(colClass), *value);
	}
    }

    fprintf(file, "=== end of contribs ===\n");
}

/*
 * Fill in text statistics from the word counts kept by this object.
 *
 *   numSentences  count of the sentence-end word (vocab.seIndex()),
 *                 or 0 if that word has no count
 *   numWords      sum of the counts of all other words, leaving out the
 *                 sentence-start word (vocab.ssIndex())
 *   prob          totalLogP()
 *   numOOVs, zeroProbs  always 0
 *
 * All of these fields are overwritten, so a TextStats object can be reused
 * across calls.
 */
void
UniqueWordClasses::getStats(TextStats &stats)
{
    LHashIter<VocabIndex,NgramCount> wordIter(wordCounts);
    NgramCount *count;
    VocabIndex word;

    /*
     * Reset both counters up front: numSentences is only assigned inside
     * the loop when a sentence-end count exists, and would otherwise keep
     * whatever value the caller's object held before.
     */
    stats.numSentences = 0;
    stats.numWords = 0;

    while ((count = wordIter.next(word)) != 0) {
	if (word == vocab.seIndex()) {
	    stats.numSentences = *count;
	} else if (word != vocab.ssIndex()) {
	    stats.numWords += *count;
	}
    }

    stats.prob = totalLogP();
    stats.numOOVs = 0;
    stats.zeroProbs = 0;
}

/*
 * Count stored in classNgramCounts for the class pair (c1, c2), or 0 when
 * there is no such entry.  With c2 == Vocab_None the key ends after c1,
 * which is how the one-argument overload obtains a unigram count.
 */
inline NgramCount
UniqueWordClasses::getCount(VocabIndex c1, VocabIndex c2)
{
    VocabIndex key[3] = { c1, c2, Vocab_None };

    NgramCount *found = classNgramCounts.findCount(key);
    if (found == 0) {
	return (NgramCount)0;
    }

    return *found;
}

/*
 * Same lookup as getCount(), but in the reversed count trie
 * classNgramCountsR; yields 0 when there is no such entry.
 */
inline NgramCount
UniqueWordClasses::getCountR(VocabIndex c1, VocabIndex c2)
{
    VocabIndex key[3] = { c1, c2, Vocab_None };

    NgramCount *found = classNgramCountsR.findCount(key);
    if (found == 0) {
	return (NgramCount)0;
    }

    return *found;
}

/*
 * Total log likelihood of the counts under the current classes:
 *
 *	  \sum_{i} n(w_i) \log n(w_i)
 *	+ \sum_{i,j} n(c_i,c_j) \log n(c_i,c_j)
 *	- 2 \sum_{j} n(c_j) \log n(c_j)
 *
 * The three sums are accumulated in that order, using NlogN() per term.
 */
LogP
UniqueWordClasses::totalLogP()
{
    LogP logLikelihood = LogP_One;

    /*
     * word term
     */
    LHashIter<VocabIndex,NgramCount> wordIter(wordCounts);
    VocabIndex w;

    for (NgramCount *n = wordIter.next(w); n != 0; n = wordIter.next(w)) {
	logLikelihood += NlogN(*n);
    }

    /*
     * class bigram term
     */
    VocabIndex classKey[3];
    NgramsIter bigramIter(classNgramCounts, classKey, 2);

    for (NgramCount *n = bigramIter.next(); n != 0; n = bigramIter.next()) {
	logLikelihood += NlogN(*n);
    }

    /*
     * class unigram term (counted twice, with negative sign)
     */
    NgramsIter unigramIter(classNgramCounts, classKey, 1);

    for (NgramCount *n = unigramIter.next(); n != 0; n = unigramIter.next()) {
	logLikelihood -= 2.0 * NlogN(*n);
    }

    return logLikelihood;
}

/*
 * Log likelihood difference associated with merging classes c1 and c2:
 * the merge contribution of the pair minus the class contributions of
 * c1 and of c2.  The two classes must be distinct.
 */
LogP
UniqueWordClasses::diffLogP(VocabIndex c1, VocabIndex c2)
{
    assert(c1 != c2);

    LogP delta = computeMergeContrib(c1, c2);
    delta -= computeClassContrib(c1);
    delta -= computeClassContrib(c2);

    return delta;
}

/*
 * Compute the contribution of each class in an auxiliary array
 * (a la the s_k(i) in Brown et al 1992).
 *
 * classContrib(i) = \sum_j n(c_i,c_j) \log n(c_i,c_j)
 *              + \sum_j n(c_j,c_i) \log n(c_j,c_i)
 *		- n(c_i,c_i) \log n(c_i,c_i)
 */
void
UniqueWordClasses::computeClassContribs()
{
    if (debug(DEBUG_TRACE_MERGE)) {
	dout() << "computing class contrib vector\n";
    }

    /*
     * Clear the classContribs array
     */
    classContribs.clear();

    /*
     * Recompute in a single pass over all counts
     */

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -