📄 hmm.java
字号:
State s = (HMM.State) in.readObject();
      initialStates.add(s);
    }
    name2state = (HashMap) in.readObject();
    size = in.readInt();
    if (size == NULL_INTEGER) {
      emissionEstimator = null;
    } else {
      emissionEstimator = new Multinomial.Estimator[size];
      for(i=0; i< size; i++) {
        emissionEstimator[i] = (Multinomial.Estimator) in.readObject();
      }
    }
    size = in.readInt();
    if (size == NULL_INTEGER) {
      transitionEstimator = null;
    } else {
      transitionEstimator = new Multinomial.Estimator[size];
      for(i=0; i< size; i++) {
        transitionEstimator[i] = (Multinomial.Estimator) in.readObject();
      }
    }
    trainable = in.readBoolean();
  }

  /**
   * A single state of the HMM. Each outgoing transition is described by three
   * parallel arrays that share the same transition position: the destination
   * state's name, the (lazily resolved) destination state object, and the
   * output label attached to the transition.
   */
  public static class State extends Transducer.State implements Serializable {
    // Parameters indexed by destination state, feature index
    String name;
    int index;
    String[] destinationNames;
    State[] destinations;
    String[] labels;
    HMM hmm;

    // No arg constructor so serialization works
    protected State() {
      super();
    }

    /**
     * Creates a state with one outgoing transition per entry of
     * {@code destinationNames} / {@code labelNames}.
     *
     * @param name             name of this state
     * @param index            index of this state
     * @param initialCost      stored in {@code initialCost}
     * @param finalCost        stored in {@code finalCost}
     * @param destinationNames names of the destination states; copied
     * @param labelNames       output label of each transition; copied. Must
     *                         have the same length as {@code destinationNames}
     *                         (checked with an {@code assert} only).
     * @param hmm              the owning HMM, used for its output alphabet and
     *                         later to resolve destination names
     */
    protected State(String name, int index, double initialCost, double finalCost,
                    String[] destinationNames, String[] labelNames, HMM hmm) {
      super();
      assert (destinationNames.length == labelNames.length);
      this.name = name;
      this.index = index;
      this.initialCost = initialCost;
      this.finalCost = finalCost;
      this.destinationNames = new String[destinationNames.length];
      // Left unresolved (all null) here; filled in on demand by getDestinationState.
      this.destinations = new State[labelNames.length];
      this.labels = new String[labelNames.length];
      this.hmm = hmm;
      for (int i = 0; i < labelNames.length; i++) {
        // Make sure this label appears in our output Alphabet
        hmm.outputAlphabet.lookupIndex(labelNames[i]);
        this.destinationNames[i] = destinationNames[i];
        this.labels[i] = labelNames[i];
      }
    }

    /** Prints this state's index, name, costs and destination names to standard output. */
    public void print() {
      System.out.println("State #" + index + " \"" + name + "\"");
      System.out.println("initialCost=" + initialCost + ", finalCost=" + finalCost);
      System.out.println("#destinations=" + destinations.length);
      for (int i = 0; i < destinations.length; i++) {
        System.out.println("-> " + destinationNames[i]);
      }
    }

    /**
     * Returns the destination state of the transition at the given position,
     * looking it up by name in {@code hmm.name2state} on first use and caching
     * the result in {@code destinations}.
     *
     * @param index position of the transition in this state's transition arrays
     * @return the destination state; an {@code assert} checks that the name resolved
     */
    public State getDestinationState(int index) {
      State ret = destinations[index];
      if (ret == null) {
        ret = (State) hmm.name2state.get(destinationNames[index]);
        destinations[index] = ret;
        assert (ret != null) : index;
      }
      return ret;
    }

    /**
     * Creates an iterator over this state's outgoing transitions for one input position.
     *
     * @param inputSequence  the observation sequence; must be a {@code FeatureSequence}
     * @param inputPosition  position in {@code inputSequence}; must be non-negative
     * @param outputSequence the label sequence, or {@code null} to leave the output
     *                       unconstrained
     * @param outputPosition position in {@code outputSequence}; must be non-negative
     * @return an iterator over the transitions whose label matches the requested output
     * @throws UnsupportedOperationException if a position is negative, if
     *         {@code inputSequence} is {@code null}, or if it is not a
     *         {@code FeatureSequence}
     */
    public Transducer.TransitionIterator transitionIterator(
        Sequence inputSequence, int inputPosition,
        Sequence outputSequence, int outputPosition) {
      if (inputPosition < 0 || outputPosition < 0) {
        throw new UnsupportedOperationException("Epsilon transitions not implemented.");
      }
      if (inputSequence == null) {
        throw new UnsupportedOperationException(
            "HMMs are generative models; but this is not yet implemented.");
      }
      if (!(inputSequence instanceof FeatureSequence)) {
        throw new UnsupportedOperationException(
            "HMMs currently expect Instances to have FeatureSequence data");
      }
      return new TransitionIterator(
          this, (FeatureSequence) inputSequence, inputPosition,
          (outputSequence == null ? null : (String) outputSequence.get(outputPosition)),
          hmm);
    }

    /** @return the name of this state */
    public String getName() {
      return name;
    }

    /** @return the index of this state */
    public int getIndex() {
      return index;
    }

    /** Intentionally does nothing in this implementation. */
    public void incrementInitialCount(double count) {
    }

    /** Intentionally does nothing in this implementation. */
    public void incrementFinalCount(double count) {
    }

    // Serialization
    // For class State

    private static final long serialVersionUID = 1;
    private static final int CURRENT_SERIAL_VERSION = 0;
    // Written in place of an array length when the array itself is null.
    private static final int NULL_INTEGER = -1;

    /**
     * Writes, in order: serial version, name, index, then each of
     * {@code destinationNames}, {@code destinations} and {@code labels} as a
     * length (or {@code NULL_INTEGER}) followed by its elements, then the HMM.
     */
    private void writeObject(ObjectOutputStream out) throws IOException {
      int i, size;
      out.writeInt(CURRENT_SERIAL_VERSION);
      out.writeObject(name);
      out.writeInt(index);
      size = (destinationNames == null) ? NULL_INTEGER : destinationNames.length;
      out.writeInt(size);
      if (size != NULL_INTEGER) {
        for (i = 0; i < size; i++) {
          out.writeObject(destinationNames[i]);
        }
      }
      size = (destinations == null) ? NULL_INTEGER : destinations.length;
      out.writeInt(size);
      if (size != NULL_INTEGER) {
        for (i = 0; i < size; i++) {
          out.writeObject(destinations[i]);
        }
      }
      size = (labels == null) ? NULL_INTEGER : labels.length;
      out.writeInt(size);
      if (size != NULL_INTEGER) {
        for (i = 0; i < size; i++) {
          out.writeObject(labels[i]);
        }
      }
      out.writeObject(hmm);
    }

    /** Reads the fields back in exactly the order {@code writeObject} wrote them. */
    private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
      int size, i;
      in.readInt(); // serial version: must be consumed to stay aligned, value is unused
      name = (String) in.readObject();
      index = in.readInt();
      size = in.readInt();
      if (size != NULL_INTEGER) {
        destinationNames = new String[size];
        for (i = 0; i < size; i++) {
          destinationNames[i] = (String) in.readObject();
        }
      } else {
        destinationNames = null;
      }
      size = in.readInt();
      if (size != NULL_INTEGER) {
        destinations = new State[size];
        for (i = 0; i < size; i++) {
          destinations[i] = (State) in.readObject();
        }
      } else {
        destinations = null;
      }
      size = in.readInt();
      if (size != NULL_INTEGER) {
        labels = new String[size];
        for (i = 0; i < size; i++) {
          labels[i] = (String) in.readObject();
        }
        //inputAlphabet = (Alphabet) in.readObject();
        //outputAlphabet = (Alphabet) in.readObject();
      } else {
        labels = null;
      }
      hmm = (HMM) in.readObject();
    }
  }

  /**
   * Iterates over the transitions leaving one source state at one input
   * position. The cost of every transition is computed up front in the
   * constructor; transitions whose label does not match the requested output
   * get {@code INFINITE_COST} and are skipped during iteration.
   */
  protected static class TransitionIterator extends Transducer.TransitionIterator
      implements Serializable {
    State source;
    int index, nextIndex, inputPos;
    double[] costs; // -logProb
    // Eventually change this because we will have a more space-efficient
    // FeatureVectorSequence that cannot break out each FeatureVector
    FeatureSequence input;
    HMM hmm;

    /**
     * Computes the cost of every transition out of {@code source} for the
     * observation at {@code inputPosition}.
     *
     * @param source        the state whose outgoing transitions are iterated
     * @param inputSeq      the observation sequence
     * @param inputPosition position of the observation to be emitted
     * @param output        required output label, or {@code null} to allow every transition
     * @param hmm           the HMM supplying the emission and transition distributions
     */
    public TransitionIterator(State source, FeatureSequence inputSeq,
                              int inputPosition, String output, HMM hmm) {
      this.source = source;
      this.hmm = hmm;
      this.input = inputSeq;
      this.inputPos = inputPosition;
      this.costs = new double[source.destinations.length];
      for (int transIndex = 0; transIndex < source.destinations.length; transIndex++) {
        if (output == null || output.equals(source.labels[transIndex])) {
          costs[transIndex] = 0;
          // xxx should this be emission of the _next_ observation?
          // double logEmissionProb = hmm.emissionMultinomial[source.getIndex()].logProbability (inputSeq.get (inputPosition));
          // The emission arrays are indexed by state, so use the destination
          // state's own index. The transition position only happens to equal
          // it when every state lists all states as destinations in state order.
          int destIndex = source.getDestinationState(transIndex).getIndex();
          double logEmissionProb =
              hmm.emissionMultinomial[destIndex].logProbability(inputSeq.get(inputPosition));
          double logTransitionProb =
              hmm.transitionMultinomial[source.getIndex()]
                  .logProbability(source.destinationNames[transIndex]);
          // cost = -logProbability
          costs[transIndex] -= (logEmissionProb + logTransitionProb);
          assert (!Double.isNaN(costs[transIndex]));
        } else {
          costs[transIndex] = INFINITE_COST;
        }
      }
      // Position on the first transition that is actually allowed.
      nextIndex = 0;
      while (nextIndex < source.destinations.length && costs[nextIndex] == INFINITE_COST) {
        nextIndex++;
      }
    }

    /** @return {@code true} while an allowed transition remains */
    public boolean hasNext() {
      return nextIndex < source.destinations.length;
    }

    /**
     * Advances to the next allowed transition and returns its destination state.
     * Callers are expected to check {@link #hasNext()} first (asserted only).
     */
    public Transducer.State nextState() {
      assert (nextIndex < source.destinations.length);
      index = nextIndex;
      nextIndex++;
      while (nextIndex < source.destinations.length && costs[nextIndex] == INFINITE_COST) {
        nextIndex++;
      }
      return source.getDestinationState(index);
    }

    /** @return the whole input sequence (not just the element at the current position) */
    public Object getInput() {
      return input;
    }

    /** @return the output label of the current transition */
    public Object getOutput() {
      return source.labels[index];
    }

    /** @return the cost (negative log probability) of the current transition */
    public double getCost() {
      return costs[index];
    }

    /** @return the state this iterator was created for */
    public Transducer.State getSourceState() {
      return source;
    }

    /** @return the destination state of the current transition */
    public Transducer.State getDestinationState() {
      return source.getDestinationState(index);
    }

    /**
     * Records one observation of the current transition in the HMM's emission
     * and transition estimators.
     *
     * NOTE(review): {@code count} is ignored and 1.0 is always added, as in the
     * original code. Confirm whether callers pass fractional (expected) counts
     * that should be honoured here.
     *
     * @param count currently unused
     */
    public void incrementCount(double count) {
      // xxx ?? want way to increment observation count and transition count separately
//      if (inputPos == 0) {
//        System.err.println ("Initial increment for " + source.destinationNames[index]);
//        hmm.initialEstimator.increment (source.destinationNames[index], 1.0);
//      }
//      else {
      //System.err.println ("Incrementing count for emission " + input.get (inputPos) + " from state " + source.getName() + " -> " + source.destinationNames[index]);
//      hmm.emissionEstimator[source.getIndex()].increment (hmm.inputAlphabet.lookupIndex (input.get (inputPos), false), 1.0);
      // Same indexing as the constructor's cost computation: the emission
      // estimator belongs to the destination state, not to the transition position.
      int destIndex = source.getDestinationState(index).getIndex();
      hmm.emissionEstimator[destIndex].increment(
          hmm.inputAlphabet.lookupIndex(input.get(inputPos), false), 1.0);
      hmm.transitionEstimator[source.getIndex()].increment(source.destinationNames[index], 1.0);
//      }
    }

    // Serialization
    // TransitionIterator

    private static final long serialVersionUID = 1;
    private static final int CURRENT_SERIAL_VERSION = 0;
    // Written in place of the array length when costs is null.
    private static final int NULL_INTEGER = -1;

    /**
     * Writes, in order: serial version, source state, index, nextIndex,
     * inputPos, the costs array (length or {@code NULL_INTEGER}, then values),
     * the input sequence and the HMM.
     */
    private void writeObject(ObjectOutputStream out) throws IOException {
      out.writeInt(CURRENT_SERIAL_VERSION);
      out.writeObject(source);
      out.writeInt(index);
      out.writeInt(nextIndex);
      out.writeInt(inputPos);
      if (costs != null) {
        out.writeInt(costs.length);
        for (int i = 0; i < costs.length; i++) {
          out.writeDouble(costs[i]);
        }
      } else {
        out.writeInt(NULL_INTEGER);
      }
      out.writeObject(input);
      out.writeObject(hmm);
    }

    /** Reads the fields back in exactly the order {@code writeObject} wrote them. */
    private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
      in.readInt(); // serial version: must be consumed to stay aligned, value is unused
      source = (State) in.readObject();
      index = in.readInt();
      nextIndex = in.readInt();
      inputPos = in.readInt();
      int size = in.readInt();
      if (size == NULL_INTEGER) {
        costs = null;
      } else {
        costs = new double[size];
        for (int i = 0; i < size; i++) {
          costs[i] = in.readDouble();
        }
      }
      input = (FeatureSequence) in.readObject();
      hmm = (HMM) in.readObject();
    }
  }
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -