
📄 BaumWelchLearner.java

📁 A Hidden Markov Model (HMM) implementation in Java
💻 Java
/* jahmm package - v0.3.1 */

/*
 *  Copyright (c) 2004, Jean-Marc Francois.
 *
 *  This file is part of Jahmm.
 *  Jahmm is free software; you can redistribute it and/or modify
 *  it under the terms of the GNU General Public License as published by
 *  the Free Software Foundation; either version 2 of the License, or
 *  (at your option) any later version.
 *
 *  Jahmm is distributed in the hope that it will be useful,
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *  GNU General Public License for more details.
 *
 *  You should have received a copy of the GNU General Public License
 *  along with Jahmm; if not, write to the Free Software
 *  Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
 */

package be.ac.ulg.montefiore.run.jahmm.learn;

import java.util.*;

import be.ac.ulg.montefiore.run.jahmm.*;

/**
 * An implementation of the Baum-Welch learning algorithm.
 */
public class BaumWelchLearner {

    static private final int NB_ITERATIONS = 9;

    int nbStates;
    Vector obsSeqs;
    OpdfFactory opdfFactory;

    /**
     * Initializes a Baum-Welch algorithm implementation.  This algorithm
     * finds an HMM that models a set of observation sequences.
     *
     * @param nbStates    The number of states the resulting HMM will be made of.
     * @param opdfFactory A class that builds the observation probability
     *                    distributions associated with the states of the HMM.
     * @param obsSeqs     A vector of observation sequences.  Each observation
     *                    sequence is a vector of
     *                    {@link be.ac.ulg.montefiore.run.jahmm.Observation
     *                    observations}.
     */
    public BaumWelchLearner(int nbStates, OpdfFactory opdfFactory,
                            Vector obsSeqs) {
        if (nbStates <= 0 || opdfFactory == null || obsSeqs.size() == 0)
            throw new IllegalArgumentException();

        this.obsSeqs = obsSeqs;
        this.opdfFactory = opdfFactory;
        this.nbStates = nbStates;
    }

    /**
     * Performs one iteration of the Baum-Welch algorithm.
     * In one iteration, a new HMM is computed using a previously estimated
     * HMM.
     *
     * @param hmm A previously estimated HMM.
     * @return A new, updated HMM.
     */
    public Hmm iterate(Hmm hmm) {
        Hmm nhmm = new Hmm(hmm.nbStates(), opdfFactory);

        /* The gamma and xi arrays are those defined by Rabiner and Juang. */
        /* allGamma[n] = gamma array associated with observation sequence n */
        double allGamma[][][] = new double[obsSeqs.size()][][];

        /* a[i][j] = aijNum[i][j] / aijDen[i] */
        double aijNum[][] = new double[hmm.nbStates()][hmm.nbStates()];
        double aijDen[] = new double[hmm.nbStates()];

        Arrays.fill(aijDen, 0.);
        for (int i = 0; i < hmm.nbStates(); i++)
            Arrays.fill(aijNum[i], 0.);

        for (int o = 0; o < obsSeqs.size(); o++) {
            Vector obsSeq = (Vector) obsSeqs.elementAt(o);

            ForwardBackwardCalculator fbc =
                generateForwardBackwardCalculator(obsSeq, hmm);

            double xi[][][] = estimation_xi(obsSeq, fbc, hmm);
            double gamma[][] = allGamma[o] = estimation_gamma(xi, fbc);

            for (int i = 0; i < hmm.nbStates(); i++)
                for (int t = 0; t < obsSeq.size() - 1; t++) {
                    aijDen[i] += gamma[t][i];

                    for (int j = 0; j < hmm.nbStates(); j++)
                        aijNum[i][j] += xi[t][i][j];
                }
        }

        for (int i = 0; i < hmm.nbStates(); i++)
            for (int j = 0; j < hmm.nbStates(); j++)
                nhmm.setAij(i, j, aijNum[i][j] / aijDen[i]);

        /* pi computation */
        for (int i = 0; i < hmm.nbStates(); i++)
            nhmm.setPi(i, 0.);

        for (int o = 0; o < obsSeqs.size(); o++)
            for (int i = 0; i < hmm.nbStates(); i++)
                nhmm.setPi(i,
                           nhmm.getPi(i) + allGamma[o][0][i] / obsSeqs.size());

        /* pdfs computation */
        for (int i = 0; i < hmm.nbStates(); i++) {
            Observation[] observations = (Observation[])
                KMeansLearner.flat(obsSeqs).toArray(new Observation[1]);
            double[] weights = new double[observations.length];
            double sum = 0.;
            int j = 0;

            for (int o = 0; o < obsSeqs.size(); o++) {
                Vector obsSeq = (Vector) obsSeqs.elementAt(o);

                for (int t = 0; t < obsSeq.size(); t++, j++)
                    sum += weights[j] = allGamma[o][t][i];
            }

            for (j--; j >= 0; j--)
                weights[j] /= sum;

            nhmm.setOpdf(i, opdfFactory.fit(observations, weights));
        }

        return nhmm;
    }

    ForwardBackwardCalculator generateForwardBackwardCalculator(Vector obsSeq,
                                                                Hmm hmm) {
        return new ForwardBackwardCalculator(obsSeq, hmm,
                                             ForwardBackwardCalculator.
                                             COMPUTE_ALPHA |
                                             ForwardBackwardCalculator.
                                             COMPUTE_BETA);
    }

    /**
     * Does a fixed number of iterations of the Baum-Welch algorithm.
     *
     * @param initialHmm An initial estimation of the expected HMM.  This
     *                   estimate is critical, as the Baum-Welch algorithm
     *                   only finds a local optimum of the likelihood function.
     * @return The HMM that best matches the set of observation sequences given
     *         (according to the Baum-Welch algorithm).
     */
    public Hmm learn(Hmm initialHmm) {
        Hmm hmm = initialHmm;

        for (int i = 0; i < NB_ITERATIONS; i++)
            hmm = iterate(hmm);

        return hmm;
    }

    double[][][] estimation_xi(Vector obsSeq, ForwardBackwardCalculator fbc,
                               Hmm hmm) {
        if (obsSeq.size() <= 1)
            throw new IllegalArgumentException("Observation sequence too " +
                                               "short");

        double xi[][][] =
            new double[obsSeq.size() - 1][hmm.nbStates()][hmm.nbStates()];
        double probability = fbc.probability();

        for (int t = 0; t < obsSeq.size() - 1; t++)
            for (int i = 0; i < hmm.nbStates(); i++)
                for (int j = 0; j < hmm.nbStates(); j++)
                    xi[t][i][j] = fbc.alphaElement(t, i) *
                        hmm.getAij(i, j) *
                        hmm.getOpdf(j).probability((Observation)
                                                   obsSeq.elementAt(t + 1)) *
                        fbc.betaElement(t + 1, j) / probability;

        return xi;
    }

    /* gamma[][] could be computed directly using the alpha and beta
       arrays, but this (slower) method is preferred because it doesn't
       change if the xi array has been scaled (whereas the direct formula
       would have to be adapted to scaled alpha and beta arrays).
    */
    double[][] estimation_gamma(double[][][] xi,
                                ForwardBackwardCalculator fbc) {
        double[][] gamma = new double[xi.length + 1][xi[0].length];

        for (int t = 0; t < xi.length + 1; t++)
            Arrays.fill(gamma[t], 0.);

        for (int t = 0; t < xi.length; t++)
            for (int i = 0; i < xi[0].length; i++)
                for (int j = 0; j < xi[0].length; j++)
                    gamma[t][i] += xi[t][i][j];

        for (int j = 0; j < xi[0].length; j++)
            for (int i = 0; i < xi[0].length; i++)
                gamma[xi.length][j] += xi[xi.length - 1][i][j];

        return gamma;
    }
}
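
For reference, here is a minimal sketch of how this class might be driven. The class only exposes the abstract Observation and OpdfFactory types, so the concrete classes ObservationInteger and OpdfIntegerFactory used below are assumptions borrowed from later Jahmm releases; substitute whatever observation/distribution pair your Jahmm build actually ships with. The wrapper class name BaumWelchExample is likewise hypothetical.

import java.util.*;

import be.ac.ulg.montefiore.run.jahmm.*;
import be.ac.ulg.montefiore.run.jahmm.learn.*;

public class BaumWelchExample {

    public static void main(String[] args) {
        // Two toy observation sequences.  ObservationInteger and
        // OpdfIntegerFactory are assumed to exist as in later Jahmm
        // versions; swap in the Observation / OpdfFactory pair of your build.
        int[][] raw = { {0, 1, 1, 0, 2}, {2, 1, 0, 0, 1} };
        Vector obsSeqs = new Vector();

        for (int s = 0; s < raw.length; s++) {
            Vector seq = new Vector();
            for (int t = 0; t < raw[s].length; t++)
                seq.addElement(new ObservationInteger(raw[s][t]));
            obsSeqs.addElement(seq);
        }

        // The factory builds one observation distribution per HMM state
        // (here over the three integer symbols 0..2).
        OpdfFactory factory = new OpdfIntegerFactory(3);

        // Initial guess with 2 states.  Baum-Welch only climbs to a local
        // optimum of the likelihood, so this starting HMM matters.
        Hmm initialHmm = new Hmm(2, factory);

        // learn() applies iterate() NB_ITERATIONS (9) times and returns
        // the re-estimated HMM.
        BaumWelchLearner learner = new BaumWelchLearner(2, factory, obsSeqs);
        Hmm trainedHmm = learner.learn(initialHmm);

        System.out.println("a(0,0) after training: " + trainedHmm.getAij(0, 0));
    }
}

Each inner Vector plays the role of one observation sequence passed to the BaumWelchLearner constructor; every sequence must contain at least two observations, or estimation_xi will reject it.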
