📄 testcrf3.java
字号:
{ int version = in.readInt(); } }

    /**
     * Pipe that rewrites an instance's source string using its Viterbi output:
     * a space is inserted before every character (other than the first) whose
     * predicted label prints as "start".
     */
    public class TestCRF2String extends Pipe implements Serializable {

        public TestCRF2String() {
            super();
        }

        /**
         * Rebuilds the carrier's source with spaces re-inserted.
         *
         * @param carrier instance whose source is a String and whose target is a
         *                Transducer.ViterbiPath with an ArraySequence output
         * @return the same carrier, with its source replaced by the re-spaced string
         */
        public Instance pipe(Instance carrier) {
            StringBuffer sb = new StringBuffer();
            String source = (String) carrier.getSource();
            Transducer.ViterbiPath vp = (Transducer.ViterbiPath) carrier.getTarget();
            ArraySequence as = (ArraySequence) vp.output();
            //int startLabelIndex = as.getAlphabet().lookupIndex("start");
            for (int i = 0; i < source.length(); i++) {
                String label = as.get(i).toString();
                System.out.println("target[" + i + "]=" + label);
                // Never prepend a space to the very first character.
                if (label.equals("start") && i != 0) {
                    sb.append(' ');
                }
                sb.append(source.charAt(i));
            }
            carrier.setSource(sb.toString());
            System.out.println("carrier.getSource() = " + carrier.getSource());
            return carrier;
        }

        private static final long serialVersionUID = 1;
        private static final int CURRENT_SERIAL_VERSION = 0;

        /** Writes only the serial-format version number. */
        private void writeObject(ObjectOutputStream out) throws IOException {
            out.writeInt(CURRENT_SERIAL_VERSION);
        }

        /** Consumes the version number written by writeObject; the value is not used. */
        private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
            int version = in.readInt();
        }
    }

    /**
     * Builds a CRF3 for the space-prediction task on {@code data}, split 50/50
     * into training and testing lists.
     *
     * @param testCostAndGradient if true, only checks the minimizable's cost and
     *                            gradient via TestMinimizable; if false, trains the
     *                            CRF and prints accuracies and transduced sources
     */
    public void doTestSpacePrediction(boolean testCostAndGradient) {
        Pipe p = new SerialPipes(new Pipe[]{
            new CharSequence2TokenSequence("."),
            new TokenSequenceLowercase(),
            new TestCRFTokenSequenceRemoveSpaces(),
            new TokenText(),
            new OffsetConjunctions(false,
                new int[][]{//{0},
                    {1}, {-1, 0}, {0, 1},
                    {-2, -1, 0}, {0, 1, 2}, {-3, -2, -1},
                    {1, 2, 3},
                    //{-2,-1}, {1,2},
                    //{-3,-2,-1}, {-2,-1,0}, {-1,0,1}, {0,1,2}, {1,2,3}
                }),
            // new PrintInputAndTarget(),
            new TokenSequence2FeatureVectorSequence()
        });
        Pipe p2 = new TestCRF2String();

        InstanceList instances = new InstanceList(p);
        instances.add(new ArrayIterator(data));
        InstanceList[] lists = instances.split(new double[]{.5, .5});

        CRF3 crf = new CRF3(p, p2);
        crf.priorCost(lists[0]);
        crf.addFullyConnectedStatesForLabels();

        if (testCostAndGradient) {
            Minimizable.ByGradient minable = crf.getMinimizableCRF(lists[0]);
            TestMinimizable.testCostAndGradient(minable);
        } else {
            System.out.println("Training Accuracy before training = " + crf.averageTokenAccuracy(lists[0]));
            System.out.println("Testing  Accuracy before training = " + crf.averageTokenAccuracy(lists[1]));
            System.out.println("Training...");
            crf.train(lists[0]);
            System.out.println("Training Accuracy after training = " + crf.averageTokenAccuracy(lists[0]));
            System.out.println("Testing  Accuracy after training = " + crf.averageTokenAccuracy(lists[1]));
            System.out.println("Training results:");
            for (int i = 0; i < lists[0].size(); i++) {
                Instance instance = crf.transduce(lists[0].getInstance(i));
                System.out.println(instance.getSource());
            }
            System.out.println("Testing results:");
            for (int i = 0; i < lists[1].size(); i++) {
                Instance instance = crf.transduce(lists[1].getInstance(i));
                System.out.println(instance.getSource());
            }
        }
    }

    /**
     * Variant of the space-prediction test that also writes the trained CRF to
     * "TestObject.obj" and, optionally, verifies that the deserialized copy
     * reproduces the accuracies of the in-memory model.
     *
     * @param testCostAndGradient if true, only checks cost and gradient via
     *                            TestMinimizable and skips training/serialization
     * @param useSaved            if true, reads the CRF back from the file and
     *                            asserts its train/test accuracies match (within
     *                            0.0001) those measured before saving
     */
    public void doTestSpacePrediction(boolean testCostAndGradient, boolean useSaved) {
        Pipe p = new SerialPipes(new Pipe[]{
            new CharSequence2TokenSequence("."),
            new TokenSequenceLowercase(),
            new TestCRFTokenSequenceRemoveSpaces(),
            new TokenText(),
            new OffsetConjunctions(false,
                new int[][]{//{0},
                    /*{1},{-1,0},{0,1}, */
                    //{-2,-1}, {-1,0}, {0,1}, {1,2},
                    //{-3,-2,-1}, {-2,-1,0}, {-1,0,1}, {0,1,2}, {1,2,3},
                }),
            // new PrintInputAndTarget(),
            new TokenSequence2FeatureVectorSequence()
        });

        File f = new File("TestObject.obj");

        InstanceList instances = new InstanceList(p);
        instances.add(new ArrayIterator(data));
        InstanceList[] lists = instances.split(new double[]{.5, .5});

        CRF3 crf = new CRF3(p.getDataAlphabet(), p.getTargetAlphabet());
        crf.priorCost(lists[0]);
        crf.addFullyConnectedStatesForLabels();

        if (testCostAndGradient) {
            Minimizable.ByGradient minable = crf.getMinimizableCRF(lists[0]);
            TestMinimizable.testCostAndGradient(minable);
        } else {
            System.out.println("Training Accuracy before training = " + crf.averageTokenAccuracy(lists[0]));
            System.out.println("Testing  Accuracy before training = " + crf.averageTokenAccuracy(lists[1]));

            System.out.println("Training serialized crf.");
            crf.train(lists[0]);
            double preTrainAcc = crf.averageTokenAccuracy(lists[0]);
            double preTestAcc = crf.averageTokenAccuracy(lists[1]);
            System.out.println("Training Accuracy after training = " + preTrainAcc);
            System.out.println("Testing  Accuracy after training = " + preTestAcc);

            boolean wroteCrf = false;
            try {
                writeCrfToFile(crf, f);
                wroteCrf = true;
                System.err.println("Wrote out CRF");
            } catch (IOException e) {
                System.err.println("Exception writing file: " + e);
            }
            System.err.println("CRF parameters. hyperbolicPriorSlope: " + crf.getUseHyperbolicPriorSlope()
                + ". hyperbolicPriorSharpness: " + crf.getUseHyperbolicPriorSharpness()
                + ". gaussianPriorVariance: " + crf.getGaussianPriorVariance());

            // And read it back in
            if (useSaved) {
                // Without a fresh file we would silently read a stale one (or nothing).
                assertTrue("CRF was not written to " + f, wroteCrf);

                crf = null;
                try {
                    crf = readCrfFromFile(f);
                } catch (IOException e) {
                    System.err.println("Exception reading file: " + e);
                } catch (ClassNotFoundException cnfe) {
                    System.err.println("Cound not find class reading in object: " + cnfe);
                }
                assertNotNull("Could not read the CRF back from " + f, crf);
                System.err.println("Read in CRF.");

                // Evaluate the DESERIALIZED model; re-using the in-memory one
                // here would make the assertions below trivially true.
                double postTrainAcc = crf.averageTokenAccuracy(lists[0]);
                double postTestAcc = crf.averageTokenAccuracy(lists[1]);
                System.out.println("Training Accuracy after saving = " + postTrainAcc);
                System.out.println("Testing  Accuracy after saving = " + postTestAcc);
                assertEquals(postTrainAcc, preTrainAcc, 0.0001);
                assertEquals(postTestAcc, preTestAcc, 0.0001);
            }
        }
    }

    /** Serializes {@code crf} to {@code file}, closing the file even if writing fails. */
    private static void writeCrfToFile(CRF3 crf, File file) throws IOException {
        FileOutputStream fileOut = new FileOutputStream(file);
        try {
            ObjectOutputStream oos = new ObjectOutputStream(fileOut);
            oos.writeObject(crf);
            oos.flush();
        } finally {
            fileOut.close();
        }
    }

    /** Deserializes a CRF3 from {@code file}, closing the file even if reading fails. */
    private static CRF3 readCrfFromFile(File file) throws IOException, ClassNotFoundException {
        FileInputStream fileIn = new FileInputStream(file);
        try {
            ObjectInputStream ois = new ObjectInputStream(fileIn);
            return (CRF3) ois.readObject();
        } finally {
            fileIn.close();
        }
    }

    /** Runs the space-prediction setup in cost-and-gradient checking mode. */
    public void testCostGradient() {
        doTestSpacePrediction(true);
    }

    /** Runs the space-prediction setup in training mode. */
    public void testTrain() {
        doTestSpacePrediction(false);
    }

    /** Trains, saves, reloads and re-evaluates the CRF. */
    public void testSerialization() {
        doTestSpacePrediction(false, true);
    }

    /** @return a suite containing the tests of TestCRF3 */
    public static Test suite() {
        return new TestSuite(TestCRF3.class);
    }

    /**
     * Command-line entry point. With no arguments the whole suite is run through
     * the JUnit text runner; otherwise {@code args[0]} names a single test to run
     * ("testCost" additionally takes an integer in {@code args[1]}).
     */
    public static void main(String[] args) {
        //args = new String[] {"testCostGradient"};
        if (args.length > 0) {
            TestCRF3 t = new TestCRF3("testTrain");
            if (args[0].equals("testTrain")) {
                t.testTrain();
            } else if (args[0].equals("testSerialization")) {
                t.testSerialization();
            } else if (args[0].equals("testGetSetParameters")) {
                t.testGetSetParameters();
            } else if (args[0].equals("testCostGradient")) {
                t.testCostGradient();
            } else if (args[0].equals("testCost")) {
                // testCost needs its integer argument; report instead of
                // dying with ArrayIndexOutOfBoundsException.
                if (args.length < 2) {
                    System.err.println("testCost requires an integer argument.");
                    System.exit(-1);
                }
                t.testCost(Integer.parseInt(args[1]));
            } else {
                System.err.println("Unrecognized test.");
                System.exit(-1);
            }
        } else {
            junit.textui.TestRunner.run(suite());
        }
    }
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -