testmemm.java

来自「mallet是自然语言处理、机器学习领域的一个开源项目。」· Java 代码 · 共 599 行 · 第 1/2 页

JAVA
599
字号
	      System.out.println("target[" + i + "]=" + as.get(i).toString());	      if (as.get(i).toString().equals("start") && i != 0)	        sb.append(' ');	      sb.append(source.charAt(i));	    }	    carrier.setSource(sb.toString());	    System.out.println("carrier.getSource() = " + carrier.getSource());	    return carrier;	  }	  private static final long serialVersionUID = 1;	  private static final int CURRENT_SERIAL_VERSION = 0;	  private void writeObject(ObjectOutputStream out) throws IOException	  {	    out.writeInt(CURRENT_SERIAL_VERSION);	  }	  private void readObject(ObjectInputStream in) throws IOException	  {	    int version = in.readInt();	  }	}	public void doTestSpacePrediction(boolean testValueAndGradient)	{    Pipe p = makeSpacePredictionPipe ();    Pipe p2 = new TestMEMM2String();	  InstanceList instances = new InstanceList(p);	  instances.add(new ArrayIterator(data));	  InstanceList[] lists = instances.split(new double[]{.5, .5});	  MEMM memm = new MEMM(p, p2);	  memm.addFullyConnectedStatesForLabels();	  if (testValueAndGradient) {	    Maximizable.ByGradient minable = memm.getMaximizableCRF(lists[0]);	    TestMaximizable.testValueAndGradient(minable);	  } else {	    System.out.println("Training Accuracy before training = " + memm.averageTokenAccuracy(lists[0]));	    System.out.println("Testing  Accuracy before training = " + memm.averageTokenAccuracy(lists[1]));	    System.out.println("Training...");	    memm.train(lists[0]);	    System.out.println("Training Accuracy after training = " + memm.averageTokenAccuracy(lists[0]));	    System.out.println("Testing  Accuracy after training = " + memm.averageTokenAccuracy(lists[1]));	    System.out.println("Training results:");      for (int i = 0; i < lists[0].size(); i++) {        Instance inst = lists[0].getInstance(i);        Sequence input = (Sequence) inst.getData ();        Sequence output = memm.transduce (input);        System.out.println (output);      }      System.out.println ("Testing results:");      for (int i = 0; i < lists[1].size(); i++) {        Instance inst = lists[1].getInstance(i);        Sequence input = (Sequence) inst.getData ();        Sequence output = memm.transduce (input);        System.out.println (output);      }	  }	}	public void doTestSpacePrediction(boolean testValueAndGradient,																		boolean useSaved,																		boolean useSparseWeights)	{    Pipe p = makeSpacePredictionPipe ();    MEMM savedCRF;	  File f = new File("TestObject.obj");	  InstanceList instances = new InstanceList(p);	  instances.add(new ArrayIterator(data));	  InstanceList[] lists = instances.split(new double[]{.5, .5});	  MEMM crf = new MEMM(p.getDataAlphabet(), p.getTargetAlphabet());	  crf.addFullyConnectedStatesForLabels();		crf.setUseSparseWeights (useSparseWeights);	  if (testValueAndGradient) {	    Maximizable.ByGradient minable = crf.getMaximizableCRF(lists[0]);	    TestMaximizable.testValueAndGradient(minable);	  } else {	    System.out.println("Training Accuracy before training = " + crf.averageTokenAccuracy(lists[0]));	    System.out.println("Testing  Accuracy before training = " + crf.averageTokenAccuracy(lists[1]));	    savedCRF = crf;	    System.out.println("Training serialized crf.");	    crf.train(lists[0]);	    double preTrainAcc = crf.averageTokenAccuracy(lists[0]);	    double preTestAcc = crf.averageTokenAccuracy(lists[1]);	    System.out.println("Training Accuracy after training = " + preTrainAcc);	    System.out.println("Testing  Accuracy after training = " + preTestAcc);	    try {	      ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream(f));	      oos.writeObject(crf);	      oos.close();	    } catch (IOException e) {	      System.err.println("Exception writing file: " + e);	    }	    System.err.println("Wrote out CRF");	    System.err.println("CRF parameters. hyperbolicPriorSlope: " + crf.getUseHyperbolicPriorSlope() + ". hyperbolicPriorSharpness: " + crf.getUseHyperbolicPriorSharpness() + ". gaussianPriorVariance: " + crf.getGaussianPriorVariance());	    // And read it back in	    if (useSaved) {	      crf = null;	      try {	        ObjectInputStream ois = new ObjectInputStream(new FileInputStream(f));	        crf = (MEMM) ois.readObject();	        ois.close();	      } catch (IOException e) {	        System.err.println("Exception reading file: " + e);	      } catch (ClassNotFoundException cnfe) {	        System.err.println("Cound not find class reading in object: " + cnfe);	      }	      System.err.println("Read in CRF.");	      crf = savedCRF;	      double postTrainAcc = crf.averageTokenAccuracy(lists[0]);	      double postTestAcc = crf.averageTokenAccuracy(lists[1]);	      System.out.println("Training Accuracy after saving = " + postTrainAcc);	      System.out.println("Testing  Accuracy after saving = " + postTestAcc);	      assertEquals(postTrainAcc, preTrainAcc, 0.0001);	      assertEquals(postTestAcc, preTestAcc, 0.0001);	    }	  }	}  public static  Pipe makeSpacePredictionPipe ()  {    Pipe p = new SerialPipes(new Pipe[]{	    new CharSequence2TokenSequence("."),	    new TokenSequenceLowercase(),	    new TestMEMMTokenSequenceRemoveSpaces(),	    new TokenText(),	    new OffsetConjunctions(true,	                           new int[][]{//{0}, /*{1},{-1,0},{0,1}, */	                             {1}, {-1, 0}, {0, 1},//	                             {-2, -1, 0}, {0, 1, 2}, {-3, -2, -1}, {1, 2, 3},	                             //{-2,-1}, {-1,0}, {0,1}, {1,2},	                             //{-3,-2,-1}, {-2,-1,0}, {-1,0,1}, {0,1,2}, {1,2,3},	                           }),//      new PrintInputAndTarget(),	    new TokenSequence2FeatureVectorSequence()	  });    return p;  }  public void disabledtestAddOrderNStates ()	{    Pipe p = makeSpacePredictionPipe ();    InstanceList instances = new InstanceList (p);	  instances.add (new ArrayIterator(data));	  InstanceList[] lists = instances.split (new java.util.Random (678), new double[]{.5, .5});		// Compare 3 CRFs trained with addOrderNStates, and make sure		// that having more features leads to a higher likelihood	  MEMM crf1 = new MEMM(p.getDataAlphabet(), p.getTargetAlphabet());	  crf1.addOrderNStates (lists [0],												 new int[] { 1, },												 new boolean[] { false, },												 "START",												 null,												 null,												 false);		crf1.train (lists [0]);	  MEMM crf2 = new MEMM(p.getDataAlphabet(), p.getTargetAlphabet());	  crf2.addOrderNStates (lists [0],													 new int[] { 1, 2, },													 new boolean[] { false, true },													 "START",													 null,													 null,													 false);		crf2.train (lists [0]);	  MEMM crf3 = new MEMM(p.getDataAlphabet(), p.getTargetAlphabet());	  crf3.addOrderNStates (lists [0],												 new int[] { 1, 2, },												 new boolean[] { false, false },												 "START",												 null,												 null,												 false);		crf3.train (lists [0]);		// Prevent cached values		double lik1 = getLikelihood (crf1, lists[0]);		double lik2 = getLikelihood (crf2, lists[0]);		double lik3 = getLikelihood (crf3, lists[0]);		System.out.println("CRF1 likelihood "+lik1);		assertTrue ("Final zero-order likelihood <"+lik1+"> greater than first-order <"+lik2+">",								lik1 < lik2);		assertTrue ("Final defaults-only likelihood <"+lik2+"> greater than full first-order <"+lik3+">",								lik2 < lik3);		assertEquals (-167.335971702, lik1, 0.0001);		assertEquals (-166.212235389, lik2, 0.0001);		assertEquals ( -90.386005741, lik3, 0.0001);	}	double getLikelihood (MEMM crf, InstanceList data) {		Maximizable.ByGradient mcrf = crf.getMaximizableCRF (data);		// Do this elaborate thing so that crf.cachedValueStale is forced true		double[] params = new double [mcrf.getNumParameters()];		mcrf.getParameters (params);		mcrf.setParameters (params);		return mcrf.getValue ();	}	public void disabledtestValueGradient()	{	  doTestSpacePrediction(true);	}	public void disabledtestTrain()	{	  doTestSpacePrediction(false);	}	public void disabledtestDenseTrain ()	{		doTestSpacePrediction (false, false, false);	}	public void disabledtestSerialization()	{	  doTestSpacePrediction(false, true, true);	}	public void disabledtestDenseSerialization ()	{		doTestSpacePrediction(false, true, false);	}	public void disabledtestPrint ()	{		Pipe p = new SerialPipes (new Pipe[] {	     new CharSequence2TokenSequence("."),			 new TokenText(),			 new TestMEMM.TestMEMMTokenSequenceRemoveSpaces(),			 new TokenSequence2FeatureVectorSequence(),			 new PrintInputAndTarget(),	  });		InstanceList one = new InstanceList (p);		String[] data = new String[] { "ABCDE", };		one.add (new ArrayIterator (data));		MEMM crf = new MEMM (p, null);		crf.addFullyConnectedStatesForLabels();		crf.setWeightsDimensionAsIn (one);		MEMM.MaximizableCRF mcrf = crf.getMaximizableCRF(one);		double[] params = new double[mcrf.getNumParameters()];		for (int i = 0; i < params.length; i++) {			params [i] = i;		}		mcrf.setParameters (params);		crf.print ();	}	public static Test suite()	{	  return new TestSuite(TestMEMM.class);	}	public static void main(String[] args)	{		TestMEMM tm = new TestMEMM ("");		tm.doTestSpacePrediction (true);		return;/*		TestSuite theSuite;		if (args.length > 0) {			theSuite = new TestSuite();			for (int i = 0; i < args.length; i++) {				theSuite.addTest (new TestMEMM (args [i]));			}		} else {			theSuite = (TestSuite) suite();		}		junit.textui.TestRunner.run (theSuite);*/	}}

⌨️ 快捷键说明

复制代码Ctrl + C
搜索代码Ctrl + F
全屏模式F11
增大字号Ctrl + =
减小字号Ctrl + -
显示快捷键?