⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 memm.java

📁 mallet是自然语言处理、机器学习领域的一个开源项目。
💻 JAVA
字号:
/* Copyright (C) 2004 Univ. of Massachusetts Amherst, Computer Science Dept.   This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).   http://www.cs.umass.edu/~mccallum/mallet   This software is provided under the terms of the Common Public License,   version 1.0, as published by http://www.opensource.org.  For further   information, see the file `LICENSE' included with this distribution. *//** 		@author Andrew McCallum <a href="mailto:mccallum@cs.umass.edu">mccallum@cs.umass.edu</a>    MEMM might have been simply implemented with a MaxEnt classifier object at each state,    but I chose not to do that so that tied features could be used in different parts of the    FSM, just as in CRF.  So, the expectation-gathering is done (in MEMM-style) without    forward-backward, just with local (normalized) distributions over destination states    from source states, but there is a global MaximizebleMEMM, and all the MEMMs parameters    are set together as part of a single optimization. */package edu.umass.cs.mallet.base.fst;import edu.umass.cs.mallet.base.maximize.LimitedMemoryBFGS;import edu.umass.cs.mallet.base.maximize.Maximizable;import edu.umass.cs.mallet.base.maximize.Maximizer;import edu.umass.cs.mallet.base.pipe.Pipe;import edu.umass.cs.mallet.base.types.*;import edu.umass.cs.mallet.base.util.MalletLogger;import edu.umass.cs.mallet.base.classify.MaxEnt;import java.io.Serializable;import java.util.BitSet;import java.util.logging.Logger;import java.text.DecimalFormat;public class MEMM extends CRF4 implements Serializable{	private static Logger logger = MalletLogger.getLogger(MEMM.class.getName());	private boolean gatheringTrainingData = false;  // After training sets have been gathered in the states, record which  //   InstanceList we've gathers, so we don't double-count instances.  private InstanceList trainingGatheredFor;  public MEMM (Pipe inputPipe, Pipe outputPipe)	{		super (inputPipe, outputPipe);	}	public MEMM (Alphabet inputAlphabet, Alphabet outputAlphabet)	{		super (inputAlphabet, outputAlphabet);	}  public MEMM (CRF4 crf)  {    super (crf);  }	protected CRF4.State newState (String name, int index,	                               double initialCost, double finalCost,	                               String[] destinationNames,	                               String[] labelNames,	                               String[][] weightNames,	                               CRF4 crf)	{		return new State (name, index, initialCost, finalCost,		                  destinationNames, labelNames, weightNames, crf);	}	public boolean train (InstanceList training, InstanceList validation, InstanceList testing,												TransducerEvaluator eval, int numIterations)	{		if (numIterations <= 0)			return false;		assert (training.size() > 0);		// Allocate space for the parameters, and place transition FeatureVectors in		// per-source-state InstanceLists.		// Here, gatheringTrainingSets will be true, and these methods will result		// in new InstanceList's being created in each source state, and the FeatureVectors		// of their outgoing transitions to be added to them as the data field in the Instances.    if (trainingGatheredFor != training) {		  gatherTrainingSets (training);    }    if (useSparseWeights) {			setWeightsDimensionAsIn (training);		} else {			setWeightsDimensionDensely ();		}		if (false) {			// Expectation-based placement of training data would go here.			for (int i = 0; i < training.size(); i++) {				Instance instance = training.getInstance(i);				FeatureVectorSequence input = (FeatureVectorSequence) instance.getData();				FeatureSequence output = (FeatureSequence) instance.getTarget();				// Do it for the paths consistent with the labels...				gatheringConstraints = true;				forwardBackward (input, output, true);				// ...and also do it for the paths selected by the current model (so we will get some negative weights)				gatheringConstraints = false;				if (this.someTrainingDone)				// (do this once some training is done)					forwardBackward (input, null, true);			}			gatheringWeightsPresent = false;			SparseVector[] newWeights = new SparseVector[weights.length];			for (int i = 0; i < weights.length; i++) {				int numLocations = weightsPresent[i].cardinality ();				logger.info ("CRF weights["+weightAlphabet.lookupObject(i)+"] num features = "+numLocations);				int[] indices = new int[numLocations];				for (int j = 0; j < numLocations; j++) {					indices[j] = weightsPresent[i].nextSetBit (j == 0 ? 0 : indices[j-1]+1);					//System.out.println ("CRF4 has index "+indices[j]);				}				newWeights[i] = new IndexedSparseVector (indices, new double[numLocations],				                                         numLocations, numLocations, false, false, false);				newWeights[i].plusEqualsSparse (weights[i]);			}			weights = newWeights;		}		MaximizableMEMM maximizable = new MaximizableMEMM (training, this);		// Gather the constraints		maximizable.gatherExpectationsOrConstraints (true);		Maximizer.ByGradient maximizer = new LimitedMemoryBFGS();		int i;		boolean continueTraining = true;		boolean converged = false;		logger.info ("CRF about to train with "+numIterations+" iterations");		for (i = 0; i < numIterations; i++) {			try {				converged = maximizer.maximize (maximizable, 1);				logger.info ("CRF finished one iteration of maximizer, i="+i);			} catch (IllegalArgumentException e) {				e.printStackTrace();				logger.info ("Catching exception; saying converged.");				converged = true;			}			if (eval != null) {				continueTraining = eval.evaluate (this, (converged || i == numIterations-1), i,																					converged, maximizable.getValue(), training, validation, testing);				if (!continueTraining)					break;			}			if (converged) {				logger.info ("CRF training has converged, i="+i);				break;			}		}		logger.info ("About to setTrainable(false)");		// Free the memory of the expectations and constraints		setTrainable (false);		logger.info ("Done setTrainable(false)");		return converged;	}  void gatherTrainingSets (InstanceList training)  {    if (trainingGatheredFor != null) {      // It would be easy enough to support this, just got through all the states and set trainingSet to null.      throw new UnsupportedOperationException ("Training with multiple sets not supported.");    }    trainingGatheredFor = training;    gatheringTrainingData = true;    for (int i = 0; i < training.size(); i++) {        Instance instance = training.getInstance(i);        FeatureVectorSequence input = (FeatureVectorSequence) instance.getData();        FeatureSequence output = (FeatureSequence) instance.getTarget();        // Do it for the paths consistent with the labels...        forwardBackward (input, output, true);     }     gatheringTrainingData = false;  }  public boolean train (InstanceList training, InstanceList validation, InstanceList testing,												TransducerEvaluator eval, int numIterations,												int numIterationsPerProportion,												double[] trainingProportions)	{		throw new UnsupportedOperationException();	}	public boolean trainWithFeatureInduction (InstanceList trainingData,																						InstanceList validationData, InstanceList testingData,																						TransducerEvaluator eval, int numIterations,																						int numIterationsBetweenFeatureInductions,																						int numFeatureInductions,																						int numFeaturesPerFeatureInduction,																						double trueLabelProbThreshold,																						boolean clusteredFeatureInduction,																						double[] trainingProportions,																						String gainName)	{		throw new UnsupportedOperationException();	}	public MaximizableCRF getMaximizableCRF (InstanceList ilist)	{		return new MaximizableMEMM (ilist, this);	}  public void printInstanceLists ()  {    for (int i = 0; i < numStates (); i++) {      State state = (State) getState (i);      InstanceList training = state.trainingSet;      System.out.println ("State "+i+" : "+state.getName());      if (training == null) {        System.out.println ("No data");        continue;      }      for (int j = 0; j < training.size(); j++) {        Instance inst = training.getInstance (j);        System.out.println ("From : "+state.getName()+" To : "+inst.getTarget());        System.out.println ("Instance "+j);        System.out.println (inst.getTarget());        System.out.println (inst.getData());      }    }  }  public static class State extends CRF4.State implements Serializable	{		InstanceList trainingSet;		protected State (String name, int index,										 double initialCost, double finalCost,										 String[] destinationNames,										 String[] labelNames,										 String[][] weightNames,										 CRF4 crf)		{		  super (name, index, initialCost, finalCost, destinationNames, labelNames, weightNames, crf);		}		// Necessary because the CRF4 implementation will return CRF4.TransitionIterator		public Transducer.TransitionIterator transitionIterator (			Sequence inputSequence, int inputPosition,			Sequence outputSequence, int outputPosition)		{			if (inputPosition < 0 || outputPosition < 0)				throw new UnsupportedOperationException ("Epsilon transitions not implemented.");			if (inputSequence == null)				throw new UnsupportedOperationException ("CRFs are not generative models; must have an input sequence.");			return new TransitionIterator (				this, (FeatureVectorSequence)inputSequence, inputPosition,				(outputSequence == null ? null : (String)outputSequence.get(outputPosition)), crf);		}    public void incrementFinalCount (double count)    {      if (!((MEMM)crf).gatheringTrainingData) {        super.incrementFinalCount (count);      }    }    public void incrementInitialCount (double count)    {      if (!((MEMM)crf).gatheringTrainingData) {        super.incrementInitialCount (count);      }    }	}	protected static class TransitionIterator extends CRF4.TransitionIterator implements Serializable	{    private double sum;		public TransitionIterator (State source,															 FeatureVectorSequence inputSeq,															 int inputPosition,															 String output, CRF4 memm)		{			super (source, inputSeq, inputPosition, output, memm);			normalizeCosts ();		}		public TransitionIterator (State source,															 FeatureVector fv,															 String output, CRF4 memm)		{			super (source, fv, output, memm);			normalizeCosts ();		}		private void normalizeCosts ()		{			// Normalize the next-state costs, so they are -(log-probabilities)			// This is the heart of the difference between the locally-normalized MEMM			// and the globally-normalized CRF		  sum = INFINITE_COST;			for (int i = 0; i < costs.length; i++)				sum = sumNegLogProb (sum, costs[i]);			assert (!Double.isNaN (sum));			if (!Double.isInfinite (sum)) {			  for (int i = 0; i < costs.length; i++)				  costs[i] -= sum;      }		}    public void incrementCount (double count)    {      if (((MEMM) crf).gatheringTrainingData) {        if (!crf.someTrainingDone && count != 0) {          // Create the source state's trainingSet if it doesn't exist yet.          if (((MEMM.State) source).trainingSet == null)          // New InstanceList with a null pipe, because it doesn't do any processing of input.            ((MEMM.State) source).trainingSet = new InstanceList (null);          // xxx We should make sure we don't add duplicates (through a second call to setWeightsDimenstion..!          // xxx Note that when the training data still allows ambiguous outgoing transitions          // this will add the same FV more than once to the source state's trainingSet, each          // with >1.0 weight.  Not incorrect, but inefficient.//        System.out.println ("From: "+source.getName()+" ---> "+getOutput()+" : "+getInput());          ((MEMM.State) source).trainingSet.add (this.getInput (), this.getOutput (), null, null, count);        }      } else {        super.incrementCount (count);      }    }    public String describeTransition (double cutoff)    {      DecimalFormat f = new DecimalFormat ("0.###");      return super.describeTransition (cutoff) + "Log Z = "+f.format(sum)+"\n";    }	}	public class MaximizableMEMM extends MaximizableCRF implements Maximizable.ByGradient	{		protected MaximizableMEMM (InstanceList trainingData, MEMM memm)		{			super (trainingData, memm);		}		// if constraints=false, return log probability of the training labels		protected double gatherExpectationsOrConstraints (boolean constraints)		{			// Instance values must either always or never be included in			// the total values; we can't just sometimes skip a value			// because it is infinite, this throws off the total values.			boolean initializingInfiniteValues = false;			if (infiniteValues == null) {				infiniteValues = new BitSet ();				initializingInfiniteValues = true;			}			double labelLogProb = 0;			for (int i = 0; i < crf.numStates(); i++) {				MEMM.State s = (State) crf.getState (i);				if (s.trainingSet == null) {					System.out.println ("Empty training set for state "+s.name);					continue;				}				for (int j = 0; j < s.trainingSet.size(); j++) {					Instance instance = s.trainingSet.getInstance (j);					double instWeight = s.trainingSet.getInstanceWeight (j);					FeatureVector fv = (FeatureVector) instance.getData ();					String labelString = (String) instance.getTarget ();					TransitionIterator iter = new TransitionIterator (s, fv, constraints?labelString:null, crf);					while (iter.hasNext ()) {						State destination = (MEMM.State) iter.nextState();  // Just to advance the iterator						double cost = iter.getCost();						iter.incrementCount (Math.exp(-cost) * instWeight);						if (!constraints && iter.getOutput() == labelString) {							if (!Double.isInfinite (cost))								labelLogProb += -instWeight * cost; // xxx   ?????							else {								logger.warning ("State "+i+" transition "+j+" has infinite cost; skipping.");								if (initializingInfiniteValues)									throw new IllegalStateException ("Infinite-cost transitions not yet supported"); //infiniteValues.set (j);								else if (!infiniteValues.get(j))									throw new IllegalStateException ("Instance i used to have non-infinite value, "									                                 +"but now it has infinite value.");							}						}					}				}      }      // Force initial & final costs to 0 ???      for (int i = 0; i < crf.numStates(); i++) {        MEMM.State s = (State) crf.getState (i);        s.initialExpectation = s.initialConstraint;        s.finalExpectation = s.finalConstraint;      }			return labelLogProb;		}		// log probability of the training sequence labels, and fill in expectations[]		protected double getExpectationValue ()		{			return gatherExpectationsOrConstraints (false);		}	}}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -