📄 hmm.java
字号:
String sep = ""; StringBuffer buf = new StringBuffer(); for (int i = 0; i < labels.length; i++) { buf.append(sep).append(labels[i]); sep = LABEL_SEPARATOR; } return buf.toString(); } private String nextKGram(String[] history, int k, String next) { String sep = ""; StringBuffer buf = new StringBuffer(); int start = history.length + 1 - k; for (int i = start; i < history.length; i++) { buf.append(sep).append(history[i]); sep = LABEL_SEPARATOR; } buf.append(sep).append(next); return buf.toString(); } private boolean allowedTransition(String prev, String curr, Pattern no, Pattern yes) { String pair = concatLabels(new String[]{prev, curr}); if (no != null && no.matcher(pair).matches()) return false; if (yes != null && !yes.matcher(pair).matches()) return false; return true; } private boolean allowedHistory(String[] history, Pattern no, Pattern yes) { for (int i = 1; i < history.length; i++) if (!allowedTransition(history[i-1], history[i], no, yes)) return false; return true; } /** * Assumes that the HMM's output alphabet contains * <code>String</code>s. Creates an order-<em>n</em> HMM with input * predicates and output labels given by <code>trainingSet</code> * and order, connectivity, and weights given by the remaining * arguments. * * @param trainingSet the training instances * @param orders an array of increasing non-negative numbers giving * the orders of the features for this HMM. The largest number * <em>n</em> is the Markov order of the HMM. States are * <em>n</em>-tuples of output labels. Each of the other numbers * <em>k</em> in <code>orders</code> represents a weight set shared * by all destination states whose last (most recent) <em>k</em> * labels agree. If <code>orders</code> is <code>null</code>, an * order-0 HMM is built. * @param defaults If non-null, it must be the same length as * <code>orders</code>, with <code>true</code> positions indicating * that the weight set for the corresponding order contains only the * weight for a default feature; otherwise, the weight set has * weights for all features built from input predicates. * @param start The label that represents the context of the start of * a sequence. It may be also used for sequence labels. * @param forbidden If non-null, specifies what pairs of successive * labels are not allowed, both for constructing <em>n</em>order * states or for transitions. A label pair (<em>u</em>,<em>v</em>) * is not allowed if <em>u</em> + "," + <em>v</em> matches * <code>forbidden</code>. * @param allowed If non-null, specifies what pairs of successive * labels are allowed, both for constructing <em>n</em>order * states or for transitions. A label pair (<em>u</em>,<em>v</em>) * is allowed only if <em>u</em> + "," + <em>v</em> matches * <code>allowed</code>. * @param fullyConnected Whether to include all allowed transitions, * even those not occurring in <code>trainingSet</code>, * @returns The name of the start state. * */ public String addOrderNStates(InstanceList trainingSet, int[] orders, boolean[] defaults, String start, Pattern forbidden, Pattern allowed, boolean fullyConnected) { boolean[][] connections = null; if (!fullyConnected) connections = labelConnectionsIn (trainingSet); int order = -1; if (defaults != null && defaults.length != orders.length) throw new IllegalArgumentException("Defaults must be null or match orders"); if (orders == null) order = 0; else { for (int i = 0; i < orders.length; i++) if (orders[i] <= order) throw new IllegalArgumentException("Orders must be non-negative and in ascending order"); else order = orders[i]; if (order < 0) order = 0; } if (order > 0) { int[] historyIndexes = new int[order]; String[] history = new String[order]; String label0 = (String)outputAlphabet.lookupObject(0); for (int i = 0; i < order; i++) history[i] = label0; int numLabels = outputAlphabet.size(); while (historyIndexes[0] < numLabels) { logger.info("Preparing " + concatLabels(history)); if (allowedHistory(history, forbidden, allowed)) { String stateName = concatLabels(history); int nt = 0; String[] destNames = new String[numLabels]; String[] labelNames = new String[numLabels]; for (int nextIndex = 0; nextIndex < numLabels; nextIndex++) { String next = (String)outputAlphabet.lookupObject(nextIndex); if (allowedTransition(history[order-1], next, forbidden, allowed) && (fullyConnected || connections[historyIndexes[order-1]][nextIndex])) { destNames[nt] = nextKGram(history, order, next); labelNames[nt] = next; nt++; } } if (nt < numLabels) { String[] newDestNames = new String[nt]; String[] newLabelNames = new String[nt]; for (int t = 0; t < nt; t++) { newDestNames[t] = destNames[t]; newLabelNames[t] = labelNames[t]; } destNames = newDestNames; labelNames = newLabelNames; } addState (stateName, 0.0, 0.0, destNames, labelNames); } for (int o = order-1; o >= 0; o--) if (++historyIndexes[o] < numLabels) { history[o] = (String)outputAlphabet.lookupObject(historyIndexes[o]); break; } else if (o > 0) { historyIndexes[o] = 0; history[o] = label0; } } for (int i = 0; i < order; i++) history[i] = start; return concatLabels(history); } else { String[] stateNames = new String[outputAlphabet.size()]; for (int s = 0; s < outputAlphabet.size(); s++) stateNames[s] = (String)outputAlphabet.lookupObject(s); for (int s = 0; s < outputAlphabet.size(); s++) addState(stateNames[s], 0.0, 0.0, stateNames, stateNames); return start; } } public State getState (String name) { return (State) name2state.get(name); } public int numStates () { return states.size(); } public Transducer.State getState (int index) { return (Transducer.State) states.get(index); } public Iterator initialStateIterator () { return initialStates.iterator (); } public boolean isTrainable () { return trainable; } public void reset () { throw new UnsupportedOperationException ("Not used in HMMs"); } public void estimate () { if (!trainable) throw new IllegalStateException ("This transducer not currently trainable."); // xxx Put stuff in here. EM training. throw new UnsupportedOperationException ("Not yet implemented. Never?"); } public boolean train (InstanceList ilist) { return train (ilist, (InstanceList)null, (InstanceList)null); } public boolean train (InstanceList ilist, InstanceList validation, InstanceList testing) { return train (ilist, validation, testing, (TransducerEvaluator)null); } public boolean train (InstanceList ilist, InstanceList validation, InstanceList testing, TransducerEvaluator eval) { assert (ilist.size() > 0); if (emissionEstimator == null) { emissionEstimator = new Multinomial.LaplaceEstimator[numStates()]; transitionEstimator = new Multinomial.LaplaceEstimator[numStates()]; emissionMultinomial = new Multinomial[numStates()]; transitionMultinomial = new Multinomial[numStates()]; Alphabet transitionAlphabet = new Alphabet (); for (int i=0; i < numStates(); i++) transitionAlphabet.lookupIndex (((State)states.get(i)).getName(), true); for (int i=0; i < numStates(); i++) { emissionEstimator[i] = new Multinomial.LaplaceEstimator(inputAlphabet); transitionEstimator[i] = new Multinomial.LaplaceEstimator(transitionAlphabet); emissionMultinomial[i] = new Multinomial (getUniformArray (inputAlphabet.size()), inputAlphabet); transitionMultinomial[i] = new Multinomial (getUniformArray (transitionAlphabet.size()), transitionAlphabet); } initialEstimator = new Multinomial.LaplaceEstimator (transitionAlphabet); } for (int i=0; i < ilist.size(); i++) { Instance instance = ilist.getInstance(i); FeatureSequence input = (FeatureSequence) instance.getData(); FeatureSequence output = (FeatureSequence) instance.getTarget(); forwardBackward (input, output, true); } initialMultinomial = initialEstimator.estimate(); for (int i=0; i < numStates(); i++) { emissionMultinomial[i] = emissionEstimator[i].estimate(); transitionMultinomial[i] = transitionEstimator[i].estimate(); getState (i).setInitialCost (-initialMultinomial.logProbability (getState(i).getName())); } return true; } public void write (File f) { try { ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream(f)); oos.writeObject(this); oos.close(); } catch (IOException e) { System.err.println("Exception writing file " + f + ": " + e); } } private double[] getUniformArray (int size) { double[] ret = new double[size]; for (int i=0; i < size; i++) ret[i] = 1.0 / (double)size; return ret; } // Serialization // For HMM class private static final long serialVersionUID = 1; private static final int CURRENT_SERIAL_VERSION = 1; static final int NULL_INTEGER = -1; /* Need to check for null pointers. */ private void writeObject (ObjectOutputStream out) throws IOException { int i, size; out.writeInt (CURRENT_SERIAL_VERSION); out.writeObject(inputPipe); out.writeObject(outputPipe); out.writeObject (inputAlphabet); out.writeObject (outputAlphabet); size = states.size(); out.writeInt(size); for (i = 0; i<size; i++) out.writeObject(states.get(i)); size = initialStates.size(); out.writeInt(size); for (i = 0; i <size; i++) out.writeObject(initialStates.get(i)); out.writeObject(name2state); if (emissionEstimator != null) { size = emissionEstimator.length; for (i=0; i<size; i++) out.writeObject(emissionEstimator[i]); } else out.writeInt(NULL_INTEGER); if (transitionEstimator != null) { size = transitionEstimator.length; for (i=0; i<size; i++) out.writeObject(transitionEstimator[i]); } else out.writeInt(NULL_INTEGER); out.writeBoolean(trainable); } private void readObject (ObjectInputStream in) throws IOException, ClassNotFoundException { int size, i; int version = in.readInt (); inputPipe = (Pipe) in.readObject(); outputPipe = (Pipe) in.readObject(); inputAlphabet = (Alphabet) in.readObject(); outputAlphabet = (Alphabet) in.readObject(); size = in.readInt(); states = new ArrayList(); for (i=0; i<size; i++) { State s = (HMM.State) in.readObject(); states.add(s); } size = in.readInt(); initialStates = new ArrayList(); for (i=0; i<size; i++) {
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -