📄 crf3.java
      out.writeInt (NULL_INTEGER);
    }
    if (defaultWeights != null) {
      size = defaultWeights.length;
      out.writeInt (size);
      for (i = 0; i < size; i++)
        out.writeDouble (defaultWeights[i]);
    } else {
      out.writeInt (NULL_INTEGER);
    }
    if (defaultConstraints != null) {
      size = defaultConstraints.length;
      out.writeInt (size);
      for (i = 0; i < size; i++)
        out.writeDouble (defaultConstraints[i]);
    } else {
      out.writeInt (NULL_INTEGER);
    }
    if (defaultExpectations != null) {
      size = defaultExpectations.length;
      out.writeInt (size);
      for (i = 0; i < size; i++)
        out.writeDouble (defaultExpectations[i]);
    } else {
      out.writeInt (NULL_INTEGER);
    }
    if (weightsPresent != null) {
      size = weightsPresent.length;
      out.writeInt (size);
      for (i = 0; i < size; i++)
        out.writeObject (weightsPresent[i]);
    } else {
      out.writeInt (NULL_INTEGER);
    }
    if (featureSelections != null) {
      size = featureSelections.length;
      out.writeInt (size);
      for (i = 0; i < size; i++)
        out.writeObject (featureSelections[i]);
    } else {
      out.writeInt (NULL_INTEGER);
    }
    out.writeObject (globalFeatureSelection);
    out.writeObject (weightAlphabet);
    out.writeBoolean (trainable);
    out.writeBoolean (gatheringConstraints);
    out.writeBoolean (gatheringWeightsPresent);
    //out.writeInt (defaultFeatureIndex);
    out.writeBoolean (usingHyperbolicPrior);
    out.writeDouble (gaussianPriorVariance);
    out.writeDouble (hyperbolicPriorSlope);
    out.writeDouble (hyperbolicPriorSharpness);
    out.writeBoolean (cachedCostStale);
    out.writeBoolean (cachedGradientStale);
    out.writeBoolean (someTrainingDone);
    out.writeInt (featureInducers.size());
    for (i = 0; i < featureInducers.size(); i++) {
      out.writeObject (featureInducers.get(i));
    }
    out.writeBoolean (printGradient);
  }
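  // readObject mirrors writeObject: every field is read back in exactly the
  // order it was written.  For each array field, a NULL_INTEGER sentinel in
  // place of the length means the array was null; otherwise the length is
  // followed by that many elements.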
  private void readObject (ObjectInputStream in) throws IOException, ClassNotFoundException {
    int size, i;
    int version = in.readInt ();
    inputPipe = (Pipe) in.readObject ();
    outputPipe = (Pipe) in.readObject ();
    inputAlphabet = (Alphabet) in.readObject ();
    outputAlphabet = (Alphabet) in.readObject ();
    size = in.readInt ();
    states = new ArrayList ();
    for (i = 0; i < size; i++) {
      State s = (CRF3.State) in.readObject ();
      states.add (s);
    }
    size = in.readInt ();
    initialStates = new ArrayList ();
    for (i = 0; i < size; i++) {
      State s = (CRF3.State) in.readObject ();
      initialStates.add (s);
    }
    name2state = (HashMap) in.readObject ();
    size = in.readInt ();
    if (size == NULL_INTEGER) {
      weights = null;
    } else {
      weights = new SparseVector[size];
      for (i = 0; i < size; i++) {
        weights[i] = (SparseVector) in.readObject ();
      }
    }
    size = in.readInt ();
    if (size == NULL_INTEGER) {
      constraints = null;
    } else {
      constraints = new SparseVector[size];
      for (i = 0; i < size; i++) {
        constraints[i] = (SparseVector) in.readObject ();
      }
    }
    size = in.readInt ();
    if (size == NULL_INTEGER) {
      expectations = null;
    } else {
      expectations = new SparseVector[size];
      for (i = 0; i < size; i++) {
        expectations[i] = (SparseVector) in.readObject ();
      }
    }
    size = in.readInt ();
    if (size == NULL_INTEGER) {
      defaultWeights = null;
    } else {
      defaultWeights = new double[size];
      for (i = 0; i < size; i++) {
        defaultWeights[i] = in.readDouble ();
      }
    }
    size = in.readInt ();
    if (size == NULL_INTEGER) {
      defaultConstraints = null;
    } else {
      defaultConstraints = new double[size];
      for (i = 0; i < size; i++) {
        defaultConstraints[i] = in.readDouble ();
      }
    }
    size = in.readInt ();
    if (size == NULL_INTEGER) {
      defaultExpectations = null;
    } else {
      defaultExpectations = new double[size];
      for (i = 0; i < size; i++) {
        defaultExpectations[i] = in.readDouble ();
      }
    }
    size = in.readInt ();
    if (size == NULL_INTEGER) {
      weightsPresent = null;
    } else {
      weightsPresent = new BitSet[size];
      for (i = 0; i < size; i++)
        weightsPresent[i] = (BitSet) in.readObject ();
    }
    size = in.readInt ();
    if (size == NULL_INTEGER) {
      featureSelections = null;
    } else {
      featureSelections = new FeatureSelection[size];
      for (i = 0; i < size; i++)
        featureSelections[i] = (FeatureSelection) in.readObject ();
    }
    globalFeatureSelection = (FeatureSelection) in.readObject ();
    weightAlphabet = (Alphabet) in.readObject ();
    trainable = in.readBoolean ();
    gatheringConstraints = in.readBoolean ();
    gatheringWeightsPresent = in.readBoolean ();
    //defaultFeatureIndex = in.readInt ();
    usingHyperbolicPrior = in.readBoolean ();
    gaussianPriorVariance = in.readDouble ();
    hyperbolicPriorSlope = in.readDouble ();
    hyperbolicPriorSharpness = in.readDouble ();
    cachedCostStale = in.readBoolean ();
    cachedGradientStale = in.readBoolean ();
    someTrainingDone = in.readBoolean ();
    size = in.readInt ();
    featureInducers = new ArrayList ();
    for (i = 0; i < size; i++) {
      featureInducers.add ((FeatureInducer) in.readObject ());
    }
    printGradient = in.readBoolean ();
  }

  public class MinimizableCRF implements Minimizable.ByGradient, Serializable
  {
    InstanceList trainingSet;
    double cachedCost = -123456789;
    DenseVector cachedGradient;
    BitSet infiniteCosts = null;
    int numParameters;
    CRF3 crf;

    protected MinimizableCRF (InstanceList ilist, CRF3 crf)
    {
      // Set up
      this.numParameters = 2 * numStates() + defaultWeights.length;
      for (int i = 0; i < weights.length; i++)
        numParameters += weights[i].numLocations();
      this.trainingSet = ilist;
      this.crf = crf;
      cachedGradient = (DenseVector) getNewMatrix ();
      // This resets any values that may have been in expectations and constraints
      setTrainable (true);
      // Set the constraints by running forward-backward with the *output
      // label sequence provided*, thus restricting it to only those
      // paths that agree with the label sequence.
      gatheringConstraints = true;
      for (int i = 0; i < ilist.size(); i++) {
        Instance instance = ilist.getInstance (i);
        FeatureVectorSequence input = (FeatureVectorSequence) instance.getData ();
        FeatureSequence output = (FeatureSequence) instance.getTarget ();
        //System.out.println ("Confidence-gathering forward-backward on instance "+i+" of "+ilist.size());
        this.crf.forwardBackward (input, output, true);
        //System.out.println ("Gathering constraints for Instance #"+i);
      }
      gatheringConstraints = false;
    }

    public Matrix getNewMatrix () { return new DenseVector (numParameters); }
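    // Layout of the flat parameter vector used by the optimizer:
    //   for each state, -initialCost followed by -finalCost (2 * numStates() entries),
    //   then, for each weight structure i, defaultWeights[i] followed by the
    //   weights[i].numLocations() values stored in weights[i].
    // getParameters/setParameters and the getParameter/setParameter index
    // arithmetic below all rely on this ordering.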
    // Negate initialCost and finalCost because the parameters are in
    // terms of "weights", not "costs".
    public Matrix getParameters (Matrix m)
    {
      assert (m instanceof DenseVector && ((Vector)m).singleSize() == numParameters);
      DenseVector parameters = (DenseVector) m;
      int pi = 0;
      for (int i = 0; i < numStates(); i++) {
        State s = (State) getState (i);
        parameters.setValue (pi++, -s.initialCost);
        parameters.setValue (pi++, -s.finalCost);
      }
      for (int i = 0; i < weights.length; i++) {
        parameters.setValue (pi++, defaultWeights[i]);
        int nl = weights[i].numLocations();
        for (int j = 0; j < nl; j++)
          parameters.setValue (pi++, weights[i].valueAtLocation (j));
      }
      return parameters;
    }

    public void setParameters (Matrix m)
    {
      assert (m instanceof DenseVector && ((DenseVector)m).singleSize() == numParameters);
      cachedCostStale = cachedGradientStale = true;
      DenseVector parameters = (DenseVector) m;
      int pi = 0;
      for (int i = 0; i < numStates(); i++) {
        State s = (State) getState (i);
        s.initialCost = -parameters.value (pi++);
        s.finalCost = -parameters.value (pi++);
      }
      for (int i = 0; i < weights.length; i++) {
        defaultWeights[i] = parameters.value (pi++);
        int nl = weights[i].numLocations();
        for (int j = 0; j < nl; j++)
          weights[i].setValueAtLocation (j, parameters.value (pi++));
      }
      someTrainingDone = true;
    }

    public double getParameter (int[] indices)
    {
      assert (indices.length == 1);
      int index = indices[0];
      int numStateParms = 2 * numStates();
      if (index < numStateParms) {
        State s = (State) getState (index / 2);
        if (index % 2 == 0)
          return -s.initialCost;
        else
          return -s.finalCost;
      } else {
        index -= numStateParms;
        for (int i = 0; i < weights.length; i++) {
          if (index == 0)
            return defaultWeights[i];
          index--;
          if (index < weights[i].numLocations())
            return weights[i].valueAtLocation (index);
          else
            index -= weights[i].numLocations();
        }
        throw new IllegalArgumentException ("index too high = " + indices[0]);
      }
    }

    public void setParameter (int[] indices, double value)
    {
      cachedCostStale = cachedGradientStale = true;
      assert (indices.length == 1);
      int index = indices[0];
      int numStateParms = 2 * numStates();
      if (index < numStateParms) {
        State s = (State) getState (index / 2);
        if (index % 2 == 0)
          s.initialCost = -value;
        else
          s.finalCost = -value;
      } else {
        index -= numStateParms;
        for (int i = 0; i < weights.length; i++) {
          if (index == 0) {
            defaultWeights[i] = value;
            return;
          } else
            index--;
          if (index < weights[i].numLocations()) {
            weights[i].setValueAtLocation (index, value);
            return;
          } else
            index -= weights[i].numLocations();
        }
        throw new IllegalArgumentException ("index too high = " + indices[0]);
      }
    }

    // Minus log probability of the training sequence labels
    public double getCost ()
    {
      if (cachedCostStale) {
        long startingTime = System.currentTimeMillis();
        cachedCost = 0;
        cachedGradientStale = true;
        // Instance costs must either always or never be included in the
        // total costs; we can't just sometimes skip a cost because it is
        // infinite, as that would throw off the total costs.
        boolean initializingInfiniteCosts = false;
        if (infiniteCosts == null) {
          infiniteCosts = new BitSet ();
          initializingInfiniteCosts = true;
        }
        // Clear the sufficient statistics that we are about to fill
        for (int i = 0; i < numStates(); i++) {
          State s = (State) getState (i);
          s.initialExpectation = 0;
          s.finalExpectation = 0;
        }
        for (int i = 0; i < weights.length; i++) {
          expectations[i].setAll (0.0);
          defaultExpectations[i] = 0;