⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 testcrf4.java

📁 mallet是自然语言处理、机器学习领域的一个开源项目。
💻 JAVA
📖 第 1 页 / 共 3 页
字号:
    private static final int CURRENT_SERIAL_VERSION = 0;    private void writeObject(ObjectOutputStream out) throws IOException    {      out.writeInt(CURRENT_SERIAL_VERSION);    }    private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException    {      int version = in.readInt();    }  }  public void doTestSpacePrediction(boolean testValueAndGradient)  {    Pipe p = makeSpacePredictionPipe ();    Pipe p2 = new TestCRF2String();    InstanceList instances = new InstanceList(p);    instances.add(new ArrayIterator(data));    InstanceList[] lists = instances.split(new double[]{.5, .5});    CRF4 crf = new CRF4(p, p2);    crf.addFullyConnectedStatesForLabels();    if (testValueAndGradient) {      Maximizable.ByGradient minable = crf.getMaximizableCRF(lists[0]);      TestMaximizable.testValueAndGradient(minable);    } else {      System.out.println("Training Accuracy before training = " + crf.averageTokenAccuracy(lists[0]));      System.out.println("Testing  Accuracy before training = " + crf.averageTokenAccuracy(lists[1]));      System.out.println("Training...");      crf.train(lists[0]);      System.out.println("Training Accuracy after training = " + crf.averageTokenAccuracy(lists[0]));      System.out.println("Testing  Accuracy after training = " + crf.averageTokenAccuracy(lists[1]));      System.out.println("Training results:");      for (int i = 0; i < lists[0].size(); i++) {        Instance inst = lists[0].getInstance(i);        Sequence input = (Sequence) inst.getData ();        Sequence output = crf.transduce (input);        System.out.println (output);      }      System.out.println ("Testing results:");      for (int i = 0; i < lists[1].size(); i++) {        Instance inst = lists[1].getInstance(i);        Sequence input = (Sequence) inst.getData ();        Sequence output = crf.transduce (input);        System.out.println (output);      }    }  }  public void doTestSpacePrediction(boolean testValueAndGradient, 																		boolean useSaved,																		boolean useSparseWeights)  {    Pipe p = makeSpacePredictionPipe ();    CRF4 savedCRF;    File f = new File("TestObject.obj");    InstanceList instances = new InstanceList(p);    instances.add(new ArrayIterator(data));    InstanceList[] lists = instances.split(new double[]{.5, .5});    CRF4 crf = new CRF4(p.getDataAlphabet(), p.getTargetAlphabet());    crf.addFullyConnectedStatesForLabels();		crf.setUseSparseWeights (useSparseWeights);    if (testValueAndGradient) {      Maximizable.ByGradient minable = crf.getMaximizableCRF(lists[0]);      TestMaximizable.testValueAndGradient(minable);    } else {      System.out.println("Training Accuracy before training = " + crf.averageTokenAccuracy(lists[0]));      System.out.println("Testing  Accuracy before training = " + crf.averageTokenAccuracy(lists[1]));      savedCRF = crf;      System.out.println("Training serialized crf.");      crf.train(lists[0]);      double preTrainAcc = crf.averageTokenAccuracy(lists[0]);      double preTestAcc = crf.averageTokenAccuracy(lists[1]);      System.out.println("Training Accuracy after training = " + preTrainAcc);      System.out.println("Testing  Accuracy after training = " + preTestAcc);      try {        ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream(f));        oos.writeObject(crf);        oos.close();      } catch (IOException e) {        System.err.println("Exception writing file: " + e);      }      System.err.println("Wrote out CRF");      System.err.println("CRF parameters. hyperbolicPriorSlope: " + crf.getUseHyperbolicPriorSlope() + ". hyperbolicPriorSharpness: " + crf.getUseHyperbolicPriorSharpness() + ". gaussianPriorVariance: " + crf.getGaussianPriorVariance());      // And read it back in      if (useSaved) {        crf = null;        try {          ObjectInputStream ois = new ObjectInputStream(new FileInputStream(f));          crf = (CRF4) ois.readObject();          ois.close();        } catch (IOException e) {          System.err.println("Exception reading file: " + e);        } catch (ClassNotFoundException cnfe) {          System.err.println("Cound not find class reading in object: " + cnfe);        }        System.err.println("Read in CRF.");        crf = savedCRF;        double postTrainAcc = crf.averageTokenAccuracy(lists[0]);        double postTestAcc = crf.averageTokenAccuracy(lists[1]);        System.out.println("Training Accuracy after saving = " + postTrainAcc);        System.out.println("Testing  Accuracy after saving = " + postTestAcc);        assertEquals(postTrainAcc, preTrainAcc, 0.0001);        assertEquals(postTestAcc, preTestAcc, 0.0001);      }    }  }  private Pipe makeSpacePredictionPipe ()  {    Pipe p = new SerialPipes(new Pipe[]{      new CharSequence2TokenSequence("."),      new TokenSequenceLowercase(),      new TestCRFTokenSequenceRemoveSpaces(),      new TokenText(),      new OffsetConjunctions(true,                             new int[][]{//{0}, /*{1},{-1,0},{0,1}, */                               {1}, {-1, 0}, {0, 1},                               {-2, -1, 0}, {0, 1, 2}, {-3, -2, -1}, {1, 2, 3},                               //{-2,-1}, {-1,0}, {0,1}, {1,2},                               //{-3,-2,-1}, {-2,-1,0}, {-1,0,1}, {0,1,2}, {1,2,3},                             }),//      new PrintInputAndTarget(),      new TokenSequence2FeatureVectorSequence()    });    return p;  }  public void testAddOrderNStates ()	{    Pipe p = makeSpacePredictionPipe ();    InstanceList instances = new InstanceList (p);    instances.add (new ArrayIterator(data));    InstanceList[] lists = instances.split (new java.util.Random (678), new double[]{.5, .5});		// Compare 3 CRFs trained with addOrderNStates, and make sure		// that having more features leads to a higher likelihood    CRF4 crf1 = new CRF4(p.getDataAlphabet(), p.getTargetAlphabet());    crf1.addOrderNStates (lists [0],												 new int[] { 1, },												 new boolean[] { false, },												 "START",												 null,												 null,												 false);		crf1.train (lists [0]);    CRF4 crf2 = new CRF4(p.getDataAlphabet(), p.getTargetAlphabet());    crf2.addOrderNStates (lists [0],													 new int[] { 1, 2, },													 new boolean[] { false, true },													 "START",													 null,													 null,													 false);		crf2.train (lists [0]);    CRF4 crf3 = new CRF4(p.getDataAlphabet(), p.getTargetAlphabet());    crf3.addOrderNStates (lists [0],												 new int[] { 1, 2, },												 new boolean[] { false, false },												 "START",												 null,												 null,												 false);		crf3.train (lists [0]);		// Prevent cached values		double lik1 = getLikelihood (crf1, lists[0]);		double lik2 = getLikelihood (crf2, lists[0]);		double lik3 = getLikelihood (crf3, lists[0]);		System.out.println("CRF1 likelihood "+lik1);				assertTrue ("Final zero-order likelihood <"+lik1+"> greater than first-order <"+lik2+">",								lik1 < lik2);		assertTrue ("Final defaults-only likelihood <"+lik2+"> greater than full first-order <"+lik3+">",								lik2 < lik3);		assertEquals (-167.2234457483949, lik1, 0.0001);		assertEquals (-165.81326484466342, lik2, 0.0001);		assertEquals (-90.37680146432787, lik3, 0.0001);	}	double getLikelihood (CRF4 crf, InstanceList data) {		Maximizable.ByGradient mcrf = crf.getMaximizableCRF (data);		// Do this elaborate thing so that crf.cachedValueStale is forced true		double[] params = new double [mcrf.getNumParameters()];		mcrf.getParameters (params);		mcrf.setParameters (params);		return mcrf.getValue ();	}  public void testFrozenWeights ()  {    Pipe p = makeSpacePredictionPipe ();    InstanceList instances = new InstanceList (p);    instances.add (new ArrayIterator (data));    CRF4 crf1 = new CRF4 (p.getDataAlphabet (), p.getTargetAlphabet ());    crf1.addFullyConnectedStatesForLabels ();    crf1.train (instances);    CRF4 crf2 = new CRF4 (p.getDataAlphabet (), p.getTargetAlphabet ());    crf2.addFullyConnectedStatesForLabels ();    // Freeze some weights, before training    for (int i = 0; i < crf2.getWeights ().length; i += 2) {      crf2.freezeWeights (i);    }    crf2.train (instances);    SparseVector[] w = crf2.getWeights ();    double[] b = crf2.getDefaultWeights ();    for (int i = 0; i < w.length; i += 2) {      assertEquals (0.0, b[i], 1e-10);      for (int loc = 0; loc < w[i].numLocations (); loc++) {        assertEquals (0.0, w[i].valueAtLocation (loc), 1e-10);      }    }    // Check that the frozen weights has worse likelihood    Maximizable.ByGradient maxable1 = crf1.getMaximizableCRF (instances);    Maximizable.ByGradient maxable2 = crf2.getMaximizableCRF (instances);    double val1 = maxable1.getValue();    double val2 = maxable2.getValue ();    assertTrue ("Error: Freezing weights helps performance!  Full "+val1+", Frozen "+val2, val1 > val2);  }  public void testValueGradient()  {    doTestSpacePrediction(true);  }  public void testTrain()  {    doTestSpacePrediction(false);  }	public void testDenseTrain ()	{

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -