📄 transducer.java
字号:
// You may pass null for output, meaning that the lattice // is not constrained to match the output protected BeamLattice (Sequence input, Sequence output, boolean increment, boolean saveXis) { this (input, output, increment, saveXis, null); } // If outputAlphabet is non-null, this will create a LabelVector // for each position in the output sequence indicating the // probability distribution over possible outputs at that time // index protected BeamLattice (Sequence input, Sequence output, boolean increment, boolean saveXis, LabelAlphabet outputAlphabet) { if (false && logger.isLoggable (Level.FINE)) { logger.fine ("Starting Lattice"); logger.fine ("Input: "); for (int ip = 0; ip < input.size(); ip++) logger.fine (" " + input.get(ip)); logger.fine ("\nOutput: "); if (output == null) logger.fine ("null"); else for (int op = 0; op < output.size(); op++) logger.fine (" " + output.get(op)); logger.fine ("\n"); } // Initialize some structures this.input = input; this.output = output; // xxx Not very efficient when the lattice is actually sparse, // especially when the number of states is large and the // sequence is long. latticeLength = input.size()+1; int numStates = numStates(); nodes = new LatticeNode[latticeLength][numStates]; // xxx Yipes, this could get big; something sparse might be better? gammas = new double[latticeLength][numStates]; if (saveXis) xis = new double[latticeLength][numStates][numStates]; double outputCounts[][] = null; if (outputAlphabet != null) outputCounts = new double[latticeLength][outputAlphabet.size()]; for (int i = 0; i < numStates; i++) { for (int ip = 0; ip < latticeLength; ip++) gammas[ip][i] = INFINITE_COST; if (saveXis) for (int j = 0; j < numStates; j++) for (int ip = 0; ip < latticeLength; ip++) xis[ip][i][j] = INFINITE_COST; } // Forward pass logger.fine ("Starting Foward pass"); boolean atLeastOneInitialState = false; for (int i = 0; i < numStates; i++) { double initialCost = getState(i).initialCost; //System.out.println ("Forward pass initialCost = "+initialCost); if (initialCost < INFINITE_COST) { getLatticeNode(0, i).alpha = initialCost; //System.out.println ("nodes[0][i].alpha="+nodes[0][i].alpha); atLeastOneInitialState = true; } } if (atLeastOneInitialState == false) logger.warning ("There are no starting states!"); // CPAL - a sorted list for our beam experiments NBestSlist[] slists = new NBestSlist[latticeLength]; // CPAL - used for stats nstatesExpl = new double[latticeLength]; // CPAL - used to adapt beam if optimizer is getting confused // tctIter++; if(curIter == 0) { curBeamWidth = numStates; } else if(tctIter > 1 && curIter != 0) { //curBeamWidth = Math.min((int)Math.round(curAvgNstatesExpl*2),numStates); //System.out.println ("Doubling Minimum Beam Size to: "+curBeamWidth); curBeamWidth = beamWidth; } else { curBeamWidth = beamWidth; } // ************************************************************ for (int ip = 0; ip < latticeLength-1; ip++) { // CPAL - add this to construct the beam // *************************************************** // CPAL - sets up the sorted list slists[ip] = new NBestSlist(numStates); // CPAL - set the slists[ip].setKLMinE(curBeamWidth); slists[ip].setKLeps(KLeps); slists[ip].setRmin(Rmin); for(int i = 0 ; i< numStates ; i++){ if (nodes[ip][i] == null || nodes[ip][i].alpha == INFINITE_COST) continue; //State s = getState(i); // CPAL - give the NB viterbi node the (cost, position) NBForBackNode cnode = new NBForBackNode(nodes[ip][i].alpha, i); slists[ip].push(cnode); } // CPAL - unlike std. n-best beam we now filter the list based // on a KL divergence like measure // *************************************************** // use method which computes the cumulative log sum and // finds the point at which the sum is within KLeps int KLMaxPos=1; int RminPos=1; if(KLeps > 0) { KLMaxPos = slists[ip].getKLpos(); nstatesExpl[ip]=(double)KLMaxPos; } else if(KLeps == 0) { if(Rmin > 0) { RminPos = slists[ip].getTHRpos(); } else { slists[ip].setRmin(-Rmin); RminPos = slists[ip].getTHRposSTRAWMAN(); } nstatesExpl[ip]=(double)RminPos; } else { // Trick, negative values for KLeps mean use the max of KL an Rmin slists[ip].setKLeps(-KLeps); KLMaxPos = slists[ip].getKLpos(); //RminPos = slists[ip].getTHRpos(); if(Rmin > 0) { RminPos = slists[ip].getTHRpos(); } else { slists[ip].setRmin(-Rmin); RminPos = slists[ip].getTHRposSTRAWMAN(); } if(KLMaxPos > RminPos) { nstatesExpl[ip]=(double)KLMaxPos; } else { nstatesExpl[ip]=(double)RminPos; } } //System.out.println(nstatesExpl[ip] + " "); // CPAL - contemplating setting values to something else int tmppos; for (int i = (int) nstatesExpl[ip]+1; i < slists[ip].size(); i++) { tmppos = slists[ip].getPosByIndex(i); nodes[ip][tmppos].alpha = INFINITE_COST; nodes[ip][tmppos] = null; // Null is faster and seems to work the same } // - done contemplation //for (int i = 0; i < numStates; i++) { for(int jj=0 ; jj< nstatesExpl[ip]; jj++) { int i = slists[ip].getPosByIndex(jj); // CPAL - dont need this anymore // should be taken care of in the lists //if (nodes[ip][i] == null || nodes[ip][i].alpha == INFINITE_COST) // xxx if we end up doing this a lot, // we could save a list of the non-null ones // continue; State s = getState(i); TransitionIterator iter = s.transitionIterator (input, ip, output, ip); if (logger.isLoggable (Level.FINE)) logger.fine (" Starting Foward transition iteration from state " + s.getName() + " on input " + input.get(ip).toString() + " and output " + (output==null ? "(null)" : output.get(ip).toString())); while (iter.hasNext()) { State destination = iter.nextState(); if (logger.isLoggable (Level.FINE)) logger.fine ("Forward Lattice[inputPos="+ip +"][source="+s.getName() +"][dest="+destination.getName()+"]"); LatticeNode destinationNode = getLatticeNode (ip+1, destination.getIndex()); destinationNode.output = iter.getOutput(); double transitionCost = iter.getCost(); if (logger.isLoggable (Level.FINE)) logger.fine ("transitionCost="+transitionCost +" nodes["+ip+"]["+i+"].alpha="+nodes[ip][i].alpha +" destinationNode.alpha="+destinationNode.alpha); destinationNode.alpha = sumNegLogProb (destinationNode.alpha, nodes[ip][i].alpha + transitionCost); //System.out.println ("destinationNode.alpha <- "+destinationNode.alpha); } } } //System.out.println("Mean Nodes Explored: " + MatrixOps.mean(nstatesExpl)); curAvgNstatesExpl = MatrixOps.mean(nstatesExpl); // Calculate total cost of Lattice. This is the normalizer cost = INFINITE_COST; for (int i = 0; i < numStates; i++) if (nodes[latticeLength-1][i] != null) { // Note: actually we could sum at any ip index, // the choice of latticeLength-1 is arbitrary //System.out.println ("Ending alpha, state["+i+"] = "+nodes[latticeLength-1][i].alpha); //System.out.println ("Ending beta, state["+i+"] = "+getState(i).finalCost); cost = sumNegLogProb (cost, (nodes[latticeLength-1][i].alpha + getState(i).finalCost)); } // Cost is now an "unnormalized cost" of the entire Lattice //assert (cost >= 0) : "cost = "+cost; // If the sequence has infinite cost, just return. // Usefully this avoids calling any incrementX methods. // It also relies on the fact that the gammas[][] and .alpha and .beta values // are already initialized to values that reflect infinite cost // xxx Although perhaps not all (alphas,betas) exactly correctly reflecting? if (cost == INFINITE_COST) return; // Backward pass for (int i = 0; i < numStates; i++) if (nodes[latticeLength-1][i] != null) { State s = getState(i); nodes[latticeLength-1][i].beta = s.finalCost; gammas[latticeLength-1][i] = nodes[latticeLength-1][i].alpha + nodes[latticeLength-1][i].beta - cost; if (increment) { double p = Math.exp(-gammas[latticeLength-1][i]); assert (p < INFINITE_COST && !Double.isNaN(p)) : "p="+p+" gamma="+gammas[latticeLength-1][i]; s.incrementFinalCount (p); } } for (int ip = latticeLength-2; ip >= 0; ip--) { for (int i = 0; i < numStates; i++) { if (nodes[ip][i] == null || nodes[ip][i].alpha == INFINITE_COST) // Note that skipping here based on alpha means that beta values won't // be correct, but since alpha is infinite anyway, it shouldn't matter. continue; State s = getState(i); TransitionIterator iter = s.transitionIterator (input, ip, output, ip); while (iter.hasNext()) { State destination = iter.nextState(); if (logger.isLoggable (Level.FINE)) logger.fine ("Backward Lattice[inputPos="+ip +"][source="+s.getName() +"][dest="+destination.getName()+"]"); int j = destination.getIndex(); LatticeNode destinationNode = nodes[ip+1][j]; if (destinationNode != null) { double transitionCost = iter.getCost(); assert (!Double.isNaN(transitionCost)); // assert (transitionCost >= 0); Not necessarily double oldBeta = nodes[ip][i].beta; assert (!Double.isNaN(nodes[ip][i].beta)); nodes[ip][i].beta = sumNegLogProb (nodes[ip][i].beta, destinationNode.beta + transitionCost); assert (!Double.isNaN(nodes[ip][i].beta)) : "dest.beta="+destinationNode.beta+" trans="+transitionCost+" sum="+(destinationNode.beta+transitionCost) + " oldBeta="+oldBeta; double xi = nodes[ip][i].alpha + transitionCost + nodes[ip+1][j].beta - cost; if (saveXis) xis[ip][i][j] = xi; assert (!Double.isNaN(nodes[ip][i].alpha)); assert (!Double.isNaN(transitionCost)); assert (!Double.isNaN(nodes[ip+1][j].beta)); assert (!Double.isNaN(cost)); if (increment || outputAlphabet != null) { double p = Math.exp(-xi); assert (p < INFINITE_COST && !Double.isNaN(p)) : "xis["+ip+"]["+i+"]["+j+"]="+-xi; if (increment) iter.incrementCount (p); if (outputAlphabet != null) { int outputIndex = outputAlphabet.lookupIndex (iter.getOutput(), false); assert (outputIndex >= 0); // xxx This assumes that "ip" == "op"! outputCounts[ip][outputIndex] += p; //System.out.println ("CRF Lattice outputCounts["+ip+"]["+outputIndex+"]+="+p); } } } } gammas[ip][i] = nodes[ip][i].alpha + nodes[ip][i].beta - cost; } if(true){ // CPAL - check the normalization double checknorm = INFINITE_COST; for (int i = 0; i < numStates; i++) if (nodes[ip][i] != null) { // Note: actually we could sum at any ip index, // the choice of latticeLength-1 is arbitrary //System.out.println ("Ending alpha, state["+i+"] = "+nodes[latticeLength-1][i].alpha); //System.out.println ("Ending beta, state["+i+"] = "+getState(i).finalCost); checknorm = sumNegLogProb (checknorm, gammas[ip][i]); } // System.out.println ("Check Gamma, sum="+checknorm); // CPAL - done check of normalization // CPAL - normalize for (int i = 0; i < numStates; i++) if (nodes[ip][i] != null) { gammas[ip][i] = gammas[ip][i] - checknorm; } //System.out.println ("Check Gamma, sum="+checknorm); // CPAL - normalization } } if (increment) for (int i = 0; i < numStates; i++) { double p = Math.exp(-gammas[0][i]); assert (p < INFINITE_CO
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -