📄 ngramprocesslm.java
字号:
* Construct an n-gram process language model with the specified * number of characters, interpolation parameter * and character sequence counter. The maximum n-gram is determined * by the sequence counter. * * <p>The counter argument allows serialized counters to be * read back in and used to create an n-gram process LM. * * @param numChars Maximum number of characters in training and * test data. * @param lambdaFactor Interpolation parameter (see class doc). * @param counter Character sequence counter to use. * @throws IllegalArgumentException If the number of characters is * not between 1 and the maximum number of characters, of if the * lambda factor is not greater than or equal to 0. */ public NGramProcessLM(int numChars, double lambdaFactor, TrieCharSeqCounter counter) { mMaxNGram = counter.mMaxLength; setLambdaFactor(lambdaFactor); // checks range setNumChars(numChars); mTrieCharSeqCounter = counter; } /** * Writes this language model to the specified output stream. * * <p>A language model is written using a {@link BitOutput} * wrapped around the specified output stream. This bit output is * used to delta encode the maximum n-gram, number of characters, * lambda factor times 1,000,000, and then the underlying sequence * counter using {@link * TrieCharSeqCounter#writeCounter(CharSeqCounter,TrieWriter,int)}. * The bit output is flushed, but the output stream is not closed. * * <p>A language model can be read and written using the following * code, given a file <code>f</code>: * * <blockquote><pre> * NGramProcessLM lm = ...; * File f = ...; * OutputStream out = new FileOutputStream(f); * BufferedOutputStream bufOut = new BufferedOutputStream(out); * lm.writeTo(bufOut); * bufOut.close(); * * ... * InputStream in = new FileInputStream(f); * BufferedInputStream bufIn = new BufferedInputStream(in); * NGramProcessLM lm2 = NGramProcessLM.readFrom(bufIn); * bufIn.close();</pre></blockquote> * * @param out Output stream to which to write language model. * @throws IOException If there is an underlying I/O error. */ public void writeTo(OutputStream out) throws IOException { BitOutput bitOut = new BitOutput(out); writeTo(bitOut); bitOut.flush(); } void writeTo(BitOutput bitOut) throws IOException { bitOut.writeDelta(mMaxNGram); bitOut.writeDelta(mNumChars); bitOut.writeDelta((int) (mLambdaFactor * 1000000)); BitTrieWriter trieWriter = new BitTrieWriter(bitOut); TrieCharSeqCounter.writeCounter(mTrieCharSeqCounter,trieWriter, mMaxNGram); } /** * Reads a language model from the specified input stream. * * <p>See {@link #writeTo(OutputStream)} for information on the * binary I/O format. * * @param in Input stream from which to read a language model. * @return The language model read from the stream. * @throws IOException If there is an underlying I/O error. */ public static NGramProcessLM readFrom(InputStream in) throws IOException { BitInput bitIn = new BitInput(in); return readFrom(bitIn); } static NGramProcessLM readFrom(BitInput bitIn) throws IOException { int maxNGram = (int) bitIn.readDelta(); int numChars = (int) bitIn.readDelta(); double lambdaFactor = bitIn.readDelta() / 1000000.0; BitTrieReader trieReader = new BitTrieReader(bitIn); TrieCharSeqCounter counter = TrieCharSeqCounter.readCounter(trieReader,maxNGram); return new NGramProcessLM(numChars,lambdaFactor,counter); } public double log2Prob(CharSequence cSeq) { return log2Estimate(cSeq); } public double prob(CharSequence cSeq) { return Math.pow(2.0,log2Estimate(cSeq)); } public final double log2Estimate(CharSequence cSeq) { char[] cs = Strings.toCharArray(cSeq); return log2Estimate(cs,0,cs.length); } public final double log2Estimate(char[] cs, int start, int end) { Strings.checkArgsStartEnd(cs,start,end); double sum = 0.0; for (int i = start+1; i <= end; ++i) sum += log2ConditionalEstimate(cs,start,i); return sum; } public void train(CharSequence cSeq) { train(cSeq,1); } public void train(CharSequence cSeq, int incr) { char[] cs = Strings.toCharArray(cSeq); train(cs,0,cs.length,incr); } public void train(char[] cs, int start, int end) { train(cs,start,end,1); } public void train(char[] cs, int start, int end, int incr) { Strings.checkArgsStartEnd(cs,start,end); mTrieCharSeqCounter.incrementSubstrings(cs,start,end,incr); } /** * Convenience implementation of the {@link TextHandler} * interface, which defers to {@link #train(char[],int,int)}. * Note that this method uses start and length encoding * of a slice, whereas the training method uses start end * end encodings. * @param cs Underlying character array. * @param start Index of first character in slice. * @param length Number of characters in slice. * @throws IndexOutOfBoundsException If the indices do not fall within * the specified character array. */ public void handle(char[] cs, int start, int length) { train(cs,start,length-start); } /** * Trains the specified conditional outcome(s) of the specified * character slice given the background slice. * <P>This method just shorthand for incrementing the counts of * all substrings of <code>cs</code> from position * <code>start</code> to <code>end-1</code> inclusive, then * decrementing all of the counts of substrings from position * <code>start</code> to <code>condEnd-1</code>. For instance, if * <code>cs</code> is * <code>"abcde".toCharArray()</code>, then calling * <code>trainConditional(cs,0,5,2)</code> will increment the * counts of <code>cde</code> given <code>ab</code>, but will not * increment the counts of <code>ab</code> directly. This increases * the following probabilities: * * <blockquote><code> * P('e'|"abcd") * P('e'|"bcd") * P('e'|"cd") * P('e'|"d") * P('e'|"") * <br> * P('d'|"abc") * P('d'|"bc") * P('d'|"c") * P('d'|"") * <br> * P('c'|"ab") * P('c'|"b") * P('c'|"") * </code></blockquote> * * but does not increase the following probabilities: * * <blockquote><code> * P('b'|"a") * P('b'|"") * <br> * P('a'|"") * </blockquote> * * @param cs Array of characters. * @param start Start position for slice. * @param end One past end position for slice. * @param condEnd One past the end of the conditional portion of * the slice. */ public void trainConditional(char[] cs, int start, int end, int condEnd) { Strings.checkArgsStartEnd(cs,start,end); Strings.checkArgsStartEnd(cs,start,condEnd); if (condEnd > end) { String msg = "Conditional end must be < end." + " Found condEnd=" + condEnd + " end=" + end; throw new IllegalArgumentException(msg); } if (condEnd == end) return; mTrieCharSeqCounter.incrementSubstrings(cs,start,end); mTrieCharSeqCounter.decrementSubstrings(cs,start,condEnd); } public char[] observedCharacters() { return mTrieCharSeqCounter.observedCharacters(); } /** * Writes a compiled version of this process language model to the * specified object output. * * <P>The object written will be an instance of {@link * CompiledNGramProcessLM}. It may be read in by casting the * result of {@link ObjectInput#readObject()}. * * <P>Compilation is time consuming, because it must traverse the * entire trie structure, and for each node, estimate its log * probability and if it is internal, its log interpolation value. * Given that time taken is proportional to the size of the trie, * pruning first may greatly speed up this operation and reduce * the size of the compiled object that is written. * * @param objOut Object output to which a compiled version of this * langauge model will be written. * @throws IOException If there is an I/O exception writing the * compiled object. */ public void compileTo(ObjectOutput objOut) throws IOException { objOut.writeObject(new Externalizer(this)); } public double log2ConditionalEstimate(CharSequence cSeq) { return log2ConditionalEstimate(cSeq,mMaxNGram,mLambdaFactor); } public double log2ConditionalEstimate(char[] cs, int start, int end) { return log2ConditionalEstimate(cs,start,end,mMaxNGram,mLambdaFactor); } /** * Returns the substring counter for this language model. * Modifying the counts in the returned counter, such as by * pruning, will change the estimates in this language model. * * @return Substring counter for this language model. */ public TrieCharSeqCounter substringCounter() { return mTrieCharSeqCounter; } /** * Returns the maximum n-gram length for this model. * * @return The maximum n-gram length for this model. */ public int maxNGram() { return mMaxNGram; } /** * Returns the log (base 2) conditional estimate of the last * character in the specified character sequence given the * previous characters based only on counts of n-grams up to the * specified maximum n-gram. If the maximum n-gram argument is * greater than or equal to the one supplied at construction time, * the results wil lbe the same as the ordinary conditional * estimate. * * @param cSeq Character sequence to estimate. * @param maxNGram Maximum length of n-gram count to use for * estimate. * @param lambdaFactor Value of interpolation hyperparameter for * this estimate. * @return Log (base 2) conditional estimate. * @throws IllegalArgumentException If the character sequence is not at * least one character long. */ public double log2ConditionalEstimate(CharSequence cSeq, int maxNGram, double lambdaFactor) { char[] cs = Strings.toCharArray(cSeq); return log2ConditionalEstimate(cs,0,cs.length,maxNGram,lambdaFactor); } /** * Returns the log (base 2) conditional estimate for a specified * character slice with a specified maximum n-gram and specified * hyperparameter. * @param cs Underlying character array. * @param start Index of first character in slice. * @param end Index of one past last character in slice.
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -