📄 TestCRF4.java
字号:
new CharSequence2TokenSequence("."), new TokenSequenceLowercase(), new TestCRFTokenSequenceRemoveSpaces(), new TokenText(), new OffsetConjunctions(true, new int[][]{//{0}, {1}, {-1, 0}, {0, 1}, {-2, -1, 0}, {0, 1, 2}, {-3, -2, -1}, {1, 2, 3}, //{-2,-1}, {1,2}, //{-3,-2,-1}, {-2,-1,0}, {-1,0,1}, {0,1,2}, {1,2,3} }),// new PrintInputAndTarget(), new TokenSequence2FeatureVectorSequence() }); Pipe p2 = new TestCRF2String(); InstanceList instances = new InstanceList(p); instances.add(new ArrayIterator(data)); InstanceList[] lists = instances.split(new double[]{.5, .5}); CRF4 crf = new CRF4(p, p2); crf.addFullyConnectedStatesForLabels(); if (testValueAndGradient) { Maximizable.ByGradient minable = crf.getMaximizableCRF(lists[0]); TestMaximizable.testValueAndGradient(minable); } else { System.out.println("Training Accuracy before training = " + crf.averageTokenAccuracy(lists[0])); System.out.println("Testing Accuracy before training = " + crf.averageTokenAccuracy(lists[1])); System.out.println("Training..."); crf.train(lists[0]); System.out.println("Training Accuracy after training = " + crf.averageTokenAccuracy(lists[0])); System.out.println("Testing Accuracy after training = " + crf.averageTokenAccuracy(lists[1])); System.out.println("Training results:"); for (int i = 0; i < lists[0].size(); i++) { Instance instance = crf.transduce(lists[0].getInstance(i)); System.out.println(instance.getSource()); } System.out.println("Testing results:"); for (int i = 0; i < lists[1].size(); i++) { Instance instance = crf.transduce(lists[1].getInstance(i)); System.out.println(instance.getSource()); } } } public void doTestSpacePrediction(boolean testValueAndGradient, boolean useSaved, boolean useSparseWeights) { Pipe p = new SerialPipes(new Pipe[]{ new CharSequence2TokenSequence("."), new TokenSequenceLowercase(), new TestCRFTokenSequenceRemoveSpaces(), new TokenText(), new OffsetConjunctions(true, new int[][]{//{0}, /*{1},{-1,0},{0,1}, */ {1}, {-1, 0}, {0, 1}, {-2, -1, 0}, {0, 1, 2}, 
{-3, -2, -1}, {1, 2, 3}, //{-2,-1}, {-1,0}, {0,1}, {1,2}, //{-3,-2,-1}, {-2,-1,0}, {-1,0,1}, {0,1,2}, {1,2,3}, }),// new PrintInputAndTarget(), new TokenSequence2FeatureVectorSequence() }); CRF4 savedCRF; File f = new File("TestObject.obj"); InstanceList instances = new InstanceList(p); instances.add(new ArrayIterator(data)); InstanceList[] lists = instances.split(new double[]{.5, .5}); CRF4 crf = new CRF4(p.getDataAlphabet(), p.getTargetAlphabet()); crf.addFullyConnectedStatesForLabels(); crf.setUseSparseWeights (useSparseWeights); if (testValueAndGradient) { Maximizable.ByGradient minable = crf.getMaximizableCRF(lists[0]); TestMaximizable.testValueAndGradient(minable); } else { System.out.println("Training Accuracy before training = " + crf.averageTokenAccuracy(lists[0])); System.out.println("Testing Accuracy before training = " + crf.averageTokenAccuracy(lists[1])); savedCRF = crf; System.out.println("Training serialized crf."); crf.train(lists[0]); double preTrainAcc = crf.averageTokenAccuracy(lists[0]); double preTestAcc = crf.averageTokenAccuracy(lists[1]); System.out.println("Training Accuracy after training = " + preTrainAcc); System.out.println("Testing Accuracy after training = " + preTestAcc); try { ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream(f)); oos.writeObject(crf); oos.close(); } catch (IOException e) { System.err.println("Exception writing file: " + e); } System.err.println("Wrote out CRF"); System.err.println("CRF parameters. hyperbolicPriorSlope: " + crf.getUseHyperbolicPriorSlope() + ". hyperbolicPriorSharpness: " + crf.getUseHyperbolicPriorSharpness() + ". 
gaussianPriorVariance: " + crf.getGaussianPriorVariance()); // And read it back in if (useSaved) { crf = null; try { ObjectInputStream ois = new ObjectInputStream(new FileInputStream(f)); crf = (CRF4) ois.readObject(); ois.close(); } catch (IOException e) { System.err.println("Exception reading file: " + e); } catch (ClassNotFoundException cnfe) { System.err.println("Cound not find class reading in object: " + cnfe); } System.err.println("Read in CRF."); crf = savedCRF; double postTrainAcc = crf.averageTokenAccuracy(lists[0]); double postTestAcc = crf.averageTokenAccuracy(lists[1]); System.out.println("Training Accuracy after saving = " + postTrainAcc); System.out.println("Testing Accuracy after saving = " + postTestAcc); assertEquals(postTrainAcc, preTrainAcc, 0.0001); assertEquals(postTestAcc, preTestAcc, 0.0001); } } } public void testAddOrderNStates () { Pipe p = new SerialPipes (new Pipe[] { new CharSequence2TokenSequence("."), new TokenSequenceLowercase(), new TestCRFTokenSequenceRemoveSpaces(), new TokenText(), new OffsetConjunctions (true, new int[][] { {1}, {-1, 0}, {0, 1}, {-2, -1, 0}, {0, 1, 2}, {-3, -2, -1}, {1, 2, 3}, }), new TokenSequence2FeatureVectorSequence() }); InstanceList instances = new InstanceList (p); instances.add (new ArrayIterator(data)); InstanceList[] lists = instances.split (new java.util.Random (678), new double[]{.5, .5}); // Compare 3 CRFs trained with addOrderNStates, and make sure // that having more features leads to a higher likelihood CRF4 crf1 = new CRF4(p.getDataAlphabet(), p.getTargetAlphabet()); crf1.addOrderNStates (lists [0], new int[] { 1, }, new boolean[] { false, }, "START", null, null, false); crf1.train (lists [0]); CRF4 crf2 = new CRF4(p.getDataAlphabet(), p.getTargetAlphabet()); crf2.addOrderNStates (lists [0], new int[] { 1, 2, }, new boolean[] { false, true }, "START", null, null, false); crf2.train (lists [0]); CRF4 crf3 = new CRF4(p.getDataAlphabet(), p.getTargetAlphabet()); crf3.addOrderNStates (lists [0], 
new int[] { 1, 2, }, new boolean[] { false, false }, "START", null, null, false); crf3.train (lists [0]); // Prevent cached values double lik1 = getLikelihood (crf1, lists[0]); double lik2 = getLikelihood (crf2, lists[0]); double lik3 = getLikelihood (crf3, lists[0]); System.out.println("CRF1 likelihood "+lik1); assertTrue ("Final zero-order likelihood <"+lik1+"> greater than first-order <"+lik2+">", lik1 < lik2); assertTrue ("Final defaults-only likelihood <"+lik2+"> greater than full first-order <"+lik3+">", lik2 < lik3); assertEquals (-167.335971702, lik1, 0.0001); assertEquals (-166.212235389, lik2, 0.0001); assertEquals ( -90.386005741, lik3, 0.0001); } double getLikelihood (CRF4 crf, InstanceList data) { Maximizable.ByGradient mcrf = crf.getMaximizableCRF (data); // Do this elaborate thing so that crf.cachedValueStale is forced true double[] params = new double [mcrf.getNumParameters()]; mcrf.getParameters (params); mcrf.setParameters (params); return mcrf.getValue (); } public void testValueGradient() { doTestSpacePrediction(true); } public void testTrain() { doTestSpacePrediction(false); } public void testDenseTrain () { doTestSpacePrediction (false, false, false); } public void testSerialization() { doTestSpacePrediction(false, true, true); } public void testDenseSerialization () { doTestSpacePrediction(false, true, false); } public void testPrint () { Pipe p = new SerialPipes (new Pipe[] { new CharSequence2TokenSequence("."), new TokenText(), new TestCRFTokenSequenceRemoveSpaces(), new TokenSequence2FeatureVectorSequence(), new PrintInputAndTarget(), }); InstanceList one = new InstanceList (p); String[] data = new String[] { "ABCDE", }; one.add (new ArrayIterator (data)); CRF4 crf = new CRF4 (p, null); crf.addFullyConnectedStatesForLabels(); crf.setWeightsDimensionAsIn (one); CRF4.MaximizableCRF mcrf = crf.getMaximizableCRF(one); double[] params = new double[mcrf.getNumParameters()]; for (int i = 0; i < params.length; i++) { params [i] = i; } 
mcrf.setParameters (params); crf.print (); } public static Test suite() { return new TestSuite(TestCRF4.class); } public static void main(String[] args) { TestSuite theSuite; if (args.length > 0) { theSuite = new TestSuite(); for (int i = 0; i < args.length; i++) { theSuite.addTest (new TestCRF4 (args [i])); } } else { theSuite = (TestSuite) suite(); } junit.textui.TestRunner.run (theSuite); }}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -