⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 transducer.java

📁 MALLET是自然语言处理、机器学习领域的一个开源项目。
💻 Java
📖 第 1 页 / 共 5 页
字号:
    // You may pass null for output, meaning that the lattice    // is not constrained to match the output    protected BeamLattice (Sequence input, Sequence output, boolean increment, boolean saveXis)    {      this (input, output, increment, saveXis, null);    }		// If outputAlphabet is non-null, this will create a LabelVector		// for each position in the output sequence indicating the		// probability distribution over possible outputs at that time		// index		protected BeamLattice (Sequence input, Sequence output, boolean increment, boolean saveXis, LabelAlphabet outputAlphabet)		{			if (false && logger.isLoggable (Level.FINE)) {				logger.fine ("Starting Lattice");				logger.fine ("Input: ");				for (int ip = 0; ip < input.size(); ip++)					logger.fine (" " + input.get(ip));				logger.fine ("\nOutput: ");				if (output == null)					logger.fine ("null");				else					for (int op = 0; op < output.size(); op++)						logger.fine (" " + output.get(op));				logger.fine ("\n");			}			// Initialize some structures			this.input = input;			this.output = output;			// xxx Not very efficient when the lattice is actually sparse,			// especially when the number of states is large and the			// sequence is long.			latticeLength = input.size()+1;			int numStates = numStates();			nodes = new LatticeNode[latticeLength][numStates];			// xxx Yipes, this could get big; something sparse might be better?			
gammas = new double[latticeLength][numStates];			if (saveXis) xis = new double[latticeLength][numStates][numStates];			double outputCounts[][] = null;			if (outputAlphabet != null)				outputCounts = new double[latticeLength][outputAlphabet.size()];			for (int i = 0; i < numStates; i++) {				for (int ip = 0; ip < latticeLength; ip++)					gammas[ip][i] = INFINITE_COST;        if (saveXis)          for (int j = 0; j < numStates; j++)					  for (int ip = 0; ip < latticeLength; ip++)						  xis[ip][i][j] = INFINITE_COST;			}			// Forward pass			logger.fine ("Starting Foward pass");			boolean atLeastOneInitialState = false;			for (int i = 0; i < numStates; i++) {				double initialCost = getState(i).initialCost;				//System.out.println ("Forward pass initialCost = "+initialCost);				if (initialCost < INFINITE_COST) {					getLatticeNode(0, i).alpha = initialCost;					//System.out.println ("nodes[0][i].alpha="+nodes[0][i].alpha);					atLeastOneInitialState = true;				}			}			if (atLeastOneInitialState == false)				logger.warning ("There are no starting states!");            // CPAL - a sorted list for our beam experiments            NBestSlist[] slists = new NBestSlist[latticeLength];            // CPAL - used for stats            nstatesExpl = new double[latticeLength];            // CPAL - used to adapt beam if optimizer is getting confused            // tctIter++;            if(curIter == 0) {                curBeamWidth = numStates;            } else if(tctIter > 1 && curIter != 0) {                //curBeamWidth = Math.min((int)Math.round(curAvgNstatesExpl*2),numStates);                //System.out.println ("Doubling Minimum Beam Size to: "+curBeamWidth);                curBeamWidth = beamWidth;            } else {                curBeamWidth = beamWidth;            }            // ************************************************************			for (int ip = 0; ip < latticeLength-1; ip++) {                    // CPAL - add this to construct the beam                      
       // ***************************************************                             // CPAL - sets up the sorted list                             slists[ip] = new NBestSlist(numStates);                             // CPAL - set the                             slists[ip].setKLMinE(curBeamWidth);                             slists[ip].setKLeps(KLeps);                             slists[ip].setRmin(Rmin);                             for(int i = 0 ; i< numStates ; i++){                                  if (nodes[ip][i] == null || nodes[ip][i].alpha == INFINITE_COST)                                     continue;                                 //State s = getState(i);                                 // CPAL - give the NB viterbi node the (cost, position)                                 NBForBackNode cnode = new NBForBackNode(nodes[ip][i].alpha, i);                                 slists[ip].push(cnode);                             }                             // CPAL - unlike std. n-best beam we now filter the list based                             // on a KL divergence like measure                             // ***************************************************                             // use method which computes the cumulative log sum and                             // finds the point at which the sum is within KLeps                             int KLMaxPos=1;                             int RminPos=1;                             if(KLeps > 0) {                                 KLMaxPos = slists[ip].getKLpos();                                 nstatesExpl[ip]=(double)KLMaxPos;                             } else if(KLeps == 0) {                                 if(Rmin > 0) {                                     RminPos = slists[ip].getTHRpos();                                 } else {                                     slists[ip].setRmin(-Rmin);                                     RminPos = slists[ip].getTHRposSTRAWMAN();                                 }    
                             nstatesExpl[ip]=(double)RminPos;                             } else {                                 // Trick, negative values for KLeps mean use the max of KL an Rmin                                 slists[ip].setKLeps(-KLeps);                                 KLMaxPos = slists[ip].getKLpos();                                 //RminPos = slists[ip].getTHRpos();                                 if(Rmin > 0) {                                     RminPos = slists[ip].getTHRpos();                                 } else {                                     slists[ip].setRmin(-Rmin);                                     RminPos = slists[ip].getTHRposSTRAWMAN();                                 }                                 if(KLMaxPos > RminPos) {                                     nstatesExpl[ip]=(double)KLMaxPos;                                 } else {                                     nstatesExpl[ip]=(double)RminPos;                                 }                             }                             //System.out.println(nstatesExpl[ip] + " ");                // CPAL - contemplating setting values to something else                int tmppos;                for (int i = (int) nstatesExpl[ip]+1; i < slists[ip].size(); i++) {                    tmppos = slists[ip].getPosByIndex(i);                    nodes[ip][tmppos].alpha = INFINITE_COST;                    nodes[ip][tmppos] = null;   // Null is faster and seems to work the same                }                // - done contemplation				//for (int i = 0; i < numStates; i++) {                for(int jj=0 ; jj< nstatesExpl[ip]; jj++) {                    int i = slists[ip].getPosByIndex(jj);                    // CPAL - dont need this anymore                    // should be taken care of in the lists					//if (nodes[ip][i] == null || nodes[ip][i].alpha == INFINITE_COST)						// xxx if we end up doing this a lot,						// we could save a list of the non-null ones					//	continue;			
		State s = getState(i);					TransitionIterator iter = s.transitionIterator (input, ip, output, ip);					if (logger.isLoggable (Level.FINE))						logger.fine (" Starting Foward transition iteration from state "												 + s.getName() + " on input " + input.get(ip).toString()												 + " and output "												 + (output==null ? "(null)" : output.get(ip).toString()));					while (iter.hasNext()) {						State destination = iter.nextState();						if (logger.isLoggable (Level.FINE))							logger.fine ("Forward Lattice[inputPos="+ip													 +"][source="+s.getName()													 +"][dest="+destination.getName()+"]");						LatticeNode destinationNode = getLatticeNode (ip+1, destination.getIndex());						destinationNode.output = iter.getOutput();						double transitionCost = iter.getCost();						if (logger.isLoggable (Level.FINE))							logger.fine ("transitionCost="+transitionCost													 +" nodes["+ip+"]["+i+"].alpha="+nodes[ip][i].alpha													 +" destinationNode.alpha="+destinationNode.alpha);						destinationNode.alpha = sumNegLogProb (destinationNode.alpha,																									 nodes[ip][i].alpha + transitionCost);						//System.out.println ("destinationNode.alpha <- "+destinationNode.alpha);					}				}            }            //System.out.println("Mean Nodes Explored: " + MatrixOps.mean(nstatesExpl));            curAvgNstatesExpl = MatrixOps.mean(nstatesExpl);			// Calculate total cost of Lattice.  
This is the normalizer			cost = INFINITE_COST;			for (int i = 0; i < numStates; i++)				if (nodes[latticeLength-1][i] != null) {					// Note: actually we could sum at any ip index,					// the choice of latticeLength-1 is arbitrary					//System.out.println ("Ending alpha, state["+i+"] = "+nodes[latticeLength-1][i].alpha);					//System.out.println ("Ending beta,  state["+i+"] = "+getState(i).finalCost);					cost = sumNegLogProb (cost,																(nodes[latticeLength-1][i].alpha + getState(i).finalCost));				}			// Cost is now an "unnormalized cost" of the entire Lattice			//assert (cost >= 0) : "cost = "+cost;			// If the sequence has infinite cost, just return.			// Usefully this avoids calling any incrementX methods.			// It also relies on the fact that the gammas[][] and .alpha and .beta values			// are already initialized to values that reflect infinite cost			// xxx Although perhaps not all (alphas,betas) exactly correctly reflecting?			if (cost == INFINITE_COST)				return;			// Backward pass			for (int i = 0; i < numStates; i++)				if (nodes[latticeLength-1][i] != null) {					State s = getState(i);					nodes[latticeLength-1][i].beta = s.finalCost;					gammas[latticeLength-1][i] =						nodes[latticeLength-1][i].alpha + nodes[latticeLength-1][i].beta - cost;					if (increment) {						double p = Math.exp(-gammas[latticeLength-1][i]);						assert (p < INFINITE_COST && !Double.isNaN(p))							: "p="+p+" gamma="+gammas[latticeLength-1][i];						s.incrementFinalCount (p);					}				}			for (int ip = latticeLength-2; ip >= 0; ip--) {				for (int i = 0; i < numStates; i++) {					if (nodes[ip][i] == null || nodes[ip][i].alpha == INFINITE_COST)						// Note that skipping here based on alpha means that beta values won't						// be correct, but since alpha is infinite anyway, it shouldn't matter.						
continue;					State s = getState(i);					TransitionIterator iter = s.transitionIterator (input, ip, output, ip);					while (iter.hasNext()) {						State destination = iter.nextState();						if (logger.isLoggable (Level.FINE))							logger.fine ("Backward Lattice[inputPos="+ip													 +"][source="+s.getName()													 +"][dest="+destination.getName()+"]");						int j = destination.getIndex();						LatticeNode destinationNode = nodes[ip+1][j];						if (destinationNode != null) {							double transitionCost = iter.getCost();							assert (!Double.isNaN(transitionCost));							//							assert (transitionCost >= 0);  Not necessarily							double oldBeta = nodes[ip][i].beta;							assert (!Double.isNaN(nodes[ip][i].beta));							nodes[ip][i].beta = sumNegLogProb (nodes[ip][i].beta,																								 destinationNode.beta + transitionCost);							assert (!Double.isNaN(nodes[ip][i].beta))								: "dest.beta="+destinationNode.beta+" trans="+transitionCost+" sum="+(destinationNode.beta+transitionCost)								+ " oldBeta="+oldBeta;              double xi = nodes[ip][i].alpha + transitionCost + nodes[ip+1][j].beta - cost;							if (saveXis) xis[ip][i][j] = xi;							assert (!Double.isNaN(nodes[ip][i].alpha));							assert (!Double.isNaN(transitionCost));							assert (!Double.isNaN(nodes[ip+1][j].beta));							assert (!Double.isNaN(cost));							if (increment || outputAlphabet != null) {								double p = Math.exp(-xi);								assert (p < INFINITE_COST && !Double.isNaN(p)) : "xis["+ip+"]["+i+"]["+j+"]="+-xi;								if (increment)									iter.incrementCount (p);								if (outputAlphabet != null) {									int outputIndex = outputAlphabet.lookupIndex (iter.getOutput(), false);									assert (outputIndex >= 0);									// xxx This assumes that "ip" == "op"!									
outputCounts[ip][outputIndex] += p;									//System.out.println ("CRF Lattice outputCounts["+ip+"]["+outputIndex+"]+="+p);								}							}						}					}					gammas[ip][i] = nodes[ip][i].alpha + nodes[ip][i].beta - cost;				}                if(true){                // CPAL - check the normalization                double checknorm = INFINITE_COST;			    for (int i = 0; i < numStates; i++)				if (nodes[ip][i] != null) {					// Note: actually we could sum at any ip index,					// the choice of latticeLength-1 is arbitrary					//System.out.println ("Ending alpha, state["+i+"] = "+nodes[latticeLength-1][i].alpha);					//System.out.println ("Ending beta,  state["+i+"] = "+getState(i).finalCost);					checknorm = sumNegLogProb (checknorm, gammas[ip][i]);				}                // System.out.println ("Check Gamma, sum="+checknorm);                // CPAL - done check of normalization                // CPAL - normalize			    for (int i = 0; i < numStates; i++)				if (nodes[ip][i] != null) {					gammas[ip][i] = gammas[ip][i] - checknorm;				}                //System.out.println ("Check Gamma, sum="+checknorm);                // CPAL - normalization                }			}			if (increment)				for (int i = 0; i < numStates; i++) {					double p = Math.exp(-gammas[0][i]);					assert (p < INFINITE_CO

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -