📄 lda.java
/* Copyright (C) 2005 Univ. of Massachusetts Amherst, Computer Science Dept.
   This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
   http://www.cs.umass.edu/~mccallum/mallet
   This software is provided under the terms of the Common Public License,
   version 1.0, as published by http://www.opensource.org.  For further
   information, see the file `LICENSE' included with this distribution. */

package edu.umass.cs.mallet.base.topics;

import edu.umass.cs.mallet.base.types.*;
import edu.umass.cs.mallet.base.util.Random;
import java.util.Arrays;
import java.io.*;

/**
 * Latent Dirichlet Allocation.
 * @author Andrew McCallum
 */

// Think about support for incrementally adding more documents...
// (I think this means we might want to use FeatureSequence directly).
// We will also need to support a growing vocabulary!

// Serializable so that write() / ObjectOutputStream can persist the model.
public class LDA implements Serializable {

  int numTopics;   // Number of topics to be fit
  double alpha;    // Dirichlet(alpha,alpha,...) is the distribution over topics
  double beta;     // Prior on per-topic multinomial distribution over words
  double tAlpha;
  double vBeta;
  InstanceList ilist;       // the data field of the instances is expected to hold a FeatureSequence
  int[][] topics;           // indexed by <document index, sequence index>
  int numTypes;
  int numTokens;
  int[][] docTopicCounts;   // indexed by <document index, topic index>
  int[][] typeTopicCounts;  // indexed by <feature index, topic index>
  int[] tokensPerTopic;     // indexed by <topic index>

  public LDA (int numberOfTopics)
  {
    this (numberOfTopics, 1.0, 0.01);
  }

  public LDA (int numberOfTopics, double alphaSum, double beta)
  {
    this.numTopics = numberOfTopics;
    this.alpha = alphaSum / numTopics;
    this.beta = beta;
  }

  public void estimate (InstanceList documents, int numIterations, int showTopicsInterval,
                        int outputModelInterval, String outputModelFilename, Random r)
  {
    ilist = documents;
    numTypes = ilist.getDataAlphabet().size ();
    int numDocs = ilist.size();
    topics = new int[numDocs][];
    docTopicCounts = new int[numDocs][numTopics];
    typeTopicCounts = new int[numTypes][numTopics];
    tokensPerTopic = new int[numTopics];
    tAlpha = alpha * numTopics;
    vBeta = beta * numTypes;

    long startTime = System.currentTimeMillis();

    // Initialize with random assignments of tokens to topics
    // and finish allocating this.topics and this.tokens
    int topic, seqLen;
    for (int di = 0; di < numDocs; di++) {
      FeatureSequence fs = (FeatureSequence) ilist.getInstance(di).getData();
      seqLen = fs.getLength();
      numTokens += seqLen;
      topics[di] = new int[seqLen];
      // Randomly assign tokens to topics
      for (int si = 0; si < seqLen; si++) {
        topic = r.nextInt(numTopics);
        topics[di][si] = topic;
        docTopicCounts[di][topic]++;
        typeTopicCounts[fs.getIndexAtPosition(si)][topic]++;
        tokensPerTopic[topic]++;
      }
    }

    for (int iterations = 0; iterations < numIterations; iterations++) {
      if (iterations % 10 == 0) System.out.print (iterations);
      else System.out.print (".");
      System.out.flush();
      if (showTopicsInterval != 0 && iterations % showTopicsInterval == 0 && iterations > 0) {
        System.out.println ();
        printTopWords (5, false);
      }
      if (outputModelInterval != 0 && iterations % outputModelInterval == 0 && iterations > 0) {
        this.write (new File(outputModelFilename+'.'+iterations));
      }
      sampleTopicsForAllDocs (r);
    }

    long seconds = Math.round((System.currentTimeMillis() - startTime)/1000.0);
    long minutes = seconds / 60;  seconds %= 60;
    long hours = minutes / 60;    minutes %= 60;
    long days = hours / 24;       hours %= 24;
    System.out.print ("\nTotal time: ");
    if (days != 0) { System.out.print(days); System.out.print(" days "); }
    if (hours != 0) { System.out.print(hours); System.out.print(" hours "); }
    if (minutes != 0) { System.out.print(minutes); System.out.print(" minutes "); }
    System.out.print(seconds); System.out.println(" seconds");

    // 124.5 seconds
    // 144.8 seconds after using FeatureSequence instead of tokens[][] array
    // 121.6 seconds after putting "final" on FeatureSequence.getIndexAtPosition()
    // 106.3 seconds after avoiding array lookup in inner loop with a temporary variable
  }

  /* One iteration of Gibbs sampling, across all documents. */
  private void sampleTopicsForAllDocs (Random r)
  {
    double[] topicWeights = new double[numTopics];
    // Loop over every word in the corpus
    for (int di = 0; di < topics.length; di++) {
      sampleTopicsForOneDoc ((FeatureSequence)ilist.getInstance(di).getData(),
                             topics[di], docTopicCounts[di], topicWeights, r);
    }
  }

/*
  public double[] assignTopics (int[] testTokens, Random r)
  {
    int[] testTopics = new int[testTokens.length];
    int[] testTopicCounts = new int[numTopics];
    int numTokens = MatrixOps.sum(testTokens);
    double[] topicWeights = new double[numTopics];
    // Randomly assign topics to the words and
    // incorporate this document in the global counts
    int topic;
    for (int si = 0; si < testTokens.length; si++) {
      topic = r.nextInt (numTopics);
      testTopics[si] = topic;       // analogous to this.topics
      testTopicCounts[topic]++;     // analogous to this.docTopicCounts
      typeTopicCounts[testTokens[si]][topic]++;
      tokensPerTopic[topic]++;
    }
    // Repeatedly sample topic assignments for the words in this document
    for (int iterations = 0; iterations < numTokens*2; iterations++)
      sampleTopicsForOneDoc (testTokens, testTopics, testTopicCounts, topicWeights, r);
    // Remove this document from the global counts
    // and also fill topicWeights with an unnormalized distribution over topics for whole doc
    Arrays.fill (topicWeights, 0.0);
    for (int si = 0; si < testTokens.length; si++) {
      topic = testTopics[si];
      typeTopicCounts[testTokens[si]][topic]--;
      tokensPerTopic[topic]--;
      topicWeights[topic]++;
    }
    // Normalize the distribution over topics for whole doc
    for (int ti = 0; ti < numTopics; ti++)
      topicWeights[ti] /= testTokens.length;
    return topicWeights;
  }
*/

  private void sampleTopicsForOneDoc (FeatureSequence oneDocTokens,
                                      int[] oneDocTopics,      // indexed by seq position
                                      int[] oneDocTopicCounts, // indexed by topic index
                                      double[] topicWeights, Random r)
  {
    int[] currentTypeTopicCounts;
    int type, oldTopic, newTopic;
    double topicWeightsSum;
    int docLen = oneDocTokens.getLength();
    double tw;
    // Iterate over the positions (words) in the document
    for (int si = 0; si < docLen; si++) {
      type = oneDocTokens.getIndexAtPosition(si);
      oldTopic = oneDocTopics[si];
      // Remove this token from all counts
      oneDocTopicCounts[oldTopic]--;
      typeTopicCounts[type][oldTopic]--;
      tokensPerTopic[oldTopic]--;
      // Build a distribution over topics for this token:
      // the collapsed Gibbs weight (count(word,topic)+beta)/(count(topic)+V*beta) * (count(doc,topic)+alpha)
      Arrays.fill (topicWeights, 0.0);
      topicWeightsSum = 0;
      currentTypeTopicCounts = typeTopicCounts[type];
      for (int ti = 0; ti < numTopics; ti++) {
        tw = ((currentTypeTopicCounts[ti] + beta) / (tokensPerTopic[ti] + vBeta))
             * ((oneDocTopicCounts[ti] + alpha)); // (/docLen-1+tAlpha); is constant across all topics
        topicWeightsSum += tw;
        topicWeights[ti] = tw;
      }
      // Sample a topic assignment from this distribution
      newTopic = r.nextDiscrete (topicWeights, topicWeightsSum);

      // Put that new topic into the counts
      oneDocTopics[si] = newTopic;
      oneDocTopicCounts[newTopic]++;
      typeTopicCounts[type][newTopic]++;
      tokensPerTopic[newTopic]++;
    }
  }

  public void printTopWords (int numWords, boolean useNewLines)
  {
    class WordProb implements Comparable {
      int wi;
      double p;
      public WordProb (int wi, double p) { this.wi = wi; this.p = p; }
      public final int compareTo (Object o2) {
        if (p > ((WordProb)o2).p)
          return -1;
        else if (p == ((WordProb)o2).p)
          return 0;
        else
          return 1;
      }
    }

    WordProb[] wp = new WordProb[numTypes];
    for (int ti = 0; ti < numTopics; ti++) {
      for (int wi = 0; wi < numTypes; wi++)
        wp[wi] = new WordProb (wi, ((double)typeTopicCounts[wi][ti]) / tokensPerTopic[ti]);
      Arrays.sort (wp);
      if (useNewLines) {
        System.out.println ("\nTopic "+ti);
        for (int i = 0; i < numWords; i++)
          System.out.println (ilist.getDataAlphabet().lookupObject(wp[i].wi).toString() + " " + wp[i].p);
      } else {
        System.out.print ("Topic "+ti+": ");
        for (int i = 0; i < numWords; i++)
          System.out.print (ilist.getDataAlphabet().lookupObject(wp[i].wi).toString() + " ");
        System.out.println();
      }
    }
  }

  public void printDocumentTopics (File f) throws IOException
  {
    printDocumentTopics (new PrintWriter (new FileWriter (f)));
  }

  public void printDocumentTopics (PrintWriter pw)
  {
    pw.println ("#doc source topic proportions");
    int docLen;
    for (int di = 0; di < topics.length; di++) {
      pw.print (di); pw.print (' ');
      docLen = topics[di].length;
      for (int ti = 0; ti < numTopics; ti++) {
        // Separate the topic proportions with spaces so the columns stay parseable
        pw.print (((float)docTopicCounts[di][ti])/docLen);
        pw.print (' ');
      }
      pw.println (ilist.getInstance(di).getSource().toString());
    }
    // Flush so output reaches disk when called via the File-based overload
    pw.flush();
  }

  public void printState (File f) throws IOException
  {
    printState (new PrintWriter (new FileWriter(f)));
  }

  public void printState (PrintWriter pw)
  {
    Alphabet a = ilist.getDataAlphabet();
    pw.println ("#doc pos typeindex type topic");
    for (int di = 0; di < topics.length; di++) {
      FeatureSequence fs = (FeatureSequence) ilist.getInstance(di).getData();
      for (int si = 0; si < topics[di].length; si++) {
        int type = fs.getIndexAtPosition(si);
        pw.print(di); pw.print(' ');
        pw.print(si); pw.print(' ');
        pw.print(type); pw.print(' ');
        pw.print(a.lookupObject(type)); pw.print(' ');
        pw.print(topics[di][si]); pw.println();
      }
    }
    // Flush so output reaches disk when called via the File-based overload
    pw.flush();
  }

  public void write (File f)
  {
    try {
      ObjectOutputStream oos = new ObjectOutputStream (new FileOutputStream(f));
      oos.writeObject(this);
      oos.close();
    }
    catch (IOException e) {
      System.err.println("Exception writing file " + f + ": " + e);
    }
  }

  // Serialization

  private static final long serialVersionUID = 1;
  private static final int CURRENT_SERIAL_VERSION = 0;
  private static final int NULL_INTEGER = -1;

  private void writeObject (ObjectOutputStream out) throws IOException
  {
    out.writeInt (CURRENT_SERIAL_VERSION);
    out.writeObject (ilist);
    out.writeInt (numTopics);
    out.writeDouble (alpha);
    out.writeDouble (beta);
    out.writeDouble (tAlpha);
    out.writeDouble (vBeta);
    for (int di = 0; di < topics.length; di++)
      for (int si = 0; si < topics[di].length; si++)
        out.writeInt (topics[di][si]);
    for (int di = 0; di < topics.length; di++)
      for (int ti = 0; ti < numTopics; ti++)
        out.writeInt (docTopicCounts[di][ti]);
    for (int fi = 0; fi < numTypes; fi++)
      for (int ti = 0; ti < numTopics; ti++)
        out.writeInt (typeTopicCounts[fi][ti]);
    for (int ti = 0; ti < numTopics; ti++)
      out.writeInt (tokensPerTopic[ti]);
  }

  private void readObject (ObjectInputStream in) throws IOException, ClassNotFoundException
  {
    int featuresLength;
    int version = in.readInt ();
    ilist = (InstanceList) in.readObject ();
    numTopics = in.readInt();
    alpha = in.readDouble();
    beta = in.readDouble();
    tAlpha = in.readDouble();
    vBeta = in.readDouble();
    int numDocs = ilist.size();
    topics = new int[numDocs][];
    for (int di = 0; di < ilist.size(); di++) {
      int docLen = ((FeatureSequence)ilist.getInstance(di).getData()).getLength();
      topics[di] = new int[docLen];
      for (int si = 0; si < docLen; si++)
        topics[di][si] = in.readInt();
    }
    docTopicCounts = new int[numDocs][numTopics];
    for (int di = 0; di < ilist.size(); di++)
      for (int ti = 0; ti < numTopics; ti++)
        docTopicCounts[di][ti] = in.readInt();
    // Restore the field (not a local) so methods like printTopWords() work after deserialization
    numTypes = ilist.getDataAlphabet().size();
    typeTopicCounts = new int[numTypes][numTopics];
    for (int fi = 0; fi < numTypes; fi++)
      for (int ti = 0; ti < numTopics; ti++)
        typeTopicCounts[fi][ti] = in.readInt();
    tokensPerTopic = new int[numTopics];
    for (int ti = 0; ti < numTopics; ti++)
      tokensPerTopic[ti] = in.readInt();
  }

  // Recommended to use mallet/bin/vectors2topics instead.
  public static void main (String[] args) throws IOException
  {
    InstanceList ilist = InstanceList.load (new File(args[0]));
    int numIterations = args.length > 1 ? Integer.parseInt(args[1]) : 1000;
    int numTopWords = args.length > 2 ? Integer.parseInt(args[2]) : 20;
    System.out.println ("Data loaded.");
    LDA lda = new LDA (10);
    lda.estimate (ilist, numIterations, 50, 0, null, new Random());  // should be 1100
    lda.printTopWords (numTopWords, true);
    lda.printDocumentTopics (new File(args[0]+".lda"));
  }

}
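
For reference, below is a minimal sketch of driving this class programmatically rather than through its main(). It simply mirrors what main() already does, with the optional estimate() arguments spelled out; the class name, file paths, topic count, and iteration settings are hypothetical, and it assumes a previously serialized InstanceList whose instances carry FeatureSequence data, as estimate() expects.

// LdaExample.java -- hypothetical driver; names and parameter values are illustrative only.
import edu.umass.cs.mallet.base.topics.LDA;
import edu.umass.cs.mallet.base.types.InstanceList;
import edu.umass.cs.mallet.base.util.Random;
import java.io.File;
import java.io.IOException;

public class LdaExample {
  public static void main (String[] args) throws IOException {
    // Load a serialized InstanceList; each instance's data field must hold a
    // FeatureSequence, as noted on the ilist field above.
    InstanceList docs = InstanceList.load (new File ("training.mallet"));  // hypothetical path
    LDA lda = new LDA (50, 1.0, 0.01);  // 50 topics; alphaSum and beta match the one-arg constructor's defaults
    // 1000 Gibbs sweeps, print top words every 100 iterations, write no intermediate model files
    lda.estimate (docs, 1000, 100, 0, null, new Random ());
    lda.printTopWords (10, true);                            // top 10 words per topic, one per line
    lda.printDocumentTopics (new File ("training.lda"));     // per-document topic proportions
    lda.printState (new File ("training.state"));            // per-token topic assignments
  }
}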