⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 twodimensiongibbslda.java

📁 dragontoolkit用于机器学习
💻 JAVA
📖 第 1 页 / 共 2 页
字号:
package dragon.ir.topicmodel;

import dragon.ir.index.*;
import java.util.*;

/**
 * <p>LDA Gibbs sampling two dimensional topical model</p>
 * <p> </p>
 * <p>Copyright: Copyright (c) 2005</p>
 * <p>Company: IST, Drexel University</p>
 * @author Davis Zhou
 * @version 1.0
 */

public class TwoDimensionGibbsLDA extends AbstractTwoDimensionModel{
    // Smoothing prior for the per-document view distribution (added to every
    // document/view count when arrDocView is estimated).
    private double alpha; //the prior for view distribution in each document;
    // beta: smoothing prior for the word distribution of each view.
    // wBeta: beta * (viewTermNum + themeTermNum), precomputed in the constructor.
    private double beta, wBeta; //the prior for word distribution in each view
    // NOTE(review): presumably priors for choosing between the view-free and the
    // view-specific theme of a word; their use is not visible in this part of the file.
    private double gamma0, gamma1; //gamma0:view-free, gamma1: view-specific
    // delta: smoothing prior for the word distribution of each view-free (common) theme.
    // wDelta: delta * themeTermNum, precomputed in the constructor.
    private double delta, wDelta; //the prior for word distribution in each view-free theme
    // epsilon: smoothing prior for the word distribution of each view-specific theme.
    // wEpsilon: epsilon * themeTermNum, precomputed in the constructor.
    private double epsilon, wEpsilon; // the prior for word distribution in each view-specific theme
    // Smoothing prior for the per-document, per-view theme distribution.
    private double rho; // the prior for theme distribution in each document;
    // Lengths of the per-token arrays: tokenNumO for the sequence read from the view
    // index, tokenNumP for the sequence read from the topic index.
    private int tokenNumO, tokenNumP;

    /**
     * Creates a two-dimensional Gibbs-sampling LDA model over a view index and a topic index.
     * The seed defaults to -1 (the sampler only seeds its random generator when the seed is
     * non-negative) and the number of iterations defaults to 1000.
     *
     * @param viewIndexReader index supplying the token sequence of the view dimension
     * @param topicIndexReader index supplying the token sequence of the topic (theme) dimension
     * @param alpha prior for the view distribution in each document
     * @param beta prior for the word distribution in each view
     * @param gamma0 view-free prior
     * @param gamma1 view-specific prior
     * @param delta prior for the word distribution in each view-free theme
     * @param epsilon prior for the word distribution in each view-specific theme
     * @param rho prior for the theme distribution in each document
     */
    public TwoDimensionGibbsLDA(IndexReader viewIndexReader, IndexReader topicIndexReader,
                            double alpha, double beta, double gamma0,double gamma1, double delta, double epsilon, double rho) {
        super(viewIndexReader,topicIndexReader);
        this.seed=-1;
        this.iterations=1000;
        this.alpha =alpha;
        this.beta=beta;
        this.gamma0 =gamma0;
        // gamma1 was accepted but never stored, leaving the field at 0.0
        this.gamma1 =gamma1;
        this.delta=delta;
        this.epsilon=epsilon;
        this.rho=rho;
        // the collection term count is a wider type and is narrowed to int here
        tokenNumO=(int)viewIndexReader.getCollection().getTermCount();
        // tokenNumP sizes the buffers filled from the topic index, so it must be
        // taken from topicIndexReader rather than from viewIndexReader
        tokenNumP=(int)topicIndexReader.getCollection().getTermCount();
        wBeta=beta*(viewTermNum+themeTermNum);
        wDelta=delta*themeTermNum;
        wEpsilon=epsilon*themeTermNum;
    }

    /**
     * Estimates the model for the given number of views and themes. Allocates the per-token
     * buffers and count tables, loads both token sequences, runs one Markov chain, and then
     * turns the resulting counts into smoothed probability tables stored on this model
     * (arrViewProb, arrDocView, arrCommonThemeProb, arrThemeProb, arrDocTheme).
     *
     * @param viewNum number of views
     * @param topicNum number of themes (topics)
     * @return always true
     */
    public boolean estimateModel(int viewNum, int topicNum){
        this.viewNum = viewNum;
        this.themeNum = topicNum;

        // per-token buffers for the sequence read from the view index
        int[] viewSeqTerms = new int[tokenNumO];
        int[] viewSeqDocs = new int[tokenNumO];
        int[] viewSeqViews = new int[tokenNumO];

        // per-token buffers for the sequence read from the topic index
        int[] themeSeqTerms = new int[tokenNumP];
        int[] themeSeqDocs = new int[tokenNumP];
        int[] themeSeqViews = new int[tokenNumP];
        int[] themeSeqThemes = new int[tokenNumP];
        int[] themeSeqFlags = new int[tokenNumP];

        // count tables handed to the sampler
        int[][] viewTermByView = new int[viewTermNum][viewNum];
        int[][] themeTermByView = new int[themeTermNum][viewNum];
        int[][] docByView = new int[docNum][viewNum];
        int[][] themeTermByTheme = new int[themeTermNum][themeNum];
        int[][][] themeTermByViewTheme = new int[themeTermNum][viewNum][themeNum];
        int[][][] themeTermCommon = new int[themeTermNum][viewNum][themeNum];
        int[][][] docByViewTheme = new int[docNum][viewNum][themeNum];

        // result tables exposed by the model
        arrViewProb = new double[viewNum][viewTermNum];
        arrDocView = new double[docNum][viewNum];
        arrThemeProb = new double[viewNum][themeNum][themeTermNum];
        arrDocTheme = new double[docNum][viewNum][themeNum];
        arrCommonThemeProb = new double[themeNum][themeTermNum];

        // load both token sequences from the index files
        readSequence(viewIndexReader, viewSeqTerms, viewSeqDocs);
        readSequence(topicIndexReader, themeSeqTerms, themeSeqDocs);

        // run a single Markov chain
        run(seed, viewSeqTerms, viewSeqDocs, viewSeqViews, viewTermByView,
            themeSeqTerms, themeSeqDocs, themeSeqViews, themeTermByView, docByView,
            themeSeqThemes, themeSeqFlags, themeTermByViewTheme, themeTermCommon,
            themeTermByTheme, docByViewTheme);

        // word distribution of every view
        for (int v = 0; v < viewNum; v++) {
            double norm = beta * viewTermNum;
            for (int t = 0; t < viewTermNum; t++) {
                norm += viewTermByView[t][v];
            }
            double[] wordProb = arrViewProb[v];
            for (int t = 0; t < viewTermNum; t++) {
                wordProb[t] = (viewTermByView[t][v] + beta) / norm;
            }
        }

        // view distribution of every document
        for (int d = 0; d < docNum; d++) {
            int[] viewCounts = docByView[d];
            double norm = viewNum * alpha;
            for (int v = 0; v < viewNum; v++) {
                norm += viewCounts[v];
            }
            double[] viewProb = arrDocView[d];
            for (int v = 0; v < viewNum; v++) {
                viewProb[v] = (viewCounts[v] + alpha) / norm;
            }
        }

        // word distribution of every common theme
        for (int z = 0; z < themeNum; z++) {
            double norm = delta * themeTermNum;
            for (int t = 0; t < themeTermNum; t++) {
                norm += themeTermByTheme[t][z];
            }
            double[] wordProb = arrCommonThemeProb[z];
            for (int t = 0; t < themeTermNum; t++) {
                wordProb[t] = (themeTermByTheme[t][z] + delta) / norm;
            }
        }

        // word distribution of every view-specific theme
        for (int v = 0; v < viewNum; v++) {
            for (int z = 0; z < themeNum; z++) {
                double norm = epsilon * themeTermNum;
                for (int t = 0; t < themeTermNum; t++) {
                    norm += themeTermByViewTheme[t][v][z];
                }
                double[] wordProb = arrThemeProb[v][z];
                for (int t = 0; t < themeTermNum; t++) {
                    wordProb[t] = (themeTermByViewTheme[t][v][z] + epsilon) / norm;
                }
            }
        }

        // theme distribution of every document within every view
        for (int d = 0; d < docNum; d++) {
            for (int v = 0; v < viewNum; v++) {
                int[] themeCounts = docByViewTheme[d][v];
                double norm = rho * themeNum;
                for (int z = 0; z < themeNum; z++) {
                    norm += themeCounts[z];
                }
                double[] themeProb = arrDocTheme[d][v];
                for (int z = 0; z < themeNum; z++) {
                    themeProb[z] = (themeCounts[z] + rho) / norm;
                }
            }
        }

        return true;
    }

    private void run(int seed, int[] arrTermO, int[] arrDocO, int[] arrZO, int[][] arrTVCountO,
                     int[] arrTermP, int[] arrDocP, int[] arrZP, int[][] arrTVCountP, int[][] arrDVCount,
                     int[] arrY, int[] arrX, int[][][] arrTTViewCount, int[][][] arrTTComCount, int[][] arrTTCount, int[][][] arrDTCount)
    {
        Random random;
        int[] arrViewCount,arrTopicCount, arrOrderO, arrOrderP;
        int[][] arrViewTopicCount;
        int termIndex, docIndex, view, topic,status;
        int i, j, k,iter;

        //initiliazation
        arrViewCount=new int[viewNum];
        arrTopicCount=new int[themeNum];
        arrViewTopicCount=new int[viewNum][themeNum];
        arrOrderO=new int[tokenNumO];
        arrOrderP=new int[tokenNumP];
        random=new Random();
        if(seed>=0)
            random.setSeed(seed);

        printStatus(new java.util.Date().toString()+" Starting random initialization...");
        for(i=0;i<tokenNumO;i++){
            view=random.nextInt(viewNum);
            arrZO[i]=view;
            termIndex=arrTermO[i];
            docIndex=arrDocO[i];
            arrTVCountO[termIndex][view]++;
            arrDVCount[docIndex][view]++;
            arrViewCount[view]++;
        }
        for(i=0;i<tokenNumP;i++){
            view=random.nextInt(viewNum);
            arrZP[i]=view;
            termIndex=arrTermP[i];
            docIndex=arrDocP[i];
            arrTVCountP[termIndex][view]++;
            arrDVCount[docIndex][view]++;
            arrViewCount[view]++;
        }

        for(i=0;i<tokenNumP;i++){
            topic=random.nextInt(themeNum);
            view=arrZP[i];
            arrY[i]=topic;
            termIndex=arrTermP[i];
            docIndex=arrDocP[i];
            arrTTViewCount[termIndex][view][topic]++;
            arrDTCount[docIndex][view][topic]++;
            arrViewTopicCount[view][topic]++;
            arrX[i]=1; //initially all property words are view-spcific
        }

        //Determine random update sequence
        printStatus(new java.util.Date().toString()+" Determining random update sequence...");
        for (i=0; i<tokenNumO; i++) arrOrderO[i]=i;
        for (i = 0; i < (tokenNumO - 1); i++) {
            // pick a random integer between i and n
            k = i +random.nextInt(tokenNumO-i);
            // switch contents on position i and position k
            j = arrOrderO[k];
            arrOrderO[k] = arrOrderO[i];
            arrOrderO[i] = j;
        }

        for (i=0; i<tokenNumP; i++) arrOrderP[i]=i;
        for (i = 0; i < (tokenNumP - 1); i++) {
            // pick a random integer between i and n
            k = i +random.nextInt(tokenNumP-i);
            // switch contents on position i and position k
            j = arrOrderP[k];
            arrOrderP[k] = arrOrderP[i];
            arrOrderP[i] = j;
        }

        for (iter = 0; iter <iterations; iter++) {
            printStatus(new java.util.Date().toString()+" Iteration #"+(iter+1));

            //sampling words in the major dimension
            for (k = 0; k <tokenNumO; k++) {
                i = arrOrderO[k]; // current word token to assess
                termIndex = arrTermO[i];
                docIndex =arrDocO[i];
                view = arrZO[i];

                // substract the current instance from counts
                arrViewCount[view]--;

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -