testmemm.java
来自「mallet是自然语言处理、机器学习领域的一个开源项目。」· Java 代码 · 共 599 行 · 第 1/2 页
JAVA
599 行
System.out.println("target[" + i + "]=" + as.get(i).toString()); if (as.get(i).toString().equals("start") && i != 0) sb.append(' '); sb.append(source.charAt(i)); } carrier.setSource(sb.toString()); System.out.println("carrier.getSource() = " + carrier.getSource()); return carrier; } private static final long serialVersionUID = 1; private static final int CURRENT_SERIAL_VERSION = 0; private void writeObject(ObjectOutputStream out) throws IOException { out.writeInt(CURRENT_SERIAL_VERSION); } private void readObject(ObjectInputStream in) throws IOException { int version = in.readInt(); } } public void doTestSpacePrediction(boolean testValueAndGradient) { Pipe p = makeSpacePredictionPipe (); Pipe p2 = new TestMEMM2String(); InstanceList instances = new InstanceList(p); instances.add(new ArrayIterator(data)); InstanceList[] lists = instances.split(new double[]{.5, .5}); MEMM memm = new MEMM(p, p2); memm.addFullyConnectedStatesForLabels(); if (testValueAndGradient) { Maximizable.ByGradient minable = memm.getMaximizableCRF(lists[0]); TestMaximizable.testValueAndGradient(minable); } else { System.out.println("Training Accuracy before training = " + memm.averageTokenAccuracy(lists[0])); System.out.println("Testing Accuracy before training = " + memm.averageTokenAccuracy(lists[1])); System.out.println("Training..."); memm.train(lists[0]); System.out.println("Training Accuracy after training = " + memm.averageTokenAccuracy(lists[0])); System.out.println("Testing Accuracy after training = " + memm.averageTokenAccuracy(lists[1])); System.out.println("Training results:"); for (int i = 0; i < lists[0].size(); i++) { Instance inst = lists[0].getInstance(i); Sequence input = (Sequence) inst.getData (); Sequence output = memm.transduce (input); System.out.println (output); } System.out.println ("Testing results:"); for (int i = 0; i < lists[1].size(); i++) { Instance inst = lists[1].getInstance(i); Sequence input = (Sequence) inst.getData (); Sequence output = memm.transduce (input); System.out.println (output); } } } public void doTestSpacePrediction(boolean testValueAndGradient, boolean useSaved, boolean useSparseWeights) { Pipe p = makeSpacePredictionPipe (); MEMM savedCRF; File f = new File("TestObject.obj"); InstanceList instances = new InstanceList(p); instances.add(new ArrayIterator(data)); InstanceList[] lists = instances.split(new double[]{.5, .5}); MEMM crf = new MEMM(p.getDataAlphabet(), p.getTargetAlphabet()); crf.addFullyConnectedStatesForLabels(); crf.setUseSparseWeights (useSparseWeights); if (testValueAndGradient) { Maximizable.ByGradient minable = crf.getMaximizableCRF(lists[0]); TestMaximizable.testValueAndGradient(minable); } else { System.out.println("Training Accuracy before training = " + crf.averageTokenAccuracy(lists[0])); System.out.println("Testing Accuracy before training = " + crf.averageTokenAccuracy(lists[1])); savedCRF = crf; System.out.println("Training serialized crf."); crf.train(lists[0]); double preTrainAcc = crf.averageTokenAccuracy(lists[0]); double preTestAcc = crf.averageTokenAccuracy(lists[1]); System.out.println("Training Accuracy after training = " + preTrainAcc); System.out.println("Testing Accuracy after training = " + preTestAcc); try { ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream(f)); oos.writeObject(crf); oos.close(); } catch (IOException e) { System.err.println("Exception writing file: " + e); } System.err.println("Wrote out CRF"); System.err.println("CRF parameters. hyperbolicPriorSlope: " + crf.getUseHyperbolicPriorSlope() + ". hyperbolicPriorSharpness: " + crf.getUseHyperbolicPriorSharpness() + ". gaussianPriorVariance: " + crf.getGaussianPriorVariance()); // And read it back in if (useSaved) { crf = null; try { ObjectInputStream ois = new ObjectInputStream(new FileInputStream(f)); crf = (MEMM) ois.readObject(); ois.close(); } catch (IOException e) { System.err.println("Exception reading file: " + e); } catch (ClassNotFoundException cnfe) { System.err.println("Cound not find class reading in object: " + cnfe); } System.err.println("Read in CRF."); crf = savedCRF; double postTrainAcc = crf.averageTokenAccuracy(lists[0]); double postTestAcc = crf.averageTokenAccuracy(lists[1]); System.out.println("Training Accuracy after saving = " + postTrainAcc); System.out.println("Testing Accuracy after saving = " + postTestAcc); assertEquals(postTrainAcc, preTrainAcc, 0.0001); assertEquals(postTestAcc, preTestAcc, 0.0001); } } } public static Pipe makeSpacePredictionPipe () { Pipe p = new SerialPipes(new Pipe[]{ new CharSequence2TokenSequence("."), new TokenSequenceLowercase(), new TestMEMMTokenSequenceRemoveSpaces(), new TokenText(), new OffsetConjunctions(true, new int[][]{//{0}, /*{1},{-1,0},{0,1}, */ {1}, {-1, 0}, {0, 1},// {-2, -1, 0}, {0, 1, 2}, {-3, -2, -1}, {1, 2, 3}, //{-2,-1}, {-1,0}, {0,1}, {1,2}, //{-3,-2,-1}, {-2,-1,0}, {-1,0,1}, {0,1,2}, {1,2,3}, }),// new PrintInputAndTarget(), new TokenSequence2FeatureVectorSequence() }); return p; } public void disabledtestAddOrderNStates () { Pipe p = makeSpacePredictionPipe (); InstanceList instances = new InstanceList (p); instances.add (new ArrayIterator(data)); InstanceList[] lists = instances.split (new java.util.Random (678), new double[]{.5, .5}); // Compare 3 CRFs trained with addOrderNStates, and make sure // that having more features leads to a higher likelihood MEMM crf1 = new MEMM(p.getDataAlphabet(), p.getTargetAlphabet()); crf1.addOrderNStates (lists [0], new int[] { 1, }, new boolean[] { false, }, "START", null, null, false); crf1.train (lists [0]); MEMM crf2 = new MEMM(p.getDataAlphabet(), p.getTargetAlphabet()); crf2.addOrderNStates (lists [0], new int[] { 1, 2, }, new boolean[] { false, true }, "START", null, null, false); crf2.train (lists [0]); MEMM crf3 = new MEMM(p.getDataAlphabet(), p.getTargetAlphabet()); crf3.addOrderNStates (lists [0], new int[] { 1, 2, }, new boolean[] { false, false }, "START", null, null, false); crf3.train (lists [0]); // Prevent cached values double lik1 = getLikelihood (crf1, lists[0]); double lik2 = getLikelihood (crf2, lists[0]); double lik3 = getLikelihood (crf3, lists[0]); System.out.println("CRF1 likelihood "+lik1); assertTrue ("Final zero-order likelihood <"+lik1+"> greater than first-order <"+lik2+">", lik1 < lik2); assertTrue ("Final defaults-only likelihood <"+lik2+"> greater than full first-order <"+lik3+">", lik2 < lik3); assertEquals (-167.335971702, lik1, 0.0001); assertEquals (-166.212235389, lik2, 0.0001); assertEquals ( -90.386005741, lik3, 0.0001); } double getLikelihood (MEMM crf, InstanceList data) { Maximizable.ByGradient mcrf = crf.getMaximizableCRF (data); // Do this elaborate thing so that crf.cachedValueStale is forced true double[] params = new double [mcrf.getNumParameters()]; mcrf.getParameters (params); mcrf.setParameters (params); return mcrf.getValue (); } public void disabledtestValueGradient() { doTestSpacePrediction(true); } public void disabledtestTrain() { doTestSpacePrediction(false); } public void disabledtestDenseTrain () { doTestSpacePrediction (false, false, false); } public void disabledtestSerialization() { doTestSpacePrediction(false, true, true); } public void disabledtestDenseSerialization () { doTestSpacePrediction(false, true, false); } public void disabledtestPrint () { Pipe p = new SerialPipes (new Pipe[] { new CharSequence2TokenSequence("."), new TokenText(), new TestMEMM.TestMEMMTokenSequenceRemoveSpaces(), new TokenSequence2FeatureVectorSequence(), new PrintInputAndTarget(), }); InstanceList one = new InstanceList (p); String[] data = new String[] { "ABCDE", }; one.add (new ArrayIterator (data)); MEMM crf = new MEMM (p, null); crf.addFullyConnectedStatesForLabels(); crf.setWeightsDimensionAsIn (one); MEMM.MaximizableCRF mcrf = crf.getMaximizableCRF(one); double[] params = new double[mcrf.getNumParameters()]; for (int i = 0; i < params.length; i++) { params [i] = i; } mcrf.setParameters (params); crf.print (); } public static Test suite() { return new TestSuite(TestMEMM.class); } public static void main(String[] args) { TestMEMM tm = new TestMEMM (""); tm.doTestSpacePrediction (true); return;/* TestSuite theSuite; if (args.length > 0) { theSuite = new TestSuite(); for (int i = 0; i < args.length; i++) { theSuite.addTest (new TestMEMM (args [i])); } } else { theSuite = (TestSuite) suite(); } junit.textui.TestRunner.run (theSuite);*/ }}
⌨️ 快捷键说明
复制代码Ctrl + C
搜索代码Ctrl + F
全屏模式F11
增大字号Ctrl + =
减小字号Ctrl + -
显示快捷键?