📄 hmmtree.java
字号:
/* * Copyright 1999-2002 Carnegie Mellon University. * Portions Copyright 2002 Sun Microsystems, Inc. * Portions Copyright 2002 Mitsubishi Electric Research Laboratories. * All Rights Reserved. Use is subject to license terms. * * See the file "license.terms" for information on usage and * redistribution of this file, and for a DISCLAIMER OF ALL * WARRANTIES. * */package edu.cmu.sphinx.linguist.lextree;import java.util.ArrayList;import java.util.Collection;import java.util.HashMap;import java.util.HashSet;import java.util.Iterator;import java.util.List;import java.util.Map;import java.util.Set;import edu.cmu.sphinx.linguist.WordSequence;import edu.cmu.sphinx.linguist.acoustic.HMM;import edu.cmu.sphinx.linguist.acoustic.HMMPosition;import edu.cmu.sphinx.linguist.acoustic.Unit;import edu.cmu.sphinx.linguist.dictionary.Dictionary;import edu.cmu.sphinx.linguist.dictionary.Pronunciation;import edu.cmu.sphinx.linguist.dictionary.Word;import edu.cmu.sphinx.linguist.language.ngram.LanguageModel;import edu.cmu.sphinx.linguist.util.HMMPool;import edu.cmu.sphinx.util.LogMath;import edu.cmu.sphinx.util.Timer;import edu.cmu.sphinx.util.Utilities;/** * Represents the vocabulary as a lex tree with nodes in the tree * representing either words (WordNode) or units (HMMNode). HMMNodes * may be shared. 
class HMMTree {

    // Pool used to map (base, lc, rc, position) to a shared HMM (see getHMM).
    private HMMPool hmmPool;
    private InitialWordNode initialNode;
    private Dictionary dictionary;
    private LanguageModel lm;
    private boolean addFillerWords = false;
    private boolean addSilenceWord = true;
    // CI units that can begin a word (entryPoints) or end a word (exitPoints);
    // populated by collectEntryAndExitUnits(). exitPoints is freed in freeze(),
    // entryPoints is kept because getHMMNodes() iterates it after compilation.
    private Set entryPoints = new HashSet();
    private Set exitPoints = new HashSet();
    // Lazily built by getAllWords(); freed in freeze().
    private Set allWords = null;
    private EntryPointTable entryPointTable;
    private boolean debug = false;
    private float languageWeight;
    // Cache for getHMMNodes(): endNode.getKey() -> HMMNode[].
    private Map endNodeMap;
    private WordNode sentenceEndWordNode;

    /**
     * Creates the HMMTree and immediately compiles the vocabulary into it.
     *
     * @param pool the pool of HMMs and units
     * @param dictionary the dictionary containing the pronunciations
     * @param lm the source of the set of words to add to the lex tree
     * @param addFillerWords if <code>true</code> add filler words
     * @param languageWeight the languageWeight
     */
    HMMTree(HMMPool pool, Dictionary dictionary, LanguageModel lm,
            boolean addFillerWords, float languageWeight) {
        this.hmmPool = pool;
        this.dictionary = dictionary;
        this.lm = lm;
        this.endNodeMap = new HashMap();
        this.addFillerWords = addFillerWords;
        this.languageWeight = languageWeight;
        Timer.start("Create HMMTree");
        compile();
        Timer.stop("Create HMMTree");
    }

    /**
     * Given a base unit and a left context, return the
     * set of entry points into the lex tree
     *
     * @param lc the left context
     * @param base the center unit
     *
     * @return the set of entry points
     */
    public Collection getEntryPoint(Unit lc, Unit base) {
        EntryPoint ep = entryPointTable.getEntryPoint(base);
        return ep.getEntryPointsFromLeftContext(lc).getSuccessors();
    }

    /**
     * Gets the set of hmm nodes associated with the given end node.
     * One HMMNode per distinct (base, lc, rc) HMM is created, covering every
     * possible right context (every entry-point unit); results are cached in
     * endNodeMap under endNode.getKey().
     *
     * @param endNode the end node
     *
     * @return an array of associated hmm nodes
     */
    public HMMNode[] getHMMNodes(EndNode endNode) {
        HMMNode[] results = (HMMNode[]) endNodeMap.get(endNode.getKey());
        if (results == null) {
            // System.out.println("Filling cache for " + endNode.getKey()
            //         + " size " + endNodeMap.size());
            Map resultMap = new HashMap();
            Unit baseUnit = endNode.getBaseUnit();
            Unit lc = endNode.getLeftContext();
            // Any entry-point unit can follow a word end, so each one is a
            // possible right context here.
            for (Iterator i = entryPoints.iterator(); i.hasNext(); ) {
                Unit rc = (Unit) i.next();
                HMM hmm = getHMM(baseUnit, lc, rc, HMMPosition.END);
                HMMNode hmmNode = (HMMNode) resultMap.get(hmm);
                if (hmmNode == null) {
                    hmmNode = new HMMNode(hmm, LogMath.getLogOne());
                    resultMap.put(hmm, hmmNode);
                }
                hmmNode.addRC(rc);
                // Every HMMNode leads to the same word successors as the end node.
                for (Iterator j = endNode.getSuccessors().iterator(); j.hasNext(); ) {
                    WordNode wordNode = (WordNode) j.next();
                    hmmNode.addSuccessor(wordNode);
                }
            }
            // cache it
            results = (HMMNode[]) resultMap.values().toArray(
                    new HMMNode[resultMap.size()]);
            endNodeMap.put(endNode.getKey(), results);
        }
        // System.out.println("GHN: " + endNode + " " + results.length);
        return results;
    }

    /**
     * Returns the word node associated with the sentence end word
     *
     * @return the sentence end word node
     */
    public WordNode getSentenceEndWordNode() {
        assert sentenceEndWordNode != null;
        return sentenceEndWordNode;
    }

    // NOTE(review): dead/unfinished code — computes base and lc but always
    // returns null and is never a useful key; endNodeMap is actually keyed
    // via endNode.getKey() in getHMMNodes(). Candidate for removal.
    private Object getKey(EndNode endNode) {
        Unit base = endNode.getBaseUnit();
        Unit lc = endNode.getLeftContext();
        return null;
    }

    /**
     * Compiles the vocabulary into an HMM Tree: collect entry/exit units,
     * add every word, build the per-entry-point maps, then free temporaries.
     */
    private void compile() {
        collectEntryAndExitUnits();
        entryPointTable = new EntryPointTable(entryPoints);
        addWords();
        entryPointTable.createEntryPointMaps();
        freeze();
    }

    /**
     * Dumps the tree to System.out (debugging aid).
     */
    void dumpTree() {
        System.out.println("Dumping Tree ...");
        Map dupNode = new HashMap();
        dumpTree(0, getInitialNode(), dupNode);
        System.out.println("... done Dumping Tree");
    }

    /**
     * Dumps the tree
     *
     * @param level the level of the dump
     * @param node the root of the tree to dump
     * @param dupNode map of visited nodes (shared nodes are printed once)
     */
    private void dumpTree(int level, Node node, Map dupNode) {
        if (dupNode.get(node) == null) {
            dupNode.put(node, node);
            // NOTE(review): 'level * 1' — the factor of 1 looks like a
            // leftover indent-width tuning knob.
            System.out.println(Utilities.pad(level * 1) + node);
            // WordNodes are leaves; don't recurse past them.
            if (!(node instanceof WordNode)) {
                Collection next = node.getSuccessors();
                for (Iterator i = next.iterator(); i.hasNext(); ) {
                    Node nextNode = (Node) i.next();
                    dumpTree(level + 1, nextNode, dupNode);
                }
            }
        }
    }

    /**
     * Collects all of the entry and exit points for the vocabulary:
     * the first and last unit of every pronunciation of every word.
     */
    private void collectEntryAndExitUnits() {
        Collection words = getAllWords();
        for (Iterator i = words.iterator(); i.hasNext(); ) {
            Word word = (Word) i.next();
            for (int j = 0; j < word.getPronunciations().length; j++) {
                Pronunciation p = word.getPronunciations()[j];
                Unit first = p.getUnits()[0];
                Unit last = p.getUnits()[p.getUnits().length - 1];
                entryPoints.add(first);
                exitPoints.add(last);
            }
        }
        if (debug) {
            System.out.println("Entry Points: " + entryPoints.size());
            System.out.println("Exit Points: " + exitPoints.size());
        }
    }

    /**
     * Called after the lex tree is built. Frees all temporary
     * structures. After this is called, no more words can be added to
     * the lex tree.
     */
    private void freeze() {
        entryPointTable.freeze();
        // Drop references only needed during construction so they can be GC'd.
        // Note entryPoints is intentionally retained (used by getHMMNodes).
        dictionary = null;
        lm = null;
        exitPoints = null;
        allWords = null;
    }

    /**
     * Adds the entire vocabulary (see getAllWords) to the lex tree.
     */
    private void addWords() {
        Set words = getAllWords();
        for (Iterator i = words.iterator(); i.hasNext(); ) {
            Word word = (Word) i.next();
            addWord(word);
        }
    }

    /**
     * Adds a single word to the lex tree
     *
     * @param word the word to add
     */
    private void addWord(Word word) {
        float prob = getWordUnigramProbability(word);
        Pronunciation[] pronunciations = word.getPronunciations();
        for (int i = 0; i < pronunciations.length; i++) {
            addPronunciation(pronunciations[i], prob);
        }
    }

    /**
     * Adds the given pronunciation to the lex tree
     *
     * @param pronunciation the pronunciation
     * @param probability the unigram probability
     */
    private void addPronunciation(Pronunciation pronunciation,
            float probability) {
        Unit baseUnit;
        Unit lc;
        Unit rc;
        Node curNode;
        WordNode wordNode;

        Unit[] units = pronunciation.getUnits();
        baseUnit = units[0];
        EntryPoint ep = entryPointTable.getEntryPoint(baseUnit);

        ep.addProbability(probability);

        if (units.length > 1) {
            curNode = ep.getNode();
            lc = baseUnit;
            // Internal units: each successor HMM is (unit, lc, rc, INTERNAL).
            for (int i = 1; i < units.length - 1; i++) {
                baseUnit = units[i];
                rc = units[i + 1];
                HMM hmm = getHMM(baseUnit, lc, rc, HMMPosition.INTERNAL);
                curNode = curNode.addSuccessor(hmm, probability);
                lc = baseUnit; // next lc is this baseUnit
            }

            // now add the last unit as an end unit
            baseUnit = units[units.length - 1];
            EndNode endNode = new EndNode(baseUnit, lc, probability);
            curNode = curNode.addSuccessor(endNode, probability);
            wordNode = curNode.addSuccessor(pronunciation, probability);
            // Remember the node for </s> so getSentenceEndWordNode can return it.
            if (wordNode.getWord() == dictionary.getSentenceEndWord()) {
                sentenceEndWordNode = wordNode;
            }
        } else {
            // Single-unit words are handled specially by the entry point.
            ep.addSingleUnitWord(pronunciation);
        }
    }

    /**
     * Retrieves an HMM for a unit in context. If there is no direct
     * match, the nearest match will be used. Note that we are
     * currently only dealing with, at most, single unit left
     * and right contexts.
     *
     * @param base the base CI unit
     * @param lc the left context
     * @param rc the right context
     * @param pos the position of the base unit within the word
     *
     * @return the HMM. (This should never return null)
     */
    private HMM getHMM(Unit base, Unit lc, Unit rc, HMMPosition pos) {
        int id = hmmPool.buildID(hmmPool.getID(base), hmmPool.getID(lc),
                hmmPool.getID(rc));
        HMM hmm = hmmPool.getHMM(id, pos);
        if (hmm == null) {
            // Diagnostic dump before the assert below fires.
            System.out.println("base ID " + hmmPool.getID(base)
                    + "left ID " + hmmPool.getID(lc)
                    + "right ID " + hmmPool.getID(rc));
            System.out.println("Unit " + base + " lc " + lc + " rc "
                    + rc + " pos " + pos);
            System.out.println("ID " + id + " hmm " + hmm);
        }
        assert hmm != null;
        return hmm;
    }

    /**
     * Gets the unigram probability for the given word
     *
     * @param word the word
     *
     * @return the unigram probability for the word.
     */
    private float getWordUnigramProbability(Word word) {
        float prob = LogMath.getLogOne();
        if (!word.isFiller()) {
            Word[] wordArray = new Word[1];
            wordArray[0] = word;
            prob = lm.getProbability(WordSequence.getWordSequence(wordArray));
            // System.out.println("gwup: " + word + " " + prob);
            // Scale the log-domain probability by the language weight.
            prob *= languageWeight;
        }
        return prob;
    }

    /**
     * Returns the entire set of words, including filler words
     *
     * @return the set of all words (as Word objects)
     */
    private Set getAllWords() {
        if (allWords == null) {
            allWords = new HashSet();
            // LM vocabulary entries with no dictionary pronunciation are skipped.
            Collection words = lm.getVocabulary();
            for (Iterator i = words.iterator(); i.hasNext(); ) {
                String spelling = (String) i.next();
                Word word = dictionary.getWord(spelling);
                if (word != null) {
                    allWords.add(word);
                }
            }
            if (addFillerWords) {
                Word[] fillerWords = dictionary.getFillerWords();
                for (int i = 0; i < fillerWords.length; i++) {
                    allWords.add(fillerWords[i]);
                }
            } else if (addSilenceWord) {
                // Even without fillers, silence is always included by default.
                allWords.add(dictionary.getSilenceWord());
            }
        }
        return allWords;
    }

    /**
     * Returns the initial node for this lex tree
     *
     * @return the initial lex node
     */
    InitialWordNode getInitialNode() {
        return initialNode;
    }

    /**
     * The EntryPoint table is used to manage the set of entry points
     * into the lex tree.
     */
    class EntryPointTable {
        // Map from CI entry-point Unit -> EntryPoint.
        private Map entryPoints;

        /**
         * Create the entry point table give the set of all possible
         * entry point units
         *
         * @param entryPointCollection the set of possible entry
         * points
         */
        EntryPointTable(Collection entryPointCollection) {
            entryPoints = new HashMap();
            for (Iterator i = entryPointCollection.iterator(); i.hasNext(); ) {
                Unit unit = (Unit) i.next();
                entryPoints.put(unit, new EntryPoint(unit));
            }
        }

        /**
         * Given a CI unit, return the EntryPoint object that manages
         * the entry point for the unit
         *
         * @param baseUnit the unit of interest (A ci unit)
         *
         * @return the object that manages the entry point for the
         * unit
         */
        EntryPoint getEntryPoint(Unit baseUnit) {
            return (EntryPoint) entryPoints.get(baseUnit);
        }

        /**
         * Creates the entry point maps for all entry points.
         */
        void createEntryPointMaps() {
            for (Iterator i = entryPoints.values().iterator(); i.hasNext(); ) {
                EntryPoint ep = (EntryPoint) i.next();
                ep.createEntryPointMap();
            }
        }

        /**
         * Freezes the entry point table
         */
        // NOTE(review): the source chunk is truncated here — the freeze()
        // body and the closing braces of EntryPointTable/HMMTree lie beyond
        // this excerpt.
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -