⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 ngram-class.cc

📁 这是一款很好用的工具包
💻 CC
📖 第 1 页 / 共 2 页
字号:
    VocabIndex bigram[3];
    NgramsIter iter(classNgramCounts, bigram, 2);
    NgramCount *count;

    while (count = iter.next()) {
	*classContribs.insert(bigram[0]) += NlogN(*count);
	if (bigram[0] != bigram[1]) {
	    *classContribs.insert(bigram[1]) += NlogN(*count);
	}
    }
}

/*
 * Return the log-likelihood contribution of class c, computed on first
 * request and cached in classContribs.
 *	contrib(c) = \sum_j n(c, c_j) \log n(c, c_j)
 *		   + \sum_{i \neq c} n(c_i, c) \log n(c_i, c)
 * The self-bigram (c, c) is counted only once, in the first sum.
 * NOTE(review): the second sum assumes classNgramCountsR holds reversed
 * (right-to-left) bigram counts -- confirm against the class header.
 */
LogP
UniqueWordClasses::computeClassContrib(VocabIndex c)
{
    Boolean cached;
    LogP *slot = classContribs.insert(c, cached);

    if (cached) {
	return *slot;		/* already computed */
    }

    VocabIndex key[2];
    key[0] = c;
    key[1] = Vocab_None;
    VocabIndex other[2];

    LogP contrib = LogP_One;
    NgramCount *cnt;

    /*
     * Bigrams with c on the left
     */
    NgramsIter rightIter(classNgramCounts, key, other, 1);
    while ((cnt = rightIter.next()) != 0) {
	contrib += NlogN(*cnt);
    }

    /*
     * Bigrams with c on the right, skipping (c, c) already counted above
     */
    NgramsIter leftIter(classNgramCountsR, key, other, 1);
    while ((cnt = leftIter.next()) != 0) {
	if (other[0] == c) {
	    continue;
	}
	contrib += NlogN(*cnt);
    }

    *slot = contrib;
    return contrib;
}

/*
 * Compute the contribution of a merge pair to log likelihood difference
 * in an auxiliary array
 * mergeContrib(c1, c2) =
 *	n(c_12,c_12) \log n(c_12,c_12)
 *    +	\sum_{i \neq 1,2}  n(c_i, c_12) \log n(c_i, c_12)
 *    + \sum_{j \neq 1,2}  n(c_12, c_j) \log n(c_12, c_j)
 *    + n(c_1, c_2) \log n(c_1, c_2) + n(c_2, c_1) \log n(c_2, c_1)
 *    - 2 n(c_12) \log n(c_12)
 */
void
UniqueWordClasses::computeMergeContribs()
{
    if (debug(DEBUG_TRACE_MERGE)) {
	dout() << "computing merge contrib matrix\n";
    }

    VocabIter iter1(classVocab);
    VocabIndex class1;

    while (iter1.next(class1)) {
	VocabIter iter2(classVocab);
	VocabIndex class2;

	while (iter2.next(class2)) {
	    if (class1 < class2) {
		(void)computeMergeContrib(class1, class2);
	    }
	}
    }
}

/*
 * Compute (and cache) the merge contribution of class c1 paired with
 * every other class currently in classVocab.
 */
void
UniqueWordClasses::computeMergeContrib(VocabIndex c1)
{
    VocabIter others(classVocab);
    VocabIndex partner;

    while (others.next(partner)) {
	if (partner == c1) {
	    continue;		/* a class is never paired with itself */
	}
	(void)computeMergeContrib(c1, partner);
    }
}

/*
 * Return the merge contribution of the unordered class pair (c1, c2):
 * the sum of the terms listed below, evaluated for the class c_12 that
 * would result from merging c1 and c2.  The value is computed on first
 * request and cached in mergeContribs; later calls return the cached value.
 * c1 and c2 must be distinct (asserted).
 */
LogP
UniqueWordClasses::computeMergeContrib(VocabIndex c1, VocabIndex c2)
{
    assert(c1 != c2);

    /*
     * For efficiency we only store and compute for c1 < c2
     * (every term below is symmetric in c1 and c2, so the pair is put
     * into canonical order before the cache lookup)
     */
    if (c1 > c2) {
	VocabIndex tmp = c2;
	c2 = c1;
	c1 = tmp;
    }

    /*
     * Look up or create the cache slot; an existing value is returned as is
     */
    Boolean found;
    LogP *logp = mergeContribs.insert(c1, c2, found);
    if (found) {
	return *logp;
    }

    /*
     * unigram[] is the one-class prefix handed to NgramsIter; each
     * enumerated context class is reported in bigram[1] (bigram[0] is not
     * used here, bigram[2] is preset to Vocab_None).
     */
    VocabIndex unigram[2]; unigram[1] = Vocab_None;
    VocabIndex bigram[3]; bigram[2] = Vocab_None;
    NgramCount *count;

    LogP total = LogP_One;

    /*
     * n(c_12,c_12) \log n(c_12,c_12)
     * where the merged self-bigram count pools the four bigrams among c1, c2
     */
    total += NlogN(getCount(c1, c1) + getCount(c1, c2) +
			getCount(c2, c1) + getCount(c2, c2));

    /*
     * + \sum_{i \neq 1,2}  n(c_i, c_12) \log n(c_i, c_12)
     *
     * First pass: every c_i enumerated from classNgramCountsR under c1;
     * its merged count adds the corresponding c2 count.
     * NOTE(review): relies on classNgramCountsR/getCountR holding reversed
     * (right-to-left) bigram counts -- confirm against the class header.
     */
    unigram[0] = c1;
    NgramsIter iter0(classNgramCountsR, unigram, &bigram[1], 1);
    while (count = iter0.next()) {
	if (bigram[1] != c1 && bigram[1] != c2) {
	    total += NlogN(*count + getCountR(c2, bigram[1]));
	}
    }

    /*
     * Second pass: c_i seen with c2 but not with c1 (c_i with a nonzero
     * c1 count were already summed in the first pass).
     */
    unigram[0] = c2;
    NgramsIter iter1(classNgramCountsR, unigram, &bigram[1], 1);
    while (count = iter1.next()) {
	if (bigram[1] != c1 && bigram[1] != c2 &&
	    getCountR(c1, bigram[1]) == 0)
	{
	    total += NlogN(*count);
	}
    }

    /*
     * + \sum_{j \neq 1,2}  n(c_12, c_j) \log n(c_12, c_j)
     * Same two-pass scheme over right contexts, using classNgramCounts.
     */
    unigram[0] = c1;
    NgramsIter iter2(classNgramCounts, unigram, &bigram[1], 1);
    while (count = iter2.next()) {
	if (bigram[1] != c1 && bigram[1] != c2) {
	    total += NlogN(*count + getCount(c2, bigram[1]));
	}
    }

    unigram[0] = c2;
    NgramsIter iter3(classNgramCounts, unigram, &bigram[1], 1);
    while (count = iter3.next()) {
	if (bigram[1] != c1 && bigram[1] != c2 &&
	    getCount(c1, bigram[1]) == 0)
	{
	    total += NlogN(*count);
	}
    }

    /*
     * + n(c_1, c_2) \log n(c_1, c_2) + n(c_2, c_1) \log n(c_2, c_1)
     * NOTE(review): presumably offsets the double counting of these two
     * bigrams in the per-class contributions -- verify against diffLogP().
     */
    total += NlogN(getCount(c1, c2)) + NlogN(getCount(c2, c1));

    /*
     * - 2 n(c_12) \log n(c_12), with n(c_12) = n(c_1) + n(c_2)
     */
    NgramCount n1 = getCount(c1);
    NgramCount n2 = getCount(c2);

    total -= 2 * NlogN(n1 + n2);

    /*
     * + 2 [ n(c_1) \log n(c_1) + n(c_2) \log n(c_2) ]
     */
    total += 2 * NlogN(n1) + 2 * NlogN(n2);

    /*
     * Cache result
     */
    *logp = total;

    return total;
}

/*
 * Find and perform best merge pair
 *	Scores every unordered pair of classes in mergeSet with diffLogP(),
 *	merges the highest-scoring pair via merge(b1, b2), reports the pair
 *	in b1/b2 and returns its score.
 *	If mergeSet holds fewer than two classes there is nothing to merge:
 *	b1 and b2 are set to Vocab_None, no merge is performed, and
 *	LogP_Zero is returned (rather than calling merge() and printing
 *	getWord() on Vocab_None).
 */
LogP
UniqueWordClasses::bestMerge(Vocab &mergeSet, VocabIndex &b1, VocabIndex &b2)
{
    VocabIndex bestC1 = Vocab_None, bestC2 = Vocab_None;
    LogP bestDiff = LogP_Zero;	/* only meaningful once bestC1 is set */

    VocabIter iter1(mergeSet);
    VocabIndex c1;

    while (iter1.next(c1)) {
	/*
	 * iter2 is a copy of iter1, presumably resuming right after c1,
	 * so that each unordered pair is scored once.
	 */
	VocabIter iter2(iter1);
	VocabIndex c2;

	while (iter2.next(c2)) {
	    LogP diff = diffLogP(c1, c2);

	    if (bestC1 == Vocab_None || diff > bestDiff) {
		bestC1 = c1;
		bestC2 = c2;
		bestDiff = diff;
	    }
	}
    }

    if (bestC1 == Vocab_None) {
	/*
	 * Fewer than two classes: no pair exists
	 */
	b1 = Vocab_None;
	b2 = Vocab_None;
	return LogP_Zero;
    }

    if (debug(DEBUG_TRACE_MERGE)) {
	dout() << "\tmerging " << mergeSet.getWord(bestC1)
	       << " and " << mergeSet.getWord(bestC2)
	       << " diff = " << bestDiff 
	       << endl;
    }

    merge(bestC1, bestC2);

    b1 = bestC1;
    b2 = bestC2;

    return bestDiff;
}

/*
 * Create a writable file if basename is defined and 
 * iter is a multiple of freq.  Used in saving preliminary results
 * during merging.
 *
 * If basename is a stdio name (stdio_filename_p), an iteration banner is
 * printed to stdout and basename itself is opened.  Otherwise the file is
 * "<basename>.<iter as at least 6 zero-padded digits><suffix>", where
 * <suffix> is COMPRESS_SUFFIX or GZIP_SUFFIX when basename is a
 * compressed/gzipped name, and empty otherwise.
 *
 * Returns a newly allocated File that the caller must delete, or 0 when
 * nothing should be saved for this iteration.
 */
static File *
logFile(const char *basename, unsigned freq, unsigned iter)
{
    if (freq == 0 || basename == 0 || iter % freq != 0) {
	return 0;
    }

    const char *suffix = compressed_filename_p(basename) ? COMPRESS_SUFFIX :
			    gzipped_filename_p(basename) ? GZIP_SUFFIX : "";

    /*
     * Room for basename, '.', up to 20 decimal digits (any unsigned of up
     * to 64 bits), the suffix and the terminating NUL.  The former fixed
     * "+ 10" overflowed by one byte for gzipped names even at 6 digits.
     */
    makeArray(char, filename, strlen(basename) + strlen(suffix) + 22);

    if (stdio_filename_p(basename)) {
	printf("*** SAVE FOR ITERATION %06u ***\n", iter);
	fflush(stdout);
	strcpy(filename, basename);
    } else {
	sprintf(filename, "%s.%06u%s", basename, iter, suffix);
    }

    File *file = new File(filename, "w");
    assert(file != 0);

    return file;
}

/*
 * Full greedy class merging
 *	This is the first, O(V^3) algorithm in Brown et al. (1992)
 */
void
UniqueWordClasses::fullMerge(unsigned numClasses)
{
    if (numClasses < 1) {
	numClasses = 1;
    }

    TextStats stats;
    getStats(stats);

    unsigned numTokens = stats.numWords + stats.numSentences;

    unsigned iters = 0;

    if (debug(DEBUG_TRACE_MERGE)) {
	dout() << "iter " << iters
	       << ": " << classVocab.numWords() << " classes, "
	       << "perplexity = " << LogPtoPPL(stats.prob/numTokens)
	       << endl;
    }

    while (classVocab.numWords() > numClasses) {
	iters ++;
	VocabIndex b1, b2;

	bestMerge(classVocab, b1, b2);

	if (debug(DEBUG_TRACE_MERGE)) {
	    dout() << "iter " << iters
		   << ": " << classVocab.numWords() << " classes, "
		   << "perplexity = " << LogPtoPPL(totalLogP()/numTokens)
		   << endl;
	}

	/*
	 * Save classes and counts if and when requested
	 */
	File *cf = logFile(classesFile, saveFreq, iters);
	if (cf != 0) {
	    writeClasses(*cf);
	    delete cf;
	}

	cf = logFile(classCountsFile, saveFreq, iters);
	if (cf != 0) {
	    writeCounts(*cf);
	    delete cf;
	}
    }
}

/*
 * Order classes by count
 *	orderByCount() is a comparison function (negative / zero / positive)
 *	that sorts classes by decreasing getCount().  It reads the counts of
 *	the UniqueWordClasses object in orderClasses, which must be set
 *	before the function is used (incrementalMerge() does this).
 */
static UniqueWordClasses *orderClasses;

static int
orderByCount(VocabIndex c1, VocabIndex c2)
{
    NgramCount n1 = orderClasses->getCount(c1);
    NgramCount n2 = orderClasses->getCount(c2);

    /*
     * Compare explicitly instead of returning n2 - n1: the difference of
     * unsigned (or wider-than-int, or fractional) counts does not reliably
     * convert to an int with the correct sign.
     */
    if (n1 > n2) {
	return -1;
    } else if (n1 < n2) {
	return 1;
    } else {
	return 0;
    }
}

/*
 * Incremental greedy class merging
 *	This is the second, O(VC^2) algorithm in Brown et al. (1992)
 */
void
UniqueWordClasses::incrementalMerge(unsigned numClasses)
{
    if (numClasses < 1) {
	numClasses = 1;
    }

    TextStats stats;
    getStats(stats);

    unsigned numTokens = stats.numWords + stats.numSentences;

    unsigned iters = 0;

    if (debug(DEBUG_TRACE_MERGE)) {
	dout() << "iter " << iters
	       << ": " << classVocab.numWords() << " classes, "
	       << "perplexity = " << LogPtoPPL(stats.prob/numTokens)
	       << endl;
    }

    /*
     * Sort classes by count
     */
    makeArray(VocabIndex, listOfClasses, vocab.numWords());

    unsigned nClasses = 0;

    VocabIndex unigram[2];
    orderClasses = this;
    NgramsIter unigramIter(classNgramCounts, unigram, 1, orderByCount);
    NgramCount *classCount;

    while (classCount = unigramIter.next()) {
	if (classVocab.getWord(unigram[0]) != 0) {
	    listOfClasses[nClasses ++] = unigram[0];
	}
    }

    /*
     * Construct the subset of classes undergoing merging
     */
    SubVocab mergeSet(classVocab);

    /*
     * Add the first numClasses to the merge set
     */
    unsigned i;
    for (i = 0; i < nClasses && i < numClasses; i++) {
	if (debug(DEBUG_TRACE_MERGE)) {
	    dout() << "\tadding " << classVocab.getWord(listOfClasses[i])
		   << endl;
	}
	mergeSet.addWord(listOfClasses[i]);
    }

    /*
     * Now add one extra class at a time, and merge after each addition
     */
    for ( ; i < nClasses; i ++) {
	if (debug(DEBUG_TRACE_MERGE)) {
	    dout() << "\tadding " << classVocab.getWord(listOfClasses[i])
		   << endl;
	}
	mergeSet.addWord(listOfClasses[i]);

	iters ++;
	VocabIndex b1, b2;

	bestMerge(mergeSet, b1, b2);

	if (debug(DEBUG_TRACE_MERGE)) {
	    dout() << "iter " << iters
		   << ": " << classVocab.numWords() << " classes, "
		   << "perplexity = " << LogPtoPPL(totalLogP()/numTokens)
		   << endl;
	}

	mergeSet.remove(b2);

	/*
	 * Save classes and counts if and when requested
	 */
	File *cf = logFile(classesFile, saveFreq, iters);
	if (cf != 0) {
	    writeClasses(*cf);
	    delete cf;
	}

	cf = logFile(classCountsFile, saveFreq, iters);
	if (cf != 0) {
	    writeCounts(*cf);
	    delete cf;
	}
    }
}

/*
 * Simple interactive class merging
 *	Repeatedly reads two class names from cin and merges the two classes.
 *	After each merge it rewrites classesFile / classCountsFile (when
 *	given), optionally dumps contribution tables, and prints statistics.
 *	Stops when no first class name can be read (e.g. end of input).
 */
void
interactiveMerge(UniqueWordClasses &classes)
{
    while (1) {
	char class1[30], class2[30];
	class1[0] = class2[0] = '\0';

	cout << "Enter two class names> ";

	/*
	 * Bound each extraction to the buffer size: a plain
	 * "cin >> char[]" writes past the array for names of 30 or more
	 * characters.  Over-long input is split across reads and then
	 * rejected below as an invalid class name.
	 */
	cin.width(sizeof(class1));
	cin >> class1;
	cin.width(sizeof(class2));
	cin >> class2;

	if (!*class1) break;

	VocabIndex c1 = classes.classVocab.getIndex(class1);
	if (c1 == Vocab_None) {
	    cerr << class1 << " is not a valid class; try again.\n";
	    continue;
	}

	VocabIndex c2 = classes.classVocab.getIndex(class2);
	if (c2 == Vocab_None) {
	    cerr << class2 << " is not a valid class; try again.\n";
	    continue;
	}

	if (c1 == c2) {
	    cerr << "Classes must be distinct; try again.\n";
	    continue;
	}

	cout << "Merging class " << classes.classVocab.getWord(c1)
	     << " and " << classes.classVocab.getWord(c2) << endl;

	LogP delta = classes.diffLogP(c1, c2);
	cout << "Projected delta = " << delta << endl;

	classes.merge(c1, c2);

	{
	    if (classesFile) {
		File file(classesFile, "w");
		classes.writeClasses(file);
	    }

	    if (classCountsFile) {
		File file(classCountsFile, "w");
		classes.writeCounts(file);
	    }

	    if (debug >= DEBUG_PRINT_CONTRIBS) {
		File file(stdout);

		/*
		 * Dump the contributions before and after recomputing them
		 */
		classes.writeContribs(file);

		classes.computeClassContribs();
		classes.computeMergeContribs();

		classes.writeContribs(file);
	    }

	    TextStats stats;
	    classes.getStats(stats);
	    cout << stats;
	}
    }
}

int
main(int argc, char **argv)
{
    setlocale(LC_CTYPE, "");
    setlocale(LC_COLLATE, "");

    Opt_Parse(argc, argv, options, Opt_Number(options), 0);

    if (version) {
	printVersion(RcsId);
	exit(0);
    }

    Vocab vocab;
    vocab.toLower() = toLower ? true : false;

    SubVocab classVocab(vocab);
    SubVocab noclassVocab(vocab);

    UniqueWordClasses classes(vocab, classVocab);
    classes.debugme(debug);

    if (vocabFile) {
	File file(vocabFile, "r");
	vocab.read(file);
    }

    if (noclassFile) {
	File file(noclassFile, "r");
	noclassVocab.read(file);
    } else {
	/*
	 * Assume <s> and </s> are not to be classed
	 */
	noclassVocab.addWord(vocab.ssIndex());
	noclassVocab.addWord(vocab.seIndex());
    }

    if (countsFile || textFile) {
	NgramStats bigrams(vocab, 2);
	bigrams.debugme(debug);

	/*
	 * Restrict vocabulary if user specied one
	 */
	if (vocabFile) {
	    bigrams.openVocab = false;
	}

	if (textFile) {
	    File file(textFile, "r");
	    bigrams.countFile(file);
	}

	if (countsFile) {
	    File file(countsFile, "r");
	    bigrams.read(file);
	}

	classes.initialize(bigrams, noclassVocab);
    } else {
	cerr << "Specify counts or text file as input.\n";
	exit(1);
    }

    if (numClasses > 0) {
	if (fullMerge) {
	    classes.fullMerge(numClasses);
	} else {
	    classes.incrementalMerge(numClasses);
	}
    }

    if (classesFile) {
	File file(classesFile, "w");

	classes.writeClasses(file);
    }

    if (classCountsFile) {
	File file(classCountsFile, "w");

	classes.writeCounts(file);
    }

    if (debug >= DEBUG_TEXTSTATS) {
	TextStats stats;
	classes.getStats(stats);
	cerr << stats;
    }

    if (interact) {
	interactiveMerge(classes);
    }

    exit(0);
}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -