📄 disambig.cc
字号:
if (bestSymbol == Vocab_None || *symbolProb > maxPosterior) {
bestSymbol = symbol;
maxPosterior = *symbolProb;
}
totalPosterior = AddLogP(totalPosterior, *symbolProb);
}
if (bestSymbol == Vocab_None) {
cerr << "no forward-backward state for position "
<< pos << endl;
return 0;
}
hiddenWids[0][pos] = bestSymbol;
/*
* Print posterior probabilities
*/
if (posteriors) {
cout << vocab.getWord(wids[pos]) << "\t";
symbolIter.init();
while (symbolProb = symbolIter.next(symbol)) {
LogP2 posterior = *symbolProb - totalPosterior;
cout << " " << map.vocab2.getWord(symbol)
<< " " << (logMap ? posterior : LogPtoProb(posterior));
}
cout << endl;
}
/*
* update v1-v2 counts if requested
*/
if (counts) {
symbolIter.init();
while (symbolProb = symbolIter.next(symbol)) {
LogP2 posterior = *symbolProb - totalPosterior;
counts->put(wids[pos], symbol,
counts->get(wids[pos], symbol) +
LogPtoProb(posteriors));
}
}
}
/*
* Return total string probability summing over all paths
*/
totalProb[0] = trellis.sumLogP(len-1);
hiddenWids[0][len] = Vocab_None;
return 1;
}
}
/*
 * Sentence-at-a-time driver: get one input sentence per line, map it to
 * wids, disambiguate it, and print out the result.
 *
 *   file    input text, one sentence per line
 *   map     map whose vocab1 indexes the input words and whose vocab2
 *           supplies the output words
 *   lm      language model, handed through to disambiguateSentence()
 *   counts  handed through to disambiguateSentence(); may be 0
 *
 * The flags noEOS, numNbest, totals, posteriors and keepUnk that steer
 * this function are defined outside of it.
 */
void
disambiguateFile(File &file, VocabMap &map, LM &lm, VocabMap *counts)
{
    VocabString sentence[maxWordsPerLine];
    char *line;

    while (line = file.getline()) {
        unsigned numWords = Vocab::parseWords(line, sentence, maxWordsPerLine);

        /*
         * A completely filled token buffer is reported and the line skipped
         */
        if (numWords == maxWordsPerLine) {
            file.position() << "too many words per sentence\n";
            continue;
        }

        /*
         * Layout of wids: [0] sentence start, [1..numWords] input words,
         * then (unless noEOS) sentence end, then the Vocab_None terminator
         */
        VocabIndex wids[maxWordsPerLine + 2];
        map.vocab1.getIndices(sentence, &wids[1], maxWordsPerLine,
                              map.vocab1.unkIndex());
        wids[0] = map.vocab1.ssIndex();

        unsigned last = numWords + 1;
        if (!noEOS) {
            wids[last++] = map.vocab1.seIndex();
        }
        wids[last] = Vocab_None;

        /*
         * One index buffer per requested hypothesis; they are released
         * at the bottom of this iteration
         */
        makeArray(VocabIndex *, hiddenWids, numNbest);
        makeArray(VocabString, hiddenWords, maxWordsPerLine + 2);
        for (unsigned hyp = 0; hyp < numNbest; hyp++) {
            hiddenWids[hyp] = new VocabIndex[maxWordsPerLine + 2];
        }
        makeArray(LogP, totalProb, numNbest);

        unsigned numHyps =
            disambiguateSentence(map.vocab1, wids, hiddenWids,
                                 totalProb, map, lm, counts, numNbest);

        if (!numHyps) {
            file.position() << "Disambiguation failed\n";
        } else if (totals) {
            /*
             * Only the score of the first hypothesis is printed
             */
            cout << totalProb[0] << endl;
        } else if (!posteriors) {
            for (unsigned hyp = 0; hyp < numHyps; hyp++) {
                VocabIndex *hidden = hiddenWids[hyp];

                map.vocab2.getWords(hidden, hiddenWords,
                                    maxWordsPerLine + 2);

                if (numNbest > 1) {
                    cout << "NBEST_" << hyp << " " << totalProb[hyp] << " ";
                }

                if (keepUnk) {
                    /*
                     * Look for <unk> symbols in the output and replace
                     * them with the corresponding input tokens.
                     * Position 0 is never replaced; output position p
                     * lines up with sentence[p - 1].
                     */
                    for (unsigned pos = 0; hidden[pos] != Vocab_None; pos++) {
                        if (pos == 0 ||
                            hidden[pos] != map.vocab2.unkIndex())
                        {
                            continue;
                        }
                        hiddenWords[pos] = sentence[pos - 1];
                    }
                }

                /*
                 * Comma expression: vocab2.use() is called first, then
                 * hiddenWords is what gets streamed
                 */
                cout << (map.vocab2.use(), hiddenWords) << endl;
            }
        }

        for (unsigned hyp = 0; hyp < numNbest; hyp++) {
            delete [] hiddenWids[hyp];
        }
    }
}
/*
 * Continuous-mode driver: read the entire file ignoring line breaks,
 * map it to wids, disambiguate it as one sequence, and print out the
 * result.
 *
 *   file    input text; all accepted lines are concatenated
 *   map     map whose vocab1 indexes the input words and whose vocab2
 *           supplies the output words
 *   lm      language model, handed through to disambiguateSentence()
 *   counts  handed through to disambiguateSentence(); may be 0
 *
 * Unlike disambiguateFile(), no sentence start/end indices are added
 * here.  The flags numNbest, totals and posteriors are defined outside
 * of this function.
 */
void
disambiguateFileContinuous(File &file, VocabMap &map, LM &lm,
                           VocabMap *counts)
{
    Array<VocabIndex> wids;
    unsigned numTokens = 0;     // words stored in wids so far, i.e. the
                                // offset at which the next line's data goes
    char *line;

    while (line = file.getline()) {
        VocabString words[maxWordsPerLine];
        unsigned numWords =
            Vocab::parseWords(line, words, maxWordsPerLine);

        /*
         * Over-long lines are reported and contribute nothing
         */
        if (numWords == maxWordsPerLine) {
            file.position() << "too many words per line\n";
            continue;
        }

        // Writing the terminator first effectively allocates more space
        // (original author's note), so it stays ahead of the
        // &wids[...] expression below.
        wids[numTokens + numWords] = Vocab_None;
        map.vocab1.getIndices(words, &wids[numTokens], numWords,
                              map.vocab1.unkIndex());
        numTokens += numWords;
    }

    if (numTokens == 0) {       // empty input -- nothing to do
        return;
    }

    /*
     * One pre-terminated index buffer per requested hypothesis
     */
    makeArray(VocabIndex *, hiddenWids, numNbest);
    makeArray(VocabString, hiddenWords, maxWordsPerLine + 2);
    for (unsigned hyp = 0; hyp < numNbest; hyp++) {
        hiddenWids[hyp] = new VocabIndex[numTokens + 1];
        hiddenWids[hyp][numTokens] = Vocab_None;
    }
    makeArray(LogP, totalProb, numNbest);

    unsigned numHyps = disambiguateSentence(map.vocab1, &wids[0], hiddenWids,
                                            totalProb, map, lm, counts,
                                            numNbest);

    if (!numHyps) {
        file.position() << "Disambiguation failed\n";
    } else if (totals) {
        cout << totalProb[0] << endl;
    } else if (!posteriors) {
        for (unsigned hyp = 0; hyp < numHyps; hyp++) {
            VocabIndex *hidden = hiddenWids[hyp];

            // NOTE(review): hiddenWords is filled here but not read again
            // in this function; the words are printed one by one below.
            map.vocab2.getWords(hidden, hiddenWords,
                                maxWordsPerLine + 2);

            if (numNbest > 1) {
                cout << "NBEST_" << hyp << " " << totalProb[hyp] << " ";
            }

            // XXX: keepUnk not implemented yet.
            for (unsigned pos = 0; hidden[pos] != Vocab_None; pos++) {
                cout << map.vocab2.getWord(hidden[pos]) << " ";
            }
            cout << endl;
        }
    }

    for (unsigned hyp = 0; hyp < numNbest; hyp++) {
        delete [] hiddenWids[hyp];
    }
}
/*
 * Read a "text-map" file, containing one word per line, followed by
 * map entries; disambiguate it, and print out the result.
 *
 * Each input line is "V1-word [V2-word [prob]] ...".  Lines are gathered
 * into one sentence until a line whose V1 word is vocab.seIndex() has been
 * consumed or the input ends; each sentence is then disambiguated using a
 * map built just for it.
 *
 *   file    text-map input
 *   vocab   V1 vocabulary; V1 words are added to it as they are read
 *   lm      language model; V2 words are added to lm.vocab
 *   counts  handed through to disambiguateSentence(); may be 0
 *
 * A line with too many fields aborts processing of the whole file.
 * The flags logMap, numNbest, totals and posteriors are defined outside
 * of this function.
 */
void
disambiguateTextMap(File &file, Vocab &vocab, LM &lm, VocabMap *counts)
{
    char *line;

    while (line = file.getline()) {
        /*
         * Hack alert! We pass the map entries associated with the word
         * instances in a VocabMap, but we encode the word position (not
         * its identity) as the first VocabIndex.
         */
        PosVocabMap map(lm.vocab);

        unsigned numWords = 0;
        Array<VocabIndex> wids;

        /*
         * Process one sentence
         */
        do {
            /*
             * Read map line
             */
            VocabString mapFields[maxWordsPerLine];

            unsigned howmany =
                Vocab::parseWords(line, mapFields, maxWordsPerLine);

            if (howmany == maxWordsPerLine) {
                file.position() << "text map line has too many fields\n";
                return;
            }

            /*
             * First field is the V1 word
             * Note we use addWord() since V1 words are by definition
             * only found in the textmap, so there are no OOVs here.
             */
            wids[numWords] = vocab.addWord(mapFields[0]);

            /*
             * Parse the remaining words as either probs or V2 words
             */
            unsigned i = 1;

            while (i < howmany) {
                double prob;

                VocabIndex w2 = lm.vocab.addWord(mapFields[i++]);

                /*
                 * The field after a V2 word is its probability only if
                 * sscanf() actually converted a number.  Compare with 1
                 * rather than testing for non-zero: sscanf() reports
                 * input failure as EOF (negative), which would otherwise
                 * be taken as success and leave prob uninitialized.
                 */
                if (i < howmany &&
                    sscanf(mapFields[i], "%lf", &prob) == 1)
                {
                    i ++;
                } else {
                    /*
                     * No explicit value: use the default, in the
                     * representation selected by logMap
                     */
                    prob = logMap ? LogP_One : 1.0;
                }

                map.put((VocabIndex)numWords, w2, prob);
            }

            /*
             * The word just stored is counted before the test, so the
             * end-of-sentence word stays part of the sentence; the next
             * line is only read if the sentence continues.
             */
        } while (wids[numWords ++] != vocab.seIndex() &&
                 (line = file.getline()));

        if (numWords > 0) {
            wids[numWords] = Vocab_None;

            /*
             * One pre-terminated index buffer per requested hypothesis
             */
            makeArray(VocabIndex *, hiddenWids, numNbest);

            for (unsigned n = 0; n < numNbest; n++) {
                hiddenWids[n] = new VocabIndex[numWords + 1];
                hiddenWids[n][numWords] = Vocab_None;
            }

            makeArray(LogP, totalProb, numNbest);

            unsigned numHyps =
                disambiguateSentence(vocab, &wids[0], hiddenWids, totalProb,
                                     map, lm, counts, numNbest, true);

            if (!numHyps) {
                file.position() << "Disambiguation failed\n";
            } else if (totals) {
                cout << totalProb[0] << endl;
            } else if (!posteriors) {
                for (unsigned n = 0; n < numHyps; n++) {
                    if (numNbest > 1) {
                        cout << "NBEST_" << n << " " << totalProb[n] << " ";
                    }
                    for (unsigned i = 0; hiddenWids[n][i] != Vocab_None; i++) {
                        cout << map.vocab2.getWord(hiddenWids[n][i]) << " ";
                    }
                    cout << endl;
                }
            }

            for (unsigned n = 0; n < numNbest; n++) {
                delete [] hiddenWids[n];
            }
        }
    }
}
/*
 * Program entry point: parse the command-line options, load the map,
 * classes and language model, run the requested disambiguation passes,
 * and write out whichever result files were asked for.
 *
 * Throughout, `vocab` holds the observed (V1) words and `hiddenVocab` the
 * hidden (V2) words: they are the first and second VocabMap arguments,
 * are configured from tolower1 and tolower2 respectively, and the
 * language model is built over hiddenVocab.
 *
 * The option variables used here (version, mapFile, lmFile, ...) and the
 * `options` table are defined outside of this function.
 */
int
main(int argc, char **argv)
{
    setlocale(LC_CTYPE, "");
    setlocale(LC_COLLATE, "");

    Opt_Parse(argc, argv, options, Opt_Number(options), 0);

    if (version) {
        printVersion(RcsId);
        exit(0);
    }

    /*
     * Construct language model
     */
    Vocab hiddenVocab;
    Vocab vocab;
    LM *hiddenLM;

    VocabMap map(vocab, hiddenVocab, logMap);

    vocab.toLower() = tolower1? true : false;
    hiddenVocab.toLower() = tolower2 ? true : false;
    hiddenVocab.unkIsWord() = keepUnk ? true : false;

    if (mapFile) {
        File file(mapFile, "r");

        if (!map.read(file)) {
            cerr << "format error in map file\n";
            exit(1);
        }
    }

    if (classesFile) {
        File file(classesFile, "r");

        if (!map.readClasses(file)) {
            cerr << "format error in classes file\n";
            exit(1);
        }
    }

    /*
     * Without an LM file a NullLM over the hidden vocabulary is used
     */
    if (lmFile) {
        File file(lmFile, "r");

        hiddenLM = new Ngram(hiddenVocab, order);
        assert(hiddenLM != 0);

        hiddenLM->debugme(debug);
        hiddenLM->read(file);
    } else {
        hiddenLM = new NullLM(hiddenVocab);
        assert(hiddenLM != 0);
        hiddenLM->debugme(debug);
    }

    /*
     * V1-V2 counts are only collected when a counts file was requested;
     * otherwise a null pointer is passed to the disambiguation routines
     */
    VocabMap *counts;
    if (countsFile) {
        counts = new VocabMap(vocab, hiddenVocab);
        assert(counts != 0);

        counts->remove(vocab.ssIndex(), hiddenVocab.ssIndex());
        counts->remove(vocab.seIndex(), hiddenVocab.seIndex());
        counts->remove(vocab.unkIndex(), hiddenVocab.unkIndex());
    } else {
        counts = 0;
    }

    if (textFile) {
        File file(textFile, "r");

        if (continuous) {
            disambiguateFileContinuous(file, map, *hiddenLM, counts);
        } else {
            disambiguateFile(file, map, *hiddenLM, counts);
        }
    }

    if (textMapFile) {
        File file(textMapFile, "r");

        disambiguateTextMap(file, vocab, *hiddenLM, counts);
    }

    if (countsFile) {
        File file(countsFile, "w");

        counts->writeBigrams(file);
    }

    if (mapWriteFile) {
        File file(mapWriteFile, "w");
        map.write(file);
    }

    /*
     * Vocabulary dumps: vocab1File receives the observed (V1) vocabulary
     * and vocab2File the hidden (V2) vocabulary.  These two were
     * previously swapped.
     */
    if (vocab1File) {
        File file(vocab1File, "w");
        vocab.write(file);
    }

    if (vocab2File) {
        File file(vocab2File, "w");
        hiddenVocab.write(file);
    }

#ifdef DEBUG
    delete hiddenLM;

    return 0;
#endif /* DEBUG */

    exit(0);
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -