📄 LatentDirichletAllocation.java
                    if (currentTopic == 0) {
                        topicDistro[0]
                            = (docTopicCountDoc[0] - 1.0 + docTopicPrior)
                            * (wordTopicCountWord[0] - 1.0 + topicWordPrior)
                            / (topicTotalCount[0] - 1.0 + numWordsTimesTopicWordPrior);
                    } else {
                        topicDistro[0]
                            = (docTopicCountDoc[0] + docTopicPrior)
                            * (wordTopicCountWord[0] + topicWordPrior)
                            / (topicTotalCount[0] + numWordsTimesTopicWordPrior);
                        for (int topic = 1; topic < currentTopic; ++topic) {
                            topicDistro[topic]
                                = (docTopicCountDoc[topic] + docTopicPrior)
                                * (wordTopicCountWord[topic] + topicWordPrior)
                                / (topicTotalCount[topic] + numWordsTimesTopicWordPrior)
                                + topicDistro[topic-1];
                        }
                        topicDistro[currentTopic]
                            = (docTopicCountDoc[currentTopic] - 1.0 + docTopicPrior)
                            * (wordTopicCountWord[currentTopic] - 1.0 + topicWordPrior)
                            / (topicTotalCount[currentTopic] - 1.0 + numWordsTimesTopicWordPrior)
                            + topicDistro[currentTopic-1];
                    }
                    for (int topic = currentTopic+1; topic < numTopics; ++topic) {
                        topicDistro[topic]
                            = (docTopicCountDoc[topic] + docTopicPrior)
                            * (wordTopicCountWord[topic] + topicWordPrior)
                            / (topicTotalCount[topic] + numWordsTimesTopicWordPrior)
                            + topicDistro[topic-1];
                    }
                    int sampledTopic = Statistics.sample(topicDistro,random); // probs computed before count updates
                    if (sampledTopic != currentTopic) {
                        currentSampleDoc[tok] = (short) sampledTopic;
                        --docTopicCountDoc[currentTopic];
                        --wordTopicCountWord[currentTopic];
                        --topicTotalCount[currentTopic];
                        ++docTopicCountDoc[sampledTopic];
                        ++wordTopicCountWord[sampledTopic];
                        ++topicTotalCount[sampledTopic];
                        ++numChangedTopics;
                    }
                    double topicProbGivenDoc
                        = docTopicCountDoc[sampledTopic] / (double) docWordsDoc.length;
                    double wordProbGivenTopic
                        = wordTopicCountWord[sampledTopic] / (double) topicTotalCount[sampledTopic];
                    double tokenLog2Prob = Math.log2(topicProbGivenDoc * wordProbGivenTopic);
                    corpusLog2Prob += tokenLog2Prob;
                }
            }
            // double crossEntropyRate = -corpusLog2Prob / numTokens;
            if ((epoch >= burninEpochs)
                && (((epoch - burninEpochs) % sampleLag) == 0)) {
                GibbsSample sample
                    = new GibbsSample(epoch,
                                      currentSample,
                                      docWords,
                                      docTopicPrior,
                                      topicWordPrior,
                                      docTopicCount,
                                      wordTopicCount,
                                      topicTotalCount,
                                      numChangedTopics,
                                      numWords,
                                      numTokens);
                if (handler != null)
                    handler.handle(sample);
                if (epoch == numEpochs)
                    return sample;
            }
        }
        throw new IllegalStateException("unreachable in practice because of return if epoch==numEpochs");
    }
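    // Illustration (not in the original source): topicDistro above is built
    // as a running sum, so each entry holds the cumulative unnormalized
    // weight of topics 0..k, with the current token's own assignment
    // subtracted out (the -1.0 terms).  Sampling from such a cumulative
    // array takes one uniform draw and a scan for the first entry exceeding
    // it.  A minimal sketch of that idea; LingPipe's Statistics.sample may
    // differ in implementation detail, and the method name below is
    // hypothetical.
    static int sampleCumulativeSketch(double[] cumulativeWeights,
                                      java.util.Random random) {
        // uniform draw over the total unnormalized mass
        double u = random.nextDouble()
            * cumulativeWeights[cumulativeWeights.length - 1];
        for (int k = 0; k < cumulativeWeights.length; ++k)
            if (u < cumulativeWeights[k])
                return k; // first index whose cumulative weight exceeds u
        return cumulativeWeights.length - 1; // guard against rounding error
    }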
    /**
     * Tokenize an array of text documents represented as character
     * sequences into a form usable by LDA, using the specified
     * tokenizer factory and symbol table.  The symbol table should
     * be constructed fresh for this application, but may be used
     * after this method is called for further token-to-symbol
     * conversions.  Only tokens whose count is equal to or larger
     * than the specified minimum count are included.  Only tokens
     * whose count meets the minimum are added to the symbol table,
     * thus producing a compact set of symbol assignments to tokens
     * for downstream processing.
     *
     * <p><i>Warning</i>: With some tokenizer factories and/or
     * minimum count thresholds, there may be documents with no
     * tokens in them.
     *
     * @param texts The text corpus.
     * @param tokenizerFactory A tokenizer factory for tokenizing the texts.
     * @param symbolTable Symbol table used to convert tokens to identifiers.
     * @param minCount Minimum count for a token to be included in a
     * document's representation.
     * @return The tokenized form of a document suitable for input to LDA.
     */
    public static int[][] tokenizeDocuments(CharSequence[] texts,
                                            TokenizerFactory tokenizerFactory,
                                            SymbolTable symbolTable,
                                            int minCount) {
        ObjectToCounterMap<String> tokenCounter = new ObjectToCounterMap<String>();
        for (CharSequence text : texts) {
            char[] cs = Strings.toCharArray(text);
            Tokenizer tokenizer = tokenizerFactory.tokenizer(cs,0,cs.length);
            for (String token : tokenizer)
                tokenCounter.increment(token);
        }
        tokenCounter.prune(minCount);
        Set<String> tokenSet = tokenCounter.keySet();
        for (String token : tokenSet)
            symbolTable.getOrAddSymbol(token);
        int[][] docTokenId = new int[texts.length][];
        for (int i = 0; i < docTokenId.length; ++i) {
            docTokenId[i] = tokenizeDocument(texts[i],tokenizerFactory,symbolTable);
        }
        return docTokenId;
    }

    /**
     * Tokenizes the specified text document using the specified
     * tokenizer factory, returning only tokens that exist in the
     * symbol table.  This method is useful within a given LDA model
     * for tokenizing new documents into lists of words.
     *
     * @param text Character sequence to tokenize.
     * @param tokenizerFactory Tokenizer factory for tokenization.
     * @param symbolTable Symbol table to use for converting tokens
     * to symbols.
     * @return The array of integer symbols for tokens that exist in
     * the symbol table.
     */
    public static int[] tokenizeDocument(CharSequence text,
                                         TokenizerFactory tokenizerFactory,
                                         SymbolTable symbolTable) {
        char[] cs = Strings.toCharArray(text);
        Tokenizer tokenizer = tokenizerFactory.tokenizer(cs,0,cs.length);
        List<Integer> idList = new ArrayList<Integer>();
        for (String token : tokenizer) {
            int id = symbolTable.symbolToID(token);
            if (id >= 0)
                idList.add(id);
        }
        int[] tokenIds = new int[idList.size()];
        for (int i = 0; i < tokenIds.length; ++i)
            tokenIds[i] = idList.get(i);
        return tokenIds;
    }
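    // Illustration (not in the original source): a typical call sequence
    // for the two tokenization methods above.  MapSymbolTable and
    // IndoEuropeanTokenizerFactory are LingPipe classes (com.aliasi.symbol,
    // com.aliasi.tokenizer); the public constructor access and the minCount
    // value are assumptions for this sketch, and the method name is
    // hypothetical.
    static int[][] tokenizeCorpusSketch(CharSequence[] texts) {
        // fresh symbol table, as the tokenizeDocuments documentation recommends
        SymbolTable symbolTable = new com.aliasi.symbol.MapSymbolTable();
        TokenizerFactory tokenizerFactory
            = new com.aliasi.tokenizer.IndoEuropeanTokenizerFactory();
        int minCount = 5; // drop tokens seen fewer than five times
        int[][] docWords
            = tokenizeDocuments(texts, tokenizerFactory, symbolTable, minCount);
        // an unseen document is later mapped into the same symbol space with
        // tokenizeDocument(text, tokenizerFactory, symbolTable)
        return docWords;
    }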
+ " Found sampleLag=" + sampleLag; throw new IllegalArgumentException(msg); } if (numSamples < 1) { String msg = "Number of samples must be positive." + " Found numSamples=" + numSamples; throw new IllegalArgumentException(msg); } } /** * The <code>LatentDirichletAllocation.GibbsSample</code> class * encapsulates all of the information related to a single Gibbs * sample for latent Dirichlet allocation (LDA). A sample * consists of the assignment of a topic identifier to each * token in the corpus. Other methods in this class are derived * from either the topic samples, the data being estimated, and * the LDA parameters such as priors. * * <p>Instances of * this class are created by the sampling method in the containing * class, {@link LatentDirichletAllocation}. For convenience, the * sample includes all of the data used to construct the sample, * as well as the hyperparameters used for sampling. * * <p>As described in the class documentation for the containing * class {@link LatentDirichletAllocation}, the primary content in * a Gibbs sample for LDA is the assignment of a single topic to * each token in the corpus. Cumulative counts for topics in * documents and words in topics as well as total counts are also * available; they do not entail any additional computation costs * as the sampler maintains them as part of the sample. * * <p>The sample also contains meta information about the state * of the sampling procedure. The epoch at which the sample * was produced is provided, as well as an indication of how many * topic assignments changed between this sample and the previous * sample (note that this is the previous sample in the chain, not * necessarily the previous sample handled by the LDA handler; the * handler only gets the samples separated by the specified lag. * * <p>The sample may be used to generate an LDA model. The * resulting model may then be used for estimation of unseen * documents. Typically, models derived from several samples are * used for Bayesian computations, as described in the class * documentation above. * * @author Bob Carpenter * @version 3.3.0 * @since LingPipe3.3 */ public static class GibbsSample { private final int mEpoch; private final short[][] mTopicSample; private final int[][] mDocWords; private final double mDocTopicPrior; private final double mTopicWordPrior; private final int[][] mDocTopicCount; private final int[][] mWordTopicCount; private final int[] mTopicCount; private final int mNumChangedTopics; private final int mNumWords; private final int mNumTokens; GibbsSample(int epoch, short[][] topicSample, int[][] docWords, double docTopicPrior, double topicWordPrior, int[][] docTopicCount, int[][] wordTopicCount, int[] topicCount, int numChangedTopics, int numWords, int numTokens) { mEpoch = epoch; mTopicSample = topicSample;