📄 crfbygisupdate.java
字号:
initializingInfiniteCosts = true; } // Clear the sufficient statistics that we are about to fill for (int i = 0; i < numStates(); i++) { State s = (State)getState(i); s.initialExpectation = 0; s.finalExpectation = 0; } for (int i = 0; i < weights.length; i++) { expectations[i].setAll (0.0); defaultExpectations[i] = 0.0; } // Calculate the cost of each instance, and also fill in expectations double unlabeledCost, labeledCost, cost; for (int ii = 0; ii < trainingSet.size(); ii++) { Instance instance = trainingSet.getInstance(ii); FeatureVectorSequence input = (FeatureVectorSequence) instance.getData(); FeatureSequence output = (FeatureSequence) instance.getTarget(); labeledCost = forwardBackward (input, output, false).getCost(); if (Double.isInfinite (labeledCost)) logger.warning (instance.getName().toString() + " has infinite labeled cost.\n" +(instance.getSource() != null ? instance.getSource() : "")); unlabeledCost = forwardBackward (input, true).getCost (); if (Double.isInfinite (unlabeledCost)) logger.warning (instance.getName().toString() + " has infinite unlabeled cost.\n" +(instance.getSource() != null ? instance.getSource() : "")); // Here cost is -log(conditional probability correct label sequence) cost = labeledCost - unlabeledCost; //System.out.println ("Instance "+ii+" CRF.MinimizableCRF.getCost = "+cost); if (Double.isInfinite(cost)) { logger.warning (instance.getName().toString() + " has infinite cost; skipping."); if (initializingInfiniteCosts) infiniteCosts.set (ii); else if (!infiniteCosts.get(ii)) throw new IllegalStateException ("Instance i used to have non-infinite cost, " +"but now it has infinite cost."); continue; } else { cachedCost += cost; } } // Incorporate prior on parameters if (usingHyperbolicPrior) { // Hyperbolic prior for (int i = 0; i < numStates(); i++) { State s = (State) getState (i); if (!Double.isInfinite(s.initialCost)) cachedCost += (hyperbolicPriorSlope / hyperbolicPriorSharpness * Math.log (Maths.cosh (hyperbolicPriorSharpness * -s.initialCost))); if (!Double.isInfinite(s.finalCost)) cachedCost += (hyperbolicPriorSlope / hyperbolicPriorSharpness * Math.log (Maths.cosh (hyperbolicPriorSharpness * -s.finalCost))); } for (int i = 0; i < weights.length; i++) { cachedCost += (hyperbolicPriorSlope / hyperbolicPriorSharpness * Math.log (Maths.cosh (hyperbolicPriorSharpness * defaultWeights[i]))); for (int j = 0; j < weights[i].numLocations(); j++) { double w = weights[i].valueAtLocation(j); if (!Double.isInfinite(w)) cachedCost += (hyperbolicPriorSlope / hyperbolicPriorSharpness * Math.log (Maths.cosh (hyperbolicPriorSharpness * w))); } } } else { // Gaussian prior double priorDenom = 2 * gaussianPriorVariance; for (int i = 0; i < numStates(); i++) { State s = (State) getState (i); if (!Double.isInfinite(s.initialCost)) cachedCost += s.initialCost * s.initialCost / priorDenom; if (!Double.isInfinite(s.finalCost)) cachedCost += s.finalCost * s.finalCost / priorDenom; } for (int i = 0; i < weights.length; i++) { if (!Double.isInfinite(defaultWeights[i])) cachedCost += defaultWeights[i] * defaultWeights[i] / priorDenom; for (int j = 0; j < weights[i].numLocations(); j++) { double w = weights[i].valueAtLocation (j); if (!Double.isInfinite(w)) cachedCost += w * w / priorDenom; } } } cachedCostStale = false; System.out.println ("getCost() (-loglikelihood) = "+cachedCost); logger.fine ("getCost() (-loglikelihood) = "+cachedCost); //crf.print(); long endingTime = System.currentTimeMillis(); System.out.println ("Inference milliseconds = "+(endingTime - startingTime)); } return cachedCost; } private boolean checkForNaN () { for (int i = 0; i < weights.length; i++) { assert (!weights[i].isNaN()); assert (constraints == null || !constraints[i].isNaN()); assert (expectations == null || !expectations[i].isNaN()); assert (!Double.isNaN(defaultExpectations[i])); assert (!Double.isNaN(defaultConstraints[i])); } for (int i = 0; i < numStates(); i++) { State s = (State) getState (i); assert (!Double.isNaN (s.initialExpectation)); assert (!Double.isNaN (s.initialConstraint)); assert (!Double.isNaN (s.initialCost)); assert (!Double.isNaN (s.finalExpectation)); assert (!Double.isNaN (s.finalConstraint)); assert (!Double.isNaN (s.finalCost)); } return true; } /** * Returns a GIS update for the current parameter setting specified by params * * @param params feature weights of current model * @param updates Matrix Object in which to store the updates */ public void getGISUpdate (Matrix params, Matrix updates) { assert (params instanceof DenseVector); assert (updates instanceof DenseVector); // get max empirical feature count and store in maxCount. if(maxCount <= 1.0) { int pi = 0; for(int i = 0; i < numStates(); i++) { pi++; pi++; } for (int i = 0; i < weights.length; i++) { if(defaultConstraints[i] > maxCount) maxCount = defaultConstraints[i]; pi++; int nl = weights[i].numLocations(); for (int j = 0; j < nl; j++) { if(constraints[i].valueAtLocation(j) > maxCount) maxCount = constraints[i].valueAtLocation(j); pi++; } } } computeGISUpdate((DenseVector)params,(DenseVector)updates,gaussianPriorVariance,maxCount); } private void computeGISUpdate(DenseVector lambda, DenseVector updates, double sigma, double s) { double p = 1/(s*sigma*sigma); double inv_s = 1/(sigma*sigma); int pi = 0; for(int i = 0; i < numStates(); i++) { updates.setValue(pi,gis_solver(lambda.value(pi)*inv_s, p, 1.0, 1.0)/s); pi++; updates.setValue(pi,gis_solver(lambda.value(pi)*inv_s, p, 1.0, 1.0)/s); pi++; } for (int i = 0; i < weights.length; i++) { updates.setValue(pi,gis_solver(lambda.value(pi)*inv_s, p, defaultConstraints[i], defaultExpectations[i])/s); pi++; int nl = weights[i].numLocations(); for (int j = 0; j < nl; j++) { updates.setValue(pi,gis_solver(lambda.value(pi)*inv_s, p, constraints[i].valueAtLocation(j), expectations[i].valueAtLocation(j))/s); pi++; } } } private double gis_solver(double m, double p, double e1, double e2) { int iter=0; double x = 2; double new_x = 1; boolean find = false; while(x>0 && iter <5000) { x= new_x; new_x = x*(p-m-p*Math.log(x) +e1)/(e2*x+p); if(Math.abs(new_x-x)<=(0.001*x) ) { return Math.log(x); } iter++; } return 0; } //Serialization of MinimizableCRF private static final long serialVersionUID = 1; private static final int CURRENT_SERIAL_VERSION = 0; private void writeObject (ObjectOutputStream out) throws IOException { out.writeInt (CURRENT_SERIAL_VERSION); out.writeObject(trainingSet); out.writeDouble(cachedCost); out.writeObject(cachedGISUpdate); out.writeObject(infiniteCosts); out.writeInt(numParameters); out.writeObject(crf); } private void readObject (ObjectInputStream in) throws IOException, ClassNotFoundException { int version = in.readInt (); trainingSet = (InstanceList) in.readObject(); cachedCost = in.readDouble(); cachedGISUpdate = (DenseVector) in.readObject(); infiniteCosts = (BitSet) in.readObject(); numParameters = in.readInt(); crf = (CRFByGISUpdate)in.readObject(); } } public static class State extends Transducer.State implements Serializable { // Parameters indexed by destination state, feature index double initialConstraint, initialExpectation; double finalConstraint, finalExpectation; String name; int index; String[] destinationNames; State[] destinations; int[][] weightsIndices; // contains indices into CRF.weights[], String[] labels; CRFByGISUpdate crf; // No arg constructor so serialization works protected State() { super (); } protected State (String name, int index, double initialCost, double finalCost, String[] destinationNames, String[] labelNames, String[][] weightNames, CRFByGISUpdate crf) { super (); assert (destinationNames.length == labelNames.length); assert (destinationNames.length == weightNames.length); this.name = name; this.index = index; this.initialCost = initialCost; this.finalCost = finalCost; this.destinationNames = new String[destinationNames.length]; this.destinations = new State[labelNames.length]; this.weightsIndices = new int[labelNames.length][]; this.labels = new String[labelNames.length]; this.crf = crf; for (int i = 0; i < labelNames.length; i++) { // Make sure this label appears in our output Alphabet crf.outputAlphabet.lookupIndex (labelNames[i]); this.destinationNames[i] = destinationNames[i]; this.labels[i] = labelNames[i]; this.weightsIndices[i] = new int[weightNames[i].length]; for (int j = 0; j < weightNames[i].length; j++) this.weightsIndices[i][j] = crf.getWeightsIndex (weightNames[i][j]); } crf.cachedCostStale = crf.cachedGISUpdateStale = true; } public void print () { System.out.println ("State #"+index+" \""+name+"\""); System.out.println ("initialCost="+initialCost+", finalCost="+finalCost); System.out.println ("#destinations="+destinations.length); for (int i = 0; i < destinations.length; i++) System.out.println ("-> "+destinationNames[i]); } public State getDestinationState (int index) { State ret; if ((ret = destinations[index]) == null) { ret = destinations[index] = (State) crf.name2state.get (destinationNames[index]); //if (ret == null) System.out.println ("this.name="+this.name+" index="+index+" destinationNames[index]="+destinationNames[index]+" name2state.size()="+ crf.name2state.size()); assert (ret != null) : index; } return ret; } public void setTrainable (boolean f) { if (f) { initialConstraint = finalConstraint = 0; initialExpectation = finalExpectation = 0; } } public Transducer.TransitionIterator transitionIterator ( Sequence inputSequence, int inputPosition, Sequence outputSequence, int outputPosition) { if (inputPosition < 0 || outputPosition < 0) throw new UnsupportedOperationException ("Epsilon transitions not implemented."); if (inputSequence == null) throw new UnsupportedOperationException ("CRFs are not generative models; must have an input sequence."); return new TransitionIterator ( this, (FeatureVectorSequence)inputSequence, inputPosition, (outputSequence == null ? null : (String)outputSequence.get(outputPosition)), crf); } public String getName () { return name; } public int getIndex () { return index; } public void incrementInitialCount (double count) { //System.out.println ("incrementInitialCount "+(gatheringConstraints?"constraints":"expectations")+" state#="+this.index+" count="+count); assert (crf.trainable || crf.gatheringWeightsPresent); if (crf.gatheringConstraints) initialConstraint += count; else initialExpectation += count; } public void incrementFinalCount (double count) { //System.out.println ("incrementFinalCount "+(gatheringConstraints?"constraints":"expectations")+" state#="+this.index+" count="+count); assert (crf.trainable || crf.gatheringWeightsPresent); if (crf.gatheringConstraints) finalConstraint += count; else finalExpectation += count; } // Serialization // For class State private static final long serialVersionUID = 1; private static final int CURRENT_SERIAL_VERSION = 0; private static final int NULL_INTEGER = -1; private void writeObject (ObjectOutputStream out) throws IOException { int i, size; out.writeInt (CURRENT_SERIAL_VERSION); out.writeDouble(initialConstraint); out.writeDouble(initialExpectation);
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -