📄 kmeanslearner.java
字号:
/* jahmm package - v0.3.1 */

/*
 * Copyright (c) 2004, Jean-Marc Francois.
 *
 * This file is part of Jahmm.
 * Jahmm is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 2 of the License, or
 * (at your option) any later version.
 *
 * Jahmm is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with Jahmm; if not, write to the Free Software
 * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
 */

package be.ac.ulg.montefiore.run.jahmm.learn;

import java.util.*;

import be.ac.ulg.montefiore.run.jahmm.*;


/**
 * An implementation of the K-Means learning algorithm.
 * <p>
 * All the observations of all the sequences are first split into as many
 * clusters as the HMM has states (cluster <i>i</i> is identified with state
 * <i>i</i>). Each iteration then builds an HMM from the current clusters and
 * moves every observation to the state the Viterbi algorithm picks for it.
 * The algorithm terminates when an iteration moves no observation.
 */
public class KMeansLearner {

    /** Current assignment of every observation to a cluster (= a state). */
    private final Clusters clusters;
    /** Number of states of the HMMs built by this learner. */
    private final int nbStates;
    /** The learning set: a vector of observation sequences (each a Vector). */
    private final Vector obsSeqs;
    /** Builds the observation distributions attached to the states. */
    private final OpdfFactory opdfFactory;
    /** True when the last iteration left every cluster unchanged. */
    private boolean terminated;


    /**
     * Initializes a K-Means algorithm implementation. This algorithm
     * finds a HMM that models a set of observation sequences.
     *
     * @param nbStates The number of states the resulting HMM will be made of.
     * @param opdfFactory A class that builds the observation probability
     *        distributions associated to the states of the HMM.
     * @param obsSeqs A vector of observation sequences. Each observation
     *        sequence is a vector of
     *        {@link be.ac.ulg.montefiore.run.jahmm.Observation
     *        observations} compatible with the
     *        {@link be.ac.ulg.montefiore.run.jahmm.KMeansCompatible
     *        k-means} algorithm.
     */
    public KMeansLearner(int nbStates, OpdfFactory opdfFactory,
            Vector obsSeqs) {
        this.obsSeqs = obsSeqs;
        this.opdfFactory = opdfFactory;
        this.nbStates = nbStates;

        // The initial clusters are built from the observations of all the
        // sequences pooled together, ignoring sequence boundaries.
        Vector observations = flat(obsSeqs);
        clusters = new Clusters(nbStates, observations);
        terminated = false;
    }


    /**
     * Performs one iteration of the K-Means algorithm.
     * In one iteration, a new HMM is computed using the current clusters, and
     * the clusters are re-estimated using this HMM.
     *
     * @return A new, updated HMM.
     */
    public Hmm iterate() {
        Hmm hmm = new Hmm(nbStates, opdfFactory);

        learnPi(hmm);
        learnAij(hmm);
        learnOpdf(hmm);

        terminated = optimizeCluster(hmm);

        return hmm;
    }


    /**
     * Returns <code>true</code> if the algorithm has reached a fix point,
     * else returns <code>false</code>.
     */
    public boolean isTerminated() {
        return terminated;
    }


    /**
     * Does iterations of the K-Means algorithm until a fix point is reached.
     * <p>
     * NOTE: there is no iteration cap; this loops until an iteration leaves
     * every observation in its cluster.
     *
     * @return The HMM that best matches the set of observation sequences given
     *         (according to the K-Means algorithm).
     */
    public Hmm learn() {
        Hmm hmm;

        do {
            hmm = iterate();
        } while (!isTerminated());

        return hmm;
    }


    /**
     * Sets the initial state probabilities of <code>hmm</code>: pi(i) is the
     * fraction of sequences whose first observation is in cluster i.
     * Every sequence is expected to hold at least one observation.
     */
    private void learnPi(Hmm hmm) {
        double[] pi = new double[nbStates]; // Java zero-initializes arrays

        for (int i = 0; i < obsSeqs.size(); i++) {
            Vector obsSeq = (Vector) obsSeqs.elementAt(i);
            Observation first = (Observation) obsSeq.elementAt(0);
            pi[clusters.clusterNb(first)]++;
        }

        for (int i = 0; i < nbStates; i++)
            hmm.setPi(i, pi[i] / obsSeqs.size());
    }


    /**
     * Sets the transition probabilities of <code>hmm</code> from the
     * cluster-to-cluster transitions observed between consecutive
     * observations of each sequence. Each row is normalized to sum to 1; a
     * state with no outgoing transition gets a uniform row.
     */
    private void learnAij(Hmm hmm) {
        for (int i = 0; i < hmm.nbStates(); i++)
            for (int j = 0; j < hmm.nbStates(); j++)
                hmm.setAij(i, j, 0.);

        // Count the transitions, accumulating the counts directly in hmm.
        for (int os = 0; os < obsSeqs.size(); os++) {
            Vector obsSeq = (Vector) obsSeqs.elementAt(os);

            if (obsSeq.size() < 2)
                continue; // no transition in this sequence

            int firstState;
            int secondState =
                clusters.clusterNb((Observation) obsSeq.elementAt(0));
            for (int i = 1; i < obsSeq.size(); i++) {
                firstState = secondState;
                secondState =
                    clusters.clusterNb((Observation) obsSeq.elementAt(i));

                hmm.setAij(firstState, secondState,
                        hmm.getAij(firstState, secondState) + 1.);
            }
        }

        /* Normalize Aij array */
        for (int i = 0; i < hmm.nbStates(); i++) {
            double sum = 0;

            for (int j = 0; j < hmm.nbStates(); j++)
                sum += hmm.getAij(i, j);

            if (sum == 0.)
                for (int j = 0; j < hmm.nbStates(); j++)
                    hmm.setAij(i, j, 1. / hmm.nbStates()); /* Arbitrarily */
            else
                for (int j = 0; j < hmm.nbStates(); j++)
                    hmm.setAij(i, j, hmm.getAij(i, j) / sum);
        }
    }


    /**
     * Sets the observation distribution of each state of <code>hmm</code>:
     * fitted on the observations of the matching cluster, or the factory's
     * default distribution when that cluster is empty.
     */
    private void learnOpdf(Hmm hmm) {
        for (int i = 0; i < hmm.nbStates(); i++) {
            Vector obsVector = clusters.cluster(i);
            Observation[] obs = (Observation[])
                obsVector.toArray(new Observation[obsVector.size()]);

            if (obs.length > 0)
                hmm.setOpdf(i, opdfFactory.fit(obs));
            else
                hmm.setOpdf(i, opdfFactory.factor());
        }
    }


    /**
     * Moves every observation to the cluster of the state the Viterbi
     * algorithm assigns to it under <code>hmm</code>.
     *
     * @return <code>true</code> iff no observation changed cluster.
     */
    private boolean optimizeCluster(Hmm hmm) {
        boolean modif = false;

        for (int i = 0; i < obsSeqs.size(); i++) {
            Vector obsSeq = (Vector) obsSeqs.elementAt(i);
            ViterbiCalculator vc = new ViterbiCalculator(obsSeq, hmm);
            int[] states = vc.stateSequence();

            for (int j = 0; j < states.length; j++) {
                Observation o = (Observation) obsSeq.elementAt(j);
                int currentCluster = clusters.clusterNb(o);

                if (currentCluster != states[j]) {
                    modif = true;
                    // Detach o from the cluster it is currently in, not from
                    // its destination: otherwise o would remain listed in its
                    // old cluster and be fitted twice by learnOpdf().
                    clusters.remove(o, currentCluster);
                    clusters.put(o, states[j]);
                }
            }
        }

        return !modif;
    }


    /**
     * Concatenates a vector of vectors into a single new vector, keeping the
     * elements' order.
     */
    static Vector flat(Vector vectors) {
        Vector v = new Vector();

        for (int i = 0; i < vectors.size(); i++) {
            Vector vector = (Vector) vectors.elementAt(i);

            for (int j = 0; j < vector.size(); j++)
                v.add(vector.elementAt(j));
        }

        return v;
    }
}


/*
 * This class holds the matching between observations and clusters.
*/
class Clusters {

    /** Maps every observation to the holder of its current cluster number. */
    private final Hashtable clustersHash;
    /** clusters[i] is the vector of the observations currently in cluster i. */
    private final Vector[] clusters;


    /**
     * Mutable holder of one observation's cluster number, stored in
     * clustersHash so that moving an observation only needs a setter call.
     * <p>
     * Declared static because it never uses the enclosing Clusters instance;
     * as an inner (non-static) class, every holder would carry a hidden
     * reference to it.
     */
    static class Value {
        private int clusterNb;

        Value(int clusterNb) {
            this.clusterNb = clusterNb;
        }

        void setClusterNb(int clusterNb) {
            this.clusterNb = clusterNb;
        }

        int getClusterNb() {
            return clusterNb;
        }
    }


    /**
     * Builds the initial clusters from the partition computed by
     * KMeansCalculator.
     *
     * @param k The number of clusters.
     * @param observations The observations to partition.
     */
    public Clusters(int k, Vector observations) {
        clustersHash = new Hashtable();
        clusters = new Vector[k];

        KMeansCalculator kmc = new KMeansCalculator(k, observations);

        for (int i = 0; i < k; i++) {
            Vector cluster = kmc.cluster(i);
            clusters[i] = cluster;

            for (int j = 0; j < cluster.size(); j++)
                clustersHash.put(cluster.elementAt(j), new Value(i));
        }
    }


    /** Returns true iff o is currently assigned to cluster clusterNb. */
    public boolean isInCluster(Observation o, int clusterNb) {
        return clusterNb(o) == clusterNb;
    }


    /**
     * Returns the cluster number of o, or -1 between a remove() and the
     * following put(). o must be one of the observations that were given to
     * the constructor; otherwise a NullPointerException is thrown.
     */
    public int clusterNb(Observation o) {
        return ((Value) clustersHash.get(o)).getClusterNb();
    }


    /** Returns the (live, not copied) vector of the observations of a cluster. */
    public Vector cluster(int clusterNb) {
        return clusters[clusterNb];
    }


    /**
     * Marks o as unassigned (-1) and removes it from cluster clusterNb.
     * clusterNb must be the cluster o currently belongs to. This is not
     * checked: passing another cluster leaves o listed in its real one.
     */
    public void remove(Observation o, int clusterNb) {
        ((Value) clustersHash.get(o)).setClusterNb(-1);
        clusters[clusterNb].remove(o);
    }


    /**
     * Assigns o to cluster clusterNb. o is expected to have been removed from
     * its previous cluster first.
     */
    public void put(Observation o, int clusterNb) {
        ((Value) clustersHash.get(o)).setClusterNb(clusterNb);
        clusters[clusterNb].add(o);
    }
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -