📄 transducer.java
字号:
public boolean isTrainable () { return false; } // If f is true, and it was already trainable, this has same effect as reset() public void setTrainable (boolean f) { if (f) throw new IllegalStateException ("Cannot be trainable."); } public boolean train (InstanceList instances) { throw new UnsupportedOperationException ("Not trainable."); } public double averageTokenAccuracy (InstanceList ilist) { double accuracy = 0; for (int i = 0; i < ilist.size(); i++) { Instance instance = ilist.getInstance(i); Sequence input = (Sequence) instance.getData(); Sequence output = (Sequence) instance.getTarget(); assert (input.size() == output.size()); double pathAccuracy = viterbiPath(input).tokenAccuracy(output); accuracy += pathAccuracy; logger.info ("Transducer path accuracy = "+pathAccuracy); } return accuracy/ilist.size(); } public double averageTokenAccuracy (InstanceList ilist, String fileName) { double accuracy = 0; PrintWriter out; File f = new File(fileName); try { out = new PrintWriter(new FileWriter(f)); } catch (IOException e) { out = null; } for (int i = 0; i < ilist.size(); i++) { Instance instance = ilist.getInstance(i); Sequence input = (Sequence) instance.getData(); Sequence output = (Sequence) instance.getTarget(); assert (input.size() == output.size()); double pathAccuracy = viterbiPath(input).tokenAccuracy(output, out); accuracy += pathAccuracy; logger.info ("Transducer path accuracy = "+pathAccuracy); } out.close(); return accuracy/ilist.size(); } // Treat the costs as if they are -log(probabilies); we will // normalize them if necessary public SequencePairAlignment generatePath () { if (isGenerative() == false) throw new IllegalStateException ("Transducer is not generative."); ArrayList initialStates = new ArrayList (); Iterator iter = initialStateIterator (); while (iter.hasNext()) { initialStates.add (iter.next()); } // xxx Not yet finished. throw new UnsupportedOperationException (); } public Lattice forwardBackward (Sequence inputSequence) { return forwardBackward (inputSequence, null, false); } public Lattice forwardBackward (Sequence inputSequence, boolean increment) { return forwardBackward (inputSequence, null, increment); } public Lattice forwardBackward (Sequence inputSequence, Sequence outputSequence) { return forwardBackward (inputSequence, outputSequence, false); } public Lattice forwardBackward (Sequence inputSequence, Sequence outputSequence, boolean increment) { return forwardBackward (inputSequence, outputSequence, increment, null); } public Lattice forwardBackward (Sequence inputSequence, Sequence outputSequence, boolean increment, LabelAlphabet outputAlphabet) { return forwardBackward (inputSequence, outputSequence, increment, false, outputAlphabet); } public Lattice forwardBackward (Sequence inputSequence, Sequence outputSequence, boolean increment, boolean saveXis, LabelAlphabet outputAlphabet) { // xxx We don't do epsilon transitions for now assert (outputSequence == null || inputSequence.size() == outputSequence.size()); return new Lattice (inputSequence, outputSequence, increment, saveXis, outputAlphabet); } // culotta: interface for constrained lattice /** Create constrained lattice such that all paths pass through the the labeling of <code> requiredSegment </code> as indicated by <code> constrainedSequence </code> @param inputSequence input sequence @param outputSequence output sequence @param requiredSegment segment of sequence that must be labelled @param constrainedSequence lattice must have labels of this sequence from <code> requiredSegment.start </code> to <code> requiredSegment.end </code> correctly */ public Lattice forwardBackward (Sequence inputSequence, Sequence outputSequence, Segment requiredSegment, Sequence constrainedSequence) { if (constrainedSequence.size () != inputSequence.size ()) throw new IllegalArgumentException ("constrainedSequence.size [" + constrainedSequence.size () + "] != inputSequence.size [" + inputSequence.size () + "]"); // constraints tells the lattice which states must emit which // observations. positive values say all paths must pass through // this state index, negative values say all paths must _not_ // pass through this state index. 0 means we don't // care. initialize to 0. include 1 extra node for start state. int [] constraints = new int [constrainedSequence.size() + 1]; for (int c = 0; c < constraints.length; c++) constraints[c] = 0; for (int i=requiredSegment.getStart (); i <= requiredSegment.getEnd(); i++) { int si = stateIndexOfString ((String)constrainedSequence.get (i)); if (si == -1) logger.warning ("Could not find state " + constrainedSequence.get (i) + ". Check that state labels match startTages and inTags, and that all labels are seen in training data.");// throw new IllegalArgumentException ("Could not find state " + constrainedSequence.get(i) + ". Check that state labels match startTags and InTags."); constraints[i+1] = si + 1; } // set additional negative constraint to ensure state after // segment is not a continue tag // xxx if segment length=1, this actually constrains the sequence // to B-tag (B-tag)', instead of the intended constraint of B-tag // (I-tag)' // the fix below is unsafe, but will have to do for now. // FIXED BELOW/* String endTag = (String) constrainedSequence.get (requiredSegment.getEnd ()); if (requiredSegment.getEnd()+2 < constraints.length) { if (requiredSegment.getStart() == requiredSegment.getEnd()) { // segment has length 1 if (endTag.startsWith ("B-")) { endTag = "I" + endTag.substring (1, endTag.length()); } else if (!(endTag.startsWith ("I-") || endTag.startsWith ("0"))) throw new IllegalArgumentException ("Constrained Lattice requires that states are tagged in B-I-O format."); } int statei = stateIndexOfString (endTag); if (statei == -1) // no I- tag for this B- tag statei = stateIndexOfString ((String)constrainedSequence.get (requiredSegment.getStart ())); constraints[requiredSegment.getEnd() + 2] = - (statei + 1); }*/ if (requiredSegment.getEnd() + 2 < constraints.length) { // if String endTag = requiredSegment.getInTag().toString(); int statei = stateIndexOfString (endTag); if (statei == -1) logger.fine ("Could not find state " + endTag + ". Check that state labels match startTags and InTags."); else constraints[requiredSegment.getEnd() + 2] = - (statei + 1); } // printStates (); logger.fine ("Segment:\n" + requiredSegment.sequenceToString () + "\nconstrainedSequence:\n" + constrainedSequence + "\nConstraints:\n"); for (int i=0; i < constraints.length; i++) { logger.fine (constraints[i] + "\t"); } logger.fine (""); return forwardBackward (inputSequence, outputSequence, constraints); } public int stateIndexOfString (String s) { for (int i = 0; i < this.numStates(); i++) { String state = this.getState (i).getName(); if (state.equals (s)) return i; } return -1; } private void printStates () { for (int i = 0; i < this.numStates(); i++) logger.fine (i + ":" + this.getState (i).getName()); } public void print () { logger.fine ("Transducer "+this); printStates(); } public Lattice forwardBackward (Sequence inputSequence, Sequence outputSequence, int [] constraints) { return new Lattice (inputSequence, outputSequence, false, null, constraints); } // Remove this method? // If "increment" is true, call incrementInitialCount, incrementFinalCount and incrementCount private Lattice forwardBackward (SequencePair inputOutputPair, boolean increment) { return this.forwardBackward (inputOutputPair.input(), inputOutputPair.output(), increment); } // xxx Include methods like this? // ...making random selections proportional to cost //public Transduction transduce (Object[] inputSequence) //{ throw new UnsupportedOperationException (); } //public Transduction transduce (Sequence inputSequence) //{ throw new UnsupportedOperationException (); } public class Lattice // ?? extends SequencePairAlignment, but there isn't just a single output! { // "ip" == "input position", "op" == "output position", "i" == "state index" double cost; Sequence input, output; LatticeNode[][] nodes; // indexed by ip,i int latticeLength; // xxx Now that we are incrementing here directly, there isn't // necessarily a need to save all these arrays... // (Actually, there are useful to have, and they can be turned off by // -log(probability) of being in state "i" at input position "ip" double[][] gammas; // indexed by ip,i double[][][] xis; // indexed by ip,i,j; saved only if saveXis is true; LabelVector labelings[]; // indexed by op, created only if "outputAlphabet" is non-null in constructor private LatticeNode getLatticeNode (int ip, int stateIndex) { if (nodes[ip][stateIndex] == null) nodes[ip][stateIndex] = new LatticeNode (ip, getState (stateIndex)); return nodes[ip][stateIndex]; } // You may pass null for output, meaning that the lattice // is not constrained to match the output protected Lattice (Sequence input, Sequence output, boolean increment) { this (input, output, increment, false, null); } // You may pass null for output, meaning that the lattice // is not constrained to match the output protected Lattice (Sequence input, Sequence output, boolean increment, boolean saveXis) { this (input, output, increment, saveXis, null); } // If outputAlphabet is non-null, this will create a LabelVector // for each position in the output sequence indicating the // probability distribution over possible outputs at that time // index protected Lattice (Sequence input, Sequence output, boolean increment, boolean saveXis, LabelAlphabet outputAlphabet) { if (false && logger.isLoggable (Level.FINE)) { logger.fine ("Starting Lattice"); logger.fine ("Input: "); for (int ip = 0; ip < input.size(); ip++) logger.fine (" " + input.get(ip)); logger.fine ("\nOutput: "); if (output == null) logger.fine ("null"); else for (int op = 0; op < output.size(); op++) logger.fine (" " + output.get(op)); logger.fine ("\n"); } // Initialize some structures this.input = input; this.output = output; // xxx Not very efficient when the lattice is actually sparse, // especially when the number of states is large and the // sequence is long. latticeLength = input.size()+1; int numStates = numStates(); nodes = new LatticeNode[latticeLength][numStates]; // xxx Yipes, this could get big; something sparse might be better? gammas = new double[latticeLength][numStates]; if (saveXis) xis = new double[latticeLength][numStates][numStates]; double outputCounts[][] = null; if (outputAlphabet != null) outputCounts = new double[latticeLength][outputAlphabet.size()]; for (int i = 0; i < numStates; i++) { for (int ip = 0; ip < latticeLength; ip++) gammas[ip][i] = INFINITE_COST; if (saveXis) for (int j = 0; j < numStates; j++) for (int ip = 0; ip < latticeLength; ip++) xis[ip][i][j] = INFINITE_COST; } // Forward pass logger.fine ("Starting Foward pass"); boolean atLeastOneInitialState = false; for (int i = 0; i < numStates; i++) { double initialCost = getState(i).initialCost; //System.out.println ("Forward pass initialCost = "+initialCost); if (initialCost < INFINITE_COST) { getLatticeNode(0, i).alpha = initialCost; //System.out.println ("nodes[0][i].alpha="+nodes[0][i].alpha); atLeastOneInitialState = true; } } if (atLeastOneInitialState == false) logger.warning ("There are no starting states!"); for (int ip = 0; ip < latticeLength-1; ip++) for (int i = 0; i < numStates; i++) { if (nodes[ip][i] == null || nodes[ip][i].alpha == INFINITE_COST) // xxx if we end up doing this a lot, // we could save a list of the non-null ones continue; State s = getState(i);
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -