⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 kmeanslearner.java

📁 java实现的隐马尔科夫模型
💻 JAVA
字号:
/* jahmm package - v0.3.1 *//* *  Copyright (c) 2004, Jean-Marc Francois. * *  This file is part of Jahmm. *  Jahmm is free software; you can redistribute it and/or modify *  it under the terms of the GNU General Public License as published by *  the Free Software Foundation; either version 2 of the License, or *  (at your option) any later version. * *  Jahmm is distributed in the hope that it will be useful, *  but WITHOUT ANY WARRANTY; without even the implied warranty of *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the *  GNU General Public License for more details. * *  You should have received a copy of the GNU General Public License *  along with Jahmm; if not, write to the Free Software *  Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA */package be.ac.ulg.montefiore.run.jahmm.learn;import java.util.*;import be.ac.ulg.montefiore.run.jahmm.*;/** * An implementation of the K-Means learning algorithm. */public class KMeansLearner {    private Clusters clusters;    private int nbStates;    private Vector obsSeqs;    private OpdfFactory opdfFactory;    private boolean terminated;        /**     * Initializes a K-Means algorithm implementation.  This algorithm     * finds a HMM that models a set of observation sequences.     *     * @param nbStates  The number of states the resulting HMM will be made of.     * @param opdfFactory A class that builds the observation probability     *                    distributions associated to the states of the HMM.     * @param obsSeqs A vector of observation sequences.  Each observation     *                sequences is a vector of     *                {@link be.ac.ulg.montefiore.run.jahmm.Observation     *                observations} compatible with the      *                {@link be.ac.ulg.montefiore.run.jahmm.KMeansCompatible     *                k-means}.     
*/    public KMeansLearner(int nbStates, OpdfFactory opdfFactory,			 Vector obsSeqs) {	this.obsSeqs = obsSeqs;	this.opdfFactory = opdfFactory;	this.nbStates = nbStates;		Vector observations = flat(obsSeqs);	clusters = new Clusters(nbStates, observations);	terminated = false;    }    /**     * Performs one iteration of the K-Means algorithm.     * In one iteration, a new HMM is computed using the current clusters, and     * the clusters are re-estimated using this HMM.     *     * @return A new, updated HMM.     */    public Hmm iterate() {	Hmm hmm = new Hmm(nbStates, opdfFactory);	learnPi(hmm);	learnAij(hmm);	learnOpdf(hmm);	terminated = optimizeCluster(hmm);	return hmm;    }    /**     * Returns <code>true</code> if the algorithm has reached a fix point,     * else returns <code>false</code>.     */    public boolean isTerminated() {	return terminated;    }    /**     * Does iterations of the K-Means algorithm until a fix point is reached.     *      * @return The HMM that best matches the set of observation sequences given     *         (according to the K-Means algorithm).     
*/    public Hmm learn() {	Hmm hmm;	do 	    hmm = iterate();	while(!isTerminated());	return hmm;    }    private void learnPi(Hmm hmm) {	double[] pi = new double[nbStates];	for (int i = 0; i < nbStates; i++)	    pi[i] = 0.;		for (int i = 0; i < obsSeqs.size(); i++) {	    Observation obs = (Observation)		((Vector) obsSeqs.elementAt(i)).elementAt(0);	    	    pi[clusters.clusterNb(obs)]++;	}		for (int i = 0; i < nbStates; i++)	    hmm.setPi(i, pi[i] / obsSeqs.size());    }        private void learnAij(Hmm hmm) {	for (int i = 0; i < hmm.nbStates(); i++)	    for (int j = 0; j < hmm.nbStates(); j++)		hmm.setAij(i, j, 0.);	for (int os = 0; os < obsSeqs.size(); os++) {	    Vector obsSeq = (Vector) obsSeqs.elementAt(os);	    	    if (obsSeq.size() < 2)		continue;	    	    int first_state;	    int second_state = 		clusters.clusterNb((Observation) obsSeq.elementAt(0));	    for (int i = 1; i < obsSeq.size(); i++) {		first_state = second_state;		second_state =		    clusters.clusterNb((Observation) obsSeq.elementAt(i));				hmm.setAij(first_state, second_state,			   hmm.getAij(first_state, second_state) + 1.);	    }	}	/* Normalize Aij array */	for (int i = 0; i < hmm.nbStates(); i++) {	    double sum = 0;	    	    for (int j = 0; j < hmm.nbStates(); j++)		sum += hmm.getAij(i, j);	    	    if (sum == 0.)		for (int j = 0; j < hmm.nbStates(); j++) 		    hmm.setAij(i, j, 1. 
/ hmm.nbStates());     /* Arbitrarily */	    else		for (int j = 0; j < hmm.nbStates(); j++)		    hmm.setAij(i, j, hmm.getAij(i, j) / sum);	}    }    private void learnOpdf(Hmm hmm) {	for (int i = 0; i < hmm.nbStates(); i++) {	    Vector obsVector = clusters.cluster(i);	    Observation[] obs = (Observation[]) 		obsVector.toArray(new Observation[obsVector.size()]);	    	    if (obs.length > 0)		hmm.setOpdf(i, opdfFactory.fit(obs));	    else		hmm.setOpdf(i, opdfFactory.factor());	}    }        private boolean optimizeCluster(Hmm hmm) {	boolean modif = false;	for (int i = 0; i < obsSeqs.size(); i++) {	    Vector obsSeq = (Vector) obsSeqs.elementAt(i);	    	    ViterbiCalculator vc = new ViterbiCalculator(obsSeq, hmm);	    int states[] = vc.stateSequence();	    	    for (int j = 0; j < states.length; j++) {		Observation o = (Observation) obsSeq.elementAt(j);				if (clusters.clusterNb(o) != states[j]) {		    modif = true;		    clusters.remove(o, states[j]);		    clusters.put(o, states[j]);		}	    }	}	return !modif;    }        static Vector flat(Vector vectors) {	Vector v = new Vector();	for (int i = 0; i < vectors.size(); i++) {	    Vector vector = (Vector) vectors.elementAt(i);	    	    for (int j = 0; j < vector.size(); j++)		v.add(vector.elementAt(j));	}	return v;    }}/* * This class holds the matching between observations and clusters. 
*/
class Clusters {
    /* Maps each observation to a holder of its current cluster index. */
    private final Hashtable clustersHash;
    /* clusters[k] lists the observations currently in cluster k. */
    private final Vector[] clusters;


    /**
     * Mutable holder of an observation's cluster index.  It never uses the
     * enclosing <code>Clusters</code> instance, so it is a static nested
     * class rather than an inner class carrying a hidden outer reference.
     */
    static class Value {
        private int clusterNb;

        Value(int clusterNb) {
            this.clusterNb = clusterNb;
        }

        void setClusterNb(int clusterNb) {
            this.clusterNb = clusterNb;
        }

        int getClusterNb() {
            return clusterNb;
        }
    }


    /**
     * Builds the initial partition of <code>observations</code> into
     * <code>k</code> clusters using a k-means computation.
     */
    public Clusters(int k, Vector observations) {
        clustersHash = new Hashtable();
        clusters = new Vector[k];

        KMeansCalculator kmc = new KMeansCalculator(k, observations);

        for (int i = 0; i < k; i++) {
            Vector cluster = kmc.cluster(i);
            clusters[i] = cluster;

            for (int j = 0; j < cluster.size(); j++) {
                clustersHash.put(cluster.elementAt(j), new Clusters.Value(i));
            }
        }
    }


    /** Returns <code>true</code> if <code>o</code> lies in the given cluster. */
    public boolean isInCluster(Observation o, int clusterNb) {
        return clusterNb(o) == clusterNb;
    }


    /** Returns the index of the cluster <code>o</code> currently lies in. */
    public int clusterNb(Observation o) {
        return ((Clusters.Value) clustersHash.get(o)).getClusterNb();
    }


    /** Returns the (live, mutable) list of members of the given cluster. */
    public Vector cluster(int clusterNb) {
        return clusters[clusterNb];
    }


    /**
     * Removes <code>o</code> from cluster <code>clusterNb</code> and marks it
     * as unassigned (index -1) until {@link #put} is called.
     */
    public void remove(Observation o, int clusterNb) {
        ((Clusters.Value) clustersHash.get(o)).setClusterNb(-1);
        clusters[clusterNb].remove(o);
    }


    /** Adds <code>o</code> to cluster <code>clusterNb</code>. */
    public void put(Observation o, int clusterNb) {
        ((Clusters.Value) clustersHash.get(o)).setClusterNb(clusterNb);
        clusters[clusterNb].add(o);
    }
}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -