📄 crf.java
字号:
this.destinations = new State[labelNames.length]; this.crf = crf; for (int i = 0; i < labelNames.length; i++) { outputAlphabet.lookupIndex (labelNames[i]); this.labels[i] = labelNames[i]; this.weightsIndices[i] = getWeightsIndex (weightNames[i]); } } public void print () { System.out.println ("State #"+index+" \""+name+"\""); System.out.println ("initialCost="+initialCost+", finalCost="+finalCost); System.out.println ("#destinations="+destinations.length); for (int i = 0; i < destinations.length; i++) System.out.println ("-> "+destinationNames[i]); } public State getDestinationState (int index) { State ret; if ((ret = destinations[index]) == null) { ret = destinations[index] = (State) crf.name2state.get (destinationNames[index]); //if (ret == null) System.out.println ("this.name="+this.name+" index="+index+" destinationNames[index]="+destinationNames[index]+" name2state.size()="+ crf.name2state.size()); assert (ret != null) : index; } return ret; } public void setTrainable (boolean f) { if (f) { initialConstraint = finalConstraint = 0; initialExpectation = finalExpectation = 0; } } public Transducer.TransitionIterator transitionIterator ( Sequence inputSequence, int inputPosition, Sequence outputSequence, int outputPosition) { if (inputPosition < 0 || outputPosition < 0) throw new UnsupportedOperationException ("Epsilon transitions not implemented."); if (inputSequence == null) throw new UnsupportedOperationException ("CRFs are not generative models; must have an input sequence."); return new TransitionIterator ( this, (FeatureVectorSequence)inputSequence, inputPosition, (outputSequence == null ? null : (String)outputSequence.get(outputPosition)), crf); } public String getName () { return name; } public int getIndex () { return index; } public void incrementInitialCount (double count) { //System.out.println ("incrementInitialCount "+(gatheringConstraints?"constraints":"expectations")+" state#="+this.index+" count="+count); assert (crf.trainable); if (crf.gatheringConstraints) initialConstraint += count; else initialExpectation += count; } public void incrementFinalCount (double count) { //System.out.println ("incrementFinalCount "+(gatheringConstraints?"constraints":"expectations")+" state#="+this.index+" count="+count); assert (crf.trainable); if (crf.gatheringConstraints) finalConstraint += count; else finalExpectation += count; } // Serialization // For class State private static final long serialVersionUID = 1; private static final int CURRENT_SERIAL_VERSION = 0; private static final int NULL_INTEGER = -1; private void writeObject (ObjectOutputStream out) throws IOException { int i, size; out.writeInt (CURRENT_SERIAL_VERSION); out.writeDouble(initialConstraint); out.writeDouble(initialExpectation); out.writeDouble(finalConstraint); out.writeDouble(finalExpectation); out.writeObject(name); out.writeInt(index); size = (destinationNames == null) ? NULL_INTEGER : destinationNames.length; out.writeInt(size); if (size != NULL_INTEGER) { for(i=0; i<size; i++){ out.writeObject(destinationNames[i]); } } size = (destinations == null) ? NULL_INTEGER : destinations.length; out.writeInt(size); if (size != NULL_INTEGER) { for(i=0; i<size;i++) { out.writeObject(destinations[i]); } } size = (weightsIndices == null) ? NULL_INTEGER : weightsIndices.length; out.writeInt(size); if (size != NULL_INTEGER) { for (i=0; i<size; i++){ out.writeInt(weightsIndices[i]); } } size = (labels == null) ? NULL_INTEGER : labels.length; out.writeInt(size); if (size != NULL_INTEGER) { for (i=0; i<size; i++) out.writeObject(labels[i]);// out.writeObject (inputAlphabet); // this will cause error// out.writeObject (outputAlphabet); } out.writeObject(crf); } private void readObject (ObjectInputStream in) throws IOException, ClassNotFoundException { int size, i; int version = in.readInt (); initialConstraint = in.readDouble(); initialExpectation = in.readDouble(); finalConstraint = in.readDouble(); finalExpectation = in.readDouble(); name = (String) in.readObject(); index = in.readInt(); size = in.readInt(); if (size != NULL_INTEGER) { destinationNames = new String[size]; for (i=0; i<size; i++) { destinationNames[i] = (String) in.readObject(); } } else { destinationNames = null; } size = in.readInt(); if (size != NULL_INTEGER) { destinations = new State[size]; for (i=0; i<size; i++) { destinations[i] = (State) in.readObject(); } } else { destinations = null; } size = in.readInt(); if (size != NULL_INTEGER) { weightsIndices = new int[size]; for (i=0; i<size; i++) { weightsIndices[i] = in.readInt(); } } else { weightsIndices = null; } size = in.readInt(); if (size != NULL_INTEGER) { labels = new String[size]; for (i=0; i<size; i++) labels[i] = (String) in.readObject();// inputAlphabet = (Alphabet) in.readObject();// outputAlphabet = (Alphabet) in.readObject(); } else { labels = null; } crf = (CRF) in.readObject(); } } protected class TransitionIterator extends Transducer.TransitionIterator implements Serializable { State source; int index, nextIndex; double[] costs; // Eventually change this because we will have a more space-efficient // FeatureVectorSequence that cannot break out each FeatureVector FeatureVector input; CRF crf; public TransitionIterator (State source, FeatureVectorSequence inputSeq, int inputPosition, String output, CRF crf) { this.source = source; this.crf = crf; this.input = (FeatureVector) inputSeq.get(inputPosition); this.costs = new double[source.destinations.length]; double totalCost = 0; //for normalization, added by Fuchun Peng for (int transIndex = 0; transIndex < source.destinations.length; transIndex++) { // xxx Or do we want output.equals(...) here? if (output == null || output.equals(source.labels[transIndex])) { // Here is the dot product of the feature weights with the lambda weights // for one transition costs[transIndex] = -(inputSeq.dotProduct (inputPosition, crf.weights[source.weightsIndices[transIndex]]) // include with implicit weight 1.0 the default feature + crf.weights[source.weightsIndices[transIndex]].value (crf.defaultFeatureIndex));// System.out.println("costs: " + costs[transIndex]); totalCost = sumNegLogProb (totalCost, costs[transIndex]); assert (!Double.isNaN(costs[transIndex])); } else costs[transIndex] = INFINITE_COST; }/* // normalized cost for (int transIndex = 0; transIndex < source.destinations.length; transIndex++) { // xxx Or do we want output.equals(...) here? if (output == null || output.equals(source.labels[transIndex])) { // Here is the dot product of the feature weights with the lambda weights // for one transition costs[transIndex] -= totalCost;// System.out.println("costs: " + costs[transIndex]); assert (!Double.isNaN(costs[transIndex])); } else costs[transIndex] = INFINITE_COST; }*/ nextIndex = 0; while (nextIndex < source.destinations.length && costs[nextIndex] == INFINITE_COST) nextIndex++; } public boolean hasNext () { return nextIndex < source.destinations.length; } public int numberNext(){return source.destinations.length;} //added by Fuchun Peng public Transducer.State nextState () { assert (nextIndex < source.destinations.length); index = nextIndex; nextIndex++; while (nextIndex < source.destinations.length && costs[nextIndex] == INFINITE_COST) nextIndex++; return source.getDestinationState (index); } public Object getInput () { return input; } public Object getOutput () { return source.labels[index]; } public double getCost () { return costs[index]; } public Transducer.State getSourceState () { return source; } public Transducer.State getDestinationState () { return source.getDestinationState (index); } public void incrementCount (double count) { //System.out.println ("incrementCount "+(gatheringConstraints?"constraints":"expectations")+" dest#="+source.index+" count="+count); assert (crf.trainable); int weightsIndex = source.weightsIndices[index]; if (crf.gatheringConstraints) { crf.constraints[weightsIndex].plusEquals (input, count); crf.constraints[weightsIndex].columnPlusEquals (crf.defaultFeatureIndex, count); } else { crf.expectations[weightsIndex].plusEquals (input, count); crf.expectations[weightsIndex].columnPlusEquals (crf.defaultFeatureIndex, count); } } // Serialization // TransitionIterator private static final long serialVersionUID = 1; private static final int CURRENT_SERIAL_VERSION = 0; private static final int NULL_INTEGER = -1; private void writeObject (ObjectOutputStream out) throws IOException { out.writeInt (CURRENT_SERIAL_VERSION); out.writeObject (source); out.writeInt (index); out.writeInt (nextIndex); if (costs != null) { out.writeInt (costs.length); for (int i = 0; i < costs.length; i++) { out.writeDouble (costs[i]); } } else { out.writeInt(NULL_INTEGER); } out.writeObject (input); out.writeObject(crf); } private void readObject (ObjectInputStream in) throws IOException, ClassNotFoundException { int version = in.readInt (); source = (State) in.readObject(); index = in.readInt (); nextIndex = in.readInt (); int size = in.readInt(); if (size == NULL_INTEGER) { costs = null; } else { costs = new double[size]; for (int i =0; i <size; i++) { costs[i] = in.readDouble(); } } input = (FeatureVector) in.readObject(); crf = (CRF) in.readObject(); } }}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -