📄 hmmchunker.java
字号:
* of bounds of the specified character array. */
    public Iterator<Chunk> nBestChunks(char[] cs, int start, int end, int maxNBest) {
        String[][] toksWhites = getToksWhites(cs,start,end);
        // Lattice over the tokens only; whitespaces are needed later to
        // recover character offsets for the chunks.
        TagWordLattice lattice = mDecoder.lattice(toksWhites[0]);
        return new NBestChunkIt(lattice,toksWhites[1],maxNBest);
    }

    /**
     * Tokenizes {@code cs[start..end)} with this chunker's tokenizer factory.
     * The slice bounds are first validated by
     * {@code Strings.checkArgsStartEnd}.
     *
     * @return a two-element array: element 0 holds the tokens and element 1
     * the whitespaces, in the order {@code Tokenizer.tokenize} produced them.
     */
    String[][] getToksWhites(char[] cs, int start, int end) {
        Strings.checkArgsStartEnd(cs,start,end);
        Tokenizer tokenizer = mTokenizerFactory.tokenizer(cs,start,end-start);
        ArrayList<String> tokList = new ArrayList<String>();
        ArrayList<String> whiteList = new ArrayList<String>();
        tokenizer.tokenize(tokList,whiteList);
        String[] toks = toStringArray(tokList);
        String[] whites = toStringArray(whiteList);
        return new String[][] { toks, whites };
    }

    /**
     * Best-first iterator over chunks, using a bounded priority queue
     * that mixes complete {@link Chunk}s with partial {@code ChunkItState}s.
     * A partial state is scored by its forward score plus the lattice
     * backward score of its current tag; a popped state is extended by a
     * mid tag and/or an end tag. Returned chunk scores have the lattice's
     * {@code log2Total()} subtracted.
     */
    private static class NBestChunkIt extends Iterators.Buffered<Chunk> {
        final TagWordLattice mLattice;
        final String[] mWhites;
        final int mMaxNBest;
        // Character offsets (relative to the tokenized slice) of each token.
        final int[] mTokenStartIndexes;
        final int[] mTokenEndIndexes;
        // Parallel arrays, one entry per "B_X" tag: base type X and the ids
        // of B_X, M_X and E_X (the latter two may be negative if absent).
        String[] mBeginTags;
        int[] mBeginTagIds;
        int[] mMidTagIds;
        int[] mEndTagIds;
        // Parallel arrays, one entry per "W_X" (single-token chunk) tag.
        String[] mWholeTags;
        int[] mWholeTagIds;
        final BoundedPriorityQueue<Scored> mQueue;
        final int mNumToks;
        final double mTotal;
        // Number of chunks returned so far.
        int mCount = 0;

        NBestChunkIt(TagWordLattice lattice, String[] whites, int maxNBest) {
            mTotal = lattice.log2Total();
            mLattice = lattice;
            mWhites = whites;
            String[] toks = lattice.tokens();
            mNumToks = toks.length;
            mTokenStartIndexes = new int[mNumToks];
            mTokenEndIndexes = new int[mNumToks];
            // whites[i] precedes toks[i], so offsets accumulate alternately.
            int pos = 0;
            for (int i = 0; i < mNumToks; ++i) {
                pos += whites[i].length();
                mTokenStartIndexes[i] = pos;
                pos += toks[i].length();
                mTokenEndIndexes[i] = pos;
            }
            mMaxNBest = maxNBest;
            mQueue = new BoundedPriorityQueue<Scored>(Scored.SCORE_COMPARATOR,
                                                      maxNBest);
            initializeTags();
            initializeQueue();
        }

        /**
         * Scans the lattice's tag symbol table and fills the begin/mid/end
         * and whole-tag arrays. NOTE(review): {@code symbolToID} is assumed
         * to return a negative id (-1) for a missing "M_X"/"E_X" tag; such
         * ids are skipped when states are extended.
         */
        void initializeTags() {
            SymbolTable tagTable = mLattice.tagSymbolTable();
            List<String> beginTagList = new ArrayList<String>();
            List<Integer> beginTagIdList = new ArrayList<Integer>();
            List<Integer> midTagIdList = new ArrayList<Integer>();
            List<Integer> endTagIdList = new ArrayList<Integer>();
            List<String> wholeTagList = new ArrayList<String>();
            List<Integer> wholeTagIdList = new ArrayList<Integer>();
            int numTags = tagTable.numSymbols();
            for (int i = 0; i < numTags; ++i) {
                String tag = tagTable.idToSymbol(i);
                if (tag.startsWith("B_")) {
                    String baseTag = tag.substring(2);
                    beginTagList.add(baseTag);
                    beginTagIdList.add(Integer.valueOf(i));
                    String midTag = "M_" + baseTag;
                    midTagIdList.add(Integer.valueOf(tagTable.symbolToID(midTag)));
                    String endTag = "E_" + baseTag;
                    endTagIdList.add(Integer.valueOf(tagTable.symbolToID(endTag)));
                } else if (tag.startsWith("W_")) {
                    String baseTag = tag.substring(2);
                    wholeTagList.add(baseTag);
                    wholeTagIdList.add(Integer.valueOf(i));
                }
            }
            mBeginTags = toStringArray(beginTagList);
            mBeginTagIds = toIntArray(beginTagIdList);
            mMidTagIds = toIntArray(midTagIdList);
            mEndTagIds = toIntArray(endTagIdList);
            mWholeTags = toStringArray(wholeTagList);
            mWholeTagIds = toIntArray(wholeTagIdList);
        }

        /** Seeds the queue with every begin state and whole-token chunk. */
        void initializeQueue() {
            for (int i = 0; i < mNumToks; ++i) {
                for (int j = 0; j < mBeginTagIds.length; ++j)
                    initializeBeginTag(i,j);
                for (int j = 0; j < mWholeTagIds.length; ++j)
                    initializeWholeTag(i,j);
            }
        }

        /** Queues a partial chunk state starting with begin tag j at tokPos. */
        void initializeBeginTag(int tokPos, int j) {
            // A begin tag needs at least one following end-tag token, so a
            // state at the last token (or for a type with no end tag) can
            // never complete. Queuing it would only take a slot in the
            // bounded queue away from real chunks.
            if (tokPos + 1 >= mNumToks || mEndTagIds[j] < 0)
                return;
            int startCharPos = mTokenStartIndexes[tokPos];
            String tag = mBeginTags[j];
            int beginTagId = mBeginTagIds[j];
            int midTagId = mMidTagIds[j];
            int endTagId = mEndTagIds[j];
            double forward = mLattice.log2Forward(tokPos,beginTagId);
            double backward = mLattice.log2Backward(tokPos,beginTagId);
            ChunkItState state
                = new ChunkItState(startCharPos,tokPos,
                                   tag,beginTagId,midTagId,endTagId,
                                   forward,backward);
            mQueue.add(state);
        }

        /** Queues the complete one-token chunk for whole tag j at tokPos. */
        void initializeWholeTag(int tokPos, int j) {
            int start = mTokenStartIndexes[tokPos];
            int end = mTokenEndIndexes[tokPos];
            String tag = mWholeTags[j];
            double log2Score = mLattice.log2ForwardBackward(tokPos,mWholeTagIds[j]);
            Chunk chunk = ChunkFactory.createChunk(start,end,tag,log2Score);
            mQueue.add(chunk);
        }

        /**
         * Returns the next best chunk, with its score reduced by the
         * lattice's {@code log2Total()}, or {@code null} once
         * {@code maxNBest} chunks have been returned or the queue is exhausted.
         */
        public Chunk bufferNext() {
            // ">=" so that at most mMaxNBest chunks are ever returned.
            if (mCount >= mMaxNBest) return null;
            while (!mQueue.isEmpty()) {
                Object next = mQueue.pop();
                if (next instanceof Chunk) {
                    ++mCount;
                    Chunk result = (Chunk) next;
                    return ChunkFactory.createChunk(result.start(),
                                                    result.end(),
                                                    result.type(),
                                                    result.score()-mTotal);
                }
                // Partial state: expand it and keep searching.
                ChunkItState state = (ChunkItState) next;
                addNextMidState(state);
                addNextEndState(state);
            }
            return null;
        }

        /** Extends a partial state by one token tagged with its mid tag. */
        void addNextMidState(ChunkItState state) {
            int nextTokPos = state.mTokPos + 1;
            if (nextTokPos + 1 >= mNumToks) return; // don't add if can't extend
            int midTagId = state.mMidTagId;
            if (midTagId < 0) return; // no mid tag for this chunk type
            double transition = mLattice.log2Transitions(nextTokPos,
                                                         state.mCurrentTagId,
                                                         midTagId);
            double forward = state.mForward + transition;
            double backward = mLattice.log2Backward(nextTokPos,midTagId);
            ChunkItState nextState
                = new ChunkItState(state.mStartCharPos,nextTokPos,
                                   state.mTag, midTagId,
                                   state.mMidTagId, state.mEndTagId,
                                   forward,backward);
            mQueue.add(nextState);
        }

        /** Completes a partial state with its end tag on the next token. */
        void addNextEndState(ChunkItState state) {
            int nextTokPos = state.mTokPos + 1;
            if (nextTokPos >= mNumToks) return;
            int endTagId = state.mEndTagId;
            if (endTagId < 0) return; // no end tag for this chunk type
            double transition = mLattice.log2Transitions(nextTokPos,
                                                         state.mCurrentTagId,
                                                         endTagId);
            double forward = state.mForward + transition;
            double backward = mLattice.log2Backward(nextTokPos,endTagId);
            // Not normalized here; bufferNext subtracts mTotal on output.
            double log2Prob = forward + backward;
            Chunk chunk = ChunkFactory.createChunk(state.mStartCharPos,
                                                   mTokenEndIndexes[nextTokPos],
                                                   state.mTag,
                                                   log2Prob);
            mQueue.add(chunk);
        }
    }

    /**
     * Partial chunk being grown token by token. Its score is
     * {@code forward + back}: the accumulated forward score of the chunk
     * prefix plus the lattice backward score of the current tag.
     */
    private static class ChunkItState implements Scored {
        final int mStartCharPos;
        final int mTokPos;
        final String mTag;
        final double mForward;
        final double mBack;
        final double mScore;
        final int mCurrentTagId;
        final int mMidTagId;
        final int mEndTagId;

        ChunkItState(int startCharPos, int tokPos,
                     String tag, int currentTagId, int midTagId, int endTagId,
                     double forward, double back) {
            mStartCharPos = startCharPos;
            mTokPos = tokPos;
            mTag = tag;
            mCurrentTagId = currentTagId;
            mMidTagId = midTagId;
            mEndTagId = endTagId;
            mForward = forward;
            mBack = back;
            mScore = forward + back;
        }

        public double score() {
            return mScore;
        }
    }

    /**
     * Adapts an iterator over scored tag sequences ({@code String[]}) into
     * scored BIO chunkings over the given tokens and whitespaces. Each tag
     * array is decoded in place before conversion.
     */
    private static class NBestIt
        implements Iterator<ScoredObject<Chunking>> {

        final Iterator<?> mIt;
        final String[] mWhites;
        final String[] mToks;

        NBestIt(Iterator<?> it, String[][] toksWhites) {
            mIt = it;
            mToks = toksWhites[0];
            mWhites = toksWhites[1];
        }

        public boolean hasNext() {
            return mIt.hasNext();
        }

        public ScoredObject<Chunking> next() {
            ScoredObject<?> so = (ScoredObject<?>) mIt.next();
            double score = so.score();
            String[] tags = (String[]) so.getObject();
            decodeNormalize(tags); // mutates the underlying tag array
            Chunking chunking
                = ChunkTagHandlerAdapter.toChunkingBIO(mToks,mWhites,tags);
            return new ScoredObject<Chunking>(chunking,score);
        }

        public void remove() {
            mIt.remove();
        }
    }

    /** Copies a collection of strings into a new array. */
    private static String[] toStringArray(Collection<?> c) {
        String[] result = new String[c.size()];
        c.toArray(result);
        return result;
    }

    /** Unboxes a collection of {@code Integer}s into a new array. */
    private static int[] toIntArray(Collection<?> c) {
        int[] result = new int[c.size()];
        int i = 0;
        for (Object val : c)
            result[i++] = ((Integer) val).intValue();
        return result;
    }

    /** Strips the two-character prefix from a tag, leaving out-tags intact. */
    static String baseTag(String tag) {
        if (ChunkTagHandlerAdapter.isOutTag(tag)) return tag;
        return tag.substring(2);
    }

    /**
     * Maps a BIO tag sequence to the richer training tag set, using each
     * tag's neighbors for context. Returns the input itself if empty.
     */
    static String[] trainNormalize(String[] tags) {
        if (tags.length == 0) return tags;
        String[] normalTags = new String[tags.length];
        for (int i = 0; i < normalTags.length; ++i) {
            // NOTE(review): both sequence boundaries are padded with "W_BOS";
            // an end pad of "W_EOS" was commented out originally — confirm intent.
            String prevTag = (i-1 >= 0) ? tags[i-1] : "W_BOS";
            String nextTag = (i+1 < tags.length) ? tags[i+1] : "W_BOS";
            normalTags[i] = trainNormalize(prevTag,tags[i],nextTag);
        }
        return normalTags;
    }

    /** Maps each training-set tag back to its BIO form, in place. */
    private static void decodeNormalize(String[] tags) {
        for (int i = 0; i < tags.length; ++i)
            tags[i] = decodeNormalize(tags[i]);
    }

    /**
     * Maps one BIO tag to a training tag given its neighbors.
     * <ul>
     * <li>Out tags become {@code MM_O}, {@code EE_O_X}, {@code BB_O_X} or {@code WW_O_X}, depending on whether the neighbors are out tags.</li>
     * <li>Begin tags become {@code B_X}, or {@code W_X} when no in-tag follows.</li>
     * <li>In tags become {@code M_X}, or {@code E_X} when no in-tag follows.</li>
     * </ul>
     *
     * @throws IllegalArgumentException if the tag is not out, begin or in.
     */
    static String trainNormalize(String prevTag, String tag, String nextTag) {
        if (ChunkTagHandlerAdapter.isOutTag(tag)) {
            if (ChunkTagHandlerAdapter.isOutTag(prevTag)) {
                if (ChunkTagHandlerAdapter.isOutTag(nextTag)) {
                    return "MM_O";
                } else {
                    return "EE_O_" + baseTag(nextTag);
                }
            } else if (ChunkTagHandlerAdapter.isOutTag(nextTag)) {
                return "BB_O_" + baseTag(prevTag);
            } else {
                return "WW_O_" + baseTag(nextTag); // WW_O
            }
        }
        if (ChunkTagHandlerAdapter.isBeginTag(tag)) {
            if (ChunkTagHandlerAdapter.isInTag(nextTag))
                return "B_" + baseTag(tag);
            else
                return "W_" + baseTag(tag);
        }
        if (ChunkTagHandlerAdapter.isInTag(tag)) {
            if (ChunkTagHandlerAdapter.isInTag(nextTag))
                return "M_" + baseTag(tag);
            else
                return "E_" + baseTag(tag);
        }
        String msg = "Unknown tag triple."
            + " prevTag=" + prevTag
            + " tag=" + tag
            + " nextTag=" + nextTag;
        throw new IllegalArgumentException(msg);
    }

    /**
     * Maps a training tag back to BIO.
     * <ul>
     * <li>{@code B_X}/{@code W_X} map to the begin tag for X.</li>
     * <li>{@code M_X}/{@code E_X} map to the in tag for X.</li>
     * <li>Everything else (the out-tag variants) maps to the out tag.</li>
     * </ul>
     */
    private static String decodeNormalize(String tag) {
        if (tag.startsWith("B_") || tag.startsWith("W_")) {
            String baseTag = tag.substring(2);
            return ChunkTagHandlerAdapter.toBeginTag(baseTag);
        }
        if (tag.startsWith("M_") || tag.startsWith("E_")) {
            String baseTag = tag.substring(2);
            return ChunkTagHandlerAdapter.toInTag(baseTag);
        }
        return ChunkTagHandlerAdapter.OUT_TAG;
    }
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -