// transducer.java
// From MALLET, an open-source project for natural language processing and machine learning.
				for (int i = 0; i < numStates; i++) {
					if (nodes[ip][i] == null || nodes[ip][i].alpha == INFINITE_COST)
						// Note that skipping here based on alpha means that beta values won't
						// be correct, but since alpha is infinite anyway, it shouldn't matter.
						continue;
					State s = getState(i);
					TransitionIterator iter = s.transitionIterator (input, ip, output, ip);
					while (iter.hasNext()) {
						State destination = iter.nextState();
						if (logger.isLoggable (Level.FINE))
							logger.fine ("Backward Lattice[inputPos="+ip
									+"][source="+s.getName()
									+"][dest="+destination.getName()+"]");
						int j = destination.getIndex();
						LatticeNode destinationNode = nodes[ip+1][j];
						if (destinationNode != null) {
							double transitionCost = iter.getCost();
							assert (!Double.isNaN(transitionCost));
							// assert (transitionCost >= 0);  Not necessarily
							double oldBeta = nodes[ip][i].beta;
							assert (!Double.isNaN(nodes[ip][i].beta));
							nodes[ip][i].beta = sumNegLogProb (nodes[ip][i].beta,
									destinationNode.beta + transitionCost);
							assert (!Double.isNaN(nodes[ip][i].beta))
								: "dest.beta="+destinationNode.beta+" trans="+transitionCost+" sum="+(destinationNode.beta+transitionCost)
								+ " oldBeta="+oldBeta;
							// xis[ip][i][j] = nodes[ip][i].alpha + transitionCost + nodes[ip+1][j].beta - cost;
							assert (!Double.isNaN(nodes[ip][i].alpha));
							assert (!Double.isNaN(transitionCost));
							assert (!Double.isNaN(nodes[ip+1][j].beta));
							assert (!Double.isNaN(cost));
							if (increment || outputAlphabet != null) {
								double xi = nodes[ip][i].alpha + transitionCost + nodes[ip+1][j].beta - cost;
								double p = Math.exp(-xi);
								assert (p < INFINITE_COST && !Double.isNaN(p)) : "xis["+ip+"]["+i+"]["+j+"]="+-xi;
								if (increment)
									iter.incrementCount (p);
								if (outputAlphabet != null) {
									int outputIndex = outputAlphabet.lookupIndex (iter.getOutput(), false);
									assert (outputIndex >= 0);
									// xxx This assumes that "ip" == "op"!
									outputCounts[ip][outputIndex] += p;
									//System.out.println ("CRF Lattice outputCounts["+ip+"]["+outputIndex+"]+="+p);
								}
							}
						}
					}
					gammas[ip][i] = nodes[ip][i].alpha + nodes[ip][i].beta - cost;
				}
			}
			if (increment)
				for (int i = 0; i < numStates; i++) {
					double p = Math.exp(-gammas[0][i]);
					assert (p < INFINITE_COST && !Double.isNaN(p));
					getState(i).incrementInitialCount (p);
				}
			if (outputAlphabet != null) {
				labelings = new LabelVector[latticeLength];
				for (int ip = latticeLength-2; ip >= 0; ip--) {
					assert (Math.abs(1.0-DenseVector.sum (outputCounts[ip])) < 0.000001);
					labelings[ip] = new LabelVector (outputAlphabet, outputCounts[ip]);
				}
			}
		}

		public double getCost () {
			assert (!Double.isNaN(cost));
			return cost;
		}

		// No, this.cost is an "unnormalized cost"
		//public double getProbability () { return Math.exp (-cost); }

		public double getGammaCost (int inputPosition, State s) {
			return gammas[inputPosition][s.getIndex()];
		}

		public double getGammaProbability (int inputPosition, State s) {
			return Math.exp (-gammas[inputPosition][s.getIndex()]);
		}

		public double getXiProbability (int ip, State s1, State s2) {
			if (xis == null)
				throw new IllegalStateException ("xis were not saved.");
			int i = s1.getIndex ();
			int j = s2.getIndex ();
			return Math.exp (-xis[ip][i][j]);
		}

		public double getXiCost (int ip, State s1, State s2)
		{
			if (xis == null)
				throw new IllegalStateException ("xis were not saved.");
			int i = s1.getIndex ();
			int j = s2.getIndex ();
			return xis[ip][i][j];
		}

		public int length () { return latticeLength; }

		public double getAlpha (int ip, State s) {
			LatticeNode node = getLatticeNode (ip, s.getIndex ());
			return node.alpha;
		}

		public double getBeta (int ip, State s) {
			LatticeNode node = getLatticeNode (ip, s.getIndex ());
			return node.beta;
		}

		public LabelVector getLabelingAtPosition (int outputPosition)	{
			if (labelings != null)
				return labelings[outputPosition];
			return null;
		}

		// Q: We are a non-static inner class so this should be easy; but how?
		// A: By the following weird syntax -cas
		public Transducer getTransducer ()
		{
			return Transducer.this;
		}

		// A container for some information about a particular input position and state
		private class LatticeNode
		{
			int inputPosition;
			// outputPosition not really needed until we deal with asymmetric epsilon.
			State state;
			Object output;
			double alpha = INFINITE_COST;
			double beta = INFINITE_COST;
			LatticeNode (int inputPosition, State state)	{
				this.inputPosition = inputPosition;
				this.state = state;
				assert (this.alpha == INFINITE_COST);	// xxx Remove this check
			}
		}

	}	// end of class Lattice

	// ******************************************************************************
	// CPAL - NEW "BEAM" Version of Forward Backward
	// ******************************************************************************

	public BeamLattice forwardBackwardBeam (Sequence inputSequence)
	{
		return forwardBackwardBeam (inputSequence, null, false);
	}

	public BeamLattice forwardBackwardBeam (Sequence inputSequence, boolean increment)
	{
		return forwardBackwardBeam (inputSequence, null, increment);
	}

	public BeamLattice forwardBackwardBeam (Sequence inputSequence, Sequence outputSequence)
	{
		return forwardBackwardBeam (inputSequence, outputSequence, false);
	}

	public BeamLattice forwardBackwardBeam (Sequence inputSequence, Sequence outputSequence, boolean increment)
	{
		return forwardBackwardBeam (inputSequence, outputSequence, increment, null);
	}

	public BeamLattice forwardBackwardBeam (Sequence inputSequence, Sequence outputSequence, boolean increment,
			LabelAlphabet outputAlphabet)
	{
		return forwardBackwardBeam (inputSequence, outputSequence, increment, false, outputAlphabet);
	}

	public BeamLattice forwardBackwardBeam (Sequence inputSequence, Sequence outputSequence, boolean increment,
			boolean saveXis, LabelAlphabet outputAlphabet)
	{
		// xxx We don't do epsilon transitions for now
		assert (outputSequence == null
				|| inputSequence.size() == outputSequence.size());
		return new BeamLattice (inputSequence, outputSequence, increment, saveXis, outputAlphabet);
	}

	// culotta: interface for constrained lattice
	/**
	   Create a constrained lattice such that all paths pass through
	   the labeling of <code> requiredSegment </code> as indicated by
	   <code> constrainedSequence </code>

	   @param inputSequence input sequence
	   @param outputSequence output sequence
	   @param requiredSegment segment of sequence that must be labelled
	   @param constrainedSequence lattice must have labels of this
	   sequence from <code> requiredSegment.start </code> to <code>
	   requiredSegment.end </code> correctly
	*/
	public BeamLattice forwardBackwardBeam (Sequence inputSequence,
			Sequence outputSequence,
			Segment requiredSegment,
			Sequence constrainedSequence) {
		if (constrainedSequence.size () != inputSequence.size ())
			throw new IllegalArgumentException ("constrainedSequence.size [" + constrainedSequence.size () + "] != inputSequence.size [" + inputSequence.size () + "]");
		// constraints tells the lattice which states must emit which
		// observations. A positive value (stateIndex + 1) says all paths
		// must pass through that state; a negative value (-(stateIndex + 1))
		// says all paths must _not_ pass through that state; 0 means we
		// don't care. Initialize to 0 and include 1 extra node for the start state.
		int [] constraints = new int [constrainedSequence.size() + 1];
		for (int c = 0; c < constraints.length; c++)
			constraints[c] = 0;
		for (int i=requiredSegment.getStart (); i <= requiredSegment.getEnd(); i++) {
			int si = stateIndexOfString ((String)constrainedSequence.get (i));
			if (si == -1)
				logger.warning ("Could not find state " + constrainedSequence.get (i) + ". Check that state labels match startTags and inTags, and that all labels are seen in training data.");
//				throw new IllegalArgumentException ("Could not find state " + constrainedSequence.get(i) + ". Check that state labels match startTags and InTags.");
			constraints[i+1] = si + 1;
		}
		// set additional negative constraint to ensure state after
		// segment is not a continue tag

		// xxx if segment length=1, this actually constrains the sequence
		// to B-tag (B-tag)', instead of the intended constraint of B-tag
		// (I-tag)'
		// the fix below is unsafe, but will have to do for now.
		// FIXED BELOW
/*
		String endTag = (String) constrainedSequence.get (requiredSegment.getEnd ());
		if (requiredSegment.getEnd()+2 < constraints.length) {
			if (requiredSegment.getStart() == requiredSegment.getEnd()) { // segment has length 1
				if (endTag.startsWith ("B-")) {
					endTag = "I" + endTag.substring (1, endTag.length());
				}
				else if (!(endTag.startsWith ("I-") || endTag.startsWith ("0")))
					throw new IllegalArgumentException ("Constrained Lattice requires that states are tagged in B-I-O format.");
			}
			int statei = stateIndexOfString (endTag);
			if (statei == -1) // no I- tag for this B- tag
				statei = stateIndexOfString ((String)constrainedSequence.get (requiredSegment.getStart ()));
			constraints[requiredSegment.getEnd() + 2] = - (statei + 1);
		}
*/
		if (requiredSegment.getEnd() + 2 < constraints.length) { // forbid the segment's in-tag right after the segment
			String endTag = requiredSegment.getInTag().toString();
			int statei = stateIndexOfString (endTag);
			if (statei == -1)
				throw new IllegalArgumentException ("Could not find state " + endTag + ". Check that state labels match startTags and InTags.");
			constraints[requiredSegment.getEnd() + 2] = - (statei + 1);
		}
		//		printStates ();
		logger.fine ("Segment:\n" + requiredSegment.sequenceToString () +
				"\nconstrainedSequence:\n" + constrainedSequence +
				"\nConstraints:\n");
		for (int i=0; i < constraints.length; i++) {
			logger.fine (constraints[i] + "\t");
		}
		logger.fine ("");
		return forwardBackwardBeam (inputSequence, outputSequence, constraints);
	}

/*
	public int stateIndexOfString (String s)
	{
		for (int i = 0; i < this.numStates(); i++) {
			String state = this.getState (i).getName();
			if (state.equals (s))
				return i;
		}
		return -1;
	}

	private void printStates () {
		for (int i = 0; i < this.numStates(); i++)
			logger.fine (i + ":" + this.getState (i).getName());
	}

	public void print () {
		logger.fine ("Transducer "+this);
		printStates();
	}
*/

	public BeamLattice forwardBackwardBeam (Sequence inputSequence,
			Sequence outputSequence, int [] constraints)
	{
		return new BeamLattice (inputSequence, outputSequence, false, null,
				constraints);
	}

	// Remove this method?
	// If "increment" is true, call incrementInitialCount, incrementFinalCount and incrementCount
	private BeamLattice forwardBackwardBeam (SequencePair inputOutputPair, boolean increment) {
		return this.forwardBackwardBeam (inputOutputPair.input(), inputOutputPair.output(), increment);
	}

	public class BeamLattice // CPAL - like Lattice but using max-product to get the viterbiPath
	{
		// "ip" == "input position", "op" == "output position", "i" == "state index"
		double cost;
		Sequence input, output;
		LatticeNode[][] nodes;	// indexed by ip,i
		int latticeLength;
		int curBeamWidth;	// CPAL - can be adapted if maximizer is confused
		// xxx Now that we are incrementing here directly, there isn't
		// necessarily a need to save all these arrays...
		// -log(probability) of being in state "i" at input position "ip"
		double[][] gammas;	// indexed by ip,i
		double[][][] xis;	// indexed by ip,i,j; saved only if saveXis is true
		LabelVector[] labelings;	// indexed by op, created only if "outputAlphabet" is non-null in constructor

		private LatticeNode getLatticeNode (int ip, int stateIndex)
		{
			if (nodes[ip][stateIndex] == null)
				nodes[ip][stateIndex] = new LatticeNode (ip, getState (stateIndex));
			return nodes[ip][stateIndex];
		}

		// You may pass null for output, meaning that the lattice
		// is not constrained to match the output
		protected BeamLattice (Sequence input, Sequence output, boolean increment)
		{
			this (input, output, increment, false, null);
		}
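The backward pass in `Lattice` works entirely in cost space, where a cost is a negative log probability. It combines path costs with `sumNegLogProb`, which the code appears to use as -log(e^-a + e^-b). It then derives the state posterior `gammas[ip][i] = alpha + beta - cost` and the transition posterior `xi = alpha + transitionCost + beta' - cost`. The sketch below reproduces that arithmetic on a dense toy model so the identities can be checked by hand. It is not MALLET's code. The class name `NegLogForwardBackward` and the `init` / `trans[t][i][j]` / `fin` layout are hypothetical, and so is the assumption that every state pair has an explicit cost. In the real lattice, costs come from `TransitionIterator.getCost()` and depend on the input.

```java
import java.util.Arrays;

/** Hypothetical sketch of forward-backward in negative-log ("cost") space; not MALLET's API. */
class NegLogForwardBackward {
    static final double INF = Double.POSITIVE_INFINITY;

    /** -log(exp(-a) + exp(-b)), computed without underflow (assumed meaning of sumNegLogProb). */
    static double sumNegLogProb(double a, double b) {
        if (a == INF) return b;
        if (b == INF) return a;
        double lo = Math.min(a, b), hi = Math.max(a, b);
        return lo - Math.log1p(Math.exp(lo - hi));
    }

    final double[][] alpha, beta, gamma;
    final double[][][] xi;
    final double cost; // -log of the total mass of all paths (the "unnormalized cost")

    /**
     * @param init  init[i]: cost of starting in state i (hypothetical layout)
     * @param trans trans[t][i][j]: cost of moving from state i at position t to state j at t+1
     * @param fin   fin[i]: cost of ending in state i
     */
    NegLogForwardBackward(double[] init, double[][][] trans, double[] fin) {
        int n = init.length, len = trans.length + 1;
        alpha = new double[len][n];
        beta = new double[len][n];
        for (double[] row : alpha) Arrays.fill(row, INF);
        for (double[] row : beta) Arrays.fill(row, INF);

        // Forward pass: alpha[t+1][j] accumulates alpha[t][i] + trans[t][i][j] over i.
        alpha[0] = init.clone();
        for (int t = 0; t < len - 1; t++)
            for (int i = 0; i < n; i++)
                for (int j = 0; j < n; j++)
                    alpha[t + 1][j] = sumNegLogProb(alpha[t + 1][j], alpha[t][i] + trans[t][i][j]);

        // Backward pass: beta[t][i] accumulates trans[t][i][j] + beta[t+1][j] over j.
        beta[len - 1] = fin.clone();
        for (int t = len - 2; t >= 0; t--)
            for (int i = 0; i < n; i++)
                for (int j = 0; j < n; j++)
                    beta[t][i] = sumNegLogProb(beta[t][i], trans[t][i][j] + beta[t + 1][j]);

        double c = INF;
        for (int i = 0; i < n; i++)
            c = sumNegLogProb(c, alpha[len - 1][i] + fin[i]);
        cost = c;

        // State posteriors, mirroring gammas[ip][i] = alpha + beta - cost.
        gamma = new double[len][n];
        for (int t = 0; t < len; t++)
            for (int i = 0; i < n; i++)
                gamma[t][i] = alpha[t][i] + beta[t][i] - cost;

        // Transition posteriors, mirroring xi = alpha + transitionCost + beta' - cost.
        xi = new double[len - 1][n][n];
        for (int t = 0; t < len - 1; t++)
            for (int i = 0; i < n; i++)
                for (int j = 0; j < n; j++)
                    xi[t][i][j] = alpha[t][i] + trans[t][i][j] + beta[t + 1][j] - cost;
    }

    public static void main(String[] args) {
        double[] init = {0.2, 1.5};
        double[][][] trans = {
            {{0.1, 2.0}, {1.0, 0.3}},
            {{0.7, 0.9}, {2.2, 0.4}}
        };
        double[] fin = {0.0, 0.5};
        NegLogForwardBackward fb = new NegLogForwardBackward(init, trans, fin);
        for (int t = 0; t < fb.gamma.length; t++) {
            double total = 0;
            for (double g : fb.gamma[t]) total += Math.exp(-g);
            System.out.printf("position %d: sum of state posteriors = %.6f%n", t, total);
        }
    }
}
```

Every position should report a sum of about 1.0, because exp(-gamma) is a distribution over states. Summing exp(-xi) over destination states recovers exp(-gamma) for the source state. Staying in cost space avoids the underflow that multiplying raw probabilities would cause on long sequences.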

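The constrained `forwardBackwardBeam` overload encodes its constraints as a signed integer array with one slot per input position plus slot 0 for the start state. A value of `stateIndex + 1` means every path must visit that state, `-(stateIndex + 1)` means no path may visit it, and 0 means unconstrained. The sketch below rebuilds that array and shows one plausible way to test a state against it. `LatticeConstraints`, `build` and `allowed` are hypothetical names. This excerpt does not show how `BeamLattice` actually reads the array, so `allowed` is an interpretation of the source comment rather than MALLET's logic.

```java
/** Hypothetical sketch of the signed state-constraint encoding; not MALLET's API. */
class LatticeConstraints {

    /**
     * Requires requiredStates[k] at input position start + k (for start..end) and forbids
     * forbiddenAfter at the input position right after the segment (pass -1 to skip).
     * A required state of -1 (label not found) yields 0, i.e. no constraint, as in the source.
     */
    static int[] build(int sequenceLength, int start, int end, int[] requiredStates, int forbiddenAfter) {
        int[] constraints = new int[sequenceLength + 1]; // Java zero-initializes: 0 = "don't care"
        for (int i = start; i <= end; i++)
            constraints[i + 1] = requiredStates[i - start] + 1; // +1 keeps state 0 distinct from "don't care"
        if (forbiddenAfter >= 0 && end + 2 < constraints.length)
            constraints[end + 2] = -(forbiddenAfter + 1);
        return constraints;
    }

    /** True if a path may be in stateIndex at lattice position latticePos (0 = start state). */
    static boolean allowed(int[] constraints, int latticePos, int stateIndex) {
        int c = constraints[latticePos];
        if (c > 0) return stateIndex == c - 1;  // must be exactly this state
        if (c < 0) return stateIndex != -c - 1; // any state except this one
        return true;                            // unconstrained
    }

    public static void main(String[] args) {
        // Length-5 input: positions 1..2 must be states 3 and 4; state 4 is forbidden at position 3.
        int[] constraints = build(5, 1, 2, new int[] {3, 4}, 4);
        System.out.println(java.util.Arrays.toString(constraints)); // prints [0, 0, 4, 5, -5, 0]
    }
}
```

The offset by one is what makes a single `int[]` sufficient. Without it, a constraint requiring state 0 would be indistinguishable from "no constraint".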