📄 transducer.java
字号:
{ if (false && logger.isLoggable (Level.FINE)) { logger.fine ("Starting Lattice"); logger.fine ("Input: "); for (int ip = 0; ip < input.size(); ip++) logger.fine (" " + input.get(ip)); logger.fine ("\nOutput: "); if (output == null) logger.fine ("null"); else for (int op = 0; op < output.size(); op++) logger.fine (" " + output.get(op)); logger.fine ("\n"); } // Initialize some structures this.input = input; this.output = output; // xxx Not very efficient when the lattice is actually sparse, // especially when the number of states is large and the // sequence is long. latticeLength = input.size()+1; int numStates = numStates(); nodes = new LatticeNode[latticeLength][numStates]; // xxx Yipes, this could get big; something sparse might be better? gammas = new double[latticeLength][numStates]; // xxx Move this to an ivar, so we can save it? But for what? // Commenting this out, because it's a memory hog and not used right now. // Uncomment and conditionalize under a flag if ever needed. -cas // double xis[][][] = new double[latticeLength][numStates][numStates]; double outputCounts[][] = null; if (outputAlphabet != null) outputCounts = new double[latticeLength][outputAlphabet.size()]; for (int i = 0; i < numStates; i++) { for (int ip = 0; ip < latticeLength; ip++) gammas[ip][i] = INFINITE_COST; /* Commenting out xis -cas for (int j = 0; j < numStates; j++) for (int ip = 0; ip < latticeLength; ip++) xis[ip][i][j] = INFINITE_COST; */ } // Forward pass logger.fine ("Starting Constrained Foward pass"); // ensure that at least one state has initial cost less than Infinity // so we can start from there boolean atLeastOneInitialState = false; for (int i = 0; i < numStates; i++) { double initialCost = getState(i).initialCost; //System.out.println ("Forward pass initialCost = "+initialCost); if (initialCost < INFINITE_COST) { getLatticeNode(0, i).alpha = initialCost; //System.out.println ("nodes[0][i].alpha="+nodes[0][i].alpha); atLeastOneInitialState = true; } } if (atLeastOneInitialState == false) logger.warning ("There are no starting states!"); for (int ip = 0; ip < latticeLength-1; ip++) for (int i = 0; i < numStates; i++) { logger.fine ("ip=" + ip+", i=" + i); // check if this node is possible at this <position, // label>. if not, skip it. if (constraints[ip] > 0) { // must be in state indexed by constraints[ip] - 1 if (constraints[ip]-1 != i) { logger.fine ("Current state does not match positive constraint. position="+ip+", constraint="+(constraints[ip]-1)+", currState="+i); continue; } } else if (constraints[ip] < 0) { // must _not_ be in state indexed by constraints[ip] if (constraints[ip]+1 == -i) { logger.fine ("Current state does not match negative constraint. position="+ip+", constraint="+(constraints[ip]+1)+", currState="+i); continue; } } if (nodes[ip][i] == null || nodes[ip][i].alpha == INFINITE_COST) { // xxx if we end up doing this a lot, // we could save a list of the non-null ones if (nodes[ip][i] == null) logger.fine ("nodes[ip][i] is NULL"); else if (nodes[ip][i].alpha == INFINITE_COST) logger.fine ("nodes[ip][i].alpha is Inf"); logger.fine ("INFINITE cost or NULL...skipping"); continue; } State s = getState(i); TransitionIterator iter = s.transitionIterator (input, ip, output, ip); if (logger.isLoggable (Level.FINE)) logger.fine (" Starting Forward transition iteration from state " + s.getName() + " on input " + input.get(ip).toString() + " and output " + (output==null ? "(null)" : output.get(ip).toString())); while (iter.hasNext()) { State destination = iter.nextState(); boolean legalTransition = true; // check constraints to see if node at <ip,i> can transition to destination if (ip+1 < constraints.length && constraints[ip+1] > 0 && ((constraints[ip+1]-1) != destination.getIndex())) { logger.fine ("Destination state does not match positive constraint. Assigning infinite cost. position="+(ip+1)+", constraint="+(constraints[ip+1]-1)+", source ="+i+", destination="+destination.getIndex()); legalTransition = false; } else if (((ip+1) < constraints.length) && constraints[ip+1] < 0 && (-(constraints[ip+1]+1) == destination.getIndex())) { logger.fine ("Destination state does not match negative constraint. Assigning infinite cost. position="+(ip+1)+", constraint="+(constraints[ip+1]+1)+", destination="+destination.getIndex()); legalTransition = false; } if (logger.isLoggable (Level.FINE)) logger.fine ("Forward Lattice[inputPos="+ip +"][source="+s.getName() +"][dest="+destination.getName()+"]"); LatticeNode destinationNode = getLatticeNode (ip+1, destination.getIndex()); destinationNode.output = iter.getOutput(); double transitionCost = iter.getCost(); if (legalTransition) { //if (logger.isLoggable (Level.FINE)) logger.fine ("transitionCost="+transitionCost +" nodes["+ip+"]["+i+"].alpha="+nodes[ip][i].alpha +" destinationNode.alpha="+destinationNode.alpha); destinationNode.alpha = sumNegLogProb (destinationNode.alpha, nodes[ip][i].alpha + transitionCost); //System.out.println ("destinationNode.alpha <- "+destinationNode.alpha); logger.fine ("Set alpha of latticeNode at ip = "+ (ip+1) + " stateIndex = " + destination.getIndex() + ", destinationNode.alpha = " + destinationNode.alpha); } else { // this is an illegal transition according to our // constraints, so set its prob to 0 . NO, alpha's are // unnormalized costs...set to Inf // // destinationNode.alpha = 0.0;// destinationNode.alpha = INFINITE_COST; logger.fine ("Illegal transition from state " + i + " to state " + destination.getIndex() + ". Setting alpha to Inf"); } } } // Calculate total cost of Lattice. This is the normalizer cost = INFINITE_COST; for (int i = 0; i < numStates; i++) if (nodes[latticeLength-1][i] != null) { // Note: actually we could sum at any ip index, // the choice of latticeLength-1 is arbitrary //System.out.println ("Ending alpha, state["+i+"] = "+nodes[latticeLength-1][i].alpha); //System.out.println ("Ending beta, state["+i+"] = "+getState(i).finalCost); if (constraints[latticeLength-1] > 0 && i != constraints[latticeLength-1]-1) continue; if (constraints[latticeLength-1] < 0 && -i == constraints[latticeLength-1]+1) continue; logger.fine ("Summing final lattice cost. state="+i+", alpha="+nodes[latticeLength-1][i].alpha + ", final cost = "+getState(i).finalCost); cost = sumNegLogProb (cost, (nodes[latticeLength-1][i].alpha + getState(i).finalCost)); } // Cost is now an "unnormalized cost" of the entire Lattice //assert (cost >= 0) : "cost = "+cost; // If the sequence has infinite cost, just return. // Usefully this avoids calling any incrementX methods. // It also relies on the fact that the gammas[][] and .alpha and .beta values // are already initialized to values that reflect infinite cost // xxx Although perhaps not all (alphas,betas) exactly correctly reflecting? if (cost == INFINITE_COST) return; // Backward pass for (int i = 0; i < numStates; i++) if (nodes[latticeLength-1][i] != null) { State s = getState(i); nodes[latticeLength-1][i].beta = s.finalCost; gammas[latticeLength-1][i] = nodes[latticeLength-1][i].alpha + nodes[latticeLength-1][i].beta - cost; if (increment) { double p = Math.exp(-gammas[latticeLength-1][i]); assert (p < INFINITE_COST && !Double.isNaN(p)) : "p="+p+" gamma="+gammas[latticeLength-1][i]; s.incrementFinalCount (p); } } for (int ip = latticeLength-2; ip >= 0; ip--) { for (int i = 0; i < numStates; i++) { if (nodes[ip][i] == null || nodes[ip][i].alpha == INFINITE_COST) // Note that skipping here based on alpha means that beta values won't // be correct, but since alpha is infinite anyway, it shouldn't matter. continue; State s = getState(i); TransitionIterator iter = s.transitionIterator (input, ip, output, ip); while (iter.hasNext()) { State destination = iter.nextState(); if (logger.isLoggable (Level.FINE)) logger.fine ("Backward Lattice[inputPos="+ip +"][source="+s.getName() +"][dest="+destination.getName()+"]"); int j = destination.getIndex(); LatticeNode destinationNode = nodes[ip+1][j]; if (destinationNode != null) { double transitionCost = iter.getCost(); assert (!Double.isNaN(transitionCost)); // assert (transitionCost >= 0); Not necessarily double oldBeta = nodes[ip][i].beta; assert (!Double.isNaN(nodes[ip][i].beta)); nodes[ip][i].beta = sumNegLogProb (nodes[ip][i].beta, destinationNode.beta + transitionCost); assert (!Double.isNaN(nodes[ip][i].beta)) : "dest.beta="+destinationNode.beta+" trans="+transitionCost+" sum="+(destinationNode.beta+transitionCost) + " oldBeta="+oldBeta; // xis[ip][i][j] = nodes[ip][i].alpha + transitionCost + nodes[ip+1][j].beta - cost; assert (!Double.isNaN(nodes[ip][i].alpha)); assert (!Double.isNaN(transitionCost)); assert (!Double.isNaN(nodes[ip+1][j].beta)); assert (!Double.isNaN(cost)); if (increment || outputAlphabet != null) { double xi = nodes[ip][i].alpha + transitionCost + nodes[ip+1][j].beta - cost; double p = Math.exp(-xi); assert (p < INFINITE_COST && !Double.isNaN(p)) : "xis["+ip+"]["+i+"]["+j+"]="+-xi; if (increment) iter.incrementCount (p); if (outputAlphabet != null) { int outputIndex = outputAlphabet.lookupIndex (iter.getOutput(), false); assert (outputIndex >= 0); // xxx This assumes that "ip" == "op"! outputCounts[ip][outputIndex] += p; //System.out.println ("CRF Lattice outputCounts["+ip+"]["+outputIndex+"]+="+p); } } } } gammas[ip][i] = nodes[ip][i].alpha + nodes[ip][i].beta - cost; } } if (increment) for (int i = 0; i < numStates; i++) { double p = Math.exp(-gammas[0][i]); assert (p < INFINITE_COST && !Double.isNaN(p)); getState(i).incrementInitialCount (p); } if (outputAlphabet != null) { labelings = new LabelVector[latticeLength]; for (int ip = latticeLength-2; ip >= 0; ip--) { assert (Math.abs(1.0-DenseVector.sum (outputCounts[ip])) < 0.000001);; labelings[ip] = new LabelVector (outputAlphabet, outputCounts[ip]); } } } public double getCost () { assert (!Double.isNaN(cost)); return cost; } // No, this.cost is an "unnormalized cost" //public double getProbability () { return Math.exp (-cost); } public double getGammaCost (int inputPosition, State s) { return gammas[inputPosition][s.getIndex()]; } public double getGammaProbability (int inputPosition, State s) { return Math.exp (-gammas[inputPosition][s.getIndex()]); } public LabelVector getLabelingAtPosition (int outputPosition) { if (labelings != null) return labelings[outputPosition]; return null; } // xxx We are a non-static inner class so this should be easy; but how? // A: By the following weird syntax -cas public Transducer getTransducer () { return Transducer.this; } // A container for some information about a particular input position and state private class LatticeNode { int inputPosition; // outputPosition not really needed until we deal with asymmetric epsilon. State state; Object output; double alpha = INFINITE_COST; double beta = INFINITE_COST; LatticeNode (int inputPosition, State state) { this.inputPosition = inputPosition; this.state = state; assert (this.alpha == INFINITE_COST); // xxx Remove this check } } } // end of class Lattice public ViterbiPath viterbiPath (Object unpipedObject) { Instance carrier = new Instance (unpipedObject, null, null, null, inputPipe); return viterbiPath ((Sequence)carrier.getData()); } public ViterbiPath viterbiPath (Sequence inputSequence) { return viterbiPath (inputSequence, null); } public ViterbiPath viterbiPath (Sequence inputSequence, Sequence outputSequence) { // xxx We don't do epsilon transitions for now assert (outputSequence == null || inputSequence.size() == outputSequence.size()); return new ViterbiPath (inputSequence, outputSequence); } public ViterbiPath_NBest viterbiPath_NBest (Sequence inputSequence, int N) { return viterbiPath_NBest (inputSequence, null, N); } public ViterbiPath_NBest viterbiPath_NBest (Sequence inputSequence, Sequence outputSequence, int N) { // xxx We don't do epsilon transitions for now assert (outputSequence == null || inputSequence.size() == outputSequence.size()); assert(N > 0); return new ViterbiPath_NBest (inputSequence, outputSequence, N); } public class ViterbiPath extends SequencePairAlignment { // double cost inherited from SequencePairAlignment // Sequence input, output inherited from SequencePairAlignment // this.output stores the actual output of Viterbi transitions Sequence providedOutput; ViterbiNode[] nodePath; int latticeLength; protected ViterbiNode getViterbiNode (ViterbiNode[][] nodes, int ip, int stateIndex) { if (nodes[ip][stateIndex] == null) nodes[ip][stateIndex] = new ViterbiNode (ip, getState (stateIndex)); return nodes[ip][stateIndex]; } // You may pass null for output protected ViterbiPath (Sequence inputSequence, Sequence outputSequence) { assert (inputSequence != null); if (logger.isLoggable (Level.FINE)) { logger.fine ("Starting ViterbiPath"); logger.fine ("Input: "); for (int ip = 0; ip < inputSequence.size(); ip++) logger.fine (" " + inputSequence.get(ip));
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -