📄 hmm.java
字号:
/* Copyright (C) 2002 Univ. of Massachusetts Amherst, Computer Science Dept.
   This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
   http://www.cs.umass.edu/~mccallum/mallet
   This software is provided under the terms of the Common Public License,
   version 1.0, as published by http://www.opensource.org.  For further
   information, see the file `LICENSE' included with this distribution. */

/**
   @author Aron Culotta <a href="mailto:culotta@cs.umass.edu">culotta@cs.umass.edu</a>
   @author Andrew McCallum <a href="mailto:mccallum@cs.umass.edu">mccallum@cs.umass.edu</a>
 */

package edu.umass.cs.mallet.base.fst;

import edu.umass.cs.mallet.base.types.*;
import edu.umass.cs.mallet.base.pipe.Pipe;
import edu.umass.cs.mallet.base.maximize.*;
import edu.umass.cs.mallet.base.maximize.tests.*;
import edu.umass.cs.mallet.base.util.Maths;
import edu.umass.cs.mallet.base.util.MalletLogger;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Arrays;
import java.util.BitSet;
import java.util.Random;
import java.util.regex.*;
import java.util.logging.*;
import java.io.*;
import java.lang.reflect.Constructor;

/**
   Hidden Markov Model.

   States are registered by name through the <code>addState</code> family of
   methods; each state carries a list of destination state names and the
   label emitted on each of those transitions.
 */
public class HMM extends Transducer implements Serializable
{
  private static Logger logger = MalletLogger.getLogger(HMM.class.getName());

  // Joins individual label names into the name of a multi-label state
  static final String LABEL_SEPARATOR = ",";

  Alphabet inputAlphabet;
  Alphabet outputAlphabet;
  // Every State added so far, in order of addition
  ArrayList states = new ArrayList ();
  // The subset of states whose initial cost is below INFINITE_COST
  ArrayList initialStates = new ArrayList ();
  // State name (String) -> State
  HashMap name2state = new HashMap ();
  Multinomial.Estimator [] transitionEstimator;
  Multinomial.Estimator [] emissionEstimator;
  Multinomial.Estimator initialEstimator;
  // print() indexes these two arrays by state index
  Multinomial [] transitionMultinomial;
  Multinomial [] emissionMultinomial;
  Multinomial initialMultinomial;
  boolean trainable = false;

  /**
     Build an HMM around a pair of pipes.  Both the input and the output
     alphabet are read from <code>inputPipe</code> (its data and target
     alphabets respectively); <code>outputPipe</code> is only stored.

     @param inputPipe  pipe supplying the data and target alphabets
     @param outputPipe pipe stored as this transducer's output pipe
   */
  public HMM (Pipe inputPipe, Pipe outputPipe)
  {
    this.inputAlphabet = inputPipe.getDataAlphabet();
    this.outputAlphabet = inputPipe.getTargetAlphabet();
    this.inputPipe = inputPipe;
    this.outputPipe = outputPipe;
  }

  /**
     Build an HMM directly from two alphabets.  Growth of the input alphabet
     is stopped and its size is logged before the alphabets are stored.

     @param inputAlphabet  alphabet of observations; stopGrowth() is called on it
     @param outputAlphabet alphabet of labels
   */
  public HMM (Alphabet inputAlphabet, Alphabet outputAlphabet)
  {
    this.inputAlphabet = inputAlphabet;
    this.outputAlphabet = outputAlphabet;
    this.inputAlphabet.stopGrowth();
    logger.info ("HMM input dictionary size = "+this.inputAlphabet.size());
  }

  /** @return the alphabet of observations */
  public Alphabet getInputAlphabet ()
  {
    return this.inputAlphabet;
  }

  /** @return the alphabet of labels */
  public Alphabet getOutputAlphabet ()
  {
    return this.outputAlphabet;
  }

  /**
     Write a description of every state to standard output: its name, number
     of outgoing transitions, initial and final cost, and the emission and
     transition multinomials stored at the state's index.
   */
  public void print ()
  {
    StringBuffer report = new StringBuffer();
    for (int stateIndex = 0; stateIndex < numStates(); stateIndex++) {
      State state = (State) getState (stateIndex);
      report.append ("STATE NAME=\"").append (state.name);
      report.append ("\" (").append (state.destinations.length);
      report.append (" outgoing transitions)\n");
      report.append (" ").append ("initialCost = ").append (state.initialCost).append ('\n');
      report.append (" ").append ("finalCost = ").append (state.finalCost).append ('\n');
      report.append ("Emission distribution:\n");
      report.append (emissionMultinomial[stateIndex]);
      report.append ("\n\n");
      report.append ("Transition distribution:\n");
      report.append (transitionMultinomial[stateIndex].toString());
    }
    System.out.println (report.toString());
  }

  /**
     Register a new state.  setTrainable(false) is called first, then the
     state is created with the next free index, printed via its own print()
     method, and recorded in the state list and the name lookup table.  The
     state is also recorded as an initial state when its initial cost is
     below INFINITE_COST.

     @param name             unique name of the new state
     @param initialCost      cost of starting in this state
     @param finalCost        cost of ending in this state
     @param destinationNames names of the states reachable from this one
     @param labelNames       label for each transition, parallel to
                             <code>destinationNames</code>
     @throws IllegalArgumentException if a state with this name already exists
   */
  public void addState (String name, double initialCost, double finalCost,
                        String[] destinationNames, String[] labelNames)
  {
    assert (labelNames.length == destinationNames.length);
    setTrainable (false);
    Object existing = name2state.get (name);
    if (existing != null) {
      throw new IllegalArgumentException ("State with name `"+name+"' already exists.");
    }
    int newIndex = states.size();
    State added = new State (name, newIndex, initialCost, finalCost,
                             destinationNames, labelNames, this);
    added.print ();
    states.add (added);
    if (initialCost < INFINITE_COST) {
      initialStates.add (added);
    }
    name2state.put (name, added);
  }

  // Add a state with parameters equal zero, and labels on out-going arcs
  // the same name as their destination state names.
  public void addState (String name, String[] destinationNames)
  {
    addState (name, 0.0, 0.0, destinationNames, destinationNames);
  }

  // Add a group of states that are fully connected with each other,
  // with parameters equal zero, and labels on their out-going arcs
  // the same name as their destination state names.
public void addFullyConnectedStates (String[] stateNames) { for (int i = 0; i < stateNames.length; i++) addState (stateNames[i], stateNames); } public void addFullyConnectedStatesForLabels () { String[] labels = new String[outputAlphabet.size()]; // This is assuming the the entries in the outputAlphabet are Strings! for (int i = 0; i < outputAlphabet.size(); i++) { labels[i] = (String) outputAlphabet.lookupObject(i); } addFullyConnectedStates (labels); } private boolean[][] labelConnectionsIn (InstanceList trainingSet) { int numLabels = outputAlphabet.size(); boolean[][] connections = new boolean[numLabels][numLabels]; for (int i = 0; i < trainingSet.size(); i++) { Instance instance = trainingSet.getInstance(i); FeatureSequence output = (FeatureSequence) instance.getTarget(); for (int j = 1; j < output.size(); j++) { int sourceIndex = outputAlphabet.lookupIndex (output.get(j-1)); int destIndex = outputAlphabet.lookupIndex (output.get(j)); assert (sourceIndex >= 0 && destIndex >= 0); connections[sourceIndex][destIndex] = true; } } return connections; } /** Add states to create a first-order Markov model on labels, adding only those transitions the occur in the given trainingSet. */ public void addStatesForLabelsConnectedAsIn (InstanceList trainingSet) { int numLabels = outputAlphabet.size(); boolean[][] connections = labelConnectionsIn (trainingSet); for (int i = 0; i < numLabels; i++) { int numDestinations = 0; for (int j = 0; j < numLabels; j++) if (connections[i][j]) numDestinations++; String[] destinationNames = new String[numDestinations]; int destinationIndex = 0; for (int j = 0; j < numLabels; j++) if (connections[i][j]) destinationNames[destinationIndex++] = (String)outputAlphabet.lookupObject(j); addState ((String)outputAlphabet.lookupObject(i), destinationNames); } } /** Add as many states as there are labels, but don't create separate weights for each source-destination pair of states. 
Instead have all the incoming transitions to a state share the same weights. */
  public void addStatesForHalfLabelsConnectedAsIn (InstanceList trainingSet)
  {
    int labelCount = outputAlphabet.size();
    boolean[][] observed = labelConnectionsIn (trainingSet);
    // Scratch space wide enough for a row in which every label is reachable
    String[] scratch = new String[labelCount];
    for (int from = 0; from < labelCount; from++) {
      int used = 0;
      for (int to = 0; to < labelCount; to++) {
        if (!observed[from][to]) {
          continue;
        }
        scratch[used] = (String) outputAlphabet.lookupObject(to);
        used++;
      }
      // Trim to exactly the destinations observed for this source label
      String[] destinationNames = new String[used];
      System.arraycopy (scratch, 0, destinationNames, 0, used);
      // Transition labels are the destination names themselves
      addState ((String) outputAlphabet.lookupObject(from), 0.0, 0.0,
                destinationNames, destinationNames);
    }
  }

  /** Add as many states as there are labels, but don't create separate
      observational-test-weights for each source-destination pair of
      states---instead have all the incoming transitions to a state share the
      same observational-feature-test weights. However, do create separate
      default feature for each transition, (which acts as an HMM-style
      transition probability).
*/
  public void addStatesForThreeQuarterLabelsConnectedAsIn (InstanceList trainingSet)
  {
    int numLabels = outputAlphabet.size();
    boolean[][] connections = labelConnectionsIn (trainingSet);
    for (int i = 0; i < numLabels; i++) {
      // First pass: count the labels observed directly after label i
      int numDestinations = 0;
      for (int j = 0; j < numLabels; j++)
        if (connections[i][j]) numDestinations++;
      // Second pass: collect their names
      String[] destinationNames = new String[numDestinations];
      int destinationIndex = 0;
      for (int j = 0; j < numLabels; j++)
        if (connections[i][j])
          destinationNames[destinationIndex++] = (String)outputAlphabet.lookupObject(j);
      // Transition labels are the destination names themselves
      addState ((String)outputAlphabet.lookupObject(i), 0.0, 0.0,
                destinationNames, destinationNames);
    }
  }

  /**
     Add one state per entry of the output alphabet, each with a transition
     to every label's state (including its own), labelled with the
     destination name.  Entries are cast to String, so the output alphabet
     must contain Strings.

     @param trainingSet not read by this method
   */
  public void addFullyConnectedStatesForThreeQuarterLabels (InstanceList trainingSet)
  {
    int numLabels = outputAlphabet.size();
    for (int i = 0; i < numLabels; i++) {
      String[] destinationNames = new String[numLabels];
      for (int j = 0; j < numLabels; j++) {
        String labelName = (String)outputAlphabet.lookupObject(j);
        destinationNames[j] = labelName;
      }
      addState ((String)outputAlphabet.lookupObject(i), 0.0, 0.0,
                destinationNames, destinationNames);
    }
  }

  /**
     Add one state for every ordered pair of labels.  The state named
     "a,b" (joined with LABEL_SEPARATOR) has one transition per label k,
     leading to state "b,k" and labelled k.
   */
  public void addFullyConnectedStatesForBiLabels ()
  {
    String[] labels = new String[outputAlphabet.size()];
    // Label names are the toString() of each output alphabet entry
    for (int i = 0; i < outputAlphabet.size(); i++) {
      labels[i] = outputAlphabet.lookupObject(i).toString();
    }
    for (int i = 0; i < labels.length; i++) {
      for (int j = 0; j < labels.length; j++) {
        String[] destinationNames = new String[labels.length];
        for (int k = 0; k < labels.length; k++)
          destinationNames[k] = labels[j]+LABEL_SEPARATOR+labels[k];
        addState (labels[i]+LABEL_SEPARATOR+labels[j], 0.0, 0.0,
                  destinationNames, labels);
      }
    }
  }

  /** Add states to create a second-order Markov model on labels,
      adding only those transitions that occur in the given
      trainingSet. */
  public void addStatesForBiLabelsConnectedAsIn (InstanceList trainingSet)
  {
    int numLabels = outputAlphabet.size();
    boolean[][] connections = labelConnectionsIn (trainingSet);
    for (int i = 0; i < numLabels; i++) {
      for (int j = 0; j < numLabels; j++) {
        // Only create state "i,j" when label j was seen directly after label i
        if (!connections[i][j])
          continue;
        int numDestinations = 0;
        for (int k = 0; k < numLabels; k++)
          if (connections[j][k]) numDestinations++;
        String[] destinationNames = new String[numDestinations];
        String[] labels = new String[numDestinations];
        int destinationIndex = 0;
        for (int k = 0; k < numLabels; k++)
          if (connections[j][k]) {
            // Destination is state "j,k"; the transition is labelled k
            destinationNames[destinationIndex] =
              (String)outputAlphabet.lookupObject(j)+LABEL_SEPARATOR+(String)outputAlphabet.lookupObject(k);
            labels[destinationIndex] = (String)outputAlphabet.lookupObject(k);
            destinationIndex++;
          }
        addState ((String)outputAlphabet.lookupObject(i)+LABEL_SEPARATOR+
                  (String)outputAlphabet.lookupObject(j), 0.0, 0.0,
                  destinationNames, labels);
      }
    }
  }

  /**
     Add one state for every ordered triple of labels.  The state named
     "a,b,c" (joined with LABEL_SEPARATOR) has one transition per label l,
     leading to state "b,c,l" and labelled l.  The class name of every
     output alphabet entry is logged while the label names are collected.
   */
  public void addFullyConnectedStatesForTriLabels ()
  {
    String[] labels = new String[outputAlphabet.size()];
    // Label names are the toString() of each output alphabet entry
    for (int i = 0; i < outputAlphabet.size(); i++) {
      logger.info ("HMM: outputAlphabet.lookup class = "+
                   outputAlphabet.lookupObject(i).getClass().getName());
      labels[i] = outputAlphabet.lookupObject(i).toString();
    }
    for (int i = 0; i < labels.length; i++) {
      for (int j = 0; j < labels.length; j++) {
        for (int k = 0; k < labels.length; k++) {
          String[] destinationNames = new String[labels.length];
          for (int l = 0; l < labels.length; l++)
            destinationNames[l] = labels[j]+LABEL_SEPARATOR+labels[k]+LABEL_SEPARATOR+labels[l];
          addState (labels[i]+LABEL_SEPARATOR+labels[j]+LABEL_SEPARATOR+labels[k], 0.0, 0.0,
                    destinationNames, labels);
        }
      }
    }
  }

  /**
     Add a single state that has one transition per output label, every one
     of which leads back to the state itself.

     @param name name of the state to add
   */
  public void addSelfTransitioningStateForAllLabels (String name)
  {
    String[] labels = new String[outputAlphabet.size()];
    String[] destinationNames = new String[outputAlphabet.size()];
    for (int i = 0; i < outputAlphabet.size(); i++) {
      labels[i] = outputAlphabet.lookupObject(i).toString();
      destinationNames[i] = name;
    }
    addState (name, 0.0, 0.0, destinationNames, labels);
  }

  private String concatLabels(String[] labels) {
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -