📄 latentdirichletallocation.java
字号:
return mTopicWordProbs[topic][word]; } /** * Returns an array representing of probabilities of words in the * specified topic. The probabilities are indexed by word * identifier. * * <p>The returned result is a copy of the underlying data in * the model so that changing it will not change the model. * * @param topic Topic identifier. * @return Array of probabilities of words in the specified topic. */ public double[] wordProbabilities(int topic) { double[] xs = new double[mTopicWordProbs[topic].length]; for (int i = 0; i < xs.length; ++i) xs[i] = mTopicWordProbs[topic][i]; return xs; } /** * Returns the specified number of Gibbs samples of topics for the * specified tokens using the specified number of burnin epochs, * the specified lag between samples, and the specified * randomizer. The array returned is an array of samples, each * sample consisting of a topic assignment to each token in the * specified list of tokens. The tokens must all be in the appropriate * range for this class * * <p>See the class documentation for more information on how the * samples are computed. * * @param tokens The tokens making up the document. * @param numSamples Number of Gibbs samples to return. * @param burnin The number of samples to take and throw away * during the burnin period. * @param sampleLag The interval between samples after burnin. * @param random The random number generator to use for this sampling * process. * @return The selection of topic samples generated by this * sampler. * @throws IndexOutOfBoundsException If there are tokens whose * value is less than zero, or whose value is greater than the * number of tokens in this model. * @throws IllegalArgumentException If the number of samples is * not positive, the sample lag is not positive, or if the burnin * period is negative. if the number of samples, burnin, and lag * are not positive numbers. 
*/ public short[][] sampleTopics(int[] tokens, int numSamples, int burnin, int sampleLag, Random random) { if (burnin < 0) { String msg = "Burnin period must be non-negative." + " Found burnin=" + burnin; throw new IllegalArgumentException(msg); } if (numSamples < 1) { String msg = "Number of samples must be at least 1." + " Found numSamples=" + numSamples; throw new IllegalArgumentException(msg); } if (sampleLag < 1) { String msg = "Sample lag must be at least 1." + " Found sampleLag=" + sampleLag; throw new IllegalArgumentException(msg); } double docTopicPrior = documentTopicPrior(); int numTokens = tokens.length; int numTopics = numTopics(); int[] topicCount = new int[numTopics]; short[][] samples = new short[numSamples][numTokens]; int sample = 0; short[] currentSample = samples[0]; for (int token = 0; token < numTokens; ++token) { int randomTopic = random.nextInt(numTopics); ++topicCount[randomTopic]; currentSample[token] = (short) randomTopic; } double[] topicDistro = new double[numTopics]; int numEpochs = burnin + sampleLag * (numSamples - 1); for (int epoch = 0; epoch < numEpochs; ++epoch) { for (int token = 0; token < numTokens; ++token) { int word = tokens[token]; int currentTopic = currentSample[token]; --topicCount[currentTopic]; if (topicCount[currentTopic] < 0) { throw new IllegalArgumentException("bomb"); } for (int topic = 0; topic < numTopics; ++topic) { topicDistro[topic] = (topicCount[topic] + docTopicPrior) * wordProbability(topic,word) + (topic == 0 ? 
0.0 : topicDistro[topic-1]); } int sampledTopic = Statistics.sample(topicDistro,random); ++topicCount[sampledTopic]; currentSample[token] = (short) sampledTopic; } if ((epoch >= burnin) && (((epoch - burnin) % sampleLag) == 0)) { short[] pastSample = currentSample; ++sample; currentSample = samples[sample]; for (int token = 0; token < numTokens; ++token) currentSample[token] = pastSample[token]; } } return samples; } /** * Return the maximum a posteriori (MAP) estimate of the topic * distribution for a document consisting of the specified tokens, * using Gibbs sampling with the specified parameters. The * Gibbs topic samples are simply averaged to produce the MAP * estimate. * * <p>See the method {@link #sampleTopics(int[],int,int,int,Random)} * and the class documentation for more information on the sampling * procedure. * * @param tokens The tokens making up the document. * @param numSamples Number of Gibbs samples to return. * @param burnin The number of samples to take and throw away * during the burnin period. * @param sampleLag The interval between samples after burnin. * @param random The random number generator to use for this sampling * process. * @return The selection of topic samples generated by this * sampler. * @throws IndexOutOfBoundsException If there are tokens whose * value is less than zero, or whose value is greater than the * number of tokens in this model. * @throws IllegalArgumentException If the number of samples is * not positive, the sample lag is not positive, or if the burnin * period is negative. if the number of samples, burnin, and lag * are not positive numbers. 
*/
    public double[] mapTopicEstimate(int[] tokens,
                                     int numSamples,
                                     int burnin,
                                     int sampleLag,
                                     Random random) {
        short[][] sampleTopics
            = sampleTopics(tokens, numSamples, burnin, sampleLag, random);
        int numTopics = numTopics();
        // Tally how often each topic was assigned across all samples
        // and all token positions.
        int[] counts = new int[numTopics];
        for (short[] topics : sampleTopics) {
            for (int tok = 0; tok < topics.length; ++tok)
                ++counts[topics[tok]];
        }
        double totalCount = 0;
        for (int topic = 0; topic < numTopics; ++topic)
            totalCount += counts[topic];
        double[] result = new double[numTopics];
        if (totalCount == 0) {
            // Empty token array: there are no assignments to average, so
            // return a uniform distribution rather than dividing by zero
            // (which previously produced an all-NaN array).
            java.util.Arrays.fill(result, 1.0 / numTopics);
            return result;
        }
        for (int topic = 0; topic < numTopics; ++topic)
            result[topic] = counts[topic] / totalCount;
        return result;
    }

    /**
     * Run Gibbs sampling for the specified multinomial data, number
     * of topics, priors, search parameters, randomization and
     * callback sample handler.  Gibbs sampling provides samples from
     * the posterior distribution of topic assignments given the
     * corpus and prior hyperparameters.  A sample is encapsulated as
     * an instance of class {@link GibbsSample}.  This method will
     * return the final sample and also send intermediate samples to
     * an optional handler.
     *
     * <p>The class documentation above explains Gibbs sampling for
     * LDA as used in this method.
     *
     * <p>The primary input is an array of documents, where each
     * document is represented as an array of integers representing
     * the tokens that appear in it.  These tokens should be numbered
     * contiguously from 0 for space efficiency.  The topic assignments
     * in the Gibbs sample are aligned as parallel arrays to the array
     * of documents.
     *
     * <p>The next three parameters are the hyperparameters of the
     * model, specifically the number of topics, the prior count
     * assigned to topics in a document, and the prior count assigned
     * to words in topics.  A rule of thumb for the document-topic
     * prior is to set it to 5 divided by the number of topics (or
     * less if there are very few topics; 0.1 is typically the maximum
     * value used).  A good general value for the topic-word prior is
     * 0.01.
Both of these priors will be diffuse and tend to lead to * skewed posterior distributions. * * <p>The following three parameters specify how the sampling is * to be done. First, the sampler is &quot;burned in&quot; for a * number of epochs specified by the burnin parameter. After burn * in, samples are taken after fixed numbers of documents to avoid * correlation in the samples; the sampling frequency is specified * by the sample lag. Finally, the number of samples to be taken * is specified. For instance, if the burnin is 1000, the sample * lag is 250, and the number of samples is 5, then samples are * taken after 1000, 1250, 1500, 1750 and 2000 epochs. If a * non-null handler object is specified in the method call, its * <code>handle(GibbsSample)</code> method is called with each of the * samples produced as above. * * <p>The final sample in the chain of samples is returned as the * result. Note that this sample will also have been passed to the * specified handler as the last sample for the handler. * * <p>A random number generator must be supplied as an argument. * This may just be a new instance of {@link java.util.Random} or * a custom extension. It is used for all randomization in this * method. * * @param docWords Corpus of documents to be processed. * @param numTopics Number of latent topics to generate. * @param docTopicPrior Prior count of topics in a document. * @param topicWordPrior Prior count of words in a topic. * @param burninEpochs Number of epochs to run before taking a sample. * @param sampleLag Frequency between samples. * @param numSamples Number of samples to take before exiting. * @param random Random number generator. * @param handler Handler to which the samples are sent. * @return The final Gibbs sample. 
*/ public static GibbsSample gibbsSampler(int[][] docWords, short numTopics, double docTopicPrior, double topicWordPrior, int burninEpochs, int sampleLag, int numSamples, Random random, ObjectHandler<GibbsSample> handler) { validateInputs(docWords,numTopics,docTopicPrior,topicWordPrior,burninEpochs,sampleLag,numSamples); int numDocs = docWords.length; int numWords = max(docWords) + 1; int numTokens = 0; for (int doc = 0; doc < numDocs; ++doc) numTokens += docWords[doc].length; // should inputs be permuted? // for (int doc = 0; doc < numDocs; ++doc) // Arrays.permute(docWords[doc]); short[][] currentSample = new short[numDocs][]; for (int doc = 0; doc < numDocs; ++doc) currentSample[doc] = new short[docWords[doc].length]; int[][] docTopicCount = new int[numDocs][numTopics]; int[][] wordTopicCount = new int[numWords][numTopics]; int[] topicTotalCount = new int[numTopics]; for (int doc = 0; doc < numDocs; ++doc) { for (int tok = 0; tok < docWords[doc].length; ++tok) { int word = docWords[doc][tok]; int topic = random.nextInt(numTopics); currentSample[doc][tok] = (short) topic; ++docTopicCount[doc][topic]; ++wordTopicCount[word][topic]; ++topicTotalCount[topic]; } } double numWordsTimesTopicWordPrior = numWords * topicWordPrior; double[] topicDistro = new double[numTopics]; long startTime = System.currentTimeMillis(); int numEpochs = burninEpochs + sampleLag * (numSamples - 1); for (int epoch = 0; epoch <= numEpochs; ++epoch) { double corpusLog2Prob = 0.0; int numChangedTopics = 0; for (int doc = 0; doc < numDocs; ++doc) { int[] docWordsDoc = docWords[doc]; short[] currentSampleDoc = currentSample[doc]; int[] docTopicCountDoc = docTopicCount[doc]; for (int tok = 0; tok < docWordsDoc.length; ++tok) { int word = docWordsDoc[tok]; int[] wordTopicCountWord = wordTopicCount[word]; int currentTopic = currentSampleDoc[tok]; if (currentTopic == 0) { topicDistro[0] = (docTopicCountDoc[0] - 1.0 + docTopicPrior) * (wordTopicCountWord[0] - 1.0 + topicWordPrior)
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -