📄 intnode.java
字号:
package com.aliasi.lm;import com.aliasi.symbol.SymbolTable;import java.util.Iterator;import java.util.List;import java.util.Map;import java.util.TreeMap;class IntNode { int mCount; long mExtCount; DtrMap mDtrs; IntNode() { mCount = 0; mExtCount = 0L; mDtrs = DtrMap0.EMPTY_DTR_MAP; } IntNode(int[] toks, int start, int end) { mCount = 1; if (start == end) { mDtrs = DtrMap0.EMPTY_DTR_MAP; mExtCount = 0L; return; } mExtCount = 1L; int tok = toks[start]; IntNode dtr = new IntNode(toks,start+1,end); mDtrs = new DtrMap1(tok,dtr); } IntNode(int[] toks, int start, int end, int count) { mCount = count; if (start == end) { mDtrs = DtrMap0.EMPTY_DTR_MAP; mExtCount = 0L; return; } mExtCount = count; int tok = toks[start]; IntNode dtr = new IntNode(toks,start+1,end,count); mDtrs = new DtrMap1(tok,dtr); } IntNode(int[] toks, int start, int end, int count, boolean incrementPath) { if (incrementPath) throw new IllegalArgumentException("require true"); if (start == end) { mCount = count; mDtrs = DtrMap0.EMPTY_DTR_MAP; mExtCount = 0L; return; } mCount = 0; mExtCount = (start + 1 == end) ? count : 0L; int tok = toks[start]; IntNode dtr = new IntNode(toks,start+1,end,count,incrementPath); mDtrs = new DtrMap1(tok,dtr); } public void prune(int minCount) { mDtrs = mDtrs.prune(minCount); mExtCount = mDtrs.extensionCount(); } public void rescale(double countMultiplier) { mCount = (int)(countMultiplier * mCount); mDtrs = mDtrs.rescale(countMultiplier); mExtCount = mDtrs.extensionCount(); } public static String idToSymbol(int id, SymbolTable st) { if (id == -2) return "EOS"; if (id == -1) return "UNK"; return st.idToSymbol(id); } int trieSize() { return 1 + mDtrs.dtrsTrieSize(); } void decrement(int symbol) { IntNode dtr = mDtrs.getDtr(symbol); if (dtr == null) { String msg = "symbol doesn't exist=" + symbol; throw new IllegalArgumentException(msg); } if (mCount <= 0) { String msg = "Cannot decrement below zero."; throw new IllegalArgumentException(msg); } if (mExtCount < 1) { String msg = "Cannot decrement extensions below zero."; throw new IllegalArgumentException(msg); } --mCount; --mExtCount; dtr.decrement(); } private void decrement() { if (mCount == 0) { String msg = "Cannot decrement below 0."; throw new IllegalArgumentException(msg); } --mCount; } void decrement(int symbol, int count) { IntNode dtr = mDtrs.getDtr(symbol); if (dtr == null) { String msg = "symbol doesn't exist=" + symbol; throw new IllegalArgumentException(msg); } if (mCount - count < 0) { String msg = "Cannot decrement below zero." + " Count=" + mCount + " decrement=" + count; throw new IllegalArgumentException(msg); } if (mExtCount - count < 0) { String msg = "Cannot decrement extension count below zero." + " Ext count=" + mExtCount + " decrement=" + count; throw new IllegalArgumentException(msg); } mCount -= count; mExtCount -= count; dtr.decrementCount(count); } private void decrementCount(int count) { if (mCount - count < 0) { String msg = "Cannot decrement below 0." + " Count=" + mCount + " decrement=" + count; throw new IllegalArgumentException(msg); } mCount -= count; } int count() { return mCount; } void addDaughters(List queue) { mDtrs.addDtrs(queue); } long extensionCount() { return mExtCount; } int numExtensions() { return mDtrs.numExtensions(); } int[] integersFollowing() { return mDtrs.integersFollowing(); } int[] integersFollowing(int[] is, int start, int end) { IntNode dtr = getDtr(is,start,end); if (dtr == null) return EMPTY_INT_ARRAY; return dtr.integersFollowing(); } int[] observedIntegers() { return integersFollowing(EMPTY_INT_ARRAY,0,0); } void incrementSequence(int[] tokIndices, int start, int end, int count) { if (start == end) { mCount += count; return; } if (start + 1 == end) mExtCount += count; DtrMap newDtrs = mDtrs.incrementSequence(tokIndices,start,end,count); if (!newDtrs.equals(mDtrs)) mDtrs = newDtrs; } void increment(int[] tokIndices, int start, int end) { ++mCount; if (start == end) return; ++mExtCount; DtrMap newDtrs = mDtrs.incrementDtrs(tokIndices,start,end); if (!newDtrs.equals(mDtrs)) mDtrs = newDtrs; } void increment(int[] tokIndices, int start, int end, int count) { mCount += count; if (start == end) return; mExtCount += count; DtrMap newDtrs = mDtrs.incrementDtrs(tokIndices,start,end,count); if (!newDtrs.equals(mDtrs)) mDtrs = newDtrs; } IntNode getDtr(int[] toks, int start, int end) { if (start == end) return this; IntNode dtr = mDtrs.getDtr(toks[start]); if (dtr == null) return null; return dtr.getDtr(toks,start+1,end); } public String toString(SymbolTable st) { StringBuffer sb = new StringBuffer(); toString(sb,0,st); return sb.toString(); } public void toString(StringBuffer sb, int depth, SymbolTable st) { sb.append(count()); AbstractNode.indent(sb,depth); mDtrs.toString(sb,depth,st); } static final int[] EMPTY_INT_ARRAY = new int[0];}interface DtrMap { // just passed along public int numExtensions(); public long extensionCount(); public void addDtrs(List queue); public int[] integersFollowing(); // distinct public IntNode getDtr(int tok); public int dtrsTrieSize(); public void toString(StringBuffer sb, int depth, SymbolTable st); // these are the killers for unfolding public DtrMap rescale(double countMultiplier); public DtrMap prune(int minCount); public DtrMap incrementDtrs(int[] tokIndices, int start, int end); public DtrMap incrementDtrs(int[] tokIndices, int start, int end, int count); public DtrMap incrementSequence(int[] tokIndices, int start, int end, int count);}class DtrMap0 implements DtrMap { public IntNode getDtr(int tok) { return null; } public int numExtensions() { return 0; } public DtrMap incrementDtrs(int[] tokIndices, int start, int end) { if (start == end) return this; IntNode dtr = new IntNode(tokIndices,start+1,end); return new DtrMap1(tokIndices[start],dtr); } public DtrMap incrementDtrs(int[] tokIndices, int start, int end, int count) { if (start == end) return this; IntNode dtr = new IntNode(tokIndices,start+1,end,count); return new DtrMap1(tokIndices[start],dtr); } public DtrMap incrementSequence(int[] tokIndices, int start, int end, int count) { if (start == end) return this; IntNode dtr = new IntNode(tokIndices,start+1,end,count,false); return new DtrMap1(tokIndices[start],dtr); } public void toString(StringBuffer sb, int depth, SymbolTable st) { /* nothing to add */ } public long extensionCount() { return 0l; } public int[] integersFollowing() { return IntNode.EMPTY_INT_ARRAY; } public DtrMap prune(int minCount) { return this; // nothing to prune } public DtrMap rescale(double countMultiplier) { return this; // nought to scale } public void addDtrs(List queue) { /* nothing to add */ } public int dtrsTrieSize() { return 0; } static final DtrMap EMPTY_DTR_MAP = new DtrMap0();}class DtrMap1 implements DtrMap { final int mTok; IntNode mDtr = new IntNode(); public DtrMap1(int tok, IntNode dtr) { mTok = tok; mDtr = dtr; } public DtrMap prune(int minCount) { if (mDtr.count() < minCount) return DtrMap0.EMPTY_DTR_MAP; mDtr.prune(minCount); return this; } public DtrMap rescale(double countMultiplier) { mDtr.rescale(countMultiplier); if (mDtr.count() == 0) return DtrMap0.EMPTY_DTR_MAP; return this; } public int numExtensions() { return 1; } public void toString(StringBuffer sb, int depth, SymbolTable st) { if (st != null) sb.append(IntNode.idToSymbol(mTok,st)); else sb.append(mTok); sb.append(": "); mDtr.toString(sb,depth+1,st); } public void addDtrs(List queue) { queue.add(mDtr); } public int dtrsTrieSize() { return mDtr.trieSize(); } public IntNode getDtr(int tok) { return tok == mTok ? mDtr : null; } public DtrMap incrementDtrs(int[] tokIndices, int start, int end) { if (start == end) return this; if (tokIndices[start] == mTok) { mDtr.increment(tokIndices,start+1,end); return this;
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -