📄 hmmdecoder.java
字号:
* for a backwards n-best pass using the A<sup>*</sup> algorithm. * Thus it will be slower than just computing the first best * result using {@link #firstBest(String[])}. The iterator stores * the entire Viterbi lattice as well as a priority queue of * partial states ordered by the A<sup>*</sup> condition. * * @param emissions String outputs whose tag sequences are returned. * @return Iterator over scored tag sequences in decreasing order * of likelihood. */ public Iterator<ScoredObject<String[]>> nBest(String[] emissions) { if (emissions.length == 0) { ScoredObject<String[]> result = new ScoredObject(EMPTY_STRING_ARRAY,0.0); return new Iterators.Singleton<ScoredObject<String[]>>(result); } Viterbi viterbiLattice = new Viterbi(emissions); return new NBestIterator(viterbiLattice,Integer.MAX_VALUE); } /** * Returns a best-first iterator of {@link ScoredObject} instances * consisting of arrays of tags and log (base 2) joint likelihoods * of the tags and emissions with respect to the underlying HMM up * to the specified maximum number of results. * * <P><i>Implementation Note:</i> This method is implemented by * doing a Viterbi search to provide exact A<sup>*</sup> bounds * for a backwards n-best pass using the A<sup>*</sup> algorithm. * Thus it will be slower than just computing the first best * result using {@link #firstBest(String[])}. The iterator stores * the entire Viterbi lattice as well as a priority queue of * partial states ordered by the A<sup>*</sup> condition. * * @param emissions String outputs whose tag sequences are returned. * @return Iterator over scored tag sequences in decreasing order * of likelihood. 
*/
    public Iterator<ScoredObject<String[]>> nBest(String[] emissions, int maxN) {
        if (emissions.length == 0) {
            // No emissions: the only analysis is the empty tag sequence,
            // returned with a log (base 2) score of 0.0.  The local is
            // declared with its type argument (it was a raw ScoredObject)
            // so the Singleton constructor needs no unchecked conversion.
            ScoredObject<String[]> result
                = new ScoredObject<String[]>(EMPTY_STRING_ARRAY,0.0);
            return new Iterators.Singleton<ScoredObject<String[]>>(result);
        }
        Viterbi viterbiLattice = new Viterbi(emissions);
        // maxN becomes the initial bound of the iterator's priority queue.
        return new NBestIterator(viterbiLattice,maxN);
    }

    /**
     * Returns a best-first iterator of scored objects consisting of
     * arrays of tags and log (base 2) conditional likelihoods of the
     * tags given the specified emissions with respect to the
     * underlying HMM.  Only analyses with non-zero probability
     * estimates are returned.  For this method, the sum of all
     * iterated estimates should be 1.0, plus or minus rounding
     * errors.
     *
     * <P>Conditional estimates of tags given emissions are derived
     * from dividing the joint estimates by the marginal likelihood
     * of the emissions as computed by summing over all joint estimates.
     *
     * <P><i>Implementation Note:</i> The total log likelihood is
     * returned by applying {@link TagWordLattice#log2Total()} to the
     * result of decoding the input with {@link #lattice(String[])}.
     * The joint estimates are iterated using the iterator returned by
     * {@link #nBest(String[])} and then modified by subtracting the
     * emission marginal log likelihood from the joint emission/tags
     * log likelihood.  This method adds the cost of the full lattice
     * computation to the joint n-best method.  The space for the full
     * lattice is used transiently when this method is called and
     * may be garbage-collected even before the first element is returned
     * by the iterator.
     *
     * @param emissions String outputs whose tag sequences are returned.
     * @return Iterator over scored tag sequences in decreasing order
     * of likelihood.
*/
    public Iterator<ScoredObject<String[]>> nBestConditional(String[] emissions) {
        // Typed local (it was a raw Iterator) so that passing it to
        // JointIterator needs no unchecked conversion.
        Iterator<ScoredObject<String[]>> nBestIterator = nBest(emissions);
        // Total log (base 2) likelihood of the emissions from the full lattice;
        // it is subtracted from each joint score by JointIterator.
        double jointLog2Prob = lattice(emissions).log2Total();
        return new JointIterator(nBestIterator,jointLog2Prob);
    }

    /**
     * Writes into {@code survivors} the indices {@code i}, in
     * increasing order, for which {@code sources[i] + beam} is at
     * least the maximum value in {@code sources}, followed by a
     * terminating {@code -1}.
     *
     * @param sources Scores to filter; element 0 is read
     * unconditionally, so the array must be non-empty.
     * @param survivors Output array; needs room for up to
     * {@code sources.length + 1} entries (all indices plus the terminator).
     * @param beam Amount added to each score before it is compared
     * to the best score.
     */
    void unprunedSources(double[] sources, int[] survivors, double beam) {
        double best = sources[0];
        for (int i = 0; i < sources.length; ++i)
            if (sources[i] > best)
                best = sources[i];
        int next = 0;
        for (int i = 0; i < sources.length; ++i)
            if (sources[i] + beam >= best)
                survivors[next++] = i;
        survivors[next] = -1;
    }

    /**
     * First-best lattice over the emissions: for each emission index
     * and state id it stores the best log (base 2) score found and
     * the id of the source state that produced it.
     */
    private class Viterbi {
        private final String[] mEmissions;
        // mLattice[i][stateId]: best score ending in stateId at emission i;
        // the last column additionally includes the end log probability.
        private final double[][] mLattice;
        // mBackPts[i][stateId]: source state id at emission i-1 on the best path.
        private final int[][] mBackPts;

        /**
         * Fills the score lattice and back pointers for the specified
         * emissions.  For an empty emission array both tables are
         * left with zero rows.
         *
         * @param emissions Emissions to decode.
         */
        Viterbi(String[] emissions) {
            mEmissions = emissions;
            HiddenMarkovModel hmm = mHmm;
            int numStates = hmm.stateSymbolTable().numSymbols();
            int numEmits = emissions.length;
            double[][] lattice = new double[numEmits][numStates];
            mLattice = lattice;
            int[][] backPts = new int[numEmits][numStates];
            mBackPts = backPts;
            if (emissions.length == 0) {
                return;
            }
            // first column: start probability plus first emission
            double[] emitLog2Probs = emitLog2Probs(emissions[0]);
            for (int stateId = 0; stateId < numStates; ++stateId) {
                lattice[0][stateId] = emitLog2Probs[stateId] + hmm.startLog2Prob(stateId);
            }
            // reusable buffer of surviving source ids, terminated by -1
            int[] unprunedSources = new int[numStates+1];
            for (int i = 1; i < numEmits; ++i) {
                double[] lastSlice = lattice[i-1];
                unprunedSources(lastSlice,unprunedSources,mLog2Beam);
                double[] emitLog2Probs2 = emitLog2Probs(emissions[i]);
                for (int targetId = 0; targetId < numStates; ++targetId) {
                    if (Double.NEGATIVE_INFINITY != emitLog2Probs2[targetId]) {
                        double best = Double.NEGATIVE_INFINITY;
                        int bk = 0; // default tag
                        for (int next = 0; unprunedSources[next] != -1; ++next) {
                            int sourceId = unprunedSources[next];
                            double est = lastSlice[sourceId] + hmm.transitLog2Prob(sourceId,targetId);
                            if (est > best) {
                                best = est;
                                bk = sourceId;
                            }
                        }
                        lattice[i][targetId] = best + emitLog2Probs2[targetId];
                        backPts[i][targetId] = bk;
                    } else {
                        // target cannot emit this symbol; skip the source scan
                        lattice[i][targetId] = Double.NEGATIVE_INFINITY;
                        backPts[i][targetId] = 0; // default tag
                    }
                }
            }
            // handles finals even if only one emission
            double[] lastColumn = lattice[numEmits-1];
            for (int i = 0; i < numStates; ++i)
                lastColumn[i] += hmm.endLog2Prob(i);
        }

        /**
         * Returns the state symbols along the best path, found by
         * taking the highest-scoring state in the last column and
         * following back pointers to the first emission.
         *
         * @return Best state symbols, one per emission; a zero-length
         * array if there are no emissions.
         */
        String[] bestStates() {
            HiddenMarkovModel hmm = mHmm;
            int numStates = hmm.stateSymbolTable().numSymbols();
            int numEmits = mEmissions.length;
            if (numEmits == 0) return new String[0];
            int[][] backPts = mBackPts;
            double[][] lattice = mLattice;
            int[] bestStateIds = new int[numEmits];
            int bestStateId = 0;
            double[] lastCol = lattice[numEmits-1];
            // strict comparison: ties resolve to the lowest state id
            for (int i = 1; i < numStates; ++i)
                if (lastCol[i] > lastCol[bestStateId])
                    bestStateId = i;
            bestStateIds[numEmits-1] = bestStateId;
            // walk i from numEmits-1 down to 1, filling position i-1
            for (int i = numEmits; --i > 0; )
                bestStateIds[i-1] = backPts[i][bestStateIds[i]];
            String[] bestStates = new String[numEmits];
            SymbolTable st = hmm.stateSymbolTable();
            for (int i = 0; i < bestStates.length; ++i)
                bestStates[i] = st.idToSymbol(bestStateIds[i]);
            return bestStates;
        }
    }

    /**
     * Iterator over scored tag sequences that expands partial
     * analyses backwards from the last emission, ordering them in a
     * bounded priority queue by backward score plus the Viterbi
     * lattice score of the remaining prefix.
     */
    private class NBestIterator extends Iterators.Buffered<ScoredObject<String[]>> {
        private final Viterbi mViterbi;
        private final BoundedPriorityQueue mPQ;

        /**
         * Seeds the queue with one state per tag that has a finite
         * lattice score at the last emission.
         *
         * @param vit Lattice for the emissions; must hold at least
         * one emission, since the last column is read directly.
         * @param maxSize Initial bound on the priority queue.
         */
        NBestIterator(Viterbi vit, int maxSize) {
            mViterbi = vit;
            mPQ = new BoundedPriorityQueue(Scored.SCORE_COMPARATOR, maxSize);
            String[] emissions = vit.mEmissions;
            int numStates = mHmm.stateSymbolTable().numSymbols();
            int numEmits = emissions.length;
            int lastEmitIndex = numEmits-1;
            for (int tagId = 0; tagId < numStates; ++tagId) {
                double contScore = vit.mLattice[lastEmitIndex][tagId];
                if (contScore > Double.NEGATIVE_INFINITY) {
                    // nothing accumulated yet on the backward side
                    double score = 0.0;
                    mPQ.add(new State(lastEmitIndex,score,contScore,
                                      tagId,null));
                }
            }
        }

        /**
         * Pops states until one at emission index 0 is found, pushing
         * every finite-scoring one-step extension of each other
         * popped state.
         *
         * @return Next scored tag sequence, or {@code null} when the
         * queue is exhausted.
         */
        public ScoredObject<String[]> bufferNext() {
            int numTags = mHmm.stateSymbolTable().numSymbols();
            int numEmissions = mViterbi.mEmissions.length;
            int lastEmitIndex = numEmissions-1;
            while (!mPQ.isEmpty()) {
                State st = (State) mPQ.pop();
                int emitIndex = st.emissionIndex();
                if (emitIndex == 0) {
                    // one result is about to be returned, so one fewer
                    // slot is needed in the queue
                    mPQ.setMaxSize(mPQ.maxSize()-1);
                    return st.result(numEmissions);
                }
                String emission = mViterbi.mEmissions[emitIndex];
                int emitTagId = st.mTagId;
                double score = st.mScore;
                // The lattice's last column already contains the end
                // probability; move it into the backward score when
                // leaving the last position.
                if (emitIndex == lastEmitIndex)
                    score += mHmm.endLog2Prob(emitTagId);
                int emitIndexMinus1 = emitIndex-1;
                // don't compile because only need one tagId
                double emitLog2Prob = mHmm.emitLog2Prob(emitTagId,emission);
                for (int tagId = 0; tagId < numTags; ++tagId) {
                    double nextScore = score + mHmm.transitLog2Prob(tagId,emitTagId) + emitLog2Prob;
                    double contScore = mViterbi.mLattice[emitIndexMinus1][tagId];
                    if (nextScore > Double.NEGATIVE_INFINITY && contScore > Double.NEGATIVE_INFINITY)
                        mPQ.add(new State(emitIndexMinus1, nextScore, contScore,
                                          tagId,st));
                }
            }
            return null;
        }
    }

    /**
     * Partial analysis in the n-best search: a tag at an emission
     * index, linked to the state for the following emission index.
     */
    private final class State implements Scored {
        // score accumulated for the emissions after this state's index
        private final double mScore;
        // Viterbi lattice score for this tag at this emission index
        private final double mContScore;
        private final int mTagId;
        // state at the next emission index; null at the last emission
        private final State mPreviousState;
        private final int mEmissionIndex; // used outside

        State(int emissionIndex, double score, double contScore, int tagId,
              State previousState) {
            mEmissionIndex = emissionIndex;
            mScore = score;
            mContScore = contScore;
            mTagId = tagId;
            mPreviousState = previousState;
        }

        /** Returns the emission index of this state. */
        public int emissionIndex() {
            return mEmissionIndex;
        }

        /** Returns the priority of this state: backward score plus lattice score. */
        public double score() {
            return mScore + mContScore;
        }

        /**
         * Returns the tag sequence starting at this state paired with
         * this state's score.
         *
         * @param numTags Length of the tag sequence to read; the
         * caller passes the number of emissions.
         */
        ScoredObject<String[]> result(int numTags) {
            return new ScoredObject<String[]>(tags(numTags),score());
        }

        /**
         * Returns the symbols of the tags on the chain starting at
         * this state and following the links to later emissions.
         *
         * @param numTags Number of states to read from the chain.
         */
        String[] tags(int numTags) {
            SymbolTable symTable = mHmm.stateSymbolTable();
            String[] tags = new String[numTags];
            State state = this;
            for (int i = 0; i < numTags; ++i) {
                tags[i] = symTable.idToSymbol(state.mTagId);
                state = state.mPreviousState;
            }
            return tags;
        }
    }

    private static final String[] EMPTY_STRING_ARRAY = new String[0];

    /**
     * Iterator adapter that rescales each joint log (base 2) score to
     * a conditional one by subtracting a fixed total log probability.
     */
    private static final class JointIterator
        extends Iterators.Modifier<ScoredObject<String[]>> {

        final double mLog2TotalProb;

        /**
         * @param nBestIterator Iterator over jointly scored tag sequences.
         * @param log2TotalProb Log (base 2) total probability to
         * subtract from every score.
         */
        JointIterator(Iterator<ScoredObject<String[]>> nBestIterator,
                      double log2TotalProb) {
            super(nBestIterator);
            mLog2TotalProb = log2TotalProb;
        }

        /**
         * Returns a new scored object with the same tags and the
         * score reduced by the total log probability.
         */
        public ScoredObject<String[]> modify(ScoredObject<String[]> jointObj) {
            String[] tags = jointObj.getObject();
            double log2JointProb = jointObj.score();
            double log2CondProb = log2JointProb - mLog2TotalProb;
            return new ScoredObject<String[]>(tags,log2CondProb);
        }
    }
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -