📄 transducer.java
字号:
/* Copyright (C) 2002 Univ. of Massachusetts Amherst, Computer Science Dept. This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit). http://www.cs.umass.edu/~mccallum/mallet This software is provided under the terms of the Common Public License, version 1.0, as published by http://www.opensource.org. For further information, see the file `LICENSE' included with this distribution. *//** @author Andrew McCallum <a href="mailto:mccallum@cs.umass.edu">mccallum@cs.umass.edu</a> */package edu.umass.cs.mallet.base.fst;// Analogous to base.types.classify.Classifierimport java.util.Iterator;import java.util.ArrayList;import java.util.logging.*;import java.util.Stack;import edu.umass.cs.mallet.base.pipe.Pipe;//import edu.umass.cs.mallet.base.pipe.SerialPipe;import edu.umass.cs.mallet.base.types.InstanceList;import edu.umass.cs.mallet.base.types.Instance;import edu.umass.cs.mallet.base.types.Sequence;import edu.umass.cs.mallet.base.types.ArraySequence;import edu.umass.cs.mallet.base.types.SequencePair;import edu.umass.cs.mallet.base.types.SequencePairAlignment;import edu.umass.cs.mallet.base.types.Label;import edu.umass.cs.mallet.base.types.LabelAlphabet;import edu.umass.cs.mallet.base.types.LabelVector;import edu.umass.cs.mallet.base.types.DenseVector;import edu.umass.cs.mallet.base.types.Alphabet;import edu.umass.cs.mallet.base.util.MalletLogger;import java.io.*;// Variable name key:// "ip" = "input position"// "op" = "output position"public abstract class Transducer implements Serializable{ private static Logger logger = MalletLogger.getLogger(Transducer.class.getName()); { // xxx Why isn't this resulting in printing the log messages? //logger.setLevel (Level.FINE); //logger.addHandler (new StreamHandler (System.out, new SimpleFormatter ())); //System.out.println ("Setting level to finer"); //System.out.println ("level = " + logger.getLevel()); //logger.warning ("Foooooo"); } public static final double ZERO_COST = 0; public static final double INFINITE_COST = Double.POSITIVE_INFINITY; //private Stack availableTransitionIterators = new Stack (); // Serialization private static final long serialVersionUID = 1; private static final int CURRENT_SERIAL_VERSION = 1; private static final int NO_PIPE_VERSION = 0; private void writeObject (ObjectOutputStream out) throws IOException { int i, size; out.writeInt (CURRENT_SERIAL_VERSION); out.writeObject(inputPipe); out.writeObject(outputPipe); } private void readObject (ObjectInputStream in) throws IOException, ClassNotFoundException { int size, i; int version = in.readInt (); if (version == NO_PIPE_VERSION) { inputPipe = null; outputPipe = null; } else { inputPipe = (Pipe) in.readObject(); outputPipe = (Pipe) in.readObject(); } } public abstract static class State implements Serializable { double initialCost = 0; double finalCost = 0; public abstract String getName(); public abstract int getIndex (); public double getInitialCost () { return initialCost; } public void setInitialCost (double c) { initialCost = c; } public double getFinalCost () { return finalCost; } public void setFinalCost (double c) { finalCost = c; } //public Transducer getTransducer () { return (Transducer)this; } //public abstract TransitionIterator transitionIterator (Object input); // Pass negative positions for a sequence to request "epsilon // transitions" for either input or output. (-position-1) should be // the position in the sequence after which we are trying to insert // the espilon transition. public abstract TransitionIterator transitionIterator (Sequence input, int inputPosition, Sequence output, int outputPosition); /* public abstract TransitionIterator transitionIterator { if (availableTransitionIterators.size() > 0) return ((TransitionIterator)availableTransitionIterators.pop()).initialize (State source, Sequence input, int inputPosition, Sequence output, int outputPosition); else return newTransitionIterator (Sequence input, int inputPosition, Sequence output, int outputPosition); } */ // Pass negative input position for a sequence to request "epsilon // transitions". (-position-1) should be the position in the // sequence after which we are trying to insert the espilon // transition. public TransitionIterator transitionIterator (Sequence input, int inputPosition) { return transitionIterator (input, inputPosition, null, 0); } // For generative transducers: // Return all possible transitions, independent of input public TransitionIterator transitionIterator () { return transitionIterator (null, 0, null, 0); } // For trainable transducers: public void incrementInitialCount (double count) { throw new UnsupportedOperationException (); } public void incrementFinalCount (double count) { throw new UnsupportedOperationException (); } // Serialization private static final long serialVersionUID = 1; private static final int CURRENT_SERIAL_VERSION = 0; private void writeObject (ObjectOutputStream out) throws IOException { int i, size; out.writeInt (CURRENT_SERIAL_VERSION); out.writeDouble(initialCost); out.writeDouble(finalCost); } private void readObject (ObjectInputStream in) throws IOException, ClassNotFoundException { int size, i; int version = in.readInt (); initialCost = in.readDouble(); finalCost = in.readDouble(); } } public abstract static class TransitionIterator implements Iterator, Serializable { //public abstract void initialize (Sequence input, int inputPosition, //Sequence output, int outputPosition); public abstract boolean hasNext (); public int numberNext(){ return -1;} public abstract State nextState (); // returns the destination state public Object next () { return nextState(); } public void remove () { throw new UnsupportedOperationException (); } public abstract Object getInput (); public abstract Object getOutput (); public abstract double getCost (); public abstract State getSourceState (); public abstract State getDestinationState (); // In future these will allow for transition that consume variable amounts of the sequences public int getInputPositionIncrement () { return 1; } public int getOutputPositionIncrement () { return 1; } //public abstract Transducer getTransducer () {return getSourceState().getTransducer();} // For trainable transducers: public void incrementCount (double count) { throw new UnsupportedOperationException (); } // Serialization private static final long serialVersionUID = 1; private static final int CURRENT_SERIAL_VERSION = 0; private void writeObject (ObjectOutputStream out) throws IOException { int i, size; out.writeInt (CURRENT_SERIAL_VERSION); } private void readObject (ObjectInputStream in) throws IOException, ClassNotFoundException { int size, i; int version = in.readInt (); } } /** A pipe that should produce a Sequence in the "data" slot, (and possibly one in the "target" slot also */ protected Pipe inputPipe; /** A pipe that should expect a ViterbiPath in the "target" slot, and should produce something printable in the "source" slot that indicates the results of transduction. */ protected Pipe outputPipe; public Pipe getInputPipe () { return inputPipe; } public Pipe getOutputPipe () { return outputPipe; } /** We aren't really a Pipe subclass, but this method works like Pipes' do. */ public Instance pipe (Instance carrier) { carrier.setTarget(viterbiPath ((Sequence)carrier.getData())); return carrier; } // xxx Enrich this later. // Perhaps to something like: // public Transduction transduce (Instance instance) // public Transduction transduce (Object obj) public Instance transduce (Instance instance) { throw new UnsupportedOperationException ("Not yet implemented"); } public abstract int numStates (); public abstract State getState (int index); // Note that this method is allowed to return states with infinite initialCost. public abstract Iterator initialStateIterator (); // Some transducers are "generative", meaning that you can get a // sequence out of them without giving them an input sequence. In // this case State.transitionIterator() should return all available // transitions, but attempts to obtain the input and cost fields may // throw an exception. // xxx Why could obtaining "cost" be a problem??? public boolean canIterateAllTransitions () { return false; } // If true, this is a "generative transducer". In this case // State.transitionIterator() should return transitions that have // valid input and cost fields. True returned here should imply // that canIterateAllTransitions() is true. public boolean isGenerative () { return false; } public boolean isTrainable () { return false; } // If f is true, and it was already trainable, this has same effect as reset() public void setTrainable (boolean f) { if (f) throw new IllegalStateException ("Cannot be trainable."); } public boolean train (InstanceList instances) { throw new UnsupportedOperationException ("Not trainable."); } public double averageTokenAccuracy (InstanceList ilist) { double accuracy = 0; for (int i = 0; i < ilist.size(); i++) { Instance instance = ilist.getInstance(i); Sequence input = (Sequence) instance.getData(); Sequence output = (Sequence) instance.getTarget(); assert (input.size() == output.size()); double pathAccuracy = viterbiPath(input).tokenAccuracy(output); accuracy += pathAccuracy; logger.info ("Transducer path accuracy = "+pathAccuracy); } return accuracy/ilist.size(); } public double averageTokenAccuracy (InstanceList ilist, String fileName) { double accuracy = 0; PrintWriter out; File f = new File(fileName); try { out = new PrintWriter(new FileWriter(f)); } catch (IOException e) { out = null; } for (int i = 0; i < ilist.size(); i++) { Instance instance = ilist.getInstance(i); Sequence input = (Sequence) instance.getData(); Sequence output = (Sequence) instance.getTarget(); assert (input.size() == output.size()); double pathAccuracy = viterbiPath(input).tokenAccuracy(output, out); accuracy += pathAccuracy; logger.info ("Transducer path accuracy = "+pathAccuracy); } out.close(); return accuracy/ilist.size(); } // Treat the costs as if they are -log(probabilies); we will // normalize them if necessary public SequencePairAlignment generatePath () { if (isGenerative() == false) throw new IllegalStateException ("Transducer is not generative."); ArrayList initialStates = new ArrayList (); Iterator iter = initialStateIterator (); while (iter.hasNext()) { initialStates.add (iter.next()); } // xxx Not yet finished. throw new UnsupportedOperationException (); } public Lattice forwardBackward (Sequence inputSequence) { return forwardBackward (inputSequence, null, false); } public Lattice forwardBackward (Sequence inputSequence, boolean increment) { return forwardBackward (inputSequence, null, increment); } public Lattice forwardBackward (Sequence inputSequence, Sequence outputSequence) { return forwardBackward (inputSequence, outputSequence, false); } public Lattice forwardBackward (Sequence inputSequence, Sequence outputSequence, boolean increment) { return forwardBackward (inputSequence, outputSequence, increment, null); } public Lattice forwardBackward (Sequence inputSequence, Sequence outputSequence, boolean increment, LabelAlphabet outputAlphabet) { // xxx We don't do epsilon transitions for now assert (outputSequence == null || inputSequence.size() == outputSequence.size()); return new Lattice (inputSequence, outputSequence, increment, outputAlphabet); } // culotta: interface for constrained lattice /** Create constrained lattice such that all paths pass through the the labeling of <code> requiredSegment </code> as indicated by <code> constrainedSequence </code> @param inputSequence input sequence @param outputSequence output sequence @param requiredSegment segment of sequence that must be labelled
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -