📄 testcrf2.java
字号:
}

    private static final long serialVersionUID = 1;
    private static final int CURRENT_SERIAL_VERSION = 0;

    private void writeObject (ObjectOutputStream out) throws IOException {
      out.writeInt (CURRENT_SERIAL_VERSION);
    }

    private void readObject (ObjectInputStream in) throws IOException, ClassNotFoundException {
      int version = in.readInt ();
    }
  }

  /**
   * Pipe that rewrites an instance's source string using the labels of its
   * Viterbi output: a space is inserted in front of every character (other
   * than the first) whose label prints as {@code "start"}.
   *
   * NOTE(review): this is a non-static inner class that implements
   * Serializable; confirm whether it is ever serialized, since an inner
   * class instance also references its enclosing instance.
   */
  public class TestCRF2String extends Pipe implements Serializable {

    public TestCRF2String () {
      super ();
    }

    /**
     * Replaces the carrier's source with a copy in which spaces have been
     * inserted according to the predicted labels.
     *
     * @param carrier instance whose source is a {@code String} and whose
     *                target is a {@code Transducer.ViterbiPath} with an
     *                {@code ArraySequence} output (both are cast unchecked)
     * @return the same carrier, with its source replaced
     */
    public Instance pipe (Instance carrier) {
      // Purely local buffer, so the unsynchronized builder is sufficient.
      StringBuilder spaced = new StringBuilder ();
      String source = (String) carrier.getSource ();
      Transducer.ViterbiPath vp = (Transducer.ViterbiPath) carrier.getTarget ();
      ArraySequence as = (ArraySequence) vp.output ();
      //int startLabelIndex = as.getAlphabet().lookupIndex("start");
      for (int i = 0; i < source.length (); i++) {
        String label = as.get (i).toString ();
        System.out.println ("target["+i+"]="+label);
        // Never emit a leading space, even if the first label is "start".
        if (label.equals ("start") && i != 0) {
          spaced.append (' ');
        }
        spaced.append (source.charAt (i));
      }
      carrier.setSource (spaced.toString ());
      System.out.println ("carrier.getSource() = "+carrier.getSource ());
      return carrier;
    }

    private static final long serialVersionUID = 1;
    private static final int CURRENT_SERIAL_VERSION = 0;

    /** Writes only the serial version number; the pipe has no other state of its own. */
    private void writeObject (ObjectOutputStream out) throws IOException {
      out.writeInt (CURRENT_SERIAL_VERSION);
    }

    /** Consumes the serial version number written by {@code writeObject}. */
    private void readObject (ObjectInputStream in) throws IOException, ClassNotFoundException {
      in.readInt ();  // version is read to keep the stream aligned; it is not otherwise used
    }
  }

  /**
   * Builds the feature-extraction pipeline shared by both
   * {@code doTestSpacePrediction} overloads.
   *
   * @param conjunctionOffsets offset groups handed to {@code OffsetConjunctions}
   */
  private Pipe spacePredictionPipe (int[][] conjunctionOffsets) {
    return new SerialPipes (new Pipe[] {
      new CharSequence2TokenSequence ("."),
      new TokenSequenceLowercase (),
      new TestCRFTokenSequenceRemoveSpaces (),
      new TokenText (),
      new OffsetConjunctions (false, conjunctionOffsets),
      new PrintInputAndTarget (),
      new TokenSequence2FeatureVectorSequence ()
    });
  }

  /**
   * Prints the average token accuracy of {@code crf} on the training half
   * ({@code lists[0]}) and the testing half ({@code lists[1]}).
   *
   * @param phase text placed between "Accuracy" and "=", e.g. "before training"
   */
  private static void reportAccuracy (CRF2 crf, InstanceList[] lists, String phase) {
    System.out.println ("Training Accuracy " + phase + " = "+crf.averageTokenAccuracy (lists[0]));
    System.out.println ("Testing Accuracy " + phase + " = "+crf.averageTokenAccuracy (lists[1]));
  }

  /** Transduces every instance of {@code list} and prints the resulting source. */
  private static void printTransducedSources (CRF2 crf, InstanceList list) {
    for (int i = 0; i < list.size (); i++) {
      Instance instance = crf.transduce (list.getInstance (i));
      System.out.println (instance.getSource ());
    }
  }

  /**
   * Serializes {@code crf} to {@code file}. A failure is reported on stderr.
   *
   * @return true if the object was written and the stream closed without error
   */
  private static boolean writeCrfTo (CRF2 crf, File file) {
    FileOutputStream raw = null;
    try {
      raw = new FileOutputStream (file);
      ObjectOutputStream oos = new ObjectOutputStream (raw);
      oos.writeObject (crf);
      oos.close ();  // flushes, and closes the underlying file stream
      return true;
    } catch (IOException e) {
      System.err.println ("Exception writing file: " + e);
      return false;
    } finally {
      // Safety net for the failure paths; closing an already-closed stream has no effect.
      if (raw != null) {
        try {
          raw.close ();
        } catch (IOException ignored) {
          // Nothing useful can be done here; any primary failure was already reported.
        }
      }
    }
  }

  /**
   * Deserializes a {@code CRF2} from {@code file}. A failure is reported on stderr.
   *
   * @return the deserialized model, or null if it could not be read
   */
  private static CRF2 readCrfFrom (File file) {
    FileInputStream raw = null;
    try {
      raw = new FileInputStream (file);
      ObjectInputStream ois = new ObjectInputStream (raw);
      CRF2 loaded = (CRF2) ois.readObject ();
      ois.close ();
      return loaded;
    } catch (IOException e) {
      System.err.println ("Exception reading file: " + e);
    } catch (ClassNotFoundException cnfe) {
      System.err.println ("Could not find class reading in object: " + cnfe);
    } finally {
      // Safety net for the failure paths; closing an already-closed stream has no effect.
      if (raw != null) {
        try {
          raw.close ();
        } catch (IOException ignored) {
          // Nothing useful can be done here; any primary failure was already reported.
        }
      }
    }
    return null;
  }

  /**
   * Returns {@code args[index]} parsed as an int. If the argument is missing,
   * prints a message and exits with status -1, like the unrecognized-test path
   * in {@code main}. A malformed number still raises NumberFormatException.
   */
  private static int requiredIntArg (String[] args, int index) {
    if (args.length <= index) {
      System.err.println ("Test " + args[0] + " requires an integer argument.");
      System.exit (-1);
    }
    return Integer.parseInt (args[index]);
  }

  /**
   * Splits the test data in half, builds a fully connected CRF over it and
   * either checks cost/gradient consistency or trains and prints accuracy
   * and transduction results.
   *
   * @param testCostAndGradient if true, only run
   *        {@code TestMinimizable.testCostAndGradient} on the training half;
   *        if false, train and report
   */
  public void doTestSpacePrediction (boolean testCostAndGradient) {
    Pipe p = spacePredictionPipe (new int[][] {
      {0}, {1}, {-1,0}, {0,1},
      {-2,-1,0}, {0,1,2}, {-3,-2,-1}, {1,2,3},
      //{-2,-1}, {1,2},
      //{-3,-2,-1}, {-2,-1,0}, {-1,0,1}, {0,1,2}, {1,2,3}
    });
    Pipe p2 = new TestCRF2String ();

    InstanceList instances = new InstanceList (p);
    instances.add (new ArrayIterator (data));
    InstanceList[] lists = instances.split (new double[] {.5,.5});

    CRF2 crf = new CRF2 (p, p2);
    crf.addFullyConnectedStatesForLabels ();

    if (testCostAndGradient) {
      Minimizable.ByGradient minable = crf.getMinimizableCRF (lists[0]);
      TestMinimizable.testCostAndGradient (minable);
      return;
    }

    reportAccuracy (crf, lists, "before training");
    System.out.println ("Training...");
    crf.train (lists[0]);
    reportAccuracy (crf, lists, "after training");

    System.out.println ("Training results:");
    printTransducedSources (crf, lists[0]);
    System.out.println ("Testing results:");
    printTransducedSources (crf, lists[1]);
  }

  /**
   * Variant of the space-prediction test that also writes the trained CRF to
   * {@code TestObject.obj}, reads it back, and reports accuracy afterwards.
   *
   * @param testCostAndGradient if true, only run the cost/gradient check
   * @param useSaved if 1, the final accuracy report uses the in-memory model
   *        instead of the deserialized one
   * @throws IllegalStateException if the model could not be read back and
   *         {@code useSaved} is not 1
   */
  public void doTestSpacePrediction (boolean testCostAndGradient, int useSaved) {
    Pipe p = spacePredictionPipe (new int[][] {
      {0}, /*{1},{-1,0},{0,1}, */
      //{-2,-1}, {-1,0}, {0,1}, {1,2},
      //{-3,-2,-1}, {-2,-1,0}, {-1,0,1}, {0,1,2}, {1,2,3},
    });

    File f = new File ("TestObject.obj");

    InstanceList instances = new InstanceList (p);
    instances.add (new ArrayIterator (data));
    InstanceList[] lists = instances.split (new double[] {.5,.5});

    CRF2 crf = new CRF2 (p.getDataAlphabet (), p.getTargetAlphabet ());
    crf.addFullyConnectedStatesForLabels ();

    if (testCostAndGradient) {
      Minimizable.ByGradient minable = crf.getMinimizableCRF (lists[0]);
      TestMinimizable.testCostAndGradient (minable);
      return;
    }

    reportAccuracy (crf, lists, "before training");

    // Same object as crf: after train() below it refers to the trained in-memory model.
    CRF2 savedCRF = crf;
    System.out.println ("Training serialized crf.");
    crf.train (lists[0]);
    reportAccuracy (crf, lists, "after training");

    if (writeCrfTo (crf, f)) {
      System.err.println ("Wrote out CRF");
    }
    //System.err.println("CRF parameters. hyperbolicPriorSlope: " + crf.getUseHyperbolicPriorSlope() + ". hyperbolicPriorSharpness: " + crf.getUseHyperbolicPriorSharpness() + ". gaussianPriorVariance: " + crf.getGaussianPriorVariance()+ ".
    // defaultFeatureIndex: " + crf.getDefaultFeatureIndex());

    // And read it back in
    crf = readCrfFrom (f);
    if (crf != null) {
      System.err.println ("Read in CRF.");
    }

    if (useSaved == 1) {
      crf = savedCRF;
    } else if (crf == null) {
      // Fail with a clear message instead of a NullPointerException on the next line.
      throw new IllegalStateException ("Could not read serialized CRF back from " + f);
    }

    reportAccuracy (crf, lists, "after saving");
  }

  /** Runs the space-prediction scenario in cost/gradient-check mode. */
  public void testCostGradient () {
    doTestSpacePrediction (true);
  }

  /** Runs the space-prediction scenario in train-and-report mode. */
  public void testTrain () {
    doTestSpacePrediction (false);
  }

  /**
   * Runs the serialization round-trip scenario.
   *
   * @param useSaved forwarded to {@code doTestSpacePrediction(boolean, int)}
   */
  public void testSerialization (int useSaved) {
    doTestSpacePrediction (false, useSaved);
  }

  public static Test suite () {
    return new TestSuite (TestCRF2.class);
  }

  /**
   * With no arguments runs the whole suite through the JUnit text runner;
   * otherwise runs the single test named by {@code args[0]}.
   * {@code testSerialization} and {@code testCost} also need an integer in
   * {@code args[1]}.
   */
  public static void main (String[] args) {
    //args = new String[] {"testCostGradient"};
    if (args.length == 0) {
      junit.textui.TestRunner.run (suite ());
      return;
    }

    TestCRF2 t = new TestCRF2 ("testTrain");
    String testName = args[0];
    if (testName.equals ("testTrain")) {
      t.testTrain ();
    } else if (testName.equals ("testSerialization")) {
      t.testSerialization (requiredIntArg (args, 1));
    } else if (testName.equals ("testGetSetParameters")) {
      t.testGetSetParameters ();
    } else if (testName.equals ("testCostGradient")) {
      t.testCostGradient ();
    } else if (testName.equals ("testCost")) {
      t.testCost (requiredIntArg (args, 1));
    } else {
      System.err.println ("Unrecognized test.");
      System.exit (-1);
    }
  }
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -