// ngram-class.cc
/*
* ngram-class --
* Induce class ngram models from counts
*
*/
#ifndef lint
static char Copyright[] = "Copyright (c) 1999-2006 SRI International. All Rights Reserved.";
static char RcsId[] = "@(#)$Id: ngram-class.cc,v 1.26 2006/01/05 20:21:27 stolcke Exp $";
#endif
#include <iostream>
using namespace std;

#include <assert.h>
#include <locale.h>
#include <stdio.h>
#include <stdlib.h>
#include "option.h"
#include "version.h"
#include "zio.h"
#include "File.h"
#include "Debug.h"
#include "Prob.h"
#include "Vocab.h"
#include "SubVocab.h"
#include "TextStats.h"
#include "NgramStats.h"
#include "LHash.cc"
#include "Map2.cc"
#include "Array.cc"
#include "Trie.cc"
#ifdef INSTANTIATE_TEMPLATES
INSTANTIATE_MAP1(VocabIndex,VocabIndex,LogP);
#ifdef USE_SARRAY
// XXX: avoid multiple definitions with NgramLM
INSTANTIATE_LHASH(VocabIndex,LogP);
#endif
#endif
#define DEBUG_TEXTSTATS 1
#define DEBUG_TRACE_MERGE 2
#define DEBUG_PRINT_CONTRIBS 3 // in interactive mode
/*
 * Command-line option values (filled in by Opt_Parse from the table below)
 */
static int version = 0;
static char *vocabFile = 0;
static int toLower = 0;
static char *noclassFile = 0;
static char *countsFile = 0;
static char *textFile = 0;
static char *classesFile = 0;
static char *classCountsFile = 0;
static unsigned numClasses = 1;
static int fullMerge = 0;
static int interact = 0;
static int debug = 0;
static int saveFreq = 0;
/*
 * Option table: flag name, target variable, and help text.
 * Note -full and -incremental are TRUE/FALSE views of the same flag.
 */
static Option options[] = {
{ OPT_TRUE, "version", &version, "print version information" },
{ OPT_UINT, "debug", &debug, "debugging level" },
{ OPT_STRING, "vocab", &vocabFile, "vocab file" },
{ OPT_TRUE, "tolower", &toLower, "map vocabulary to lowercase" },
{ OPT_STRING, "noclass-vocab", &noclassFile, "vocabulary not to be classed" },
{ OPT_STRING, "counts", &countsFile, "counts file to read" },
{ OPT_STRING, "text", &textFile, "text file to count" },
{ OPT_UINT, "numclasses", &numClasses, "number of classes to induce" },
{ OPT_TRUE, "full", &fullMerge, "perform full greedy merging" },
{ OPT_FALSE, "incremental", &fullMerge, "perform incremental greedy merging" },
{ OPT_TRUE, "interact", &interact, "perform interactive merging" },
{ OPT_STRING, "class-counts", &classCountsFile, "class N-gram count output" },
{ OPT_STRING, "classes", &classesFile, "class definitions output" },
{ OPT_INT, "save", &saveFreq, "save classes/counts every this many iterations" },
};
/*
 * n log n, defined to be 0 (= LogP_One) at n = 0, so that
 * entropy-style sums need no special-casing of zero counts.
 */
static inline LogP
NlogN(NgramCount count)
{
    return (count == 0) ? LogP_One : count * ProbToLogP(count);
}
/*
 * Many-to-one word-to-class mapping, together with the class bigram
 * statistics and cached likelihood contributions used to drive greedy
 * class induction.
 */
class UniqueWordClasses: public Debug
{
public:
    UniqueWordClasses(Vocab &vocab, SubVocab &classVocab);
    ~UniqueWordClasses() {};

    VocabIndex newClass();			// create a new class
    void initialize(NgramStats &counts, SubVocab &noclassVocab);
						// initialize classes from word counts
    void merge(VocabIndex c1, VocabIndex c2);	// merge classes
    LogP bestMerge(Vocab &mergeSet, VocabIndex &c1, VocabIndex &c2);
						// single best merge step
    // NOTE(review): parameter name fixed from misspelled "numClases"
    void fullMerge(unsigned numClasses);	// full greedy merging
    void incrementalMerge(unsigned numClasses);	// incremental merging

    void writeClasses(File &file);		// write class definitions
    void writeCounts(File &file)		// write class ngram counts
	{ classNgramCounts.write(file, 0, true); };

    NgramCount getCount(VocabIndex c1)		// class unigram count
	{ return getCount(c1, Vocab_None); };
    inline NgramCount getCount(VocabIndex c1, VocabIndex c2);
						// class bigram count
    inline NgramCount getCountR(VocabIndex c1, VocabIndex c2);
						// reverse bigram count

    LogP totalLogP();				// total log likelihood
    LogP diffLogP(VocabIndex c1, VocabIndex c2);
						// log likelihood difference for
						// class merging
    void computeClassContribs();		// recompute classContribs vector
    void computeMergeContribs();		// recompute mergeContribs matrix

    void getStats(TextStats &stats);		// fill TextStats from word counts
    void writeContribs(File &file);		// dump contrib vector (debugging)

    Vocab &vocab;				// base (word) vocabulary
    Vocab &classVocab;				// vocabulary holding class names

protected:
    LHash<VocabIndex,VocabIndex> wordToClass;	// word->class map
    LHash<VocabIndex,NgramCount> wordCounts;	// word counts
    NgramStats classNgramCounts;		// trie of (C,...,C) counts
    NgramStats classNgramCountsR;		// trie of reversed counts
    LHash<VocabIndex,LogP> classContribs;	// class contributions to
						// total log likelihood
    Map2<VocabIndex,VocabIndex,LogP> mergeContribs;
						// merge-pair contributions
						// to delta log likelihood

    void computeMergeContrib(VocabIndex c1);	// recompute mergeContribs
    LogP computeClassContrib(VocabIndex c);	// recompute classContribs
    LogP computeMergeContrib(VocabIndex c1, VocabIndex c2); // same
    void mergeCounts(NgramStats &counts, VocabIndex c1, VocabIndex c2);

    unsigned genSymCount;			// counter for generated class names
};
/*
 * Constructor: binds the word vocabulary and the (sub)vocabulary that
 * will hold the induced class names, and allocates the forward and
 * reversed class bigram count tries.
 */
UniqueWordClasses::UniqueWordClasses(Vocab &vocab, SubVocab &classVocab)
    : vocab(vocab), classVocab(classVocab),
      classNgramCounts(vocab, 2), classNgramCountsR(vocab, 2),
      genSymCount(0)
{
    /*
     * Make sure the classes are a subset of the base vocabulary
     */
    assert(&vocab == &classVocab.baseVocab());
}	// note: stray ';' after function body removed
/*
 * Create a fresh class name of the form CLASS-xxxxx, add it to the
 * class vocabulary, and return its index.
 */
VocabIndex
UniqueWordClasses::newClass()
{
    char className[30];

    /*
     * snprintf bounds the write explicitly (the name is at most
     * "CLASS-" + 10 digits + NUL even for the largest unsigned value,
     * so it always fits, but don't rely on that silently).
     */
    snprintf(className, sizeof(className), "CLASS-%05u", ++genSymCount);

    /*
     * The generated name must not collide with an existing word
     */
    assert(vocab.getIndex(className) == Vocab_None);

    return classVocab.addWord(className);
}
/*
 * Build the initial classing from word unigram/bigram counts:
 * each classable word gets its own singleton class, and word bigram
 * counts are projected onto class bigrams (stored both in normal and
 * reversed order).  Words in noclassVocab remain unclassed and stand
 * for themselves in the class ngram tries.
 */
void
UniqueWordClasses::initialize(NgramStats &counts, SubVocab &noclassVocab)
{
VocabIndex ngram[3];
/*
 * Make sure the noclassVocab is subset of base vocabulary
 */
assert(&vocab == &noclassVocab.baseVocab());
/*
 * Enumerate unigrams: create one class per classable word, and
 * accumulate class unigram counts as well as raw word counts.
 */
NgramsIter iter1(counts, ngram, 1);
NgramCount *count;
while (count = iter1.next()) {
VocabIndex classUnigram[2];
if (noclassVocab.getWord(ngram[0]) != 0) {
/*
 * A word that is not supposed to be classed ...
 */
classUnigram[0] = ngram[0];
} else {
/*
 * ... one that is: find or create class for it.
 */
Boolean found;
VocabIndex *class1 = wordToClass.insert(ngram[0], found);
if (!found) {
*class1 = newClass();
if (debug(DEBUG_TRACE_MERGE)) {
dout() << "\tcreating " << classVocab.getWord(*class1)
<< " for word " << vocab.getWord(ngram[0])
<< endl;
}
}
classUnigram[0] = *class1;
}
classUnigram[1] = Vocab_None;
/* unigram counts are the same in the forward and reversed tries */
*classNgramCounts.insertCount(classUnigram) += *count;
*classNgramCountsR.insertCount(classUnigram) += *count;
*wordCounts.insert(ngram[0]) += *count;
}
/*
 * Enumerate bigrams: map each word bigram onto its class bigram.
 * Every word must already have been seen as a unigram above, or we
 * bail out with an error.
 */
NgramsIter iter2(counts, ngram, 2);
while (count = iter2.next()) {
VocabIndex classBigram[3];
if (noclassVocab.getWord(ngram[0]) != 0) {
classBigram[0] = ngram[0];
} else {
VocabIndex *class1 = wordToClass.find(ngram[0]);
if (class1 == 0) {
cerr << "word 1 in bigram \"" << (vocab.use(), ngram)
<< "\" has no unigram count\n";
exit(1);
}
classBigram[0] = *class1;
}
if (noclassVocab.getWord(ngram[1]) != 0) {
classBigram[1] = ngram[1];
} else {
VocabIndex *class2 = wordToClass.find(ngram[1]);
if (class2 == 0) {
cerr << "word 2 in bigram \"" << (vocab.use(), ngram)
<< "\" has no unigram count\n";
exit(1);
}
classBigram[1] = *class2;
}
classBigram[2] = Vocab_None;
*classNgramCounts.insertCount(classBigram) += *count;
/* store the same count under reversed class order as well */
Vocab::reverse(classBigram);
*classNgramCountsR.insertCount(classBigram) += *count;
}
}
/*
 * Fold the bigram count "row" and "column" of class c2 into those of
 * class c1 in the given count trie (called once for the forward trie
 * and once for the reversed one).  Afterwards no counts involving c2
 * remain in the trie.
 */
void
UniqueWordClasses::mergeCounts(NgramStats &counts,
VocabIndex c1, VocabIndex c2)
{
VocabIndex unigram[2]; unigram[1] = Vocab_None;
VocabIndex unigram2[2]; unigram2[1] = Vocab_None;
VocabIndex bigram[3]; bigram[2] = Vocab_None;
NgramCount *count;
/*
 * Update Counts
 * 1) add row c2 to c1 row, i.e., every count (c2, x) is added to (c1, x)
 */
unigram[0] = c2;
NgramsIter iter2(counts, unigram, unigram2, 1);
while(count = iter2.next()) {
bigram[0] = c1;
bigram[1] = unigram2[0];
*counts.insertCount(bigram) += *count;
}
/*
 * 2) Remove row c2
 * (removing the c2 unigram node presumably also discards its subtrie,
 * whose counts were already folded into c1's row in step 1 -- confirm
 * against NgramStats::removeCount; the returned unigram count of c2
 * is added to c1's unigram count)
 */
unigram[0] = c1;
unigram2[0] = c2;
count = counts.removeCount(unigram2);
if (count != 0) {
*counts.insertCount(unigram) += *count;
}
/*
 * 3) add column c2 to column c1 and remove column c2,
 * i.e., for every remaining row x, fold (x, c2) into (x, c1)
 */
NgramsIter iter3(counts, unigram, 1);
while (count = iter3.next()) {
bigram[0] = unigram[0];
bigram[1] = c2;
NgramCount *count2 = counts.removeCount(bigram);
if (count2 != 0) {
bigram[1] = c1;
*counts.insertCount(bigram) += *count2;
}
}
}
/*
 * Destructively merge class c2 into class c1: remap member words,
 * incrementally adjust the cached contribution tables, fold the
 * bigram counts, and finally delete c2 from the class vocabulary.
 */
void
UniqueWordClasses::merge(VocabIndex c1, VocabIndex c2)
{
/*
 * Destructively merge c2 into c1 ...
 */
assert(c1 != c2);
/*
 * Make sure both c1 and c2 are classes
 */
assert(classVocab.getWord(c1) != 0);
assert(classVocab.getWord(c2) != 0);
/*
 * Update class membership: every word formerly in c2 now maps to c1
 */
LHashIter<VocabIndex,VocabIndex> iter1(wordToClass);
VocabIndex *clasz;
VocabIndex word;
while (clasz = iter1.next(word)) {
if (*clasz == c2) {
*clasz = c1;
}
}
/*
 * Update class contribs vector.
 * The cached entries for c1 and c2 themselves are invalidated by the
 * merge and dropped (presumably recomputed on demand -- see
 * computeClassContrib).  Every other class's contribution is adjusted
 * by replacing the separate c1/c2 count terms with merged-count terms.
 */
classContribs.remove(c1);
classContribs.remove(c2);
{
LHashIter<VocabIndex,LogP> iter(classContribs);
VocabIndex clasz;
LogP *logp;
while (logp = iter.next(clasz)) {
*logp += NlogN(getCount(clasz, c1) + getCount(clasz, c2))
+ NlogN(getCount(c1, clasz) + getCount(c2, clasz))
- NlogN(getCount(clasz, c1)) - NlogN(getCount(clasz, c2))
- NlogN(getCount(c1, clasz)) - NlogN(getCount(c2, clasz));
}
}
/*
 * Update merge contribs matrix analogously: rows and columns involving
 * c1 or c2 are dropped, and every surviving pair entry is adjusted to
 * reflect the counts after the merge.
 */
mergeContribs.remove(c1);
mergeContribs.remove(c2);
{
Map2Iter<VocabIndex,VocabIndex,LogP> iter1(mergeContribs);
VocabIndex class1;
while(iter1.next(class1)) {
mergeContribs.remove(class1, c1);
mergeContribs.remove(class1, c2);
Map2Iter2<VocabIndex,VocabIndex,LogP> iter2(mergeContribs, class1);
VocabIndex class2;
LogP *logp;
while (logp = iter2.next(class2)) {
*logp += NlogN(getCount(class1,c1) + getCount(class2,c1) +
getCount(class1,c2) + getCount(class2,c2))
+ NlogN(getCount(c1,class1) + getCount(c1,class2) +
getCount(c2,class1) + getCount(c2,class2))
- NlogN(getCount(class1,c1) + getCount(class2,c1))
- NlogN(getCount(class1,c2) + getCount(class2,c2))
- NlogN(getCount(c1,class1) + getCount(c1,class2))
- NlogN(getCount(c2,class1) + getCount(c2,class2));
}
}
}
/*
 * Update counts in both the forward and the reversed trie
 */
mergeCounts(classNgramCounts, c1, c2);
mergeCounts(classNgramCountsR, c1, c2);
/*
 * Get rid of old class
 */
classVocab.remove(c2);
}
/*
 * Write class definitions to file in "class prob word" format, where
 * a word's probability within its class is wordCount / classCount.
 * Words that were never assigned a class are skipped.
 */
void
UniqueWordClasses::writeClasses(File &file)
{
/*
 * Sort words by class and compute probabilities
 */
Map2<VocabIndex,VocabIndex,Prob> classWordProbs;
LHashIter<VocabIndex,NgramCount> wordIter(wordCounts);
NgramCount *wordCount;
VocabIndex word;
while (wordCount = wordIter.next(word)) {
VocabIndex *clasz = wordToClass.find(word);
/*
 * Ignore words that are not classed
 */
if (clasz == 0) continue;
/*
 * get total class count (must exist, since class counts were
 * accumulated from these same word counts in initialize())
 */
VocabIndex unigram[2];
unigram[0] = *clasz;
unigram[1] = Vocab_None;
NgramCount *classCount = classNgramCounts.findCount(unigram);
assert(classCount != 0);
assert(*classCount != 0 || *wordCount == 0);
/* avoid 0/0 when both the word and its class have zero count */
Prob prob = (*classCount == 0) ? 0.0 : ((Prob)*wordCount) / *classCount;
*classWordProbs.insert(*clasz,word) = prob;
}
/*
 * Dump class expansion in sorted order
 * (second VocabIter argument presumably selects sorted enumeration --
 * confirm against VocabIter)
 */
VocabIter classIter(classVocab, true);
VocabIndex clasz;
while (classIter.next(clasz)) {
Map2Iter2<VocabIndex,VocabIndex,Prob> wordIter(classWordProbs, clasz);
VocabIndex word;
Prob *prob;
while (prob = wordIter.next(word)) {
fprintf(file, "%s %lg %s\n", classVocab.getWord(clasz),
*prob, vocab.getWord(word));
}
}
}
/*
 * Dump the cached class and merge-pair contribution tables to file
 * (debugging aid).
 */
void
UniqueWordClasses::writeContribs(File &file)
{
    fprintf(file, "=== class contribs ===\n");

    VocabIndex cls;
    LogP *contrib;
    LHashIter<VocabIndex,LogP> classIter(classContribs);
    while ((contrib = classIter.next(cls))) {
	fprintf(file, "%s %lg\n", vocab.getWord(cls), *contrib);
    }

    fprintf(file, "=== merge contribs ===\n");

    VocabIndex outer;
    Map2Iter<VocabIndex,VocabIndex,LogP> outerIter(mergeContribs);
    while (outerIter.next(outer)) {
	VocabIndex inner;
	Map2Iter2<VocabIndex,VocabIndex,LogP> innerIter(mergeContribs, outer);
	while ((contrib = innerIter.next(inner))) {
	    fprintf(file, "%s %s %lg\n", vocab.getWord(outer),
					 vocab.getWord(inner), *contrib);
	}
    }

    fprintf(file, "=== end of contribs ===\n");
}
/*
 * Fill a TextStats record: numWords is the total word count excluding
 * sentence-begin and sentence-end tokens, numSentences is the count of
 * the sentence-end token, and prob is the current total log likelihood.
 */
void
UniqueWordClasses::getStats(TextStats &stats)
{
    stats.numWords = 0;

    VocabIndex word;
    NgramCount *wordCount;
    LHashIter<VocabIndex,NgramCount> countIter(wordCounts);

    while ((wordCount = countIter.next(word))) {
	if (word == vocab.seIndex()) {
	    stats.numSentences = *wordCount;
	} else if (word != vocab.ssIndex()) {
	    stats.numWords += *wordCount;
	}
    }

    stats.prob = totalLogP();
    stats.numOOVs = 0;
    stats.zeroProbs = 0;
}
/*
 * Class bigram count n(c1,c2); 0 if the bigram is absent.
 */
inline NgramCount
UniqueWordClasses::getCount(VocabIndex c1, VocabIndex c2)
{
    VocabIndex key[3] = { c1, c2, Vocab_None };

    NgramCount *found = classNgramCounts.findCount(key);
    return (found == 0) ? (NgramCount)0 : *found;
}
/*
 * Reversed class bigram count, looked up in the reversed-order trie;
 * 0 if absent.
 */
inline NgramCount
UniqueWordClasses::getCountR(VocabIndex c1, VocabIndex c2)
{
    VocabIndex key[3] = { c1, c2, Vocab_None };

    NgramCount *found = classNgramCountsR.findCount(key);
    return (found == 0) ? (NgramCount)0 : *found;
}
/*
 * Total log likelihood of the current classing:
 *
 *	  \sum_{i} n(w_i) \log n(w_i)
 *	+ \sum_{i,j} n(c_i,c_j) \log n(c_i,c_j)
 *	- 2 \sum_{j} n(c_j) \log n(c_j)
 */
LogP
UniqueWordClasses::totalLogP()
{
    LogP sum = LogP_One;
    NgramCount *cnt;

    /*
     * word unigram term
     */
    VocabIndex word;
    LHashIter<VocabIndex,NgramCount> wordIter(wordCounts);
    while ((cnt = wordIter.next(word))) {
	sum += NlogN(*cnt);
    }

    VocabIndex ngramBuf[3];

    /*
     * class bigram term
     */
    NgramsIter bigramIter(classNgramCounts, ngramBuf, 2);
    while ((cnt = bigramIter.next())) {
	sum += NlogN(*cnt);
    }

    /*
     * class unigram term, subtracted twice (once per bigram position)
     */
    NgramsIter unigramIter(classNgramCounts, ngramBuf, 1);
    while ((cnt = unigramIter.next())) {
	sum -= 2.0 * NlogN(*cnt);
    }

    return sum;
}
/*
 * Change in total log likelihood that would result from merging
 * classes c1 and c2.
 */
LogP
UniqueWordClasses::diffLogP(VocabIndex c1, VocabIndex c2)
{
    assert(c1 != c2);

    LogP merged = computeMergeContrib(c1, c2);
    LogP separate = computeClassContrib(c1) + computeClassContrib(c2);

    return merged - separate;
}
/*
* Compute the contribution of each class in an auxiliary array
* (a la the s_k(i) in Brown et al 1992).
*
* classContrib(i) = \sum_j n(c_i,c_j) \log n(c_i,c_j)
* + \sum_j n(c_j,c_i) \log n(c_j,c_i)
* - n(c_i,c_i) \log n(c_i,c_i)
*/
void
UniqueWordClasses::computeClassContribs()
{
if (debug(DEBUG_TRACE_MERGE)) {
dout() << "computing class contrib vector\n";
}
/*
* Clear the classContribs array
*/
classContribs.clear();
/*
* Recompute in a single pass over all counts
*/
/*
 * NOTE(review): the remainder of computeClassContribs() and the rest of
 * this file were lost when the source was extracted from a web page;
 * the non-source page residue (keyboard-shortcut help) that appeared
 * here has been removed.  Restore the tail of the file from the
 * original SRILM distribution.
 */