⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 crf_pl.java

📁 mallet是自然语言处理、机器学习领域的一个开源项目。
💻 JAVA
📖 第 1 页 / 共 2 页
字号:
		}

		/**
		 * Adds to this state's final count, unless the enclosing CRF_PL is
		 * currently gathering training data, in which case the count is dropped.
		 *
		 * @param count the amount to add
		 */
		public void incrementFinalCount (double count)
		{
			if (!((CRF_PL)crf).gatheringTrainingData) {
				super.incrementFinalCount (count);
			}
		}

		/**
		 * Adds to this state's initial count, unless the enclosing CRF_PL is
		 * currently gathering training data, in which case the count is dropped.
		 *
		 * @param count the amount to add
		 */
		public void incrementInitialCount (double count)
		{
			if (!((CRF_PL)crf).gatheringTrainingData) {
				super.incrementInitialCount (count);
			}
		}
	}

	/**
	 * Transition iterator that remembers the feature vectors at the current and
	 * the previous input position.  While the CRF is gathering training data,
	 * {@link #incrementCount} records pseudolikelihood training instances
	 * instead of forwarding counts to the superclass.
	 */
	protected class TransitionIterator extends CRF4.TransitionIterator implements Serializable
	{
		// You need *two* feature vectors, not one, to compute PL.
		FeatureVector fv0;       // feature vector at the current input position
		FeatureVector fv1;       // feature vector at the previous position; null at position 0
		int ip;                  // input position of this transition
		boolean isStart = false; // True if at first transition
		boolean isLast = false;  // True if at last transition

		/**
		 * Creates an iterator over the transitions leaving <code>source</code>
		 * at position <code>inputPosition</code> of <code>inputSeq</code>.
		 */
		public TransitionIterator (State source,
		                           FeatureVectorSequence inputSeq,
		                           int inputPosition,
		                           String output, CRF4 memm)
		{
			super (source, inputSeq, inputPosition, output, memm);
			this.fv0 = inputSeq.getFeatureVector (inputPosition);
			if (inputPosition > 0)
				this.fv1 = inputSeq.getFeatureVector (inputPosition-1);
			this.ip = inputPosition;
			if (ip == 0) isStart = true;
			// NOTE(review): fv0 was just fetched at index ip, so ip == size() looks
			// unreachable and isLast would never be set; was size()-1 intended?
			// Left unchanged because the ending-instance layout below depends on it.
			if (ip == inputSeq.size()) isLast = true;
		}

		/**
		 * Creates an iterator for a single feature vector.  fv0, fv1, ip and the
		 * start/last flags are not set by this constructor and keep their defaults.
		 */
		public TransitionIterator (State source,
		                           FeatureVector fv,
		                           String output, CRF4 memm)
		{
			super (source, fv, output, memm);
		}

		/**
		 * Returns the superclass cost, minus the local log-normalizer when
		 * <code>normalizeCosts</code> is set.
		 */
		public double getCost ()
		{
			if (normalizeCosts) {
				double cost = super.getCost ();
				// NOTE(review): fv1 is null at position 0, and both vectors are null
				// after the FeatureVector-only constructor; confirm that the dot
				// products tolerate a null vector.
				double logZ = computeLocalLogZ (fv0, fv1, (State) source, (State) getDestinationState ());
				return cost - logZ;
			} else {
				return super.getCost ();
			}
		}

		// I'm going to regret writing a separate function to do this. -cas
		/**
		 * While training data is being gathered, records a PLInstance for this
		 * transition (nothing is recorded for a zero count); otherwise forwards
		 * the count to the superclass.
		 *
		 * @param count the weight of this transition
		 */
		public void incrementCount (double count)
		{
			if (((CRF_PL) crf).gatheringTrainingData) {
				if (count != 0) {
					State dest = (State) getDestinationState ();
					if (count != 1) {
						System.out.println ("Huh?");
					}
					String stateNameL = leftNameOfState ((State) source);
					String stateNameR = rightNameOfState (dest);
					int stateL = leftIndexFromStateName (crf, source.getName ());
					int stateR = rightIndexFromStateName (crf, dest.getName ());

					// Create the training set for this (left, right) pair if it doesn't exist yet.
					if (trainingSets[stateL][stateR] == null) {
						// A plain list: the instances are stored as-is, with no further processing.
						trainingSets[stateL][stateR] = new ArrayList ();
					}
					// xxx We should make sure we don't add duplicates (through a second call to setWeightsDimension..!)
					// xxx Note that when the training data still allows ambiguous outgoing transitions
					// this will add the same FV more than once to the source state's trainingSet, each
					// with >1.0 weight.  Not incorrect, but inefficient.
//        System.out.println ("From: "+source.getName()+" ---> "+getOutput()+" : "+getInput());
					String stateNameC = rightNameOfState ((State) source);

					if (isStart) {
						// First transition: there is no previous feature vector.
						PLInstance inst = new PLInstance (null, stateNameL, stateNameC, fv0, null, inum, ip, count);
						startingInstances.add (inst);
					} else {
						PLInstance inst = new PLInstance (stateNameL, stateNameC, stateNameR, fv0, fv1, inum, ip, count);
						trainingSets[stateL][stateR].add (inst);
					}
					if (isLast) {
						PLInstance inst = new PLInstance (stateNameC, stateNameR, null, fv1, null, inum, ip, count);
						endingInstances.add (inst);
					}
				}
			} else {
				super.incrementCount (count);
			}
		}
	}

	/**
	 * Computes the log of the local normalizer for the transition
	 * biState1 -> biState2, summing over every original state that could
	 * take the place of the shared middle label.
	 */
	private double computeLocalLogZ (FeatureVector fv0, FeatureVector fv1, State biState1, State biState2)
	{
		// Let the transition's source state be L,C and the destination be C,R.
		//  To do the pseudolikelihood normalization, we need to normalize over
		//  all values of C!  Messy, messy.
		// OTOH, this also means that for a single transition, each destination will
		//  have a different normalizing factor, which could be an advantage.
		double[] costs = new double [numOriginalStates];
		String stateNameL = leftNameOfState (biState1);
		String stateNameR = rightNameOfState (biState2);
		for (int i = 0; i < numOriginalStates; i++) {
			String stateNameC = getOriginalStateName (i);
			State twiddledState1 = (State) getState (stateNameL + LABEL_SEPARATOR + stateNameC);
			State twiddledState2 = (State) getState (stateNameC + LABEL_SEPARATOR + stateNameR);
			if ((twiddledState1 != null) && (twiddledState2 != null) && isTransition (twiddledState1, twiddledState2)) {
				costs [i] = transitionCost (fv0, fv1, twiddledState1, twiddledState2);
			} else {
				// Middle labels with no matching pair of states or no transition get -infinity.
				costs [i] = Double.NEGATIVE_INFINITY;
			}
		}
		double logZ = Maths.sumLogProb (costs);
		return logZ;
	}

	/** Returns true if s2's name is listed among s1's destination names. */
	private boolean isTransition (State s1, State s2)
	{
		return (ArrayUtils.indexOf (s1.destinationNames, s2.getName()) >= 0);
	}

	/**
	 * Returns the unnormalized score of the transition biState1 -> biState2.
	 * The caller must ensure the transition exists (see isTransition).
	 */
	private double transitionCost (FeatureVector fv0, FeatureVector fv1, State biState1, State biState2)
	{
		int widx = ArrayUtils.indexOf (biState1.destinationNames, biState2.getName());
		int[] weightIndices = biState1.weightsIndices [widx];
		int[] prevWeightIndices = biState1.prevWeightsIndices [widx];
		// The current transition's weights go with fv0 and the previous
		// transition's weights with fv1.  This is the pairing that
		// incrementConstraints and incrementExpectations use, so the value
		// computed here and the gradient they accumulate refer to the same features.
		double sum = 0;
		sum += weightsDotProduct (weightIndices, fv0);
		sum += weightsDotProduct (prevWeightIndices, fv1);
		return sum;
	}

	/**
	 * Sums, over the given weight-set indices, the dot product of that weight
	 * vector with fv plus that weight set's default weight.
	 */
	private double weightsDotProduct (int[] weightIndices, FeatureVector fv)
	{
		double sum = 0;
		for (int wi = 0; wi < weightIndices.length; wi++) {
			int weightsIndex = weightIndices[wi];
			SparseVector w = weights [weightsIndex];
			sum += w.dotProduct (fv) + defaultWeights[weightsIndex];
		}
		return sum;
	}

	/** Returns the transition score minus the local log-normalizer. */
	private double logTransitionProb (FeatureVector fv0, FeatureVector fv1, State biState1, State biState2)
	{
		double cost = transitionCost (fv0, fv1, biState1, biState2);
		double logZ = computeLocalLogZ (fv0, fv1, biState1, biState2);
		return cost - logZ;
	}

	// Decoding the 2nd-order state names

	/**
	 * Returns the index, in originalStateNames, of the left half of the given
	 * state name.  The crf argument is not used.
	 *
	 * @throws IllegalStateException if the left half is not a known original state name
	 */
	int leftIndexFromStateName (CRF4 crf, String name)
	{
		String leftName = leftNameOfState (name);
		int idx = originalStateNames.lookupIndex (leftName, false);
		if (idx == -1)
			throw new IllegalStateException ("Could not extract left state name from "+name+"  Tried "+leftName);
		return idx;
	}

	private String leftNameOfState (State state) { return leftNameOfState (state.getName ()); }

	/**
	 * Returns the part of name before the first LABEL_SEPARATOR, or the whole
	 * name if it contains no separator.
	 */
	private String leftNameOfState (String name)
	{
		int leftIdx = name.indexOf (LABEL_SEPARATOR);
		if (leftIdx < 0) {
			return name;
		} else {
			String leftName = name.substring (0, leftIdx);
			return leftName;
		}
	}

	/**
	 * Returns the index, in originalStateNames, of the right half of the given
	 * state name.  The crf argument is not used.
	 *
	 * @throws IllegalStateException if the right half is not a known original state name
	 */
	int rightIndexFromStateName (CRF4 crf, String name)
	{
		String rightName = rightNameOfState (name);
		int idx = originalStateNames.lookupIndex (rightName, false);
		if (idx == -1)
			throw new IllegalStateException ("Could not extract right state name from "+name+"  Tried "+rightName);
		return idx;
	}

	private String rightNameOfState (State state) { return rightNameOfState (state.getName ()); }

	/**
	 * Returns the part of name after the first LABEL_SEPARATOR, or the whole
	 * name if it contains no separator (indexOf yields -1, so substring starts at 0).
	 */
	private String rightNameOfState (String name)
	{
		int leftIdx = name.indexOf (LABEL_SEPARATOR);
		// NOTE(review): skips exactly one character, so this presumably relies on
		// LABEL_SEPARATOR being a single character; confirm.
		String rightName = name.substring (leftIdx + 1);
		return rightName;
	}

	private int numOriginalStates;
	private Alphabet originalStateNames;

	/**
	 * Rebuilds originalStateNames from the left and right halves of every
	 * higher-order state name, and updates numOriginalStates to match.
	 */
	private void initOriginalStates ()
	{
		originalStateNames = new Alphabet ();
		for (int snum = 0; snum < numStates (); snum++) {
			State state = (State) getState (snum);
			if (!higherOrderState(state)) continue;
			String stateL = leftNameOfState (state);
			originalStateNames.lookupIndex (stateL);
			String stateR = rightNameOfState (state);
			originalStateNames.lookupIndex (stateR);
		}
		numOriginalStates = originalStateNames.size();
	}

	/** Returns true if the state's name contains LABEL_SEPARATOR at an index greater than 0. */
	private boolean higherOrderState (State state)
	{
		return state.getName().indexOf (LABEL_SEPARATOR) > 0;
	}

	/** Returns the original state name stored at index idx of originalStateNames. */
	private String getOriginalStateName (int idx)
	{
		return (String) originalStateNames.lookupObject (idx);
	}

	/**
	 * Maximizable objective that computes its value and gradient from the
	 * PLInstance lists in trainingSets.
	 */
	public class MaximizableCRF_PL extends MaximizableCRF implements Maximizable.ByGradient
	{
		protected MaximizableCRF_PL (InstanceList trainingData, CRF_PL memm)
		{
			super (trainingData, memm);
		}

		/**
		 * Walks every stored training instance, accumulating either constraints
		 * (constraints == true) or expectations (constraints == false).
		 *
		 * @return the weighted sum of the log probabilities of the training labels;
		 *         this is computed in both modes
		 */
		protected double gatherExpectationsOrConstraints (boolean constraints)
		{
			double totalLogProb = 0;
			for (int i = 0; i < numOriginalStates; i++) {
				String stateNameL = getOriginalStateName (i);
				for (int j = 0; j < numOriginalStates; j++) {
					String stateNameR = getOriginalStateName (j);
					List training = trainingSets[i][j];
					if (training == null) {
						continue;
					}
					for (int inum = 0; inum < training.size(); inum++) {
						PLInstance instance = (PLInstance) training.get (inum);
						double instWeight = instance.weight;
						FeatureVector fv0 = instance.fv0;
						FeatureVector fv1 = instance.fv1;

						// Compute the value
						// This way doesn't allow different labels than states, but I don't know what the best way to implement
						//   that is right now.
						String stateNameC = instance.stateNameC;
						// NOTE(review): concatState returns whatever getState returns, which
						// the expectation loop below null-checks; here it is used unchecked.
						State trueBiSource = concatState (stateNameL, stateNameC);
						State trueBiDest = concatState (stateNameC, stateNameR);
						double logLabelProb = logTransitionProb (fv0, fv1, trueBiSource, trueBiDest);
						totalLogProb += instWeight * logLabelProb;
						if (dumpProbabilities) {
							System.out.println ("Instance "+instance.inum+" pos "+instance.ip+" (w="+instWeight+") prob = "+Math.exp (logLabelProb));
						}

						// Now do the gradient
						if (constraints) {
							incrementConstraints (trueBiSource, trueBiDest, fv0, fv1, instWeight);
						} else {
							// Spread the instance weight over every admissible middle label,
							// in proportion to its local probability.
							for (int stateC = 0; stateC < numOriginalStates; stateC++) {
								String twiddledNameC = getOriginalStateName (stateC);
								State biState1 = concatState (stateNameL, twiddledNameC);
								State biState2 = concatState (twiddledNameC, stateNameR);
								if ((biState1 != null) && (biState2 != null) && isTransition (biState1, biState2)) {
									double logProb = logTransitionProb (fv0, fv1, biState1, biState2);
									incrementExpectations (biState1, biState2, fv0, fv1, Math.exp (logProb) * instWeight);
								}
							}
						}
					}
				}
			}

			//xxx Now the starting and ending instances
			/*
			for (int i = 0; i < startingInstances.size(); i++) {
				PLInstance inst = (PLInstance) startingInstances.get (i);
			}
			*/

			// Force initial & final costs to 0 ???
			// Copies each state's initial/final constraint into its expectation.
			for (int i = 0; i < crf.numStates(); i++) {
				State s = (State) crf.getState (i);
				s.initialExpectation = s.initialConstraint;
				s.finalExpectation = s.finalConstraint;
			}
			return totalLogProb;
		}

		/**
		 * Adds prob-weighted features to the constraints of the transition
		 * source -> target: fv0 for its own weight sets, fv1 for the
		 * previous-transition weight sets.
		 */
		private void incrementConstraints (State source, State target, FeatureVector fv0, FeatureVector fv1, double prob)
		{
			int targetIdx = ArrayUtils.indexOf (source.destinationNames, target.getName());
			int[] widxs = source.weightsIndices [targetIdx];
			// Current transition
			for (int wi = 0; wi < widxs.length; wi++) {
				int widx = widxs[wi];
				constraints[widx].plusEqualsSparse (fv0, prob);
				defaultConstraints[widx] += prob;
			}
			// Previous transition
			widxs = source.prevWeightsIndices [targetIdx];
			for (int wi = 0; wi < widxs.length; wi++) {
				int widx = widxs[wi];
				constraints[widx].plusEqualsSparse (fv1, prob);
				defaultConstraints[widx] += prob;
			}
		}

		/**
		 * Adds prob-weighted features to the expectations of the transition
		 * source -> target: fv0 for its own weight sets, fv1 for the
		 * previous-transition weight sets.
		 */
		private void incrementExpectations (State source, State target, FeatureVector fv0, FeatureVector fv1, double prob)
		{
			int targetIdx = ArrayUtils.indexOf (source.destinationNames, target.getName());
			int[] widxs = source.weightsIndices [targetIdx];
			// Increment expectations for current transition
			for (int wi = 0; wi < widxs.length; wi++) {
				int widx = widxs[wi];
				expectations[widx].plusEqualsSparse (fv0, prob);
				defaultExpectations[widx] += prob;
			}
			// And previous transition
			widxs = source.prevWeightsIndices [targetIdx];
			for (int wi = 0; wi < widxs.length; wi++) {
				int widx = widxs[wi];
				expectations[widx].plusEqualsSparse (fv1, prob);
				defaultExpectations[widx] += prob;
			}
		}

		/**
		 * Looks up the state named stateNameL + LABEL_SEPARATOR + stateNameR and
		 * returns whatever getState returns for that name.
		 */
		private State concatState (String stateNameL, String stateNameR)
		{
			String biStateName = stateNameL + LABEL_SEPARATOR + stateNameR;
			return (State) getState (biStateName);
		}

		// log probability of the training sequence labels, and fill in expectations[]
		protected double getExpectationValue ()
		{
			return gatherExpectationsOrConstraints (false);
		}

		/** Fills in the constraints from the stored training instances; ilist is not used. */
		protected void gatherConstraints (InstanceList ilist)
		{
			gatherExpectationsOrConstraints (true);
		}
	}
}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -