📄 ngramprocesslm.java
字号:
* @param maxNGram Maximum length of n-gram to use in estimates. * @param lambdaFactor Value of interpolation hyperparameter. * @return Log (base 2) conditional estimate of the last character * in the slice given the previous characters. * @throws IndexOutOfBoundsException If the start index and end * index minus one are out of range of the character array or if the * character slice is less than one character long. */ public double log2ConditionalEstimate(char[] cs, int start, int end, int maxNGram, double lambdaFactor) { if (end <= start) { String msg = "Conditional estimates require at least one character."; throw new IllegalArgumentException(msg); } Strings.checkArgsStartEnd(cs,start,end); checkMaxNGram(maxNGram); checkLambdaFactor(lambdaFactor); int maxUsableNGram = Math.min(maxNGram,mMaxNGram); if (start == end) return 0.0; double currentEstimate = mUniformEstimate; int contextEnd = end-1; int longestContextStart = Math.max(start,end-maxUsableNGram); for (int currentContextStart = contextEnd; currentContextStart >= longestContextStart; --currentContextStart) { long contextCount = mTrieCharSeqCounter.extensionCount(cs,currentContextStart,contextEnd); if (contextCount == 0) break; long outcomeCount = mTrieCharSeqCounter.count(cs,currentContextStart,end); double lambda = lambda(cs,currentContextStart,contextEnd,lambdaFactor); currentEstimate = lambda * (((double)outcomeCount) / (double)contextCount) + (1.0 - lambda) * currentEstimate; } return com.aliasi.util.Math.log2(currentEstimate); } /** * Returns the interpolation ratio for the specified character * slice interpreted as a context. The hyperparameter used is * that returned by {@link #getLambdaFactor()}. The definition of * <code>lambda()</code> is provided in the class documentation * above. * * @param cs Underlying character array. * @param start Index of first character in slice. * @param end Index of one past last character in slice. 
* @throws IndexOutOfBoundsException If the start index and end * index minus one are out of range of the character array. */ double lambda(char[] cs, int start, int end) { return lambda(cs,start,end,getLambdaFactor()); } /** * Returns the interpolation ratio for the specified character * slice interpreted as a context with the specified * hyperparameter. The definition of <code>lambda()</code> is * provided in the class documentation above. * * @param cs Underlying character array. * @param start Index of first character in slice. * @param end Index of one past last character in slice. * @param lambdaFactor Value for interpolation ratio hyperparameter. * @throws IndexOutOfBoundsException If the start index and end * index minus one are out of range of the character array. */ double lambda(char[] cs, int start, int end, double lambdaFactor) { checkLambdaFactor(lambdaFactor); Strings.checkArgsStartEnd(cs,start,end); double count = mTrieCharSeqCounter.extensionCount(cs,start,end); if (count <= 0.0) return 0.0; double numOutcomes = mTrieCharSeqCounter.numCharactersFollowing(cs,start,end); return lambda(count,numOutcomes,lambdaFactor); } /** * Returns the current setting of the interpolation ratio * hyperparameter. See the class documentation above for * information on how the interpolation ratio is used in * estimates. * * @return The current setting of the interpolation ratio * hyperparameter. */ public double getLambdaFactor() { return mLambdaFactor; } /** * Sets the value of the interpolation ratio hyperparameter * to the specified value. See the class documentation above * for information on how the interpolation ratio is used in estimates. * * @param lambdaFactor New value for interpolation ratio * hyperparameter. * @throws IllegalArgumentException If the value is not greater * than or equal to zero. 
*/
    public final void setLambdaFactor(double lambdaFactor) {
        checkLambdaFactor(lambdaFactor);
        mLambdaFactor = lambdaFactor;
    }

    /**
     * Sets the number of characters for this language model.  All
     * subsequent estimates will be based on this number.  See the
     * class definition above for information on how the number of
     * character is used to determine the base case uniform
     * distribution.
     *
     * <p>Also recomputes the cached uniform estimate
     * (<code>1/numChars</code>) and its base 2 logarithm.
     *
     * @param numChars New number of characters for this language model.
     * @throws IllegalArgumentException If the number of characters is
     * less than <code>0</code> or more than
     * <code>Character.MAX_VALUE</code>.
     */
    public final void setNumChars(int numChars) {
        checkNumChars(numChars);
        mNumChars = numChars;
        // NOTE(review): numChars == 0 passes validation and makes this
        // division yield positive infinity -- confirm whether zero
        // should be rejected.
        mUniformEstimate = 1.0 / (double)mNumChars;
        mLog2UniformEstimate = com.aliasi.util.Math.log2(mUniformEstimate);
    }

    /**
     * Returns a string-based representation of this language model.
     *
     * @return A string-based representation of this language model.
     */
    public String toString() {
        StringBuffer sb = new StringBuffer();
        toStringBuffer(sb);
        return sb.toString();
    }

    /**
     * Appends a description of this model to the specified buffer:
     * the maximum n-gram, the number of characters, and the trie of
     * counts.
     *
     * @param sb Buffer to which the description is appended.
     */
    void toStringBuffer(StringBuffer sb) {
        // Chained appends avoid building temporary concatenated strings.
        sb.append("Max NGram=").append(mMaxNGram).append(" ");
        sb.append("Num characters=").append(mNumChars).append("\n");
        sb.append("Trie of counts=\n");
        mTrieCharSeqCounter.toStringBuffer(sb);
    }

    // need this for the process model to get boundaries right
    /**
     * Decrements the unigram count of the specified character by one.
     *
     * @param c Character whose unigram count is decremented.
     */
    void decrementUnigram(char c) {
        decrementUnigram(c,1);
    }

    /**
     * Decrements the unigram count of the specified character by the
     * specified amount, delegating to the underlying trie counter.
     *
     * @param c Character whose unigram count is decremented.
     * @param count Amount passed to the counter's decrement.
     */
    void decrementUnigram(char c, int count) {
        mTrieCharSeqCounter.decrementUnigram(c,count);
    }

    /**
     * Core interpolation ratio formula:
     * <code>count / (count + lambdaFactor * numOutcomes)</code>.
     * The result is NaN when the numerator and denominator are both zero.
     *
     * @param count Count for the context.
     * @param numOutcomes Number of outcomes for the context.
     * @param lambdaFactor Interpolation ratio hyperparameter.
     * @return The interpolation ratio.
     */
    private double lambda(double count, double numOutcomes,
                          double lambdaFactor) {
        return count / (count + lambdaFactor * numOutcomes);
    }

    /**
     * Interpolation ratio for a trie node, evaluated with an empty
     * character slice and this model's current lambda factor.
     *
     * @param node Node supplying the context count and outcome count.
     * @return The interpolation ratio for the node.
     */
    private double lambda(Node node) {
        double count = node.contextCount(Strings.EMPTY_CHAR_ARRAY,0,0);
        double numOutcomes = node.numOutcomes(Strings.EMPTY_CHAR_ARRAY,0,0);
        return lambda(count,numOutcomes,mLambdaFactor);
    }

    /**
     * Walks the trie in queue (first-in, first-out) order starting at
     * the root and returns the zero-based position of the last visited
     * node that has at least one outcome.  Returns <code>0</code> if no
     * visited node has an outcome.
     *
     * @return Zero-based index of the last node with outcomes.
     */
    private int lastInternalNodeIndex() {
        int last = 1;
        LinkedList queue = new LinkedList();
        queue.add(mTrieCharSeqCounter.mRootNode);
        // i is the one-based position of the node being visited.
        for (int i = 1; !queue.isEmpty(); ++i) {
            Node node = (Node) queue.removeFirst();
            if (node.numOutcomes(com.aliasi.util.Arrays.EMPTY_CHAR_ARRAY,
                                 0,0) > 0)
                last = i;
            node.addDaughters(queue);
        }
        // Convert from one-based to zero-based.
        return last-1;
    }

    /**
     * Serialization hook: substitutes a {@link Serializer} wrapping
     * this model as the object written to the stream.
     *
     * @return Serialization proxy for this model.
     */
    private Object writeReplace() {
        return new Serializer(this);
    }

    // unfortunately, this depends on serialization happening with streams
    /**
     * Serialization proxy that writes the model with
     * <code>writeTo(OutputStream)</code> and restores it with
     * <code>readFrom(InputStream)</code>.  The object output/input
     * passed in is cast directly to a stream, so an
     * <code>ObjectOutput</code> or <code>ObjectInput</code> that is
     * not also a stream causes a <code>ClassCastException</code>.
     */
    static class Serializer implements Externalizable {
        static final long serialVersionUID = -7101238964823109652L;
        NGramProcessLM mLM;

        /** Public no-argument constructor, as required of an Externalizable. */
        public Serializer() {
        }

        /**
         * Constructs a proxy for the specified model.
         *
         * @param lm Model to serialize.
         */
        public Serializer(NGramProcessLM lm) {
            mLM = lm;
        }

        /**
         * Writes the wrapped model to the specified output.
         *
         * @param out Output to write to; must also be an OutputStream.
         * @throws IOException If the underlying write fails.
         */
        public void writeExternal(ObjectOutput out) throws IOException {
            mLM.writeTo((OutputStream) out);
        }

        /**
         * Reads a model from the specified input and stores it in
         * this proxy.
         *
         * @param in Input to read from; must also be an InputStream.
         * @throws IOException If the underlying read fails.
         * @throws ClassNotFoundException Declared by the interface.
         */
        public void readExternal(ObjectInput in)
            throws IOException, ClassNotFoundException {
            mLM = NGramProcessLM.readFrom((InputStream) in);
        }

        /**
         * Replaces this proxy with the model it read.
         *
         * @return The model held by this proxy.
         */
        public Object readResolve() {
            return mLM;
        }
    }

    /**
     * Externalizer that writes the model in compiled form; reading the
     * data back constructs a <code>CompiledNGramProcessLM</code>.
     */
    static class Externalizer extends AbstractExternalizable {
        static final long serialVersionUID = -3623859317152451545L;
        final NGramProcessLM mLM;

        /** No-argument constructor; leaves the model reference null. */
        public Externalizer() {
            this(null);
        }

        /**
         * Constructs an externalizer for the specified model.
         *
         * @param lm Model to compile.
         */
        public Externalizer(NGramProcessLM lm) {
            mLM = lm;
        }

        /**
         * Reads a compiled model from the specified input.
         *
         * @param in Input from which to read.
         * @return The compiled language model.
         * @throws IOException If the underlying read fails.
         */
        public Object read(ObjectInput in) throws IOException {
            return new CompiledNGramProcessLM(in);
        }

        /**
         * Writes the compiled form of the model.  The layout is:
         * max n-gram (int), log2 uniform estimate (float), number of
         * nodes (int), last internal node index (int), the root node
         * record, and then one record per n-gram in queue order.
         * Each record holds the node's character and log2 estimate;
         * records up to the last internal node index also hold
         * log2(1 - lambda) and the index of the first daughter.
         *
         * @param dataOut Output to which the model is written.
         * @throws IOException If the underlying write fails.
         * @throws IllegalArgumentException If the number of nodes
         * exceeds <code>Integer.MAX_VALUE</code>.
         */
        public void writeExternal(ObjectOutput dataOut) throws IOException {
            dataOut.writeInt(mLM.mMaxNGram);

            dataOut.writeFloat((float) mLM.mLog2UniformEstimate);

            long numNodes = mLM.mTrieCharSeqCounter.uniqueSequenceCount();
            if (numNodes > Integer.MAX_VALUE) {
                String msg = "Maximum number of compiled nodes is"
                    + " Integer.MAX_VALUE = " + Integer.MAX_VALUE
                    + " Found number of nodes=" + numNodes;
                throw new IllegalArgumentException(msg);
            }
            dataOut.writeInt((int)numNodes);

            int lastInternalNodeIndex = mLM.lastInternalNodeIndex();
            dataOut.writeInt(lastInternalNodeIndex);

            // write root node (char,logP,log(1-L),firstDtr)
            dataOut.writeChar('\uFFFF');
            dataOut.writeFloat((float) mLM.mLog2UniformEstimate);
            double oneMinusLambda
                = 1.0 - mLM.lambda(mLM.mTrieCharSeqCounter.mRootNode);
            // Root lambda is NaN when its count and outcomes are both
            // zero (0/0); write 0 in that case instead of log2(NaN).
            float log2OneMinusLambda
                = Double.isNaN(oneMinusLambda)
                ? 0f
                : (float) com.aliasi.util.Math.log2(oneMinusLambda);
            dataOut.writeFloat(log2OneMinusLambda);
            dataOut.writeInt(1);  // firstDtr

            // Seed the queue with the unigrams; index 0 is the root.
            char[] cs = mLM.mTrieCharSeqCounter.observedCharacters();
            LinkedList queue = new LinkedList();
            for (int i = 0; i < cs.length; ++i)
                queue.add(new char[] { cs[i] });

            for (int index = 1; !queue.isEmpty(); ++index) {
                char[] nGram = (char[]) queue.removeFirst();
                char c = nGram[nGram.length-1];
                dataOut.writeChar(c);

                float logConditionalEstimate
                    = (float) mLM.log2ConditionalEstimate(nGram,0,nGram.length);
                dataOut.writeFloat(logConditionalEstimate);

                if (index <= lastInternalNodeIndex) {
                    double oneMinusLambda2
                        = 1.0 - mLM.lambda(nGram,0,nGram.length);
                    float log2OneMinusLambda2
                        = (float) com.aliasi.util.Math.log2(oneMinusLambda2);
                    dataOut.writeFloat(log2OneMinusLambda2);
                    // The queue holds every node already discovered
                    // after this one; this node's daughters are
                    // appended next, so the first lands right after.
                    int firstChildIndex = index + queue.size() + 1;
                    dataOut.writeInt(firstChildIndex);
                }

                char[] cs2
                    = mLM.mTrieCharSeqCounter
                    .charactersFollowing(nGram,0,nGram.length);
                for (int i = 0; i < cs2.length; ++i)
                    queue.add(com.aliasi.util.Arrays.concatenate(nGram,cs2[i]));
            }
        }
    }

    /**
     * Validates an interpolation ratio hyperparameter.
     *
     * @param lambdaFactor Value to check.
     * @throws IllegalArgumentException If the value is negative,
     * infinite, or not a number.
     */
    static void checkLambdaFactor(double lambdaFactor) {
        if (lambdaFactor < 0.0
            || Double.isInfinite(lambdaFactor)
            || Double.isNaN(lambdaFactor)) {
            String msg = "Lambda factor must be ordinary non-negative double."
                + " Found lambdaFactor=" + lambdaFactor;
            throw new IllegalArgumentException(msg);
        }
    }

    /**
     * Validates a maximum n-gram length.
     *
     * @param maxNGram Value to check.
     * @throws IllegalArgumentException If the value is less than one.
     */
    static void checkMaxNGram(int maxNGram) {
        if (maxNGram < 1) {
            String msg = "Maximum n-gram must be greater than zero."
                + " Found max n-gram=" + maxNGram;
            throw new IllegalArgumentException(msg);
        }
    }

    /**
     * Validates a number of characters.
     *
     * @param numChars Value to check.
     * @throws IllegalArgumentException If the value is less than
     * <code>0</code> or more than <code>Character.MAX_VALUE</code>.
     */
    private static void checkNumChars(int numChars) {
        if (numChars < 0 || numChars > Character.MAX_VALUE) {
            // Message states the range the check actually enforces.
            String msg = "Number of characters must be >= 0 and"
                + " must be less than or equal to Character.MAX_VALUE."
                + " Found numChars=" + numChars;
            throw new IllegalArgumentException(msg);
        }
    }
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -