📄 testcrf4.java
字号:
doTestSpacePrediction (false, false, false); }

  /**
   * Runs the shared space-prediction test with the second flag of
   * doTestSpacePrediction set (per this test's name, the serialization path --
   * confirm against doTestSpacePrediction, which is defined outside this excerpt).
   */
  public void testSerialization ()
  {
    doTestSpacePrediction (false, true, true);
  }

  /**
   * Same as testSerialization but with the third flag of doTestSpacePrediction
   * cleared (per the name, dense instead of sparse weights -- TODO confirm).
   */
  public void testDenseSerialization ()
  {
    doTestSpacePrediction (false, true, false);
  }

  /**
   * Trains a fully connected, sparse-weight CRF4 on one half of the
   * space-prediction data (split 50/50 with seed 777) and checks the token
   * accuracy on the other half against a recorded value.
   */
  public void testTokenAccuracy ()
  {
    Pipe p = makeSpacePredictionPipe ();
    InstanceList instances = new InstanceList (p);
    instances.add (new ArrayIterator (data));
    InstanceList[] lists = instances.split (new Random (777), new double[] {.5, .5});

    CRF4 crf = new CRF4 (p.getDataAlphabet (), p.getTargetAlphabet ());
    crf.addFullyConnectedStatesForLabels ();
    crf.setUseSparseWeights (true);
    crf.train (lists[0]);

    TokenAccuracyEvaluator eval = new TokenAccuracyEvaluator ();
    eval.test (crf, lists[1], "Testing", null);
    assertEquals (0.9409, eval.getLastAccuracy (), 0.001);
  }

  /**
   * Smoke test for CRF4.print(). It builds a CRF with three-quarter label
   * connectivity over the single sequence "ABCDE", sets every parameter to its
   * own index, and prints the result. There is no assertion: the test passes
   * as long as nothing throws.
   */
  public void testPrint ()
  {
    Pipe p = new SerialPipes (new Pipe[] {
        new CharSequence2TokenSequence ("."),
        new TokenText (),
        new TestCRFTokenSequenceRemoveSpaces (),
        new TokenSequence2FeatureVectorSequence (),
        new PrintInputAndTarget (),
    });
    InstanceList one = new InstanceList (p);
    // Named to avoid shadowing the class-level "data" field.
    String[] sequences = new String[] { "ABCDE", };
    one.add (new ArrayIterator (sequences));

    CRF4 crf = new CRF4 (p, null);
    crf.addFullyConnectedStatesForThreeQuarterLabels (one);
    crf.setWeightsDimensionAsIn (one);
    CRF4.MaximizableCRF mcrf = crf.getMaximizableCRF (one);
    // Give every parameter a distinct value (its index) so individual weights
    // can be told apart in the printout.
    double[] params = new double[mcrf.getNumParameters ()];
    for (int i = 0; i < params.length; i++) {
      params[i] = i;
    }
    mcrf.setParameters (params);
    crf.print ();
  }

  /**
   * Checks that the copy constructor CRF4(CRF4) reproduces the original CRF.
   * The copy must have identical print() output and the same
   * MaximizableCRF value on the same data.
   */
  public void testCopyStatesAndWeights ()
  {
    Pipe p = new SerialPipes (new Pipe[] {
        new CharSequence2TokenSequence ("."),
        new TokenText (),
        new TestCRFTokenSequenceRemoveSpaces (),
        new TokenSequence2FeatureVectorSequence (),
        new PrintInputAndTarget (),
    });
    InstanceList one = new InstanceList (p);
    // Named to avoid shadowing the class-level "data" field.
    String[] sequences = new String[] { "ABCDE", };
    one.add (new ArrayIterator (sequences));

    CRF4 crf = new CRF4 (p, null);
    crf.addFullyConnectedStatesForLabels ();
    crf.setWeightsDimensionAsIn (one);
    CRF4.MaximizableCRF mcrf = crf.getMaximizableCRF (one);
    // Distinct non-trivial weights, so a copy that drops or reorders weights
    // would show up in the comparison below.
    double[] params = new double[mcrf.getNumParameters ()];
    for (int i = 0; i < params.length; i++) {
      params[i] = i;
    }
    mcrf.setParameters (params);

    StringWriter out = new StringWriter ();
    crf.print (new PrintWriter (out, true));

    CRF4 crf2 = new CRF4 (crf);
    StringWriter out2 = new StringWriter ();
    crf2.print (new PrintWriter (out2, true));
    assertEquals (out.toString (), out2.toString ());

    double val1 = mcrf.getValue ();
    double val2 = crf2.getMaximizableCRF (one).getValue ();
    assertEquals (val1, val2, 1e-5);
  }

  /**
   * Toy sequence data for testStartState, one token per line. testStartState
   * splits each line with "^(\\S+) (.*)" into a label part and a feature part.
   */
  static final String toy = "A a\nB b\nC c\nD d\nB b\nC c\n";

  /**
   * Checks the (untrained) MaximizableCRF value on the toy data for two
   * topologies, each against a recorded constant:
   * <ul>
   *   <li>states connected as observed in the data, plus an explicit start state;</li>
   *   <li>order-1 states whose start state is "A".</li>
   * </ul>
   */
  public void testStartState ()
  {
    Pipe p = new SerialPipes (new Pipe[] {
        new LineGroupString2TokenSequence (),
        new TokenSequenceMatchDataAndTarget (Pattern.compile ("^(\\S+) (.*)"), 2, 1),
        new TokenSequenceParseFeatureString (false),
        new TokenText (),
        new TokenSequence2FeatureVectorSequence (true, false),
        new Target2LabelSequence (),
        new PrintInputAndTarget (),
    });
    InstanceList data = new InstanceList (p);
    data.add (new LineGroupIterator (new StringReader (toy), Pattern.compile ("\n"), true));

    CRF4 crf = new CRF4 (p, null);
    crf.print ();
    crf.addStatesForLabelsConnectedAsIn (data);
    crf.addStartState ();
    Maximizable.ByGradient maxable = crf.getMaximizableCRF (data);
    assertEquals (-1.3862, maxable.getValue (), 1e-4);

    crf = new CRF4 (p, null);
    crf.addOrderNStates (data, new int[] { 1 }, null, "A", null, null, false);
    crf.print ();
    maxable = crf.getMaximizableCRF (data);
    assertEquals (-3.09104245335831, maxable.getValue (), 1e-4);
  }

  /**
   * Tests that setWeightsDimensionDensely respects featureSelections.
   * Dense observation weights must not be added for "default-feature" edges.
   * Adding order-1 states with the default-only flag must therefore grow the
   * parameter count by exactly 4 over the order-0 CRF.
   */
  public void testDenseFeatureSelection ()
  {
    Pipe p = makeSpacePredictionPipe ();
    InstanceList instances = new InstanceList (p);
    instances.add (new ArrayIterator (data));

    CRF4 crf1 = new CRF4 (p, null);
    crf1.addOrderNStates (instances, new int[] { 0 }, null, "start", null, null, true);
    crf1.setUseSparseWeights (false);
    crf1.train (instances, null, null, null, 1); // Set weights dimension
    int nParams1 = crf1.getMaximizableCRF (instances).getNumParameters ();

    CRF4 crf2 = new CRF4 (p, null);
    crf2.addOrderNStates (instances, new int[] { 0, 1 }, new boolean[] {false, true},
                          "start", null, null, true);
    crf2.setUseSparseWeights (false);
    crf2.train (instances, null, null, null, 1); // Set weights dimension
    int nParams2 = crf2.getMaximizableCRF (instances).getNumParameters ();

    // JUnit convention: expected value first, actual value second.
    assertEquals (nParams1 + 4, nParams2);
  }

  /**
   * Checks lattice consistency on the first instance after brief training
   * (train(..., 10)). For every position except the last and for every state,
   * the gamma (state) probability must equal the sum of the xi (transition)
   * probabilities over that state's outgoing transitions.
   */
  public void testXis ()
  {
    Pipe p = makeSpacePredictionPipe ();
    InstanceList instances = new InstanceList (p);
    instances.add (new ArrayIterator (data));

    CRF4 crf1 = new CRF4 (p, null);
    crf1.addFullyConnectedStatesForLabels ();
    crf1.train (instances, null, null, null, 10); // Let's get some parameters

    Instance inst = instances.getInstance (0);
    Sequence input = (Sequence) inst.getData ();
    Transducer.Lattice lattice =
        crf1.forwardBackward (input, (Sequence) inst.getTarget (), false, true, null);
    for (int ip = 0; ip < lattice.length () - 1; ip++) {
      for (int i = 0; i < crf1.numStates (); i++) {
        Transducer.State state = crf1.getState (i);
        Transducer.TransitionIterator it = state.transitionIterator (input, ip);
        double gamma = lattice.getGammaProbability (ip, state);
        double xiSum = 0;
        while (it.hasNext ()) {
          Transducer.State dest = it.nextState ();
          xiSum += lattice.getXiProbability (ip, state, dest);
        }
        assertEquals (gamma, xiSum, 1e-5);
      }
    }
  }

  /** Returns a suite containing every test method of this class. */
  public static Test suite ()
  {
    return new TestSuite (TestCRF4.class);
  }

  /**
   * Returns true if the string form of any of the first {@code end} elements
   * of {@code output} equals {@code label}. A non-positive {@code end} checks
   * nothing and returns false.
   */
  private static boolean containsLabel (Sequence output, String label, int end)
  {
    for (int i = 0; i < end; i++) {
      if (output.get (i).toString ().equals (label)) {
        return true;
      }
    }
    return false;
  }

  /**
   * Checks that State.addWeight is honored by Viterbi decoding.
   * After training, the "notstart" state must appear in the best path.
   * After a weight vector of all negative infinities is attached to the
   * transitions added on "notstart", that state must no longer appear.
   */
  public void testStateAddWeights ()
  {
    Pipe p = TestMEMM.makeSpacePredictionPipe ();
    InstanceList training = new InstanceList (p);
    training.add (new ArrayIterator (TestMEMM.data));
    CRF4 crf = new CRF4 (p, null);
    crf.addFullyConnectedStatesForLabels ();
    crf.train (training);

    // Check that the notstart state is used at test time
    Sequence input = (Sequence) training.getInstance (0).getData ();
    Sequence output = crf.viterbiPath (input).output ();
    assertTrue (containsLabel (output, "notstart", output.size ()));

    // Now add a negative-infinite weight onto transitions of "notstart", and
    // make sure that it's honored.
    CRF4.State state = crf.getState ("notstart");
    int widx = crf.getWeightsIndex ("BadBad");
    // Size the vector from the data alphabet rather than a hard-coded 250, so
    // it covers every feature index the pipe can produce.
    SparseVector w = new SparseVector (new double[p.getDataAlphabet ().size ()]);
    w.setAll (Double.NEGATIVE_INFINITY);
    crf.setWeights (widx, w);
    state.addWeight (0, "BadBad");
    state.addWeight (1, "BadBad");

    // Verify that this effectively prevents the notstart state from being used.
    // The final position is deliberately not checked. Presumably the penalty
    // sits on transitions leaving "notstart", and a path can still end in that
    // state -- NOTE(review): confirm.
    output = crf.viterbiPath (input).output ();
    assertTrue (!containsLabel (output, "notstart", output.size () - 1));
  }

  /** Path of a previously serialized CRF4 used by testOldCrf. */
  private static final String oldCrfFile = "test/edu/umass/cs/mallet/base/fst/crf.cnl03.ser.gz";

  /** Input for testOldCrf, one token with its feature columns per line. */
  private static final String testString =
      "John NNP B-NP O\nDoe NNP I-NP O\nsaid VBZ B-VP O\nhi NN B-NP O\n";

  /**
   * Deserializes the CRF4 stored at oldCrfFile (per the name, one written by an
   * older build) and checks that it still transduces testString to the
   * recorded label sequence.
   */
  public void testOldCrf ()
  {
    CRF4 crf = (CRF4) FileUtils.readObject (new File (oldCrfFile));
    Instance inst = new Instance (testString, null, null, null, crf.getInputPipe ());
    Sequence output = crf.transduce ((Sequence) inst.getData ());
    String std = output.toString ();
    assertEquals (" B-PER I-PER O O", std);
  }

  /**
   * Runs the test methods named on the command line, or the whole suite when
   * no arguments are given.
   */
  public static void main (String[] args)
  {
    TestSuite theSuite;
    if (args.length > 0) {
      theSuite = new TestSuite ();
      for (int i = 0; i < args.length; i++) {
        theSuite.addTest (new TestCRF4 (args[i]));
      }
    } else {
      theSuite = (TestSuite) suite ();
    }
    junit.textui.TestRunner.run (theSuite);
  }
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -