📄 ngram-class.cc
字号:
VocabIndex bigram[3];
NgramsIter iter(classNgramCounts, bigram, 2);
NgramCount *count;
while (count = iter.next()) {
*classContribs.insert(bigram[0]) += NlogN(*count);
if (bigram[0] != bigram[1]) {
*classContribs.insert(bigram[1]) += NlogN(*count);
}
}
}
/*
 * Contribution of class c to the log likelihood: the sum of n log n over
 * every bigram count involving c -- n(c, x) for all successors x, plus
 * n(x, c) for all predecessors x other than c itself, so that n(c, c) is
 * counted only once.  The result is memoized in classContribs.
 */
LogP
UniqueWordClasses::computeClassContrib(VocabIndex c)
{
    Boolean cached;
    LogP *slot = classContribs.insert(c, cached);
    if (cached) {
        return *slot;
    }

    VocabIndex context[2] = { c, Vocab_None };
    VocabIndex other[2];
    NgramCount *cnt;

    LogP sum = LogP_One;

    /* n(c, x) for every x, read from the forward bigram table */
    NgramsIter successors(classNgramCounts, context, other, 1);
    while ((cnt = successors.next()) != 0) {
        sum += NlogN(*cnt);
    }

    /*
     * n(x, c) for every x != c, read from the reversed table
     * (classNgramCountsR, keyed by the second class); n(c, c) was
     * already included by the forward scan.
     */
    NgramsIter predecessors(classNgramCountsR, context, other, 1);
    while ((cnt = predecessors.next()) != 0) {
        if (other[0] == c) {
            continue;
        }
        sum += NlogN(*cnt);
    }

    /* cache result */
    *slot = sum;
    return sum;
}
/*
 * Compute the contribution of a merge pair to the log likelihood difference,
 * cached in an auxiliary table (mergeContribs).  With c_12 denoting the class
 * obtained by merging c_1 and c_2:
 * mergeContrib(c1, c2) =
 * n(c_12,c_12) \log n(c_12,c_12)
 * + \sum_{i \neq 1,2} n(c_i, c_12) \log n(c_i, c_12)
 * + \sum_{j \neq 1,2} n(c_12, c_j) \log n(c_12, c_j)
 * + n(c_1, c_2) \log n(c_1, c_2) + n(c_2, c_1) \log n(c_2, c_1)
 * - 2 n(c_12) \log n(c_12)
 * + 2 [ n(c_1) \log n(c_1) + n(c_2) \log n(c_2) ]
 */
void
UniqueWordClasses::computeMergeContribs()
{
if (debug(DEBUG_TRACE_MERGE)) {
dout() << "computing merge contrib matrix\n";
}
VocabIter iter1(classVocab);
VocabIndex class1;
while (iter1.next(class1)) {
VocabIter iter2(classVocab);
VocabIndex class2;
while (iter2.next(class2)) {
if (class1 < class2) {
(void)computeMergeContrib(class1, class2);
}
}
}
}
void
UniqueWordClasses::computeMergeContrib(VocabIndex c1)
{
    /*
     * Compute (and thereby cache) the merge contribution of c1 paired
     * with every other class in classVocab.
     */
    VocabIter others(classVocab);
    VocabIndex c2;
    while (others.next(c2)) {
        if (c2 == c1) {
            continue;
        }
        (void)computeMergeContrib(c1, c2);
    }
}
/*
 * Return the merge contribution of the pair (c1, c2), computing it and
 * caching it in mergeContribs on first use.  c1 and c2 must be distinct
 * (asserted).  The pair is normalized so that c1 < c2 before lookup, so
 * (c1, c2) and (c2, c1) share one cache entry.
 *
 * With c_12 = c_1 merged with c_2 the value is
 *   n(c_12,c_12) log n(c_12,c_12)
 *   + sum_{i != 1,2} n(c_i, c_12) log n(c_i, c_12)
 *   + sum_{j != 1,2} n(c_12, c_j) log n(c_12, c_j)
 *   + n(c_1,c_2) log n(c_1,c_2) + n(c_2,c_1) log n(c_2,c_1)
 *   - 2 n(c_12) log n(c_12)
 *   + 2 [ n(c_1) log n(c_1) + n(c_2) log n(c_2) ]
 *
 * NOTE: the terms are accumulated into a floating-point total in a fixed
 * order; preserve that order when restructuring so cached values stay
 * bit-identical.
 */
LogP
UniqueWordClasses::computeMergeContrib(VocabIndex c1, VocabIndex c2)
{
    assert(c1 != c2);
    /*
     * For efficiency we only store and compute for c1 < c2
     * (the value does not depend on the order of the pair)
     */
    if (c1 > c2) {
        VocabIndex tmp = c2;
        c2 = c1;
        c1 = tmp;
    }
    /*
     * Return the cached value if this pair was computed before;
     * otherwise *logp is the fresh cache slot filled in at the end.
     */
    Boolean found;
    LogP *logp = mergeContribs.insert(c1, c2, found);
    if (found) {
        return *logp;
    }
    /*
     * unigram[0] holds the fixed one-class prefix for each scan.  The
     * iterators write their one-class key into bigram[1]; bigram[2] holds
     * Vocab_None (presumably the key terminator) and bigram[0] is unused.
     */
    VocabIndex unigram[2]; unigram[1] = Vocab_None;
    VocabIndex bigram[3]; bigram[2] = Vocab_None;
    NgramCount *count;
    LogP total = LogP_One;
    /*
     * n(c_12,c_12) \log n(c_12,c_12)
     * (the merged self-transition count is the sum of all four bigram
     * counts among c_1 and c_2)
     */
    total += NlogN(getCount(c1, c1) + getCount(c1, c2) +
                   getCount(c2, c1) + getCount(c2, c2));
    /*
     * + \sum_{i \neq 1,2} n(c_i, c_12) \log n(c_i, c_12)
     * First pass: the classes i seen with prefix c_1 in the reversed table,
     * i.e. the predecessors of c_1; n(c_i, c_12) = n(c_i, c_1) + n(c_i, c_2).
     * (assumes getCountR(c2, i) returns n(c_i, c_2) -- the reversed lookup)
     */
    unigram[0] = c1;
    NgramsIter iter0(classNgramCountsR, unigram, &bigram[1], 1);
    while (count = iter0.next()) {
        if (bigram[1] != c1 && bigram[1] != c2) {
            total += NlogN(*count + getCountR(c2, bigram[1]));
        }
    }
    /*
     * Second pass: predecessors of c_2 that do not also precede c_1
     * (those were already handled above), so n(c_i, c_12) = n(c_i, c_2).
     */
    unigram[0] = c2;
    NgramsIter iter1(classNgramCountsR, unigram, &bigram[1], 1);
    while (count = iter1.next()) {
        if (bigram[1] != c1 && bigram[1] != c2 &&
            getCountR(c1, bigram[1]) == 0)
        {
            total += NlogN(*count);
        }
    }
    /*
     * + \sum_{j \neq 1,2} n(c_12, c_j) \log n(c_12, c_j)
     * Same two-pass scheme on the forward table: successors of c_1 first
     * (adding n(c_2, c_j)), then successors of c_2 not following c_1.
     */
    unigram[0] = c1;
    NgramsIter iter2(classNgramCounts, unigram, &bigram[1], 1);
    while (count = iter2.next()) {
        if (bigram[1] != c1 && bigram[1] != c2) {
            total += NlogN(*count + getCount(c2, bigram[1]));
        }
    }
    unigram[0] = c2;
    NgramsIter iter3(classNgramCounts, unigram, &bigram[1], 1);
    while (count = iter3.next()) {
        if (bigram[1] != c1 && bigram[1] != c2 &&
            getCount(c1, bigram[1]) == 0)
        {
            total += NlogN(*count);
        }
    }
    /*
     * + n(c_1, c_2) \log n(c_1, c_2) + n(c_2, c_1) \log n(c_2, c_1)
     */
    total += NlogN(getCount(c1, c2)) + NlogN(getCount(c2, c1));
    /*
     * - 2 n(c_12) \log n(c_12), where n(c_12) = n(c_1) + n(c_2)
     */
    NgramCount n1 = getCount(c1);
    NgramCount n2 = getCount(c2);
    total -= 2 * NlogN(n1 + n2);
    /*
     * + 2 [ n(c_1) \log n(c_1) + n(c_2) \log n(c_2) ]
     */
    total += 2 * NlogN(n1) + 2 * NlogN(n2);
    /*
     * Cache result
     */
    *logp = total;
    return total;
}
/*
 * Find and perform best merge pair
 *
 * Scores candidate pairs of classes from mergeSet with diffLogP(), merges
 * the highest-scoring pair via merge(), and reports it through b1/b2.
 * Returns the winning pair's diffLogP() score.
 *
 * If no pair could be scored (mergeSet has fewer than two classes), nothing
 * is merged, b1 and b2 are set to Vocab_None, and LogP_One (no change in log
 * likelihood) is returned.  Without this guard the function would call
 * merge(Vocab_None, Vocab_None), print getWord(Vocab_None), and return an
 * uninitialized score.
 */
LogP
UniqueWordClasses::bestMerge(Vocab &mergeSet, VocabIndex &b1, VocabIndex &b2)
{
    VocabIndex bestC1 = Vocab_None, bestC2 = Vocab_None;
    LogP bestDiff = LogP_One;   /* replaced by the first pair scored */

    VocabIter iter1(mergeSet);
    VocabIndex c1;
    while (iter1.next(c1)) {
        /*
         * The inner scan starts from a copy of the outer iterator --
         * presumably resuming after c1 so each unordered pair is scored
         * once (TODO confirm VocabIter copy semantics).
         */
        VocabIter iter2(iter1);
        VocabIndex c2;
        while (iter2.next(c2)) {
            LogP diff = diffLogP(c1, c2);
            if (bestC1 == Vocab_None || diff > bestDiff) {
                bestC1 = c1;
                bestC2 = c2;
                bestDiff = diff;
            }
        }
    }

    if (bestC1 == Vocab_None) {
        /* no candidate pair: leave the classes untouched */
        b1 = b2 = Vocab_None;
        return LogP_One;
    }

    if (debug(DEBUG_TRACE_MERGE)) {
        dout() << "\tmerging " << mergeSet.getWord(bestC1)
               << " and " << mergeSet.getWord(bestC2)
               << " diff = " << bestDiff
               << endl;
    }

    merge(bestC1, bestC2);

    b1 = bestC1;
    b2 = bestC2;
    return bestDiff;
}
/*
 * Create a writable file if basename is defined and
 * iter is a multiple of freq. Used in saving preliminary results
 * during merging.
 *
 * If stdio_filename_p(basename) holds, an iteration banner is printed to
 * stdout and basename itself is opened.  Otherwise the file is named
 * "<basename>.<iter, zero-padded to 6 digits><suffix>", where suffix is
 * COMPRESS_SUFFIX / GZIP_SUFFIX when basename is compressed / gzipped.
 * Returns 0 when nothing should be saved this iteration; otherwise the
 * caller owns the returned File and must delete it.
 */
static File *
logFile(const char *basename, unsigned freq, unsigned iter)
{
    if (freq == 0 || basename == 0 || iter % freq != 0) {
        return 0;
    }

    const char *suffix = compressed_filename_p(basename) ? COMPRESS_SUFFIX :
                         gzipped_filename_p(basename) ? GZIP_SUFFIX : "";

    /*
     * Room for basename, '.', the decimal iteration number (at most three
     * digits per byte of unsigned), the suffix, and the terminating NUL.
     * A fixed "+ 10" is too small: a gzipped name already needs
     * 1 + 6 + 3 = 10 extra characters plus the NUL, and iter >= 1000000
     * prints more than 6 digits.
     */
    size_t bufSize = strlen(basename) + 1 + 3 * sizeof(unsigned)
                     + strlen(suffix) + 1;
    makeArray(char, filename, bufSize);

    if (stdio_filename_p(basename)) {
        printf("*** SAVE FOR ITERATION %06u ***\n", iter);
        fflush(stdout);
        strcpy(filename, basename);
    } else {
        snprintf(filename, bufSize, "%s.%06u%s", basename, iter, suffix);
    }

    File *file = new File(filename, "w");
    assert(file != 0);
    return file;
}
/*
* Full greedy class merging
* This is the first, O(V^3) algorithm in Brown et al. (1992)
*/
void
UniqueWordClasses::fullMerge(unsigned numClasses)
{
if (numClasses < 1) {
numClasses = 1;
}
TextStats stats;
getStats(stats);
unsigned numTokens = stats.numWords + stats.numSentences;
unsigned iters = 0;
if (debug(DEBUG_TRACE_MERGE)) {
dout() << "iter " << iters
<< ": " << classVocab.numWords() << " classes, "
<< "perplexity = " << LogPtoPPL(stats.prob/numTokens)
<< endl;
}
while (classVocab.numWords() > numClasses) {
iters ++;
VocabIndex b1, b2;
bestMerge(classVocab, b1, b2);
if (debug(DEBUG_TRACE_MERGE)) {
dout() << "iter " << iters
<< ": " << classVocab.numWords() << " classes, "
<< "perplexity = " << LogPtoPPL(totalLogP()/numTokens)
<< endl;
}
/*
* Save classes and counts if and when requested
*/
File *cf = logFile(classesFile, saveFreq, iters);
if (cf != 0) {
writeClasses(*cf);
delete cf;
}
cf = logFile(classCountsFile, saveFreq, iters);
if (cf != 0) {
writeCounts(*cf);
delete cf;
}
}
}
/*
 * Order classes by count
 *
 * Three-way comparator (negative / zero / positive) handed to NgramsIter:
 * classes with larger counts sort first.  Counts are read through the
 * file-static orderClasses, which must point at the object being sorted
 * before the sorted iterator is created (incrementalMerge() sets it).
 */
static UniqueWordClasses *orderClasses;
static int
orderByCount(VocabIndex c1, VocabIndex c2)
{
    /*
     * Compare explicitly rather than returning count(c2) - count(c1):
     * NgramCount may be unsigned and/or wider than int, so the difference
     * can wrap or be truncated when converted to int and yield the wrong
     * sign, corrupting the sort order.
     */
    NgramCount n1 = orderClasses->getCount(c1);
    NgramCount n2 = orderClasses->getCount(c2);
    if (n1 > n2) {
        return -1;
    } else if (n1 < n2) {
        return 1;
    } else {
        return 0;
    }
}
/*
* Incremental greedy class merging
* This is the second, O(VC^2) algorithm in Brown et al. (1992)
*/
void
UniqueWordClasses::incrementalMerge(unsigned numClasses)
{
if (numClasses < 1) {
numClasses = 1;
}
TextStats stats;
getStats(stats);
unsigned numTokens = stats.numWords + stats.numSentences;
unsigned iters = 0;
if (debug(DEBUG_TRACE_MERGE)) {
dout() << "iter " << iters
<< ": " << classVocab.numWords() << " classes, "
<< "perplexity = " << LogPtoPPL(stats.prob/numTokens)
<< endl;
}
/*
* Sort classes by count
*/
makeArray(VocabIndex, listOfClasses, vocab.numWords());
unsigned nClasses = 0;
VocabIndex unigram[2];
orderClasses = this;
NgramsIter unigramIter(classNgramCounts, unigram, 1, orderByCount);
NgramCount *classCount;
while (classCount = unigramIter.next()) {
if (classVocab.getWord(unigram[0]) != 0) {
listOfClasses[nClasses ++] = unigram[0];
}
}
/*
* Construct the subset of classes undergoing merging
*/
SubVocab mergeSet(classVocab);
/*
* Add the first numClasses to the merge set
*/
unsigned i;
for (i = 0; i < nClasses && i < numClasses; i++) {
if (debug(DEBUG_TRACE_MERGE)) {
dout() << "\tadding " << classVocab.getWord(listOfClasses[i])
<< endl;
}
mergeSet.addWord(listOfClasses[i]);
}
/*
* Now add one extra class at a time, and merge after each addition
*/
for ( ; i < nClasses; i ++) {
if (debug(DEBUG_TRACE_MERGE)) {
dout() << "\tadding " << classVocab.getWord(listOfClasses[i])
<< endl;
}
mergeSet.addWord(listOfClasses[i]);
iters ++;
VocabIndex b1, b2;
bestMerge(mergeSet, b1, b2);
if (debug(DEBUG_TRACE_MERGE)) {
dout() << "iter " << iters
<< ": " << classVocab.numWords() << " classes, "
<< "perplexity = " << LogPtoPPL(totalLogP()/numTokens)
<< endl;
}
mergeSet.remove(b2);
/*
* Save classes and counts if and when requested
*/
File *cf = logFile(classesFile, saveFreq, iters);
if (cf != 0) {
writeClasses(*cf);
delete cf;
}
cf = logFile(classCountsFile, saveFreq, iters);
if (cf != 0) {
writeCounts(*cf);
delete cf;
}
}
}
/*
 * Simple interactive class merging
 *
 * Repeatedly reads two class names from cin and merges those classes.
 * After each merge the classes and class counts are rewritten (when
 * classesFile / classCountsFile are set), the contribution tables are
 * dumped before and after recomputation at debug >= DEBUG_PRINT_CONTRIBS,
 * and the resulting TextStats are printed to cout.  The session ends when
 * no first name could be read (e.g. end of input).
 */
void
interactiveMerge(UniqueWordClasses &classes)
{
    while (1) {
        char class1[30], class2[30];
        class1[0] = class2[0] = '\0';

        cout << "Enter two class names> ";

        /*
         * Bound each extraction by the buffer size: an unbounded
         * "cin >> char[]" writes past the end of the array for names of
         * 30 or more characters.  An over-long name is instead split into
         * buffer-sized pieces, which then fail the class lookup below.
         */
        cin.width(sizeof(class1));
        cin >> class1;
        cin.width(sizeof(class2));
        cin >> class2;

        if (!*class1) break;

        VocabIndex c1 = classes.classVocab.getIndex(class1);
        if (c1 == Vocab_None) {
            cerr << class1 << " is not a valid class; try again.\n";
            continue;
        }
        VocabIndex c2 = classes.classVocab.getIndex(class2);
        if (c2 == Vocab_None) {
            cerr << class2 << " is not a valid class; try again.\n";
            continue;
        }
        if (c1 == c2) {
            cerr << "Classes must be distinct; try again.\n";
            continue;
        }

        cout << "Merging class " << classes.classVocab.getWord(c1)
             << " and " << classes.classVocab.getWord(c2) << endl;

        LogP delta = classes.diffLogP(c1, c2);
        cout << "Projected delta = " << delta << endl;

        classes.merge(c1, c2);

        if (classesFile) {
            File file(classesFile, "w");
            classes.writeClasses(file);
        }
        if (classCountsFile) {
            File file(classCountsFile, "w");
            classes.writeCounts(file);
        }
        if (debug >= DEBUG_PRINT_CONTRIBS) {
            File file(stdout);
            classes.writeContribs(file);
            classes.computeClassContribs();
            classes.computeMergeContribs();
            classes.writeContribs(file);
        }

        TextStats stats;
        classes.getStats(stats);
        cout << stats;
    }
}
int
main(int argc, char **argv)
{
setlocale(LC_CTYPE, "");
setlocale(LC_COLLATE, "");
Opt_Parse(argc, argv, options, Opt_Number(options), 0);
if (version) {
printVersion(RcsId);
exit(0);
}
Vocab vocab;
vocab.toLower() = toLower ? true : false;
SubVocab classVocab(vocab);
SubVocab noclassVocab(vocab);
UniqueWordClasses classes(vocab, classVocab);
classes.debugme(debug);
if (vocabFile) {
File file(vocabFile, "r");
vocab.read(file);
}
if (noclassFile) {
File file(noclassFile, "r");
noclassVocab.read(file);
} else {
/*
* Assume <s> and </s> are not to be classed
*/
noclassVocab.addWord(vocab.ssIndex());
noclassVocab.addWord(vocab.seIndex());
}
if (countsFile || textFile) {
NgramStats bigrams(vocab, 2);
bigrams.debugme(debug);
/*
* Restrict vocabulary if user specied one
*/
if (vocabFile) {
bigrams.openVocab = false;
}
if (textFile) {
File file(textFile, "r");
bigrams.countFile(file);
}
if (countsFile) {
File file(countsFile, "r");
bigrams.read(file);
}
classes.initialize(bigrams, noclassVocab);
} else {
cerr << "Specify counts or text file as input.\n";
exit(1);
}
if (numClasses > 0) {
if (fullMerge) {
classes.fullMerge(numClasses);
} else {
classes.incrementalMerge(numClasses);
}
}
if (classesFile) {
File file(classesFile, "w");
classes.writeClasses(file);
}
if (classCountsFile) {
File file(classCountsFile, "w");
classes.writeCounts(file);
}
if (debug >= DEBUG_TEXTSTATS) {
TextStats stats;
classes.getStats(stats);
cerr << stats;
}
if (interact) {
interactiveMerge(classes);
}
exit(0);
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -