⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 transducer.java

📁 这是一个matlab的java实现。里面有许多内容。请大家慢慢琢磨。
💻 JAVA
📖 第 1 页 / 共 5 页
字号:
		 @param constrainedSequence lattice must have labels of this		 sequence from <code> requiredSegment.start </code> to <code>		 requiredSegment.end </code> correctly	*/	public Lattice forwardBackward (Sequence inputSequence,																	Sequence outputSequence,																	Segment requiredSegment,																	Sequence constrainedSequence) {		if (constrainedSequence.size () != inputSequence.size ())			throw new IllegalArgumentException ("constrainedSequence.size [" + constrainedSequence.size () + "] != inputSequence.size [" + inputSequence.size () + "]"); 		// constraints tells the lattice which states must emit which		// observations.  positive values say all paths must pass through		// this state index, negative values say all paths must _not_		// pass through this state index.  0 means we don't		// care. initialize to 0. include 1 extra node for start state.		int [] constraints = new int [constrainedSequence.size() + 1];		for (int c = 0; c < constraints.length; c++)			constraints[c] = 0;		for (int i=requiredSegment.getStart (); i <= requiredSegment.getEnd(); i++) {			int si = stateIndexOfString ((String)constrainedSequence.get (i));			if (si == -1)				logger.warning ("Could not find state " + constrainedSequence.get (i) + ". Check that state labels match startTages and inTags, and that all labels are seen in training data.");//				throw new IllegalArgumentException ("Could not find state " + constrainedSequence.get(i) + ". Check that state labels match startTags and InTags.");			constraints[i+1] = si + 1;		}		// set additional negative constraint to ensure state after		// segment is not a continue tag				// xxx if segment length=1, this actually constrains the sequence		// to B-tag (B-tag)', instead of the intended constraint of B-tag		// (I-tag)'		// the fix below is unsafe, but will have to do for now.		
// FIXED BELOW/*		String endTag = (String) constrainedSequence.get (requiredSegment.getEnd ());		if (requiredSegment.getEnd()+2 < constraints.length) {			if (requiredSegment.getStart() == requiredSegment.getEnd()) { // segment has length 1				if (endTag.startsWith ("B-")) {					endTag = "I" + endTag.substring (1, endTag.length());				}				else if (!(endTag.startsWith ("I-") || endTag.startsWith ("0")))					throw new IllegalArgumentException ("Constrained Lattice requires that states are tagged in B-I-O format.");			}			int statei = stateIndexOfString (endTag);			if (statei == -1) // no I- tag for this B- tag				statei = stateIndexOfString ((String)constrainedSequence.get (requiredSegment.getStart ()));			constraints[requiredSegment.getEnd() + 2] = - (statei + 1);		}*/		if (requiredSegment.getEnd() + 2 < constraints.length) { // if 			String endTag = requiredSegment.getInTag().toString();			int statei = stateIndexOfString (endTag);			if (statei == -1)				throw new IllegalArgumentException ("Could not find state " + endTag + ". 
Check that state labels match startTags and InTags.");						constraints[requiredSegment.getEnd() + 2] = - (statei + 1);					}						//		printStates ();		logger.fine ("Segment:\n" + requiredSegment.sequenceToString () +								 "\nconstrainedSequence:\n" + constrainedSequence +								 "\nConstraints:\n");		for (int i=0; i < constraints.length; i++) {			logger.fine (constraints[i] + "\t");		}		logger.fine ("");		return forwardBackward (inputSequence, outputSequence, constraints);	}				public int stateIndexOfString (String s)	{		for (int i = 0; i < this.numStates(); i++) {			String state = this.getState (i).getName();			if (state.equals (s))				return i;	 	}	 	return -1;			}		private void printStates () {		for (int i = 0; i < this.numStates(); i++) 			logger.fine (i + ":" + this.getState (i).getName());	}			public void print () {		logger.fine ("Transducer "+this);		printStates();	}	public Lattice forwardBackward (Sequence inputSequence,																	Sequence outputSequence, int [] constraints)	{		return new Lattice (inputSequence, outputSequence, false, null,												constraints);	}		// Remove this method?	// If "increment" is true, call incrementInitialCount, incrementFinalCount and incrementCount	private Lattice forwardBackward (SequencePair inputOutputPair, boolean increment) {		return this.forwardBackward (inputOutputPair.input(), inputOutputPair.output(), increment);	}		// xxx Include methods like this?	// ...making random selections proportional to cost	//public Transduction transduce (Object[] inputSequence)	//{	throw new UnsupportedOperationException (); }	//public Transduction transduce (Sequence inputSequence)	//{	throw new UnsupportedOperationException (); }	public class Lattice // ?? extends SequencePairAlignment, but there isn't just a single output!	
{		// "ip" == "input position", "op" == "output position", "i" == "state index"		double cost;		Sequence input, output;		LatticeNode[][] nodes;			 // indexed by ip,i		int latticeLength;		// xxx Now that we are incrementing here directly, there isn't		// necessarily a need to save all these arrays...		// -log(probability) of being in state "i" at input position "ip"		double[][] gammas;					 // indexed by ip,i		LabelVector labelings[];			 // indexed by op, created only if "outputAlphabet" is non-null in constructor		private LatticeNode getLatticeNode (int ip, int stateIndex)		{			if (nodes[ip][stateIndex] == null)				nodes[ip][stateIndex] = new LatticeNode (ip, getState (stateIndex));			return nodes[ip][stateIndex];		}		// You may pass null for output, meaning that the lattice		// is not constrained to match the output		protected Lattice (Sequence input, Sequence output, boolean increment)		{			this (input, output, increment, null);		}		// If outputAlphabet is non-null, this will create a LabelVector		// for each position in the output sequence indicating the		// probability distribution over possible outputs at that time		// index		protected Lattice (Sequence input, Sequence output, boolean increment, LabelAlphabet outputAlphabet)		{			if (false && logger.isLoggable (Level.FINE)) {				logger.fine ("Starting Lattice");				logger.fine ("Input: ");				for (int ip = 0; ip < input.size(); ip++)					logger.fine (" " + input.get(ip));				logger.fine ("\nOutput: ");				if (output == null)					logger.fine ("null");				else					for (int op = 0; op < output.size(); op++)						logger.fine (" " + output.get(op));				logger.fine ("\n");			}			// Initialize some structures			this.input = input;			this.output = output;			// xxx Not very efficient when the lattice is actually sparse,			// especially when the number of states is large and the			// sequence is long.			
latticeLength = input.size()+1;			int numStates = numStates();			nodes = new LatticeNode[latticeLength][numStates];			// xxx Yipes, this could get big; something sparse might be better?			gammas = new double[latticeLength][numStates];			// xxx Move this to an ivar, so we can save it?  But for what?      // Commenting this out, because it's a memory hog and not used right now.      //  Uncomment and conditionalize under a flag if ever needed. -cas			// double xis[][][] = new double[latticeLength][numStates][numStates];			double outputCounts[][] = null;			if (outputAlphabet != null)				outputCounts = new double[latticeLength][outputAlphabet.size()];			for (int i = 0; i < numStates; i++) {				for (int ip = 0; ip < latticeLength; ip++)					gammas[ip][i] = INFINITE_COST;        /* Commenting out xis -cas				for (int j = 0; j < numStates; j++)					for (int ip = 0; ip < latticeLength; ip++)						xis[ip][i][j] = INFINITE_COST;        */			}			// Forward pass			logger.fine ("Starting Foward pass");			boolean atLeastOneInitialState = false;			for (int i = 0; i < numStates; i++) {				double initialCost = getState(i).initialCost;				//System.out.println ("Forward pass initialCost = "+initialCost);				if (initialCost < INFINITE_COST) {					getLatticeNode(0, i).alpha = initialCost;					//System.out.println ("nodes[0][i].alpha="+nodes[0][i].alpha);					atLeastOneInitialState = true;				}			}			if (atLeastOneInitialState == false)				logger.warning ("There are no starting states!");			for (int ip = 0; ip < latticeLength-1; ip++)				for (int i = 0; i < numStates; i++) {					if (nodes[ip][i] == null || nodes[ip][i].alpha == INFINITE_COST)						// xxx if we end up doing this a lot,						// we could save a list of the non-null ones						continue;					State s = getState(i);					TransitionIterator iter = s.transitionIterator (input, ip, output, ip);					if (logger.isLoggable (Level.FINE))						logger.fine (" Starting Foward transition iteration from state "												 + s.getName() + " 
on input " + input.get(ip).toString()												 + " and output "												 + (output==null ? "(null)" : output.get(ip).toString()));					while (iter.hasNext()) {						State destination = iter.nextState();						if (logger.isLoggable (Level.FINE))							logger.fine ("Forward Lattice[inputPos="+ip													 +"][source="+s.getName()													 +"][dest="+destination.getName()+"]");						LatticeNode destinationNode = getLatticeNode (ip+1, destination.getIndex());						destinationNode.output = iter.getOutput();						double transitionCost = iter.getCost();						if (logger.isLoggable (Level.FINE))							logger.fine ("transitionCost="+transitionCost													 +" nodes["+ip+"]["+i+"].alpha="+nodes[ip][i].alpha													 +" destinationNode.alpha="+destinationNode.alpha);						destinationNode.alpha = sumNegLogProb (destinationNode.alpha,																									 nodes[ip][i].alpha + transitionCost);						//System.out.println ("destinationNode.alpha <- "+destinationNode.alpha);					}				}			// Calculate total cost of Lattice.  This is the normalizer			cost = INFINITE_COST;			for (int i = 0; i < numStates; i++)				if (nodes[latticeLength-1][i] != null) {					// Note: actually we could sum at any ip index,					// the choice of latticeLength-1 is arbitrary					//System.out.println ("Ending alpha, state["+i+"] = "+nodes[latticeLength-1][i].alpha);					//System.out.println ("Ending beta,  state["+i+"] = "+getState(i).finalCost);					cost = sumNegLogProb (cost,																(nodes[latticeLength-1][i].alpha + getState(i).finalCost));				}			// Cost is now an "unnormalized cost" of the entire Lattice			//assert (cost >= 0) : "cost = "+cost;			// If the sequence has infinite cost, just return.			// Usefully this avoids calling any incrementX methods.			
// It also relies on the fact that the gammas[][] and .alpha and .beta values			// are already initialized to values that reflect infinite cost			// xxx Although perhaps not all (alphas,betas) exactly correctly reflecting?			if (cost == INFINITE_COST)				return;			// Backward pass			for (int i = 0; i < numStates; i++)				if (nodes[latticeLength-1][i] != null) {					State s = getState(i);					nodes[latticeLength-1][i].beta = s.finalCost;					gammas[latticeLength-1][i] =						nodes[latticeLength-1][i].alpha + nodes[latticeLength-1][i].beta - cost;					if (increment) {						double p = Math.exp(-gammas[latticeLength-1][i]);						assert (p < INFINITE_COST && !Double.isNaN(p))							: "p="+p+" gamma="+gammas[latticeLength-1][i];						s.incrementFinalCount (p);					}				}			for (int ip = latticeLength-2; ip >= 0; ip--) {				for (int i = 0; i < numStates; i++) {					if (nodes[ip][i] == null || nodes[ip][i].alpha == INFINITE_COST)						// Note that skipping here based on alpha means that beta values won't						// be correct, but since alpha is infinite anyway, it shouldn't matter.						
continue;					State s = getState(i);					TransitionIterator iter = s.transitionIterator (input, ip, output, ip);					while (iter.hasNext()) {						State destination = iter.nextState();						if (logger.isLoggable (Level.FINE))							logger.fine ("Backward Lattice[inputPos="+ip													 +"][source="+s.getName()													 +"][dest="+destination.getName()+"]");						int j = destination.getIndex();						LatticeNode destinationNode = nodes[ip+1][j];						if (destinationNode != null) {							double transitionCost = iter.getCost();							assert (!Double.isNaN(transitionCost));							//							assert (transitionCost >= 0);  Not necessarily							double oldBeta = nodes[ip][i].beta;							assert (!Double.isNaN(nodes[ip][i].beta));							nodes[ip][i].beta = sumNegLogProb (nodes[ip][i].beta,																								 destinationNode.beta + transitionCost);							assert (!Double.isNaN(nodes[ip][i].beta))								: "dest.beta="+destinationNode.beta+" trans="+transitionCost+" sum="+(destinationNode.beta+transitionCost)								+ " oldBeta="+oldBeta;//							xis[ip][i][j] = nodes[ip][i].alpha + transitionCost + nodes[ip+1][j].beta - cost;							assert (!Double.isNaN(nodes[ip][i].alpha));							assert (!Double.isNaN(transitionCost));							assert (!Double.isNaN(nodes[ip+1][j].beta));							assert (!Double.isNaN(cost));							if (increment || outputAlphabet != null) {                double xi = nodes[ip][i].alpha + transitionCost + nodes[ip+1][j].beta - cost;								double p = Math.exp(-xi);								assert (p < INFINITE_COST && !Double.isNaN(p)) : "xis["+ip+"]["+i+"]["+j+"]="+-xi;								if (increment)									iter.incrementCount (p);								if (outputAlphabet != null) {									int outputIndex = outputAlphabet.lookupIndex (iter.getOutput(), false);									assert (outputIndex >= 0);									// xxx This assumes that "ip" == "op"!									
outputCounts[ip][outputIndex] += p;									//System.out.println ("CRF Lattice outputCounts["+ip+"]["+outputIndex+"]+="+p);								}							}						}					}					gammas[ip][i] = nodes[ip][i].alpha + nodes[ip][i].beta - cost;				}			}			if (increment)				for (int i = 0; i < numStates; i++) {					double p = Math.exp(-gammas[0][i]);					assert (p < INFINITE_COST && !Double.isNaN(p));					getState(i).incrementInitialCount (p);				}			if (outputAlphabet != null) {				labelings = new LabelVector[latticeLength];				for (int ip = latticeLength-2; ip >= 0; ip--) {					assert (Math.abs(1.0-DenseVector.sum (outputCounts[ip])) < 0.000001);;					labelings[ip] = new LabelVector (outputAlphabet, outputCounts[ip]);				}			}		}		// culotta: constructor for constrained lattice		/** Create a lattice that constrains its transitions such that the		 * <position,label> pairs in "constraints" are adhered		 * to. constraints is an array where each entry is the index of		 * the required label at that position. An entry of 0 means there		 * are no constraints on that <position, label>. Positive values		 * mean the path must pass through that state. Negative values		 * mean the path must _not_ pass through that state. NOTE -		 * constraints.length must be equal to output.size() + 1. A		 * lattice has one extra position for the initial		 * state. Generally, this should be unconstrained, since it does		 * not produce an observation.		*/		protected Lattice (Sequence input, Sequence output, boolean increment, LabelAlphabet outputAlphabet, int [] constraints)

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -