📄 transducer.java
字号:
for (int i = 0; i < numStates; i++) { if (nodes[ip][i] == null || nodes[ip][i].alpha == INFINITE_COST) // Note that skipping here based on alpha means that beta values won't // be correct, but since alpha is infinite anyway, it shouldn't matter. continue; State s = getState(i); TransitionIterator iter = s.transitionIterator (input, ip, output, ip); while (iter.hasNext()) { State destination = iter.nextState(); if (logger.isLoggable (Level.FINE)) logger.fine ("Backward Lattice[inputPos="+ip +"][source="+s.getName() +"][dest="+destination.getName()+"]"); int j = destination.getIndex(); LatticeNode destinationNode = nodes[ip+1][j]; if (destinationNode != null) { double transitionCost = iter.getCost(); assert (!Double.isNaN(transitionCost)); // assert (transitionCost >= 0); Not necessarily double oldBeta = nodes[ip][i].beta; assert (!Double.isNaN(nodes[ip][i].beta)); nodes[ip][i].beta = sumNegLogProb (nodes[ip][i].beta, destinationNode.beta + transitionCost); assert (!Double.isNaN(nodes[ip][i].beta)) : "dest.beta="+destinationNode.beta+" trans="+transitionCost+" sum="+(destinationNode.beta+transitionCost) + " oldBeta="+oldBeta; // xis[ip][i][j] = nodes[ip][i].alpha + transitionCost + nodes[ip+1][j].beta - cost; assert (!Double.isNaN(nodes[ip][i].alpha)); assert (!Double.isNaN(transitionCost)); assert (!Double.isNaN(nodes[ip+1][j].beta)); assert (!Double.isNaN(cost)); if (increment || outputAlphabet != null) { double xi = nodes[ip][i].alpha + transitionCost + nodes[ip+1][j].beta - cost; double p = Math.exp(-xi); assert (p < INFINITE_COST && !Double.isNaN(p)) : "xis["+ip+"]["+i+"]["+j+"]="+-xi; if (increment) iter.incrementCount (p); if (outputAlphabet != null) { int outputIndex = outputAlphabet.lookupIndex (iter.getOutput(), false); assert (outputIndex >= 0); // xxx This assumes that "ip" == "op"! outputCounts[ip][outputIndex] += p; //System.out.println ("CRF Lattice outputCounts["+ip+"]["+outputIndex+"]+="+p); } } } } gammas[ip][i] = nodes[ip][i].alpha + nodes[ip][i].beta - cost; } } if (increment) for (int i = 0; i < numStates; i++) { double p = Math.exp(-gammas[0][i]); assert (p < INFINITE_COST && !Double.isNaN(p)); getState(i).incrementInitialCount (p); } if (outputAlphabet != null) { labelings = new LabelVector[latticeLength]; for (int ip = latticeLength-2; ip >= 0; ip--) { assert (Math.abs(1.0-DenseVector.sum (outputCounts[ip])) < 0.000001);; labelings[ip] = new LabelVector (outputAlphabet, outputCounts[ip]); } } } public double getCost () { assert (!Double.isNaN(cost)); return cost; } // No, this.cost is an "unnormalized cost" //public double getProbability () { return Math.exp (-cost); } public double getGammaCost (int inputPosition, State s) { return gammas[inputPosition][s.getIndex()]; } public double getGammaProbability (int inputPosition, State s) { return Math.exp (-gammas[inputPosition][s.getIndex()]); } public double getXiProbability (int ip, State s1, State s2) { if (xis == null) throw new IllegalStateException ("xis were not saved."); int i = s1.getIndex (); int j = s2.getIndex (); return Math.exp (-xis[ip][i][j]); } public double getXiCost (int ip, State s1, State s2) { if (xis == null) throw new IllegalStateException ("xis were not saved."); int i = s1.getIndex (); int j = s2.getIndex (); return xis[ip][i][j]; } public int length () { return latticeLength; } public double getAlpha (int ip, State s) { LatticeNode node = getLatticeNode (ip, s.getIndex ()); return node.alpha; } public double getBeta (int ip, State s) { LatticeNode node = getLatticeNode (ip, s.getIndex ()); return node.beta; } public LabelVector getLabelingAtPosition (int outputPosition) { if (labelings != null) return labelings[outputPosition]; return null; } // Q: We are a non-static inner class so this should be easy; but how? // A: By the following weird syntax -cas public Transducer getTransducer () { return Transducer.this; } // A container for some information about a particular input position and state private class LatticeNode { int inputPosition; // outputPosition not really needed until we deal with asymmetric epsilon. State state; Object output; double alpha = INFINITE_COST; double beta = INFINITE_COST; LatticeNode (int inputPosition, State state) { this.inputPosition = inputPosition; this.state = state; assert (this.alpha == INFINITE_COST); // xxx Remove this check } } } // end of class Lattice // ****************************************************************************** // CPAL - NEW "BEAM" Version of Forward Backward // ****************************************************************************** public BeamLattice forwardBackwardBeam (Sequence inputSequence) { return forwardBackwardBeam (inputSequence, null, false); } public BeamLattice forwardBackwardBeam (Sequence inputSequence, boolean increment) { return forwardBackwardBeam (inputSequence, null, increment); } public BeamLattice forwardBackwardBeam (Sequence inputSequence, Sequence outputSequence) { return forwardBackwardBeam (inputSequence, outputSequence, false); } public BeamLattice forwardBackwardBeam (Sequence inputSequence, Sequence outputSequence, boolean increment) { return forwardBackwardBeam (inputSequence, outputSequence, increment, null); } public BeamLattice forwardBackwardBeam (Sequence inputSequence, Sequence outputSequence, boolean increment, LabelAlphabet outputAlphabet) { return forwardBackwardBeam (inputSequence, outputSequence, increment, false, outputAlphabet); } public BeamLattice forwardBackwardBeam (Sequence inputSequence, Sequence outputSequence, boolean increment, boolean saveXis, LabelAlphabet outputAlphabet) { // xxx We don't do epsilon transitions for now assert (outputSequence == null || inputSequence.size() == outputSequence.size()); return new BeamLattice (inputSequence, outputSequence, increment, saveXis, outputAlphabet); } // culotta: interface for constrained lattice /** Create constrained lattice such that all paths pass through the the labeling of <code> requiredSegment </code> as indicated by <code> constrainedSequence </code> @param inputSequence input sequence @param outputSequence output sequence @param requiredSegment segment of sequence that must be labelled @param constrainedSequence lattice must have labels of this sequence from <code> requiredSegment.start </code> to <code> requiredSegment.end </code> correctly */ public BeamLattice forwardBackwardBeam (Sequence inputSequence, Sequence outputSequence, Segment requiredSegment, Sequence constrainedSequence) { if (constrainedSequence.size () != inputSequence.size ()) throw new IllegalArgumentException ("constrainedSequence.size [" + constrainedSequence.size () + "] != inputSequence.size [" + inputSequence.size () + "]"); // constraints tells the lattice which states must emit which // observations. positive values say all paths must pass through // this state index, negative values say all paths must _not_ // pass through this state index. 0 means we don't // care. initialize to 0. include 1 extra node for start state. int [] constraints = new int [constrainedSequence.size() + 1]; for (int c = 0; c < constraints.length; c++) constraints[c] = 0; for (int i=requiredSegment.getStart (); i <= requiredSegment.getEnd(); i++) { int si = stateIndexOfString ((String)constrainedSequence.get (i)); if (si == -1) logger.warning ("Could not find state " + constrainedSequence.get (i) + ". Check that state labels match startTages and inTags, and that all labels are seen in training data.");// throw new IllegalArgumentException ("Could not find state " + constrainedSequence.get(i) + ". Check that state labels match startTags and InTags."); constraints[i+1] = si + 1; } // set additional negative constraint to ensure state after // segment is not a continue tag // xxx if segment length=1, this actually constrains the sequence // to B-tag (B-tag)', instead of the intended constraint of B-tag // (I-tag)' // the fix below is unsafe, but will have to do for now. // FIXED BELOW/* String endTag = (String) constrainedSequence.get (requiredSegment.getEnd ()); if (requiredSegment.getEnd()+2 < constraints.length) { if (requiredSegment.getStart() == requiredSegment.getEnd()) { // segment has length 1 if (endTag.startsWith ("B-")) { endTag = "I" + endTag.substring (1, endTag.length()); } else if (!(endTag.startsWith ("I-") || endTag.startsWith ("0"))) throw new IllegalArgumentException ("Constrained Lattice requires that states are tagged in B-I-O format."); } int statei = stateIndexOfString (endTag); if (statei == -1) // no I- tag for this B- tag statei = stateIndexOfString ((String)constrainedSequence.get (requiredSegment.getStart ())); constraints[requiredSegment.getEnd() + 2] = - (statei + 1); }*/ if (requiredSegment.getEnd() + 2 < constraints.length) { // if String endTag = requiredSegment.getInTag().toString(); int statei = stateIndexOfString (endTag); if (statei == -1) throw new IllegalArgumentException ("Could not find state " + endTag + ". Check that state labels match startTags and InTags."); constraints[requiredSegment.getEnd() + 2] = - (statei + 1); } // printStates (); logger.fine ("Segment:\n" + requiredSegment.sequenceToString () + "\nconstrainedSequence:\n" + constrainedSequence + "\nConstraints:\n"); for (int i=0; i < constraints.length; i++) { logger.fine (constraints[i] + "\t"); } logger.fine (""); return forwardBackwardBeam (inputSequence, outputSequence, constraints); }/* public int stateIndexOfString (String s) { for (int i = 0; i < this.numStates(); i++) { String state = this.getState (i).getName(); if (state.equals (s)) return i; } return -1; } private void printStates () { for (int i = 0; i < this.numStates(); i++) logger.fine (i + ":" + this.getState (i).getName()); } public void print () { logger.fine ("Transducer "+this); printStates(); }*/ public BeamLattice forwardBackwardBeam (Sequence inputSequence, Sequence outputSequence, int [] constraints) { return new BeamLattice (inputSequence, outputSequence, false, null, constraints); } // Remove this method? // If "increment" is true, call incrementInitialCount, incrementFinalCount and incrementCount private BeamLattice forwardBackwardBeam (SequencePair inputOutputPair, boolean increment) { return this.forwardBackwardBeam (inputOutputPair.input(), inputOutputPair.output(), increment); } public class BeamLattice // CPAL - like Lattice but using max-product to get the viterbiPath { // "ip" == "input position", "op" == "output position", "i" == "state index" double cost; Sequence input, output; LatticeNode[][] nodes; // indexed by ip,i int latticeLength; int curBeamWidth; // CPAL - can be adapted if maximizer is confused // xxx Now that we are incrementing here directly, there isn't // necessarily a need to save all these arrays... // -log(probability) of being in state "i" at input position "ip" double[][] gammas; // indexed by ip,i double[][][] xis; // indexed by ip,i,j; saved only if saveXis is true; LabelVector labelings[]; // indexed by op, created only if "outputAlphabet" is non-null in constructor private LatticeNode getLatticeNode (int ip, int stateIndex) { if (nodes[ip][stateIndex] == null) nodes[ip][stateIndex] = new LatticeNode (ip, getState (stateIndex)); return nodes[ip][stateIndex]; } // You may pass null for output, meaning that the lattice // is not constrained to match the output protected BeamLattice (Sequence input, Sequence output, boolean increment) { this (input, output, increment, false, null); }
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -