📄 crf.java
字号:
TransducerEvaluator eval, int numIterations, int numIterationsBetweenFeatureInductions, int numFeatureInductions, int numFeaturesPerFeatureInduction, double trueLabelProbThreshold, boolean clusteredFeatureInduction, double[] trainingProportions) {
  // NOTE(review): the method name and leading parameters lie before this chunk and are
  // not visible here.  From the parameter list and body, this is the error-driven
  // feature-induction training loop: it alternates a few rounds of CRF training with
  // rounds of inducing new features on the tokens the current model labels poorly.
  //
  // Parameters (visible portion):
  //   eval                                   - evaluator passed through to this.train()
  //   numIterations                          - total training-iteration budget
  //   numIterationsBetweenFeatureInductions  - training iterations per induction round
  //   numFeatureInductions                   - number of induction rounds
  //   numFeaturesPerFeatureInduction         - features to add per round (per cluster
  //                                            when clusteredFeatureInduction is true)
  //   trueLabelProbThreshold                 - a token counts as an "error" when the
  //                                            lattice value of its true label falls
  //                                            below this threshold
  //   clusteredFeatureInduction              - if true, induce features separately per
  //                                            (previous predicted label, predicted label)
  //   trainingProportions                    - optional per-round fraction of the
  //                                            training data to train on
  // Returns the result of the final this.train(...) call on the remaining iterations.
  int trainingIteration = 0;
  int numLabels = outputAlphabet.size();
  for (int featureInductionIteration = 0; featureInductionIteration < numFeatureInductions; featureInductionIteration++) {
    // Print out some feature information
    logger.info ("Feature induction iteration "+featureInductionIteration);
    // Train the CRF, possibly on a subsample of the training data for this round.
    InstanceList theTrainingData = trainingData;
    if (trainingProportions != null && featureInductionIteration < trainingProportions.length) {
      logger.info ("Training on "+trainingProportions[featureInductionIteration]+"% of the data this round.");
      // Fixed seed (1) makes the subsample reproducible across runs.
      InstanceList[] sampledTrainingData = trainingData.split (new Random(1), new double[] {trainingProportions[featureInductionIteration], 1-trainingProportions[featureInductionIteration]});
      theTrainingData = sampledTrainingData[0];
    }
    // NOTE(review): 'converged' is assigned but never read afterwards.
    boolean converged = this.train (theTrainingData, validationData, testingData, eval, numIterationsBetweenFeatureInductions);
    trainingIteration += numIterationsBetweenFeatureInductions;
    logger.info ("Starting feature induction with "+inputAlphabet.size()+" features.");
    // Create the list of error tokens, for both unclustered and clustered feature induction
    InstanceList errorInstances = new InstanceList (trainingData.getDataAlphabet(), trainingData.getTargetAlphabet());
    ArrayList errorLabelVectors = new ArrayList();
    // Clustered variants are bucketed by (previous predicted label, predicted label).
    InstanceList clusteredErrorInstances[][] = new InstanceList[numLabels][numLabels];
    ArrayList clusteredErrorLabelVectors[][] = new ArrayList[numLabels][numLabels];
    for (int i = 0; i < numLabels; i++)
      for (int j = 0; j < numLabels; j++) {
        clusteredErrorInstances[i][j] = new InstanceList (trainingData.getDataAlphabet(), trainingData.getTargetAlphabet());
        clusteredErrorLabelVectors[i][j] = new ArrayList();
      }
    // Collect every token position whose true-label lattice value is below threshold.
    for (int i = 0; i < theTrainingData.size(); i++) {
      logger.info ("instance="+i);
      Instance instance = theTrainingData.getInstance(i);
      Sequence input = (Sequence) instance.getData();
      Sequence trueOutput = (Sequence) instance.getTarget();
      assert (input.size() == trueOutput.size());
      // Forward-backward over the input yields per-position labelings
      // (presumably label posteriors -- TODO confirm against Transducer.Lattice).
      Transducer.Lattice lattice = this.forwardBackward (input, null, false, (LabelAlphabet)theTrainingData.getTargetAlphabet());
      int prevLabelIndex = 0; // This will put extra error instances in this cluster
      for (int j = 0; j < trueOutput.size(); j++) {
        Label label = (Label) ((LabelSequence)trueOutput).getLabelAtPosition(j);
        assert (label != null);
        //System.out.println ("Instance="+i+" position="+j+" fv="+lattice.getLabelingAtPosition(j).toString(true));
        LabelVector latticeLabeling = lattice.getLabelingAtPosition(j);
        double trueLabelProb = latticeLabeling.value(label.getIndex());
        int labelIndex = latticeLabeling.getBestIndex();
        //System.out.println ("position="+j+" trueLabelProb="+trueLabelProb);
        if (trueLabelProb < trueLabelProbThreshold) {
          // "*" in the log marks positions where the best predicted label differs
          // from the true label.
          logger.info ("Adding error: instance="+i+" position="+j+" prtrue="+trueLabelProb+ (label == latticeLabeling.getBestLabel() ? " " : " *")+ " truelabel="+label+ " predlabel="+latticeLabeling.getBestLabel()+ " fv="+((FeatureVector)input.get(j)).toString(true));
          errorInstances.add (input.get(j), label, null, null);
          errorLabelVectors.add (latticeLabeling);
          clusteredErrorInstances[prevLabelIndex][labelIndex].add (input.get(j), label, null, null);
          clusteredErrorLabelVectors[prevLabelIndex][labelIndex].add (latticeLabeling);
        }
        prevLabelIndex = labelIndex;
      }
    }
    logger.info ("Error instance list size = "+errorInstances.size());
    if (clusteredFeatureInduction) {
      // One feature inducer per (prevLabel, predictedLabel) cluster, skipping
      // clusters with too few error instances to be informative.
      FeatureInducer[][] klfi = new FeatureInducer[numLabels][numLabels];
      for (int i = 0; i < numLabels; i++) {
        for (int j = 0; j < numLabels; j++) {
          logger.info ("Doing feature induction for "+ outputAlphabet.lookupObject(i)+" -> "+outputAlphabet.lookupObject(j));
          if (clusteredErrorInstances[i][j].size() < 20) {
            // Minimum-cluster-size cutoff of 20 is a hard-coded heuristic.
            logger.info ("..skipping because only "+clusteredErrorInstances[i][j].size()+" instances.");
            continue;
          }
          int s = clusteredErrorLabelVectors[i][j].size();
          LabelVector[] lvs = new LabelVector[s];
          for (int k = 0; k < s; k++)
            lvs[k] = (LabelVector) clusteredErrorLabelVectors[i][j].get(k);
          klfi[i][j] = new FeatureInducer (new ExpGain.Factory (lvs), clusteredErrorInstances[i][j], numFeaturesPerFeatureInduction);
        }
      }
      // Second pass: apply each cluster's induced features to the data sets.
      for (int i = 0; i < numLabels; i++) {
        for (int j = 0; j < numLabels; j++) {
          logger.info ("Adding new induced features for "+ outputAlphabet.lookupObject(i)+" -> "+outputAlphabet.lookupObject(j));
          if (klfi[i][j] == null) {
            logger.info ("...skipping because no features induced.");
            continue;
          }
          klfi[i][j].induceFeaturesFor (trainingData, false, false);
          klfi[i][j].induceFeaturesFor (testingData, false, false);
        }
      }
      klfi = null;
    } else if (true) {
      // NOTE(review): 'else if (true)' makes the final 'else' branch below dead code;
      // this branch (ExpGain-based induction over all error instances) always runs
      // when clusteredFeatureInduction is false.
      int s = errorLabelVectors.size();
      LabelVector[] lvs = new LabelVector[s];
      for (int i = 0; i < s; i++)
        lvs[i] = (LabelVector) errorLabelVectors.get(i);
      FeatureInducer klfi = new FeatureInducer (new ExpGain.Factory (lvs), errorInstances, numFeaturesPerFeatureInduction);
      klfi.induceFeaturesFor (trainingData, false, false);
      klfi.induceFeaturesFor (testingData, false, false);
      klfi = null;
    } else { // Currently not ever used
      // Dead branch kept for reference: InfoGain-based induction.
      FeatureInducer igfi = new FeatureInducer (new InfoGain.Factory(), errorInstances, numFeaturesPerFeatureInduction);
      igfi.induceFeaturesFor (trainingData, false, false);
      igfi.induceFeaturesFor (testingData, false, false);
      igfi = null;
    }
    // Make room in the weight vectors for the newly induced input features.
    this.growWeightsDimensionToInputAlphabet ();
  }
  // Spend whatever iteration budget remains on a final training run.
  return this.train (trainingData, validationData, testingData, eval, numIterations - trainingIteration);
}

/**
 * Serializes this CRF to file {@code f} using Java object serialization.
 * IOExceptions are reported to stderr rather than propagated to the caller.
 */
public void write (File f) {
  try {
    ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream(f));
    oos.writeObject(this);
    // NOTE(review): the stream is not closed if writeObject throws; a
    // try/finally (or try-with-resources on newer Java) would avoid the leak.
    oos.close();
  } catch (IOException e) {
    System.err.println("Exception writing file " + f + ": " + e);
  }
}

/** Wraps this CRF and the given training set in a gradient-minimizable view. */
public MinimizableCRF getMinimizableCRF (InstanceList ilist) {
  return new MinimizableCRF (ilist, this);
}

// Serialization
// For CRF class
private static final long serialVersionUID = 1;
// Version tag written first by writeObject and read back by readObject.
private static final int CURRENT_SERIAL_VERSION = 0;
// Sentinel written in place of an array length when the array is null.
static final int NULL_INTEGER = -1;
/* Need to check for null pointers.
*/
/**
 * Custom Java-serialization writer.  Writes the serial version tag, the input
 * and output alphabets, the state lists, the name-to-state map, then the
 * weights/constraints/expectations vector arrays (a null array is encoded by
 * writing NULL_INTEGER in place of its length), and finally the scalar
 * configuration fields.  Must stay field-for-field in sync with readObject.
 */
private void writeObject (ObjectOutputStream out) throws IOException {
  int i, size;
  out.writeInt (CURRENT_SERIAL_VERSION);
  out.writeObject (inputAlphabet);
  out.writeObject (outputAlphabet);
  size = states.size();
  out.writeInt(size);
  for (i = 0; i<size; i++) {
    out.writeObject(states.get(i));
  }
  size = initialStates.size();
  out.writeInt(size);
  for (i = 0; i <size; i++) {
    out.writeObject(initialStates.get(i));
  }
  out.writeObject(name2state);
  // Length-prefixed arrays; NULL_INTEGER length means "array was null".
  if(weights != null) {
    size = weights.length;
    out.writeInt(size);
    for (i=0; i<size; i++) {
      out.writeObject(weights[i]);
    }
  } else {
    out.writeInt(NULL_INTEGER);
  }
  if(constraints != null) {
    size = constraints.length;
    out.writeInt(size);
    for (i=0; i<size; i++) {
      out.writeObject(constraints[i]);
    }
  } else {
    out.writeInt(NULL_INTEGER);
  }
  if (expectations != null) {
    size = expectations.length;
    out.writeInt(size);
    for (i=0; i<size; i++) {
      out.writeObject(expectations[i]);
    }
  } else {
    out.writeInt(NULL_INTEGER);
  }
  // Scalar settings, read back in the same order by readObject.
  out.writeObject(weightAlphabet);
  out.writeBoolean(trainable);
  out.writeBoolean(gatheringConstraints);
  out.writeInt(defaultFeatureIndex);
  out.writeBoolean(usingHyperbolicPrior);
  out.writeDouble(gaussianPriorVariance);
  out.writeDouble(hyperbolicPriorSlope);
  out.writeDouble(hyperbolicPriorSharpness);
  out.writeBoolean(cachedCostStale);
  out.writeBoolean(cachedGradientStale);
}

/**
 * Custom Java-serialization reader; exact mirror of writeObject above.
 * Reads fields in the identical order and reconstructs null arrays from the
 * NULL_INTEGER length sentinel.
 */
private void readObject (ObjectInputStream in) throws IOException, ClassNotFoundException {
  int size, i;
  // NOTE(review): 'version' is read to keep the stream aligned but is not
  // otherwise used (only CURRENT_SERIAL_VERSION == 0 exists so far).
  int version = in.readInt ();
  inputAlphabet = (Alphabet) in.readObject();
  outputAlphabet = (Alphabet) in.readObject();
  size = in.readInt();
  states = new ArrayList(size);
  for (i=0; i<size; i++) {
    State s = (CRF.State) in.readObject();
    states.add(s);
  }
  size = in.readInt();
  initialStates = new ArrayList();
  for (i=0; i<size; i++) {
    State s = (CRF.State) in.readObject();
    initialStates.add(s);
  }
  name2state = (HashMap) in.readObject();
  size = in.readInt();
  if (size == NULL_INTEGER) {
    weights = null;
  } else {
    weights = new DenseVector[size];
    for(i=0; i< size; i++) {
      weights[i] = (DenseVector) in.readObject();
    }
  }
  size = in.readInt();
  if (size == NULL_INTEGER) {
    constraints = null;
  } else {
    constraints = new DenseVector[size];
    for(i=0; i< size; i++) {
      constraints[i] = (DenseVector) in.readObject();
    }
  }
  size = in.readInt();
  if (size == NULL_INTEGER) {
    expectations = null;
  } else {
    expectations = new DenseVector[size];
    for(i=0; i< size; i++) {
      expectations[i] = (DenseVector)in.readObject();
    }
  }
  weightAlphabet = (Alphabet) in.readObject();
  trainable = in.readBoolean();
  gatheringConstraints = in.readBoolean();
  defaultFeatureIndex = in.readInt();
  usingHyperbolicPrior = in.readBoolean();
  gaussianPriorVariance = in.readDouble();
  hyperbolicPriorSlope = in.readDouble();
  hyperbolicPriorSharpness = in.readDouble();
  cachedCostStale = in.readBoolean();
  cachedGradientStale = in.readBoolean();
}

/**
 * A view of this CRF as a function minimizable by gradient methods: the CRF's
 * state costs and weights become one flat parameter vector.
 * NOTE(review): this inner class is truncated at the end of this chunk.
 */
public class MinimizableCRF implements Minimizable.ByGradient, Serializable {
  InstanceList trainingSet;
  double cachedCost = -123456789;  // sentinel "not yet computed" value
  DenseVector cachedGradient;
  BitSet infiniteCosts = null;
  int numParameters;
  CRF crf;

  protected MinimizableCRF (InstanceList ilist, CRF crf) {
    // Set up: two parameters per state (initial and final cost) plus one weight
    // row per weight set, each of length defaultFeatureIndex+1.
    this.numParameters = 2 * numStates() + weights.length * (defaultFeatureIndex+1);
    this.trainingSet = ilist;
    this.crf = crf;
    cachedGradient = (DenseVector) getNewMatrix ();
    // This resets any values that may have been in expectations and constraints
    setTrainable (true);
    // Set the constraints by running forward-backward with the *output
    // label sequence provided*, thus restricting it to only those
    // paths that agree with the label sequence.
    gatheringConstraints = true;
    for (int i = 0; i < ilist.size(); i++) {
      Instance instance = ilist.getInstance(i);
      FeatureVectorSequence input = (FeatureVectorSequence) instance.getData();
      FeatureSequence output = (FeatureSequence) instance.getTarget();
      //System.out.println ("Confidence-gathering forward-backward on instance "+i+" of "+ilist.size());
      this.crf.forwardBackward (input, output, true);
    }
    gatheringConstraints = false;
  }

  /** Returns a zeroed parameter-sized vector. */
  public Matrix getNewMatrix () { return new DenseVector (numParameters); }

  // Negate initialCost and finalCost because the parameters are in
  // terms of "weights", not "costs".
  public Matrix getParameters (Matrix m) {
    assert (m instanceof DenseVector && ((Vector)m).singleSize() == numParameters);
    DenseVector parameters = (DenseVector)m;
    int pi = 0;
    // First the per-state initial/final costs (negated into weight space)...
    for (int i = 0; i < numStates(); i++) {
      State s = (State) getState (i);
      parameters.setValue (pi++, -s.initialCost);
      parameters.setValue (pi++, -s.finalCost);
    }
    // ...then the weight vectors.
    // NOTE(review): the chunk is truncated here; the loop body and the rest of
    // getParameters (and of MinimizableCRF) continue beyond this excerpt.
    for (int i = 0; i < weights.length; i++)
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -