📄 crf_pl.java
字号:
/* Copyright (C) 2003 Univ. of Massachusetts Amherst, Computer Science Dept. This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit). http://www.cs.umass.edu/~mccallum/mallet This software is provided under the terms of the Common Public License, version 1.0, as published by http://www.opensource.org. For further information, see the file `LICENSE' included with this distribution. */package edu.umass.cs.mallet.base.fst;import edu.umass.cs.mallet.base.maximize.LimitedMemoryBFGS;import edu.umass.cs.mallet.base.maximize.Maximizable;import edu.umass.cs.mallet.base.maximize.Maximizer;import edu.umass.cs.mallet.base.types.*;import edu.umass.cs.mallet.base.util.ArrayUtils;import edu.umass.cs.mallet.base.util.MalletLogger;import edu.umass.cs.mallet.base.util.Maths;import gnu.trove.TIntArrayList;import java.io.Serializable;import java.util.ArrayList;import java.util.List;import java.util.logging.Logger;/** * A CRF trained by pseudolikelihood. zt test time, the standard, globally-normalized * model as used, as in the clssical work on pseudolikelihood. */public class CRF_PL extends CRF4 implements Serializable{ private static Logger logger = MalletLogger.getLogger(CRF_PL.class.getName()); private static final String LABEL_SEPARATOR = "^"; public boolean dumpProbabilities = false; private boolean gatheringTrainingData = false; // After training sets have been gathered in the states, record which // InstanceList we've gathers, so we don't double-count instances. private InstanceList trainingGatheredFor; // I hate that trainingSets can't be a member of MaximizableCRF... breaks thread-safety. // It would be better if you could pass a hook to forwardBackward than this TransitionIterator business /* Indexed by [state@-1][state@+1] */ private static class PLInstance { // Names of the true states involved String stateNameL; String stateNameC; String stateNameR; FeatureVector fv0; // Feature Vector for L-->C transition FeatureVector fv1; // Feature Vector for C-->R transition double weight; // Rest is for debugging int ip; // Position from original sequence int inum; public PLInstance (String stateNameL, String stateNameC, String stateNameR, FeatureVector fv0, FeatureVector fv1, int inum, int ip, double weight) { this.stateNameL = stateNameL; this.stateNameC = stateNameC; this.stateNameR = stateNameR; this.fv0 = fv0; this.fv1 = fv1; this.inum = inum; this.ip = ip; this.weight = weight; } } private List[][] trainingSets; private List startingInstances; private List endingInstances; // If true, use normalizatio at test time as in Toutanova et al. private boolean normalizeCosts = false; public CRF_PL (CRF4 crf) { super (crf.getInputAlphabet (), crf.getOutputAlphabet ()); this.inputPipe = crf.inputPipe; this.outputPipe = crf.outputPipe; // To do local normalization, we use a second-order model with scores p(y_t | y_{t-1},y_{t+1},x_t) // on the transition (y_t-1&y_t ==> y_t+1) // This trick helps when gathering training sets, too. makeSecondOrderStatesFrom (crf); } private void makeSecondOrderStatesFrom (CRF4 initialCrf) { weightAlphabet = new Alphabet (); for (int widx = 0; widx < initialCrf.weightAlphabet.size(); widx++) { getWeightsIndex (initialCrf.getWeightsName (widx)); } for (int snum = 0; snum < initialCrf.numStates (); snum++) { CRF4.State s = (CRF4.State) initialCrf.getState (snum); for (int didx = 0; didx < s.destinationNames.length; didx++) { CRF4.State dest = initialCrf.getState (s.destinationNames[didx]); String newStateName = s.getName () + LABEL_SEPARATOR + dest.getName (); // create new destination names String[] newDests = new String[dest.destinationNames.length]; for (int didx2 = 0; didx2 < dest.destinationNames.length; didx2++) { newDests[didx2] = dest.getName () + LABEL_SEPARATOR + dest.destinationNames[didx2]; } // and new weight names. On String[][] weightNames = new String[dest.weightsIndices.length][]; int[][] prevWeightIndices = new int [dest.weightsIndices.length][]; for (int j = 0; j < weightNames.length; j++) { TIntArrayList weightIdxList = new TIntArrayList (); weightIdxList.add (dest.weightsIndices [j]); int[] widxs = weightIdxList.toNativeArray (); weightNames[j] = (String[]) initialCrf.weightAlphabet.lookupObjects (widxs, new String[widxs.length]); // Computing p(C|L,R) requires having the weights L-->C as well as C-->R prevWeightIndices [j] = (int[]) s.weightsIndices [didx].clone(); } // better initial & final state handling? addState (newStateName, INFINITE_COST, dest.finalCost, newDests, dest.labels, weightNames); State theState = (State) getState (newStateName); theState.prevWeightsIndices = prevWeightIndices; } } // Add start states for (int snum = 0; snum < initialCrf.numStates (); snum++) { CRF4.State s = (CRF4.State) initialCrf.getState (snum); if (Double.isInfinite (s.getInitialCost ())) continue; String[] destNames = new String [s.destinationNames.length]; String[][] weightNames = new String [s.weightsIndices.length][]; for (int didx = 0; didx < s.destinationNames.length; didx++) { destNames [didx] = s.getName() + LABEL_SEPARATOR + s.destinationNames[didx]; int[] widxs = s.weightsIndices[didx]; weightNames [didx] = (String[]) weightAlphabet.lookupObjects (widxs, new String [widxs.length]); } addState (s.getName(), s.getInitialCost (), s.getFinalCost (), destNames, s.labels, weightNames); } } protected CRF4.State newState (String name, int index, double initialCost, double finalCost, String[] destinationNames, String[] labelNames, String[][] weightNames, CRF4 crf) { return new State (name, index, initialCost, finalCost, destinationNames, labelNames, weightNames, crf); } public boolean train (InstanceList training, InstanceList validation, InstanceList testing, TransducerEvaluator eval, int numIterations) { if (numIterations <= 0) return false; assert (training.size() > 0); initializeTrainingFor (training); if (false) { // Expectation-based placement of training data would go here. for (int i = 0; i < training.size(); i++) { Instance instance = training.getInstance(i); FeatureVectorSequence input = (FeatureVectorSequence) instance.getData(); FeatureSequence output = (FeatureSequence) instance.getTarget(); // Do it for the paths consistent with the labels... gatheringConstraints = true; forwardBackward (input, output, true); // ...and also do it for the paths selected by the current model (so we will get some negative weights) gatheringConstraints = false; if (this.someTrainingDone) // (do this once some training is done) forwardBackward (input, null, true); } gatheringWeightsPresent = false; SparseVector[] newWeights = new SparseVector[weights.length]; for (int i = 0; i < weights.length; i++) { int numLocations = weightsPresent[i].cardinality (); logger.info ("CRF weights["+weightAlphabet.lookupObject(i)+"] num features = "+numLocations); int[] indices = new int[numLocations]; for (int j = 0; j < numLocations; j++) { indices[j] = weightsPresent[i].nextSetBit (j == 0 ? 0 : indices[j-1]+1); //System.out.println ("CRF4 has index "+indices[j]); } newWeights[i] = new IndexedSparseVector (indices, new double[numLocations], numLocations, numLocations, false, false, false); newWeights[i].plusEqualsSparse (weights[i]); } weights = newWeights; } MaximizableCRF_PL maximizable = new MaximizableCRF_PL (training, this); // Gather the constraints maximizable.gatherExpectationsOrConstraints (true); Maximizer.ByGradient maximizer = new LimitedMemoryBFGS();// Maximizer.ByGradient maximizer = new GradientAscent ();// ((GradientAscent)maximizer).setLineMaximizer (new BoldDriverLineSearch ()); int i; boolean continueTraining = true; boolean converged = false; boolean retry = true; logger.info ("CRF about to train with "+numIterations+" iterations"); for (i = 0; i < numIterations; i++) { try { converged = maximizer.maximize (maximizable, 1); logger.info ("CRF finished one iteration of maximizer, i="+i); retry = true; } catch (IllegalArgumentException e) { e.printStackTrace(); if (retry && maximizer instanceof LimitedMemoryBFGS) { logger.info ("Catching expception "+e+"... retrying..."); ((LimitedMemoryBFGS)maximizer).reset (); retry = false; } else { logger.info ("Catching exception; saying converged."); converged = true; } } if (eval != null) { continueTraining = eval.evaluate (this, (converged || i == numIterations-1), i, converged, maximizable.getValue(), training, validation, testing); if (!continueTraining) break; } if (converged) { logger.info ("CRF training has converged, i="+i); break; } } logger.info ("About to setTrainable(false)"); // Free the memory of the expectations and constraints setTrainable (false); logger.info ("Done setTrainable(false)"); return converged; } public void initializeTrainingFor (InstanceList training) { initOriginalStates (); // Allocate space for the parameters, and place transition FeatureVectors in // per-source-state InstanceLists. // Here, gatheringTrainingSets will be true, and these methods will result // in new InstanceList's being created in each source state, and the FeatureVectors // of their outgoing transitions to be added to them as the data field in the Instances. if (trainingGatheredFor != training) { gatherTrainingSets (training); } if (useSparseWeights) { setWeightsDimensionAsIn (training); } else { setWeightsDimensionDensely (); } } // hack for debugging private int inum; public void gatherTrainingSets (InstanceList training) { if (trainingGatheredFor != null) { // It would be easy enough to support this, just got through all the states and set trainingSet to null. logger.warning ("Training sets already gathered. Clearing...."); } trainingSets = new ArrayList [numOriginalStates][numOriginalStates]; startingInstances = new ArrayList (); endingInstances = new ArrayList (); trainingGatheredFor = training; gatheringTrainingData = true; for (int i = 0; i < training.size(); i++) { Instance instance = training.getInstance(i); inum = i; FeatureVectorSequence input = (FeatureVectorSequence) instance.getData(); FeatureSequence output = (FeatureSequence) instance.getTarget(); // Do it for the paths consistent with the labels... forwardBackward (input, output, true); } gatheringTrainingData = false; int total = 0; for (int i = 0; i < trainingSets.length; i++) { for (int j = 0; j < trainingSets[i].length; j++) { if (trainingSets[i][j] != null) { total += trainingSets[i][j].size(); } } } logger.info ("Total local training instances = "+total); } public boolean train (InstanceList training, InstanceList validation, InstanceList testing, TransducerEvaluator eval, int numIterations, int numIterationsPerProportion, double[] trainingProportions) { throw new UnsupportedOperationException(); } public boolean trainWithFeatureInduction (InstanceList trainingData, InstanceList validationData, InstanceList testingData, TransducerEvaluator eval, int numIterations, int numIterationsBetweenFeatureInductions, int numFeatureInductions, int numFeaturesPerFeatureInduction, double trueLabelProbThreshold, boolean clusteredFeatureInduction, double[] trainingProportions, String gainName) { throw new UnsupportedOperationException(); } public MaximizableCRF getMaximizableCRF (InstanceList ilist) { return new MaximizableCRF_PL (ilist, this); } public void printInstanceLists () { for (int i = 0; i < numOriginalStates; i++) { String s1 = getOriginalStateName (i); for (int j = 0; j < numOriginalStates; j++) { String s2 = getOriginalStateName (j); List training = trainingSets[i][j]; System.out.println ("States ("+i+","+j+") : ("+s1+","+s2); if (training == null) { System.out.println ("No data"); continue; } for (int inum = 0; inum < training.size(); inum++) { PLInstance inst = (PLInstance) training.get (inum); System.out.println ("State C : "+inst.stateNameC); System.out.println ("Instance "+inum+" weight: "+inst.weight); System.out.println ("FV0 is\n"+inst.fv0); System.out.println ("FV1 is\n"+inst.fv1); } } } } public static class State extends CRF4.State implements Serializable { // The weights used by the previous feature vector int[][] prevWeightsIndices; protected State (String name, int index, double initialCost, double finalCost, String[] destinationNames, String[] labelNames, String[][] weightNames, CRF4 crf) { super (name, index, initialCost, finalCost, destinationNames, labelNames, weightNames, crf); } public void setPrevWeightsIndices (int[][] prevWeightsIndices) { this.prevWeightsIndices = prevWeightsIndices; } // Necessary because the CRF4 implementation will return CRF4.TransitionIterator public Transducer.TransitionIterator transitionIterator ( Sequence inputSequence, int inputPosition, Sequence outputSequence, int outputPosition) { if (inputPosition < 0 || outputPosition < 0) throw new UnsupportedOperationException ("Epsilon transitions not implemented."); if (inputSequence == null) throw new UnsupportedOperationException ("CRFs are not generative models; must have an input sequence."); return ((CRF_PL)crf).new TransitionIterator ( this, (FeatureVectorSequence)inputSequence, inputPosition, (outputSequence == null ? null : (String)outputSequence.get(outputPosition)), crf);
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -