📄 testcrf4.java
字号:
private static final int CURRENT_SERIAL_VERSION = 0; private void writeObject(ObjectOutputStream out) throws IOException { out.writeInt(CURRENT_SERIAL_VERSION); } private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException { int version = in.readInt(); } } public void doTestSpacePrediction(boolean testValueAndGradient) { Pipe p = makeSpacePredictionPipe (); Pipe p2 = new TestCRF2String(); InstanceList instances = new InstanceList(p); instances.add(new ArrayIterator(data)); InstanceList[] lists = instances.split(new double[]{.5, .5}); CRF4 crf = new CRF4(p, p2); crf.addFullyConnectedStatesForLabels(); if (testValueAndGradient) { Maximizable.ByGradient minable = crf.getMaximizableCRF(lists[0]); TestMaximizable.testValueAndGradient(minable); } else { System.out.println("Training Accuracy before training = " + crf.averageTokenAccuracy(lists[0])); System.out.println("Testing Accuracy before training = " + crf.averageTokenAccuracy(lists[1])); System.out.println("Training..."); crf.train(lists[0]); System.out.println("Training Accuracy after training = " + crf.averageTokenAccuracy(lists[0])); System.out.println("Testing Accuracy after training = " + crf.averageTokenAccuracy(lists[1])); System.out.println("Training results:"); for (int i = 0; i < lists[0].size(); i++) { Instance inst = lists[0].getInstance(i); Sequence input = (Sequence) inst.getData (); Sequence output = crf.transduce (input); System.out.println (output); } System.out.println ("Testing results:"); for (int i = 0; i < lists[1].size(); i++) { Instance inst = lists[1].getInstance(i); Sequence input = (Sequence) inst.getData (); Sequence output = crf.transduce (input); System.out.println (output); } } } public void doTestSpacePrediction(boolean testValueAndGradient, boolean useSaved, boolean useSparseWeights) { Pipe p = makeSpacePredictionPipe (); CRF4 savedCRF; File f = new File("TestObject.obj"); InstanceList instances = new InstanceList(p); instances.add(new ArrayIterator(data)); InstanceList[] lists = instances.split(new double[]{.5, .5}); CRF4 crf = new CRF4(p.getDataAlphabet(), p.getTargetAlphabet()); crf.addFullyConnectedStatesForLabels(); crf.setUseSparseWeights (useSparseWeights); if (testValueAndGradient) { Maximizable.ByGradient minable = crf.getMaximizableCRF(lists[0]); TestMaximizable.testValueAndGradient(minable); } else { System.out.println("Training Accuracy before training = " + crf.averageTokenAccuracy(lists[0])); System.out.println("Testing Accuracy before training = " + crf.averageTokenAccuracy(lists[1])); savedCRF = crf; System.out.println("Training serialized crf."); crf.train(lists[0]); double preTrainAcc = crf.averageTokenAccuracy(lists[0]); double preTestAcc = crf.averageTokenAccuracy(lists[1]); System.out.println("Training Accuracy after training = " + preTrainAcc); System.out.println("Testing Accuracy after training = " + preTestAcc); try { ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream(f)); oos.writeObject(crf); oos.close(); } catch (IOException e) { System.err.println("Exception writing file: " + e); } System.err.println("Wrote out CRF"); System.err.println("CRF parameters. hyperbolicPriorSlope: " + crf.getUseHyperbolicPriorSlope() + ". hyperbolicPriorSharpness: " + crf.getUseHyperbolicPriorSharpness() + ". gaussianPriorVariance: " + crf.getGaussianPriorVariance()); // And read it back in if (useSaved) { crf = null; try { ObjectInputStream ois = new ObjectInputStream(new FileInputStream(f)); crf = (CRF4) ois.readObject(); ois.close(); } catch (IOException e) { System.err.println("Exception reading file: " + e); } catch (ClassNotFoundException cnfe) { System.err.println("Cound not find class reading in object: " + cnfe); } System.err.println("Read in CRF."); crf = savedCRF; double postTrainAcc = crf.averageTokenAccuracy(lists[0]); double postTestAcc = crf.averageTokenAccuracy(lists[1]); System.out.println("Training Accuracy after saving = " + postTrainAcc); System.out.println("Testing Accuracy after saving = " + postTestAcc); assertEquals(postTrainAcc, preTrainAcc, 0.0001); assertEquals(postTestAcc, preTestAcc, 0.0001); } } } private Pipe makeSpacePredictionPipe () { Pipe p = new SerialPipes(new Pipe[]{ new CharSequence2TokenSequence("."), new TokenSequenceLowercase(), new TestCRFTokenSequenceRemoveSpaces(), new TokenText(), new OffsetConjunctions(true, new int[][]{//{0}, /*{1},{-1,0},{0,1}, */ {1}, {-1, 0}, {0, 1}, {-2, -1, 0}, {0, 1, 2}, {-3, -2, -1}, {1, 2, 3}, //{-2,-1}, {-1,0}, {0,1}, {1,2}, //{-3,-2,-1}, {-2,-1,0}, {-1,0,1}, {0,1,2}, {1,2,3}, }),// new PrintInputAndTarget(), new TokenSequence2FeatureVectorSequence() }); return p; } public void testAddOrderNStates () { Pipe p = makeSpacePredictionPipe (); InstanceList instances = new InstanceList (p); instances.add (new ArrayIterator(data)); InstanceList[] lists = instances.split (new java.util.Random (678), new double[]{.5, .5}); // Compare 3 CRFs trained with addOrderNStates, and make sure // that having more features leads to a higher likelihood CRF4 crf1 = new CRF4(p.getDataAlphabet(), p.getTargetAlphabet()); crf1.addOrderNStates (lists [0], new int[] { 1, }, new boolean[] { false, }, "START", null, null, false); crf1.train (lists [0]); CRF4 crf2 = new CRF4(p.getDataAlphabet(), p.getTargetAlphabet()); crf2.addOrderNStates (lists [0], new int[] { 1, 2, }, new boolean[] { false, true }, "START", null, null, false); crf2.train (lists [0]); CRF4 crf3 = new CRF4(p.getDataAlphabet(), p.getTargetAlphabet()); crf3.addOrderNStates (lists [0], new int[] { 1, 2, }, new boolean[] { false, false }, "START", null, null, false); crf3.train (lists [0]); // Prevent cached values double lik1 = getLikelihood (crf1, lists[0]); double lik2 = getLikelihood (crf2, lists[0]); double lik3 = getLikelihood (crf3, lists[0]); System.out.println("CRF1 likelihood "+lik1); assertTrue ("Final zero-order likelihood <"+lik1+"> greater than first-order <"+lik2+">", lik1 < lik2); assertTrue ("Final defaults-only likelihood <"+lik2+"> greater than full first-order <"+lik3+">", lik2 < lik3); assertEquals (-167.2234457483949, lik1, 0.0001); assertEquals (-165.81326484466342, lik2, 0.0001); assertEquals (-90.37680146432787, lik3, 0.0001); } double getLikelihood (CRF4 crf, InstanceList data) { Maximizable.ByGradient mcrf = crf.getMaximizableCRF (data); // Do this elaborate thing so that crf.cachedValueStale is forced true double[] params = new double [mcrf.getNumParameters()]; mcrf.getParameters (params); mcrf.setParameters (params); return mcrf.getValue (); } public void testFrozenWeights () { Pipe p = makeSpacePredictionPipe (); InstanceList instances = new InstanceList (p); instances.add (new ArrayIterator (data)); CRF4 crf1 = new CRF4 (p.getDataAlphabet (), p.getTargetAlphabet ()); crf1.addFullyConnectedStatesForLabels (); crf1.train (instances); CRF4 crf2 = new CRF4 (p.getDataAlphabet (), p.getTargetAlphabet ()); crf2.addFullyConnectedStatesForLabels (); // Freeze some weights, before training for (int i = 0; i < crf2.getWeights ().length; i += 2) { crf2.freezeWeights (i); } crf2.train (instances); SparseVector[] w = crf2.getWeights (); double[] b = crf2.getDefaultWeights (); for (int i = 0; i < w.length; i += 2) { assertEquals (0.0, b[i], 1e-10); for (int loc = 0; loc < w[i].numLocations (); loc++) { assertEquals (0.0, w[i].valueAtLocation (loc), 1e-10); } } // Check that the frozen weights has worse likelihood Maximizable.ByGradient maxable1 = crf1.getMaximizableCRF (instances); Maximizable.ByGradient maxable2 = crf2.getMaximizableCRF (instances); double val1 = maxable1.getValue(); double val2 = maxable2.getValue (); assertTrue ("Error: Freezing weights helps performance! Full "+val1+", Frozen "+val2, val1 > val2); } public void testValueGradient() { doTestSpacePrediction(true); } public void testTrain() { doTestSpacePrediction(false); } public void testDenseTrain () {
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -