⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 latentdirichletallocation.java

📁 一个自然语言处理的Java开源工具包。LingPipe目前已有很丰富的功能
💻 JAVA
📖 第 1 页 / 共 5 页
字号:
        return mTopicWordProbs[topic][word];    }    /**     * Returns an array representing of probabilities of words in the     * specified topic.  The probabilities are indexed by word     * identifier.     *     * <p>The returned result is a copy of the underlying data in     * the model so that changing it will not change the model.     *     * @param topic Topic identifier.     * @return Array of probabilities of words in the specified topic.     */    public double[] wordProbabilities(int topic) {        double[] xs = new double[mTopicWordProbs[topic].length];        for (int i = 0; i < xs.length; ++i)            xs[i] = mTopicWordProbs[topic][i];        return xs;    }    /**     * Returns the specified number of Gibbs samples of topics for the     * specified tokens using the specified number of burnin epochs,     * the specified lag between samples, and the specified     * randomizer.  The array returned is an array of samples, each     * sample consisting of a topic assignment to each token in the     * specified list of tokens.  The tokens must all be in the appropriate     * range for this class     *     * <p>See the class documentation for more information on how the     * samples are computed.     *     * @param tokens The tokens making up the document.     * @param numSamples Number of Gibbs samples to return.     * @param burnin The number of samples to take and throw away     * during the burnin period.     * @param sampleLag The interval between samples after burnin.     * @param random The random number generator to use for this sampling     * process.     * @return The selection of topic samples generated by this     * sampler.     * @throws IndexOutOfBoundsException If there are tokens whose     * value is less than zero, or whose value is greater than the     * number of tokens in this model.     * @throws IllegalArgumentException If the number of samples is     * not positive, the sample lag is not positive, or if the burnin     * period is negative.  if the number of samples, burnin, and lag     * are not positive numbers.     */    public short[][] sampleTopics(int[] tokens,                                  int numSamples,                                  int burnin,                                  int sampleLag,                                  Random random) {        if (burnin < 0) {            String msg = "Burnin period must be non-negative."                + " Found burnin=" + burnin;            throw new IllegalArgumentException(msg);        }        if (numSamples < 1) {            String msg = "Number of samples must be at least 1."                + " Found numSamples=" + numSamples;            throw new IllegalArgumentException(msg);        }        if (sampleLag < 1) {            String msg = "Sample lag must be at least 1."                + " Found sampleLag=" + sampleLag;            throw new IllegalArgumentException(msg);        }        double docTopicPrior = documentTopicPrior();        int numTokens = tokens.length;        int numTopics = numTopics();        int[] topicCount = new int[numTopics];        short[][] samples = new short[numSamples][numTokens];        int sample = 0;        short[] currentSample = samples[0];        for (int token = 0; token < numTokens; ++token) {            int randomTopic = random.nextInt(numTopics);            ++topicCount[randomTopic];            currentSample[token] = (short) randomTopic;        }        double[] topicDistro = new double[numTopics];        int numEpochs = burnin + sampleLag * (numSamples - 1);        for (int epoch = 0; epoch < numEpochs; ++epoch) {            for (int token = 0; token < numTokens; ++token) {                int word = tokens[token];                int currentTopic = currentSample[token];                --topicCount[currentTopic];                if (topicCount[currentTopic] < 0) {                    throw new IllegalArgumentException("bomb");                }                for (int topic = 0; topic < numTopics; ++topic) {                    topicDistro[topic]                        = (topicCount[topic] + docTopicPrior)                        * wordProbability(topic,word)                        + (topic == 0 ? 0.0 : topicDistro[topic-1]);                }                int sampledTopic = Statistics.sample(topicDistro,random);                ++topicCount[sampledTopic];                currentSample[token] = (short) sampledTopic;            }            if ((epoch >= burnin) && (((epoch - burnin) % sampleLag) == 0)) {                short[] pastSample = currentSample;                ++sample;                currentSample = samples[sample];                for (int token = 0; token < numTokens; ++token)                    currentSample[token] = pastSample[token];            }        }        return samples;    }    /**     * Return the maximum a posteriori (MAP) estimate of the topic     * distribution for a document consisting of the specified tokens,     * using Gibbs sampling with the specified parameters.  The     * Gibbs topic samples are simply averaged to produce the MAP     * estimate.     *     * <p>See the method {@link #sampleTopics(int[],int,int,int,Random)}     * and the class documentation for more information on the sampling     * procedure.     *     * @param tokens The tokens making up the document.     * @param numSamples Number of Gibbs samples to return.     * @param burnin The number of samples to take and throw away     * during the burnin period.     * @param sampleLag The interval between samples after burnin.     * @param random The random number generator to use for this sampling     * process.     * @return The selection of topic samples generated by this     * sampler.     * @throws IndexOutOfBoundsException If there are tokens whose     * value is less than zero, or whose value is greater than the     * number of tokens in this model.     * @throws IllegalArgumentException If the number of samples is     * not positive, the sample lag is not positive, or if the burnin     * period is negative.  if the number of samples, burnin, and lag     * are not positive numbers.     */    public double[] mapTopicEstimate(int[] tokens,                                     int numSamples,                                     int burnin,                                     int sampleLag,                                     Random random) {        short[][] sampleTopics = sampleTopics(tokens,numSamples,burnin,sampleLag,random);        int numTopics = numTopics();        int[] counts = new int[numTopics];        for (short[] topics : sampleTopics) {            for (int tok = 0; tok < topics.length; ++tok)                ++counts[topics[tok]];        }        double totalCount = 0;        for (int topic = 0; topic < numTopics; ++topic)            totalCount += counts[topic];        double[] result = new double[numTopics];        for (int topic = 0; topic < numTopics; ++topic)            result[topic] = counts[topic] / totalCount;        return result;    }    /**     * Run Gibbs sampling for the specified multinomial data, number     * of topics, priors, search parameters, randomization and     * callback sample handler.  Gibbs sampling provides samples from the     * posterior distribution of topic assignments given the corpus     * and prior hyperparameters.  A sample is encapsulated as an     * instance of class {@link GibbsSample}.  This method will return     * the final sample and also send intermediate samples to an     * optional handler.     *     * <p>The class documentation above explains Gibbs sampling for     * LDA as used in this method.     *     * <p>The primary input is an array of documents, where each     * document is represented as an array of integers representing     * the tokens that appear in it.  These tokens should be numbered     * contiguously from 0 for space efficiency.  The topic assignments     * in the Gibbs sample are aligned as parallel arrays to the array     * of documents.     *     * <p>The next three parameters are the hyperparameters of the     * model, specifically the number of topics, the prior count     * assigned to topics in a document, and the prior count assigned     * to words in topics.  A rule of thumb for the document-topic     * prior is to set it to 5 divided by the number of topics (or     * less if there are very few topics; 0.1 is typically the maximum     * value used).  A good general value for the topic-word prior is     * 0.01.  Both of these priors will be diffuse and tend to lead to     * skewed posterior distributions.     *     * <p>The following three parameters specify how the sampling is     * to be done.  First, the sampler is &quot;burned in&quot; for a     * number of epochs specified by the burnin parameter.  After burn     * in, samples are taken after fixed numbers of documents to avoid     * correlation in the samples; the sampling frequency is specified     * by the sample lag.  Finally, the number of samples to be taken     * is specified.  For instance, if the burnin is 1000, the sample     * lag is 250, and the number of samples is 5, then samples are     * taken after 1000, 1250, 1500, 1750 and 2000 epochs.  If a     * non-null handler object is specified in the method call, its     * <code>handle(GibbsSample)</code> method is called with each the     * samples produced as above.     *     * <p>The final sample in the chain of samples is returned as the     * result.  Note that this sample will also have been passed to the     * specified handler as the last sample for the handler.     *     * <p>A random number generator must be supplied as an argument.     * This may just be a new instance of {@link java.util.Random} or     * a custom extension. It is used for all randomization in this     * method.     *     * @param docWords Corpus of documents to be processed.     * @param numTopics Number of latent topics to generate.     * @param docTopicPrior Prior count of topics in a document.     * @param topicWordPrior Prior count of words in a topic.     * @param burninEpochs  Number of epochs to run before taking a sample.     * @param sampleLag Frequency between samples.     * @param numSamples Number of samples to take before exiting.     * @param random Random number generator.     * @param handler Handler to which the samples are sent.     * @return The final Gibbs sample.     */    public static GibbsSample        gibbsSampler(int[][] docWords,                     short numTopics,                     double docTopicPrior,                     double topicWordPrior,                     int burninEpochs,                     int sampleLag,                     int numSamples,                     Random random,                     ObjectHandler<GibbsSample> handler) {        validateInputs(docWords,numTopics,docTopicPrior,topicWordPrior,burninEpochs,sampleLag,numSamples);        int numDocs = docWords.length;        int numWords = max(docWords) + 1;        int numTokens = 0;        for (int doc = 0; doc < numDocs; ++doc)            numTokens += docWords[doc].length;        // should inputs be permuted?        // for (int doc = 0; doc < numDocs; ++doc)        // Arrays.permute(docWords[doc]);        short[][] currentSample = new short[numDocs][];        for (int doc = 0; doc < numDocs; ++doc)            currentSample[doc] = new short[docWords[doc].length];        int[][] docTopicCount = new int[numDocs][numTopics];        int[][] wordTopicCount = new int[numWords][numTopics];        int[] topicTotalCount = new int[numTopics];        for (int doc = 0; doc < numDocs; ++doc) {            for (int tok = 0; tok < docWords[doc].length; ++tok) {                int word = docWords[doc][tok];                int topic = random.nextInt(numTopics);                currentSample[doc][tok] = (short) topic;                ++docTopicCount[doc][topic];                ++wordTopicCount[word][topic];                ++topicTotalCount[topic];            }        }        double numWordsTimesTopicWordPrior = numWords * topicWordPrior;        double[] topicDistro = new double[numTopics];        long startTime = System.currentTimeMillis();        int numEpochs = burninEpochs + sampleLag * (numSamples - 1);        for (int epoch = 0; epoch <= numEpochs; ++epoch) {            double corpusLog2Prob = 0.0;            int numChangedTopics = 0;            for (int doc = 0; doc < numDocs; ++doc) {                int[] docWordsDoc = docWords[doc];                short[] currentSampleDoc = currentSample[doc];                int[] docTopicCountDoc = docTopicCount[doc];                for (int tok = 0; tok < docWordsDoc.length; ++tok) {                    int word = docWordsDoc[tok];                    int[] wordTopicCountWord = wordTopicCount[word];                    int currentTopic = currentSampleDoc[tok];                    if (currentTopic == 0) {                        topicDistro[0]                            = (docTopicCountDoc[0] - 1.0 + docTopicPrior)                            * (wordTopicCountWord[0] - 1.0 + topicWordPrior)

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -