📄 crfbygisupdate.java
字号:
// --- Tail of a weight-printing method whose opening lines precede this chunk. ---
// Appends one "state -> destination: feature = value" line per nonzero weight,
// then dumps the accumulated buffer to stdout.
double v = rfv.getValueAtRank(m);
int index = rfv.getIndexAtRank(m);
Object feature = inputAlphabet.lookupObject (index);
if (v != 0) {
	sb.append (" ");
	sb.append (s.name);
	sb.append (" -> ");
	sb.append (s.destinations[j].name);
	sb.append (": ");
	sb.append (feature);
	sb.append (" = ");
	sb.append (v);
	sb.append ('\n');
}
} } } }
System.out.println (sb.toString());
}

// Java question:
// If I make a non-static inner class CRF.Trainer,
// can that class by subclassed in another .java file,
// and can that subclass still have access to all the CRF's
// instance variables?

/** Trains on ilist alone, with no validation/testing sets and no evaluator. */
public boolean train (InstanceList ilist)
{
	return train (ilist, (InstanceList)null, (InstanceList)null);
}

/** Trains with optional validation and testing sets but no evaluator. */
public boolean train (InstanceList ilist, InstanceList validation, InstanceList testing)
{
	return train (ilist, validation, testing, (TransducerEvaluator)null);
}

/**
 * Trains with an optional evaluator, defaulting to 9999 iterations and a
 * fresh AGIS minimizer with initial step 2.0.
 */
public boolean train (InstanceList ilist, InstanceList validation, InstanceList testing,
                      TransducerEvaluator eval)
{
	return train (ilist, validation, testing, eval, 9999, new AGIS(2.0));
}

/**
 * Core training loop: runs the GIS-update minimizer one iteration at a time,
 * optionally evaluating after each iteration.
 *
 * @param ilist training instances; must be non-empty
 * @param validation optional validation set passed through to the evaluator (may be null)
 * @param testing optional testing set passed through to the evaluator (may be null)
 * @param eval optional per-iteration evaluator; may be null
 * @param numIterations maximum minimizer iterations; returns false immediately if <= 0
 * @param minimizer the GIS-update minimizer driving the weight updates
 * @return true if the minimizer reported convergence
 */
public boolean train (InstanceList ilist, InstanceList validation, InstanceList testing,
                      TransducerEvaluator eval, int numIterations,
                      Minimizer.ByGISUpdate minimizer)
{
	if (numIterations <= 0)
		return false;
	assert (ilist.size() > 0);
	// Size the weight arrays to match the features actually present in ilist.
	setWeightsDimensionAsIn (ilist);
	MinimizableCRF mc = new MinimizableCRF (ilist, this);
	int i;
	boolean continueTraining = true;
	boolean converged = false;
	System.out.println ("CRF about to train with "+numIterations+" iterations");
	for (i = 0; i < numIterations; i++) {
		try {
			// One minimizer step per loop iteration so we can evaluate between steps.
			converged = minimizer.minimize (mc, 1);
			System.out.println ("CRF finished one iteration of minimizer, i="+i);
		} catch (IllegalArgumentException e) {
			// NOTE(review): an IllegalArgumentException from the minimizer is treated
			// as convergence rather than failure; this silently ends training early.
			e.printStackTrace();
			System.out.println ("Catching exception; saying converged.");
			converged = true;
		}
		if (eval != null) {
			continueTraining = eval.evaluate (this, (converged || i == numIterations-1), i,
			                                  converged, mc.getCost(), ilist, validation, testing);
			// Only let the evaluator stop training after a 25-iteration warm-up.
			if (!continueTraining && i > 25)
				break;
		}
		if (converged && i > 25) {
			// Same 25-iteration warm-up before honoring convergence.
			System.out.println ("CRF training has converged, i="+i);
			break;
		}
	}
	System.out.println ("About to setTrainable(false)");
	// Free the memory of the expectations and constraints
	setTrainable (false);
	System.out.println ("Done setTrainable(false)");
	return converged;
}

/**
 * Trains in stages on increasing proportions of the data (one stage per entry
 * in trainingProportions, numIterationsPerProportion iterations each), then
 * finishes with a final run on 100% of the data for the remaining iterations.
 *
 * @param trainingProportions per-stage fraction of the data to sample; must be non-null
 *        (the loop bound dereferences it unconditionally)
 * @return the result of the final full-data training run
 */
public boolean train (InstanceList training, InstanceList validation, InstanceList testing,
                      TransducerEvaluator eval, int numIterations,
                      int numIterationsPerProportion,
                      double[] trainingProportions,
                      Minimizer.ByGISUpdate minimizer)
{
	int trainingIteration = 0;
	for (int i = 0; i < trainingProportions.length; i++) {
		// Train the CRF
		InstanceList theTrainingData = training;
		// NOTE(review): this null check is dead — trainingProportions was already
		// dereferenced in the loop condition above, and i < length always holds here.
		if (trainingProportions != null && i < trainingProportions.length) {
			// NOTE(review): message says "%" but the value is a fraction in [0,1].
			System.out.println ("Training on "+trainingProportions[i]+"% of the data this round.");
			// Fixed seed (1) makes the sampled split reproducible across runs.
			InstanceList[] sampledTrainingData = training.split (new Random(1),
				new double[] {trainingProportions[i], 1-trainingProportions[i]});
			theTrainingData = sampledTrainingData[0];
		}
		// NOTE(review): 'converged' is never read; each stage's convergence is discarded.
		boolean converged = this.train (theTrainingData, validation, testing, eval,
		                                numIterationsPerProportion, minimizer);
		trainingIteration += numIterationsPerProportion;
	}
	System.out.println ("Training on 100% of the data this round, for "+
	                    (numIterations-trainingIteration)+" iterations.");
	return this.train (training, validation, testing, eval,
	                   numIterations - trainingIteration, minimizer);
}

/**
 * Alternates CRF training with feature induction: repeatedly trains, collects
 * tokens whose true label got probability below trueLabelProbThreshold under the
 * lattice (gamma) posteriors, induces new features from those error tokens
 * (globally, or clustered per predicted-label transition), and finishes with a
 * final training run for the remaining iterations.
 *
 * @param trueLabelProbThreshold tokens with posterior true-label probability below
 *        this are treated as errors and feed the feature inducer
 * @param clusteredFeatureInduction if true, run one inducer per (prev,pred) label
 *        pair; otherwise a single inducer over all error tokens
 * @param trainingProportions optional per-round subsampling fractions (may be null)
 * @return the result of the final training run
 */
public boolean trainWithFeatureInduction (InstanceList trainingData,
                                          InstanceList validationData,
                                          InstanceList testingData,
                                          TransducerEvaluator eval,
                                          int numIterations,
                                          int numIterationsBetweenFeatureInductions,
                                          int numFeatureInductions,
                                          int numFeaturesPerFeatureInduction,
                                          double trueLabelProbThreshold,
                                          boolean clusteredFeatureInduction,
                                          double[] trainingProportions,
                                          Minimizer.ByGISUpdate minimizer)
{
	int trainingIteration = 0;
	int numLabels = outputAlphabet.size();
	// Share one FeatureSelection across training/validation/testing so induced
	// features become visible everywhere at once.
	this.globalFeatureSelection = trainingData.getFeatureSelection();
	if (this.globalFeatureSelection == null) {
		// Mask out all features; some will be added later by FeatureInducer.induceFeaturesFor(.)
		this.globalFeatureSelection = new FeatureSelection (trainingData.getDataAlphabet());
		trainingData.setFeatureSelection (this.globalFeatureSelection);
	}
	if (validationData != null) validationData.setFeatureSelection (this.globalFeatureSelection);
	if (testingData != null) testingData.setFeatureSelection (this.globalFeatureSelection);
	for (int featureInductionIteration = 0;
	     featureInductionIteration < numFeatureInductions;
	     featureInductionIteration++) {
		// Print out some feature information
		System.out.println ("Feature induction iteration "+featureInductionIteration);
		// Train the CRF
		InstanceList theTrainingData = trainingData;
		if (trainingProportions != null && featureInductionIteration < trainingProportions.length) {
			// NOTE(review): message says "%" but the value is a fraction in [0,1].
			System.out.println ("Training on "+trainingProportions[featureInductionIteration]+"% of the data this round.");
			InstanceList[] sampledTrainingData = trainingData.split (new Random(1),
				new double[] {trainingProportions[featureInductionIteration],
				              1-trainingProportions[featureInductionIteration]});
			theTrainingData = sampledTrainingData[0];
			theTrainingData.setFeatureSelection (this.globalFeatureSelection); // xxx necessary?
			System.out.println (" which is "+theTrainingData.size()+" instances");
		}
		boolean converged = false;
		if (featureInductionIteration != 0)
			// Don't train until we have added some features
			converged = this.train (theTrainingData, validationData, testingData, eval,
			                        numIterationsBetweenFeatureInductions, minimizer);
		trainingIteration += numIterationsBetweenFeatureInductions;
		// xxx Remove this next line
		this.print ();
		System.out.println ("Starting feature induction with "+inputAlphabet.size()+" features.");
		// Create the list of error tokens, for both unclustered and clustered feature induction
		InstanceList errorInstances = new InstanceList (trainingData.getDataAlphabet(),
		                                                trainingData.getTargetAlphabet());
		// This errorInstances.featureSelection will get examined by FeatureInducer,
		// so it can know how to add "new" singleton features
		errorInstances.setFeatureSelection (this.globalFeatureSelection);
		ArrayList errorLabelVectors = new ArrayList();
		// Per-(prev,pred)-label-pair buckets for the clustered induction variant.
		InstanceList clusteredErrorInstances[][] = new InstanceList[numLabels][numLabels];
		ArrayList clusteredErrorLabelVectors[][] = new ArrayList[numLabels][numLabels];
		for (int i = 0; i < numLabels; i++)
			for (int j = 0; j < numLabels; j++) {
				clusteredErrorInstances[i][j] = new InstanceList (trainingData.getDataAlphabet(),
				                                                  trainingData.getTargetAlphabet());
				clusteredErrorInstances[i][j].setFeatureSelection (this.globalFeatureSelection);
				clusteredErrorLabelVectors[i][j] = new ArrayList();
			}
		// Scan every token of every training sequence, collecting low-confidence tokens.
		for (int i = 0; i < theTrainingData.size(); i++) {
			System.out.println ("instance="+i);
			Instance instance = theTrainingData.getInstance(i);
			Sequence input = (Sequence) instance.getData();
			Sequence trueOutput = (Sequence) instance.getTarget();
			assert (input.size() == trueOutput.size());
			// Forward-backward gives per-position label posteriors (gammas).
			Transducer.Lattice lattice = this.forwardBackward (input, null, false,
				(LabelAlphabet)theTrainingData.getTargetAlphabet());
			int prevLabelIndex = 0; // This will put extra error instances in this cluster
			for (int j = 0; j < trueOutput.size(); j++) {
				Label label = (Label) ((LabelSequence)trueOutput).getLabelAtPosition(j);
				assert (label != null);
				//System.out.println ("Instance="+i+" position="+j+" fv="+lattice.getLabelingAtPosition(j).toString(true));
				LabelVector latticeLabeling = lattice.getLabelingAtPosition(j);
				double trueLabelProb = latticeLabeling.value(label.getIndex());
				int labelIndex = latticeLabeling.getBestIndex();
				//System.out.println ("position="+j+" trueLabelProb="+trueLabelProb);
				if (trueLabelProb < trueLabelProbThreshold) {
					System.out.println ("Adding error: instance="+i+" position="+j+" prtrue="+trueLabelProb+
						(label == latticeLabeling.getBestLabel() ? " " : " *")+
						" truelabel="+label+
						" predlabel="+latticeLabeling.getBestLabel()+
						" fv="+((FeatureVector)input.get(j)).toString(true));
					errorInstances.add (input.get(j), label, null, null);
					errorLabelVectors.add (latticeLabeling);
					clusteredErrorInstances[prevLabelIndex][labelIndex].add (input.get(j), label, null, null);
					clusteredErrorLabelVectors[prevLabelIndex][labelIndex].add (latticeLabeling);
				}
				prevLabelIndex = labelIndex;
			}
		}
		System.out.println ("Error instance list size = "+errorInstances.size());
		if (clusteredFeatureInduction) {
			FeatureInducer[][] klfi = new FeatureInducer[numLabels][numLabels];
			for (int i = 0; i < numLabels; i++) {
				for (int j = 0; j < numLabels; j++) {
					// Note that we may see some "impossible" transitions here (like O->I in a OIB model)
					// because we are using lattice gammas to get the predicted label, not Viterbi.
					// I don't believe this does any harm, and may do some good.
					System.out.println ("Doing feature induction for "+
						outputAlphabet.lookupObject(i)+" -> "+outputAlphabet.lookupObject(j)+
						" with "+clusteredErrorInstances[i][j].size()+" instances");
					// Skip clusters too small to induce features from reliably.
					if (clusteredErrorInstances[i][j].size() < 20) {
						System.out.println ("..skipping because only "+clusteredErrorInstances[i][j].size()+" instances.");
						continue;
					}
					int s = clusteredErrorLabelVectors[i][j].size();
					LabelVector[] lvs = new LabelVector[s];
					for (int k = 0; k < s; k++)
						lvs[k] = (LabelVector) clusteredErrorLabelVectors[i][j].get(k);
					klfi[i][j] = new FeatureInducer (new ExpGain.Factory (lvs, gaussianPriorVariance),
						clusteredErrorInstances[i][j],
						numFeaturesPerFeatureInduction,
						2*numFeaturesPerFeatureInduction,
						2*numFeaturesPerFeatureInduction);
					featureInducers.add(klfi[i][j]);
				}
			}
			for (int i = 0; i < numLabels; i++) {
				for (int j = 0; j < numLabels; j++) {
					System.out.println ("Adding new induced features for "+
						outputAlphabet.lookupObject(i)+" -> "+outputAlphabet.lookupObject(j));
					if (klfi[i][j] == null) {
						System.out.println ("...skipping because no features induced.");
						continue;
					}
					// Note that this adds features globally, but not on a per-transition basis
					klfi[i][j].induceFeaturesFor (trainingData, false, false);
					if (testingData != null) klfi[i][j].induceFeaturesFor (testingData, false, false);
				}
			}
			klfi = null;
		} else {
			// Unclustered: one inducer over the full pool of error tokens.
			int s = errorLabelVectors.size();
			LabelVector[] lvs = new LabelVector[s];
			for (int i = 0; i < s; i++)
				lvs[i] = (LabelVector) errorLabelVectors.get(i);
			FeatureInducer klfi = new FeatureInducer (new ExpGain.Factory (lvs, gaussianPriorVariance),
				errorInstances,
				numFeaturesPerFeatureInduction,
				2*numFeaturesPerFeatureInduction,
				2*numFeaturesPerFeatureInduction);
			featureInducers.add(klfi);
			// Note that this adds features globally, but not on a per-transition basis
			klfi.induceFeaturesFor (trainingData, false, false);
			if (testingData != null) klfi.induceFeaturesFor (testingData, false, false);
			System.out.println ("CRFByGISUpdate FeatureSelection now includes "+this.globalFeatureSelection.cardinality()+" features");
			klfi = null;
		}
		// This is done in CRFByGISUpdate.train() anyway
		//this.setWeightsDimensionAsIn (trainingData);
		////this.growWeightsDimensionToInputAlphabet ();
	}
	return this.train (trainingData, validationData, testingData, eval,
	                   numIterations - trainingIteration, minimizer);
}

/**
 * This method is deprecated.
 * Applies all recorded feature inducers to the testing set, then returns the
 * Viterbi output sequence for each instance.
 */
public Sequence[] predict (InstanceList testing)
{
	testing.setFeatureSelection(this.globalFeatureSelection);
	// Replay every past feature induction so testing data carries the induced features.
	for (int i = 0; i < featureInducers.size(); i++) {
		FeatureInducer klfi = (FeatureInducer)featureInducers.get(i);
		klfi.induceFeaturesFor (testing, false, false);
	}
	Sequence[] ret = new Sequence[testing.size()];
	for (int i = 0; i < testing.size(); i++) {
		Instance instance = testing.getInstance(i);
		Sequence input = (Sequence) instance.getData();
		Sequence trueOutput = (Sequence) instance.getTarget();
		assert (input.size() == trueOutput.size());
		Sequence predOutput = viterbiPath(input).output();
		assert (predOutput.size() == trueOutput.size());
		ret[i] = predOutput;
	}
	return ret;
}

/**
 * This method is deprecated.
 * Applies all recorded feature inducers to the testing set, then runs the
 * evaluator once as if training had just converged.
 */
public void evaluate (TransducerEvaluator eval, InstanceList testing)
{
	testing.setFeatureSelection(this.globalFeatureSelection);
	for (int i = 0; i < featureInducers.size(); i++) {
		FeatureInducer klfi = (FeatureInducer)featureInducers.get(i);
		klfi.induceFeaturesFor (testing, false, false);
	}
	eval.evaluate (this, true, 0, true, 0.0, null, null, testing);
}

/**
 * Serializes this CRF to the given file via Java object serialization.
 * NOTE(review): IOException is reported to stderr and swallowed — callers
 * cannot detect a failed write.
 */
public void write (File f)
{
	try {
		ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream(f));
		oos.writeObject(this);
		oos.close();
	} catch (IOException e) {
		System.err.println("Exception writing file " + f + ": " + e);
	}
}

/** Wraps this CRF and the given instances in a fresh MinimizableCRF. */
public MinimizableCRF getMinimizableCRF (InstanceList ilist)
{
	return new MinimizableCRF (ilist, this);
}

// Serialization
// For CRF class
private static final long serialVersionUID = 1;
private static final int CURRENT_SERIAL_VERSION = 1;
// Sentinel written in place of a length when an array field is null.
static final int NULL_INTEGER = -1;
/* Need to check for null pointers. */
// Custom serialization: version tag, pipes, alphabets, states, initial states,
// name->state map, then each nullable array prefixed by its length (or NULL_INTEGER).
// (Method continues beyond this chunk.)
private void writeObject (ObjectOutputStream out) throws IOException {
	int i, size;
	out.writeInt (CURRENT_SERIAL_VERSION);
	out.writeObject(inputPipe);
	out.writeObject(outputPipe);
	out.writeObject (inputAlphabet);
	out.writeObject (outputAlphabet);
	size = states.size();
	out.writeInt(size);
	for (i = 0; i<size; i++)
		out.writeObject(states.get(i));
	size = initialStates.size();
	out.writeInt(size);
	for (i = 0; i <size; i++)
		out.writeObject(initialStates.get(i));
	out.writeObject(name2state);
	if(weights != null) {
		size = weights.length;
		out.writeInt(size);
		for (i=0; i<size; i++)
			out.writeObject(weights[i]);
	} else {
		out.writeInt(NULL_INTEGER);
	}
	if(constraints != null) {
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -