📄 crf_pl.java
字号:
} // closes a member that begins above this excerpt

    /**
     * Forwards the final-state count to the superclass, except while the
     * CRF_PL is gathering training data, in which case the count is dropped.
     *
     * @param count the count to add
     */
    public void incrementFinalCount (double count) {
      if (!((CRF_PL) crf).gatheringTrainingData) {
        super.incrementFinalCount (count);
      }
    }

    /**
     * Forwards the initial-state count to the superclass, except while the
     * CRF_PL is gathering training data, in which case the count is dropped.
     *
     * @param count the count to add
     */
    public void incrementInitialCount (double count) {
      if (!((CRF_PL) crf).gatheringTrainingData) {
        super.incrementInitialCount (count);
      }
    }
  } // end of the inner class whose header lies above this excerpt

  /**
   * Transition iterator that keeps the feature vectors at the current and the
   * previous input position, and that records pseudolikelihood training
   * instances instead of counts while training data is being gathered.
   */
  protected class TransitionIterator extends CRF4.TransitionIterator implements Serializable {
    // You need *two* feature vectors, not one, to compute PL.
    FeatureVector fv0; // feature vector at inputPosition
    FeatureVector fv1; // feature vector at inputPosition-1; stays null when inputPosition == 0
    int ip;            // the input position this iterator was created for
    boolean isStart = false; // True if at first transition
    boolean isLast = false;  // True if at last transition

    /**
     * Creates an iterator over the transitions leaving {@code source} at
     * {@code inputPosition} of {@code inputSeq}.
     */
    public TransitionIterator (State source, FeatureVectorSequence inputSeq,
                               int inputPosition, String output, CRF4 memm) {
      super (source, inputSeq, inputPosition, output, memm);
      this.fv0 = inputSeq.getFeatureVector (inputPosition);
      if (inputPosition > 0) {
        this.fv1 = inputSeq.getFeatureVector (inputPosition-1);
      }
      this.ip = inputPosition;
      if (ip == 0) {
        isStart = true;
      }
      // NOTE(review): fv0 was just read at index ip, so ip == size() looks
      // unreachable if getFeatureVector bounds-checks; confirm whether
      // size() - 1 was intended. Left unchanged.
      if (ip == inputSeq.size()) {
        isLast = true;
      }
    }

    /**
     * Creates an iterator from a single feature vector. fv0, fv1 and ip are
     * not set by this constructor.
     */
    public TransitionIterator (State source, FeatureVector fv, String output, CRF4 memm) {
      super (source, fv, output, memm);
    }

    /**
     * Returns the superclass cost; when {@code normalizeCosts} is set, the
     * local log-normalizer from {@code computeLocalLogZ} is subtracted.
     */
    public double getCost () {
      if (normalizeCosts) {
        double cost = super.getCost ();
        double logZ = computeLocalLogZ (fv0, fv1, (State) source, (State) getDestinationState ());
        return cost - logZ;
      } else {
        return super.getCost ();
      }
    }

    // I'm going to regret writing a separate function to do this. -cas
    /**
     * While training data is being gathered, records a PLInstance for the
     * current transition (nothing is recorded when {@code count} is 0);
     * otherwise delegates to the superclass.
     *
     * @param count weight of the transition
     */
    public void incrementCount (double count) {
      if (((CRF_PL) crf).gatheringTrainingData) {
        if (count != 0) {
          State dest = (State) getDestinationState ();
          if (count != 1) {
            System.out.println ("Huh?");
          }
          String stateNameL = leftNameOfState ((State) source);
          String stateNameR = rightNameOfState (dest);
          int stateL = leftIndexFromStateName (crf, source.getName ());
          int stateR = rightIndexFromStateName (crf, dest.getName ());

          // Create the (left, right) training set if it doesn't exist yet.
          if (trainingSets[stateL][stateR] == null) {
            trainingSets[stateL][stateR] = new ArrayList ();
          }

          // xxx We should make sure we don't add duplicates (through a second call to setWeightsDimension..!
          // xxx Note that when the training data still allows ambiguous outgoing transitions
          // this will add the same FV more than once to the source state's trainingSet, each
          // with >1.0 weight. Not incorrect, but inefficient.
          // System.out.println ("From: "+source.getName()+" ---> "+getOutput()+" : "+getInput());
          String stateNameC = rightNameOfState ((State) source);

          if (isStart) {
            PLInstance inst = new PLInstance (null, stateNameL, stateNameC, fv0, null, inum, ip, count);
            startingInstances.add (inst);
          } else {
            PLInstance inst = new PLInstance (stateNameL, stateNameC, stateNameR, fv0, fv1, inum, ip, count);
            trainingSets[stateL][stateR].add (inst);
          }

          if (isLast) {
            PLInstance inst = new PLInstance (stateNameC, stateNameR, null, fv1, null, inum, ip, count);
            endingInstances.add (inst);
          }
        }
      } else {
        super.incrementCount (count);
      }
    }
  }

  /**
   * Computes the local log-normalizer for the transition biState1 -> biState2.
   *
   * Let the transition's source state be L,C and the destination be C,R.
   * To do the pseudolikelihood normalization, we need to normalize over
   * all values of C!  Messy, messy.
   * OTOH, this also means that for a single transition, each destination will
   * have a different normalizing factor, which could be an advantage.
   *
   * @return Maths.sumLogProb over one entry per original state name C: the
   *         transitionCost of (L,C) -> (C,R) when both states exist and are
   *         connected, Double.NEGATIVE_INFINITY otherwise
   */
  private double computeLocalLogZ (FeatureVector fv0, FeatureVector fv1, State biState1, State biState2) {
    double[] costs = new double [numOriginalStates];
    String stateNameL = leftNameOfState (biState1);
    String stateNameR = rightNameOfState (biState2);

    for (int i = 0; i < numOriginalStates; i++) {
      String stateNameC = getOriginalStateName (i);
      State twiddledState1 = (State) getState (stateNameL + LABEL_SEPARATOR + stateNameC);
      State twiddledState2 = (State) getState (stateNameC + LABEL_SEPARATOR + stateNameR);
      if ((twiddledState1 != null) && (twiddledState2 != null)
          && isTransition (twiddledState1, twiddledState2)) {
        costs [i] = transitionCost (fv0, fv1, twiddledState1, twiddledState2);
      } else {
        costs [i] = Double.NEGATIVE_INFINITY;
      }
    }

    double logZ = Maths.sumLogProb (costs);
    return logZ;
  }

  /** Returns true if s2's name appears among s1's destination names. */
  private boolean isTransition (State s1, State s2) {
    return (ArrayUtils.indexOf (s1.destinationNames, s2.getName()) >= 0);
  }

  /**
   * Sums the weight dot products for the transition biState1 -> biState2.
   * biState2 must be one of biState1's destinations; otherwise the index
   * lookup yields a negative array index.
   */
  private double transitionCost (FeatureVector fv0, FeatureVector fv1, State biState1, State biState2) {
    int widx = ArrayUtils.indexOf (biState1.destinationNames, biState2.getName());
    int[] weightIndices = biState1.weightsIndices [widx];
    int[] prevWeightIndices = biState1.prevWeightsIndices [widx];

    // NOTE(review): weightsIndices is paired with fv1 and prevWeightsIndices
    // with fv0 here, whereas incrementConstraints / incrementExpectations pair
    // them the other way round; confirm which pairing is intended.
    double sum = 0;
    sum += weightsDotProduct (weightIndices, fv1);
    sum += weightsDotProduct (prevWeightIndices, fv0);
    return sum;
  }

  /**
   * For every listed weight index, adds the dot product of that weight vector
   * with fv plus the matching default weight.
   */
  private double weightsDotProduct (int[] weightIndices, FeatureVector fv) {
    double sum = 0;
    for (int wi = 0; wi < weightIndices.length; wi++) {
      int weightsIndex = weightIndices[wi];
      SparseVector w = weights [weightsIndex];
      sum += w.dotProduct (fv) + defaultWeights[weightsIndex];
    }
    return sum;
  }

  /** Returns transitionCost minus the local log-normalizer for the same arguments. */
  private double logTransitionProb (FeatureVector fv0, FeatureVector fv1, State biState1, State biState2) {
    double cost = transitionCost (fv0, fv1, biState1, biState2);
    double logZ = computeLocalLogZ (fv0, fv1, biState1, biState2);
    return cost - logZ;
  }

  // Decoding the 2nd-order state names

  /**
   * Looks up the index of the left component of a state name.
   *
   * @param crf  not used by this method
   * @param name a state name
   * @throws IllegalStateException if the left component is not in originalStateNames
   */
  int leftIndexFromStateName (CRF4 crf, String name) {
    String leftName = leftNameOfState (name);
    int idx = originalStateNames.lookupIndex (leftName, false);
    if (idx == -1) {
      throw new IllegalStateException ("Could not extract left state name from "+name+" Tried "+leftName);
    }
    return idx;
  }

  private String leftNameOfState (State state) {
    return leftNameOfState (state.getName ());
  }

  /**
   * Returns the part of name before the first LABEL_SEPARATOR, or the whole
   * name when it contains no separator.
   */
  private String leftNameOfState (String name) {
    int leftIdx = name.indexOf (LABEL_SEPARATOR);
    if (leftIdx < 0) {
      return name;
    } else {
      String leftName = name.substring (0, leftIdx);
      return leftName;
    }
  }

  /**
   * Looks up the index of the right component of a state name.
   *
   * @param crf  not used by this method
   * @param name a state name
   * @throws IllegalStateException if the right component is not in originalStateNames
   */
  int rightIndexFromStateName (CRF4 crf, String name) {
    String rightName = rightNameOfState (name);
    int idx = originalStateNames.lookupIndex (rightName, false);
    if (idx == -1) {
      // Message used to say "left", copied from leftIndexFromStateName.
      throw new IllegalStateException ("Could not extract right state name from "+name+" Tried "+rightName);
    }
    return idx;
  }

  private String rightNameOfState (State state) {
    return rightNameOfState (state.getName ());
  }

  /**
   * Returns everything from one character past the first LABEL_SEPARATOR
   * match; when there is no match (indexOf returns -1) this is the whole name.
   */
  private String rightNameOfState (String name) {
    int leftIdx = name.indexOf (LABEL_SEPARATOR);
    String rightName = name.substring (leftIdx + 1);
    return rightName;
  }

  private int numOriginalStates;
  private Alphabet originalStateNames;

  /**
   * Rebuilds originalStateNames from the left and right components of every
   * state whose name passes higherOrderState, then caches the alphabet size.
   */
  private void initOriginalStates () {
    originalStateNames = new Alphabet ();
    for (int snum = 0; snum < numStates (); snum++) {
      State state = (State) getState (snum);
      if (!higherOrderState(state)) {
        continue;
      }
      String stateL = leftNameOfState (state);
      originalStateNames.lookupIndex (stateL);
      String stateR = rightNameOfState (state);
      originalStateNames.lookupIndex (stateR);
    }
    numOriginalStates = originalStateNames.size();
  }

  /** True when LABEL_SEPARATOR occurs in the state's name at an index greater than 0. */
  private boolean higherOrderState (State state) {
    return state.getName().indexOf (LABEL_SEPARATOR) > 0;
  }

  private String getOriginalStateName (int idx) {
    return (String) originalStateNames.lookupObject (idx);
  }

  /**
   * Maximizable whose value and gradient statistics are accumulated from the
   * stored PLInstances in trainingSets.
   */
  public class MaximizableCRF_PL extends MaximizableCRF implements Maximizable.ByGradient {

    protected MaximizableCRF_PL (InstanceList trainingData, CRF_PL memm) {
      super (trainingData, memm);
    }

    /**
     * Walks every stored training instance, accumulating the weighted log
     * probability of its true (L,C) -> (C,R) transition, and updates either
     * the constraints or the expectations.
     *
     * @param constraints if true, increment constraints for the true
     *                    transition; if false, increment expectations for every
     *                    existing alternative value of C
     * @return the weighted sum of log transition probabilities of the training labels
     */
    // if constraints=false, return log probability of the training labels
    protected double gatherExpectationsOrConstraints (boolean constraints) {
      double totalLogProb = 0;

      for (int i = 0; i < numOriginalStates; i++) {
        String stateNameL = getOriginalStateName (i);
        for (int j = 0; j < numOriginalStates; j++) {
          String stateNameR = getOriginalStateName (j);
          List training = trainingSets[i][j];
          if (training == null) {
            continue;
          }

          for (int inum = 0; inum < training.size(); inum++) {
            PLInstance instance = (PLInstance) training.get (inum);
            double instWeight = instance.weight;
            FeatureVector fv0 = instance.fv0;
            FeatureVector fv1 = instance.fv1;

            // Compute the value
            // This way doesn't allow different labels than states, but I don't know what the best way to implement
            // that is right now.
            String stateNameC = instance.stateNameC;
            State trueBiSource = concatState (stateNameL, stateNameC);
            State trueBiDest = concatState (stateNameC, stateNameR);
            double logLabelProb = logTransitionProb (fv0, fv1, trueBiSource, trueBiDest);
            totalLogProb += instWeight * logLabelProb;

            if (dumpProbabilities) {
              System.out.println ("Instance "+instance.inum+" pos "+instance.ip+" (w="+instWeight+") prob = "+Math.exp (logLabelProb));
            }

            // Now do the gradient
            if (constraints) {
              incrementConstraints (trueBiSource, trueBiDest, fv0, fv1, instWeight);
            } else {
              for (int stateC = 0; stateC < numOriginalStates; stateC++) {
                String twiddledNameC = getOriginalStateName (stateC);
                State biState1 = concatState (stateNameL, twiddledNameC);
                State biState2 = concatState (twiddledNameC, stateNameR);
                if ((biState1 != null) && (biState2 != null) && isTransition (biState1, biState2)) {
                  double logProb = logTransitionProb (fv0, fv1, biState1, biState2);
                  incrementExpectations (biState1, biState2, fv0, fv1, Math.exp (logProb) * instWeight);
                }
              }
            }
          }
        }
      }

      //xxx Now the starting and ending instances
      /*
      for (int i = 0; i < startingInstances.size(); i++) {
        PLInstance inst = (PLInstance) startingInstances.get (i);
      }
      */

      // Force initial & final costs to 0 ???
      for (int i = 0; i < crf.numStates(); i++) {
        State s = (State) crf.getState (i);
        s.initialExpectation = s.initialConstraint;
        s.finalExpectation = s.finalConstraint;
      }

      return totalLogProb;
    }

    /**
     * Adds prob-weighted fv0 to the constraints of the transition's
     * weightsIndices and prob-weighted fv1 to those of its prevWeightsIndices,
     * bumping the matching default constraints by prob each time.
     */
    private void incrementConstraints (State source, State target, FeatureVector fv0, FeatureVector fv1, double prob) {
      int targetIdx = ArrayUtils.indexOf (source.destinationNames, target.getName());

      int[] widxs = source.weightsIndices [targetIdx];
      for (int wi = 0; wi < widxs.length; wi++) {
        int widx = widxs[wi];
        constraints[widx].plusEqualsSparse (fv0, prob);
        defaultConstraints[widx] += prob;
      }

      widxs = source.prevWeightsIndices [targetIdx];
      for (int wi = 0; wi < widxs.length; wi++) {
        int widx = widxs[wi];
        constraints[widx].plusEqualsSparse (fv1, prob);
        defaultConstraints[widx] += prob;
      }
    }

    /**
     * Same accumulation as incrementConstraints, but into expectations and
     * defaultExpectations.
     */
    private void incrementExpectations (State source, State target, FeatureVector fv0, FeatureVector fv1, double prob) {
      int targetIdx = ArrayUtils.indexOf (source.destinationNames, target.getName());

      // Increment expectations for current transition
      int[] widxs = source.weightsIndices [targetIdx];
      for (int wi = 0; wi < widxs.length; wi++) {
        int widx = widxs[wi];
        expectations[widx].plusEqualsSparse (fv0, prob);
        defaultExpectations[widx] += prob;
      }

      // And previous transition
      widxs = source.prevWeightsIndices [targetIdx];
      for (int wi = 0; wi < widxs.length; wi++) {
        int widx = widxs[wi];
        expectations[widx].plusEqualsSparse (fv1, prob);
        defaultExpectations[widx] += prob;
      }
    }

    /** Looks up the state named stateNameL + LABEL_SEPARATOR + stateNameR via getState. */
    private State concatState (String stateNameL, String stateNameR) {
      String biStateName = stateNameL + LABEL_SEPARATOR + stateNameR;
      return (State) getState (biStateName);
    }

    // log probability of the training sequence labels, and fill in expectations[]
    protected double getExpectationValue () {
      return gatherExpectationsOrConstraints (false);
    }

    /** Gathers constraints from the stored training sets; ilist itself is not read here. */
    protected void gatherConstraints (InstanceList ilist) {
      gatherExpectationsOrConstraints (true);
    }
  }
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -