📄 crf2.java
字号:
// terms of "weights", not "costs". public Matrix getParameters (Matrix m) { assert (m instanceof DenseVector && ((Vector)m).singleSize() == numParameters); DenseVector parameters = (DenseVector)m; int pi = 0; for (int i = 0; i < numStates(); i++) { State s = (State) getState (i); parameters.setValue (pi++, -s.initialCost); parameters.setValue (pi++, -s.finalCost); } for (int i = 0; i < weights.length; i++) { parameters.setValue (pi++, defaultWeights[i]); int nl = weights[i].numLocations(); for (int j = 0; j < nl; j++) parameters.setValue (pi++, weights[i].valueAtLocation(j)); } return parameters; } public void setParameters (Matrix m) { assert (m instanceof DenseVector && ((DenseVector)m).singleSize() == numParameters); cachedCostStale = cachedGradientStale = true; DenseVector parameters = (DenseVector)m; int pi = 0; for (int i = 0; i < numStates(); i++) { State s = (State) getState (i); s.initialCost = -parameters.value (pi++); s.finalCost = -parameters.value (pi++); } for (int i = 0; i < weights.length; i++) { defaultWeights[i] = parameters.value (pi++); int nl = weights[i].numLocations(); for (int j = 0; j < nl; j++) weights[i].setValueAtLocation (j, parameters.value (pi++)); } } public double getParameter (int[] indices) { assert (indices.length == 1); int index = indices[0]; int numStateParms = 2 * numStates(); if (index < numStateParms) { State s = (State)getState(index / 2); if (index % 2 == 0) return -s.initialCost; else return -s.finalCost; } else { index -= numStateParms; for (int i = 0; i < weights.length; i++) { if (index == 0) return defaultWeights[i]; index--; if (index < weights[i].numLocations()) return weights[i].valueAtLocation (index); else index -= weights[i].numLocations(); } throw new IllegalArgumentException ("index too high = "+indices[0]); } } public void setParameter (int[] indices, double value) { cachedCostStale = cachedGradientStale = true; assert (indices.length == 1); int index = indices[0]; int numStateParms = 2 * numStates(); if (index < numStateParms) { State s = (State)getState(index / 2); if (index % 2 == 0) s.initialCost = -value; else s.finalCost = -value; } else { index -= numStateParms; for (int i = 0; i < weights.length; i++) { if (index == 0) { defaultWeights[i] = value; return; } else index--; if (index < weights[i].numLocations()) { weights[i].setValueAtLocation (index, value); } else index -= weights[i].numLocations(); } throw new IllegalArgumentException ("index too high = "+indices[0]); } } // Minus log probability of the training sequence labels public double getCost () { if (cachedCostStale) { cachedCost = 0; cachedGradientStale = true; // Instance costs must either always or never be included in // the total costs; we can't just sometimes skip a cost // because it is infinite, this throws off the total costs. boolean initializingInfiniteCosts = false; if (infiniteCosts == null) { infiniteCosts = new BitSet (); initializingInfiniteCosts = true; } // Clear the sufficient statistics that we are about to fill for (int i = 0; i < numStates(); i++) { State s = (State)getState(i); s.initialExpectation = 0; s.finalExpectation = 0; } for (int i = 0; i < weights.length; i++) { expectations[i].setAll (0.0); defaultExpectations[i] = 0.0; } // Calculate the cost of each instance, and also fill in expectations double unlabeledCost, labeledCost, cost; for (int ii = 0; ii < trainingSet.size(); ii++) { Instance instance = trainingSet.getInstance(ii); FeatureVectorSequence input = (FeatureVectorSequence) instance.getData(); FeatureSequence output = (FeatureSequence) instance.getTarget(); labeledCost = forwardBackward (input, output, false).getCost(); if (Double.isInfinite (labeledCost)) logger.warning (instance.getName().toString() + " has infinite labeled cost.\n" +(instance.getSource() != null ? instance.getSource() : "")); unlabeledCost = forwardBackward (input, true).getCost (); if (Double.isInfinite (unlabeledCost)) logger.warning (instance.getName().toString() + " has infinite unlabeled cost.\n" +(instance.getSource() != null ? instance.getSource() : "")); // Here cost is -log(conditional probability correct label sequence) cost = labeledCost - unlabeledCost; logger.info("Instance "+ii+" CRF.MinimizableCRF.getCost = "+cost); if (Double.isInfinite(cost)) { logger.warning (instance.getName().toString() + " has infinite cost; skipping."); if (initializingInfiniteCosts) infiniteCosts.set (ii); else if (!infiniteCosts.get(ii)) throw new IllegalStateException ("Instance i used to have non-infinite cost, " +"but now it has infinite cost."); continue; } else { cachedCost += cost; } } // Incorporate prior on parameters if (usingHyperbolicPrior) { // Hyperbolic prior for (int i = 0; i < numStates(); i++) { State s = (State) getState (i); if (!Double.isInfinite(s.initialCost)) cachedCost += (hyperbolicPriorSlope / hyperbolicPriorSharpness * Math.log (Maths.cosh (hyperbolicPriorSharpness * -s.initialCost))); if (!Double.isInfinite(s.finalCost)) cachedCost += (hyperbolicPriorSlope / hyperbolicPriorSharpness * Math.log (Maths.cosh (hyperbolicPriorSharpness * -s.finalCost))); } for (int i = 0; i < weights.length; i++) { cachedCost += (hyperbolicPriorSlope / hyperbolicPriorSharpness * Math.log (Maths.cosh (hyperbolicPriorSharpness * defaultWeights[i]))); for (int j = 0; j < weights[i].numLocations(); j++) { double w = weights[i].valueAtLocation(j); if (!Double.isInfinite(w)) cachedCost += (hyperbolicPriorSlope / hyperbolicPriorSharpness * Math.log (Maths.cosh (hyperbolicPriorSharpness * w))); } } } else { // Gaussian prior double priorDenom = 2 * gaussianPriorVariance; for (int i = 0; i < numStates(); i++) { State s = (State) getState (i); if (!Double.isInfinite(s.initialCost)) cachedCost += s.initialCost * s.initialCost / priorDenom; if (!Double.isInfinite(s.finalCost)) cachedCost += s.finalCost * s.finalCost / priorDenom; } for (int i = 0; i < weights.length; i++) { if (!Double.isInfinite(defaultWeights[i])) cachedCost += defaultWeights[i] * defaultWeights[i] / priorDenom; for (int j = 0; j < weights[i].numLocations(); j++) { double w = weights[i].valueAtLocation (j); if (!Double.isInfinite(w)) cachedCost += w * w / priorDenom; } } } cachedCostStale = false; logger.info ("getCost() (-loglikelihood) = "+cachedCost); logger.fine ("getCost() (-loglikelihood) = "+cachedCost); //crf.print(); } return cachedCost; } private boolean checkForNaN () { for (int i = 0; i < weights.length; i++) { assert (!weights[i].isNaN()); assert (constraints == null || !constraints[i].isNaN()); assert (expectations == null || !expectations[i].isNaN()); assert (!Double.isNaN(defaultExpectations[i])); assert (!Double.isNaN(defaultConstraints[i])); } for (int i = 0; i < numStates(); i++) { State s = (State) getState (i); assert (!Double.isNaN (s.initialExpectation)); assert (!Double.isNaN (s.initialConstraint)); assert (!Double.isNaN (s.initialCost)); assert (!Double.isNaN (s.finalExpectation)); assert (!Double.isNaN (s.finalConstraint)); assert (!Double.isNaN (s.finalCost)); } return true; } public Matrix getCostGradient (Matrix m) { // Gradient is -(constraint - expectation - parameters/gaussianPriorVariance) // == (expectation + parameters/gaussianPriorVariance - constraint) // This might be opposite from what you are used to seeing, this // is because this is the gradient of the "cost" and the // gradient should point "up-hill", which is actually away from // the direction we want to parameters to go. if (cachedGradientStale) { if (cachedCostStale) // This will fill in the this.expectation getCost (); assert (checkForNaN()); Vector g = (Vector) m; int gi = 0; for (int i = 0; i < numStates(); i++) { State s = (State) getState (i); cachedGradient.setValue (gi++, (Double.isInfinite(s.initialCost) ? 0.0 : (s.initialExpectation + (usingHyperbolicPrior ? (hyperbolicPriorSlope * Maths.tanh (-s.initialCost) * hyperbolicPriorSharpness) : ((-s.initialCost) / gaussianPriorVariance)) - s.initialConstraint))); cachedGradient.setValue (gi++, (Double.isInfinite (s.finalCost) ? 0.0 : s.finalExpectation + (usingHyperbolicPrior ? (hyperbolicPriorSlope * Maths.tanh (-s.finalCost) * hyperbolicPriorSharpness) : ((-s.finalCost) / gaussianPriorVariance)) - s.finalConstraint)); } if (usingHyperbolicPrior) { // Hyperbolic prior for (int i = 0; i < weights.length; i++) { cachedGradient.setValue (gi++, (Double.isInfinite (defaultWeights[i]) ? 0.0 : (defaultExpectations[i] + (hyperbolicPriorSlope * Maths.tanh (defaultWeights[i]) * hyperbolicPriorSharpness) - defaultConstraints[i]))); if (printGradient) System.out.println ("CRF gradient["+crf.getWeightsName(i)+"][<DEFAULT_FEATURE>]="+cachedGradient.value(gi-1)); for (int j = 0; j < weights[i].numLocations(); j++) { cachedGradient.setValue (gi++, (Double.isInfinite (weights[i].valueAtLocation(j)) ? 0.0 : (expectations[i].valueAtLocation(j) + (hyperbolicPriorSlope * Maths.tanh (weights[i].valueAtLocation(j)) * hyperbolicPriorSharpness) - constraints[i].valueAtLocation(j)))); if (printGradient) System.out.println ("CRF gradient["+crf.getWeightsName(i)+"]["+inputAlphabet.lookupObject(j)+"]="+cachedGradient.value(gi-1)); } } } else { // Gaussian prior for (int i = 0; i < weights.length; i++) { cachedGradient.setValue (gi++, (Double.isInfinite (defaultWeights[i]) ? 0.0 : (defaultExpectations[i] + defaultWeights[i] / gaussianPriorVariance - defaultConstraints[i]))); if (printGradient) System.out.println ("CRF gradient["+crf.getWeightsName(i)+"][<DEFAULT_FEATURE>]="+cachedGradient.value(gi-1)); for (int j = 0; j < weights[i].numLocations(); j++) { cachedGradient.setValue (gi++, (Double.isInfinite (weights[i].valueAtLocation(j)) ? 0.0 : (expectations[i].valueAtLocation(j) + weights[i].valueAtLocation(j) / gaussianPriorVariance - constraints[i].valueAtLocation(j)))); if (printGradient) System.out.println ("CRF gradient["+crf.getWeightsName(i)+"]["+inputAlphabet.lookupObject(j)+"]="+cachedGradient.value(gi-1)); } } } // xxx Show the feature with maximum gradient cachedGradientStale = false; assert (!cachedGradient.isNaN()); } m.set (cachedGradient); printGradient = false; return m; } //Serialization of MinimizableCRF private static final long serialVersionUID = 1; private static final int CURRENT_SERIAL_VERSION = 0; private void writeObject (ObjectOutputStream out) throws IOException { out.writeInt (CURRENT_SERIAL_VERSION); out.writeObject(trainingSet); out.writeDouble(cachedCost); out.writeObject(cachedGradient); out.writeObject(infiniteCosts); out.writeInt(numParameters); out.writeObject(crf); } private void readObject (ObjectInputStream in) throws IOException, ClassNotFoundException { int version = in.readInt (); trainingSet = (InstanceList) in.readObject(); cachedCost = in.readDouble(); cachedGradient = (DenseVector) in.readObject(); infiniteCosts = (BitSet) in.readObject(); numParameters = in.readInt(); crf = (CRF2)in.readObject(); } } public class State extends Transducer.State implements Serializable { // Parameters indexed by destination state, feature index
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -