⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 simplemixturemodel.java

📁 dragontoolkit用于机器学习
💻 JAVA
字号:
package dragon.ir.topicmodel;

import dragon.ir.index.IndexReader;
import dragon.matrix.vector.DoubleVector;
import java.util.*;

/**
 * <p>Simple mixture model</p>
 * <p> </p>
 * <p>Copyright: Copyright (c) 2005</p>
 * <p>Company: IST, Drexel University</p>
 * @author Davis Zhou
 * @version 1.0
 */

public class SimpleMixtureModel extends AbstractTopicModel{
    protected DoubleVector bkgModel;
    protected double bkgCoefficient;

    public SimpleMixtureModel(IndexReader indexReader, double bkgCoefficient){
        this(indexReader,null,bkgCoefficient);
    }

    public SimpleMixtureModel(IndexReader indexReader, DoubleVector bkgModel, double bkgCoefficient) {
        super(indexReader);
        if(bkgModel==null)
            bkgModel=getBkgModel(indexReader);
        else
            bkgModel=bkgModel.copy();
        this.bkgModel =bkgModel;
        this.bkgCoefficient=bkgCoefficient;
        this.bkgModel.multiply(bkgCoefficient);
    }

    public boolean estimateModel(int topicNum){
        return estimateModel(null, topicNum);
    }

    public boolean estimateModel(int[] arrDoc, int topicNum){
        double[][] arrTempProb;
        double[] arrDocWeightSum;
        double bkgProb, themeProb, themeProbSum, docWeightSum, termProbSum, termProb;
        int[] arrIndex, arrFreq;
        int termIndex;
        int i,j,k,m;

        //initialization
        this.themeNum =topicNum;
        termNum=indexReader.getCollection().getTermNum();
        if(arrDoc==null)
            docNum=indexReader.getCollection().getDocNum();
        else
            docNum=arrDoc.length;

        arrThemeTerm=new double[themeNum][termNum];
        arrTempProb=new double[themeNum][termNum];
        arrDocTheme=new double[docNum][themeNum];
        arrDocWeightSum=new double[themeNum];

        //initialize parameters
        initialize(termNum, themeNum, docNum, arrThemeTerm, arrDocTheme);

        //compute coefficients for mixture components
        printStatus("Estimating the coefficients of simple mixture model...");
        for(k=0;k<iterations;k++){
            printStatus((new java.util.Date()).toString()+" Iteration #" + (k + 1));
            for(i=0;i<themeNum;i++)
                for (j = 0; j < termNum; j++)
                    arrTempProb[i][j] = 0;

            for (i = 0; i < docNum; i++) {
                if(arrDoc==null){
                    arrIndex = indexReader.getTermIndexList(i);
                    arrFreq = indexReader.getTermFrequencyList(i);
                }
                else{
                    arrIndex = indexReader.getTermIndexList(arrDoc[i]);
                    arrFreq = indexReader.getTermFrequencyList(arrDoc[i]);
                }
                for(m=0;m<themeNum;m++) arrDocWeightSum[m] = 0;

                for (j = 0; j < arrIndex.length; j++) {
                    termIndex=arrIndex[j];
                    themeProbSum=0;
                    for(m=0;m<themeNum;m++){
                        themeProbSum+=arrThemeTerm[m][termIndex]*arrDocTheme[i][m];
                    }
                    bkgProb=bkgModel.get(termIndex)/(themeProbSum*(1-bkgCoefficient)+bkgModel.get(termIndex));

                    for (m = 0; m <themeNum; m++) {
                        if(themeProbSum!=0)
                           // themeProb is Zd,w=j
                           themeProb=arrThemeTerm[m][termIndex]*arrDocTheme[i][m]/themeProbSum;
                        else
                            themeProb=0;
                        termProb =arrFreq[j]*themeProb;
                        arrDocWeightSum[m]+=termProb;
                        arrTempProb[m][termIndex]+=termProb*(1-bkgProb);
                    }
                }

                //update the doc-specific coefficient for each theme
                docWeightSum=0;
                for (m = 0; m < themeNum; m++)
                    docWeightSum+=arrDocWeightSum[m];
                if(docWeightSum>0){
                    for (m = 0; m < themeNum; m++)
                        arrDocTheme[i][m] = arrDocWeightSum[m] / docWeightSum;
                }
                else{
                    for (m = 0; m < themeNum; m++)
                        arrDocTheme[i][m] = 0;
                }
            }

            //update the generative model for each theme
            for(i=0;i<themeNum;i++){
                termProbSum=0;
                for(j=0;j<termNum;j++)
                    termProbSum+=arrTempProb[i][j];
                for(j=0;j<termNum;j++)
                    arrThemeTerm[i][j]=arrTempProb[i][j]/termProbSum;
            }
        }
        printStatus("");
        return true;
    }

    protected void initialize(int termNum, int themeNum, int docNum, double[][] arrModel, double[][] arrDocMembership){
        Random random;
        double termProb, docProb;
        int i, j;

        termProb=1.0/termNum;
        for(i=0;i<themeNum;i++)
            for(j=0;j<termNum;j++)
                arrModel[i][j]=termProb;

        random=new Random(seed);
        for(i=0;i<docNum;i++){
            docProb=0;
            for(j=0;j<themeNum;j++){
                arrDocMembership[i][j] = random.nextDouble();
                docProb+=arrDocMembership[i][j];
            }
            for(j=0;j<themeNum;j++)
                arrDocMembership[i][j]=arrDocMembership[i][j]/docProb;
        }
    }
}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -