⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 testcrf3.java

📁 这是 CRF（条件随机场）测试类 TestCRF3 的 Java 源代码（使用 MALLET 工具包的 CRF3、Pipe 等类）。里面有许多内容，请大家慢慢琢磨。
💻 JAVA
📖 第 1 页 / 共 2 页
字号:
    {      int version = in.readInt();    }  }  public class TestCRF2String extends Pipe implements Serializable {    public TestCRF2String()    {      super();    }    public Instance pipe(Instance carrier)    {      StringBuffer sb = new StringBuffer();      String source = (String) carrier.getSource();      Transducer.ViterbiPath vp = (Transducer.ViterbiPath) carrier.getTarget();      ArraySequence as = (ArraySequence) vp.output();      //int startLabelIndex = as.getAlphabet().lookupIndex("start");      for (int i = 0; i < source.length(); i++) {        System.out.println("target[" + i + "]=" + as.get(i).toString());        if (as.get(i).toString().equals("start") && i != 0)          sb.append(' ');        sb.append(source.charAt(i));      }      carrier.setSource(sb.toString());      System.out.println("carrier.getSource() = " + carrier.getSource());      return carrier;    }    private static final long serialVersionUID = 1;    private static final int CURRENT_SERIAL_VERSION = 0;    private void writeObject(ObjectOutputStream out) throws IOException    {      out.writeInt(CURRENT_SERIAL_VERSION);    }    private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException    {      int version = in.readInt();    }  }  public void doTestSpacePrediction(boolean testCostAndGradient)  {    Pipe p = new SerialPipes(new Pipe[]{      new CharSequence2TokenSequence("."),      new TokenSequenceLowercase(),      new TestCRFTokenSequenceRemoveSpaces(),      new TokenText(),      new OffsetConjunctions(false,                             new int[][]{//{0},                               {1}, {-1, 0}, {0, 1},                               {-2, -1, 0}, {0, 1, 2}, {-3, -2, -1}, {1, 2, 3},                               //{-2,-1}, {1,2},                               //{-3,-2,-1}, {-2,-1,0}, {-1,0,1}, {0,1,2}, {1,2,3}                             }),//      new PrintInputAndTarget(),      new TokenSequence2FeatureVectorSequence()    });    Pipe p2 = new 
TestCRF2String();    InstanceList instances = new InstanceList(p);    instances.add(new ArrayIterator(data));    InstanceList[] lists = instances.split(new double[]{.5, .5});    CRF3 crf = new CRF3(p, p2);    crf.priorCost(lists[0]);    crf.addFullyConnectedStatesForLabels();    if (testCostAndGradient) {      Minimizable.ByGradient minable = crf.getMinimizableCRF(lists[0]);      TestMinimizable.testCostAndGradient(minable);    } else {      System.out.println("Training Accuracy before training = " + crf.averageTokenAccuracy(lists[0]));      System.out.println("Testing  Accuracy before training = " + crf.averageTokenAccuracy(lists[1]));      System.out.println("Training...");      crf.train(lists[0]);      System.out.println("Training Accuracy after training = " + crf.averageTokenAccuracy(lists[0]));      System.out.println("Testing  Accuracy after training = " + crf.averageTokenAccuracy(lists[1]));      System.out.println("Training results:");      for (int i = 0; i < lists[0].size(); i++) {        Instance instance = crf.transduce(lists[0].getInstance(i));        System.out.println(instance.getSource());      }      System.out.println("Testing results:");      for (int i = 0; i < lists[1].size(); i++) {        Instance instance = crf.transduce(lists[1].getInstance(i));        System.out.println(instance.getSource());      }    }  }  public void doTestSpacePrediction(boolean testCostAndGradient, boolean useSaved)  {    Pipe p = new SerialPipes(new Pipe[]{      new CharSequence2TokenSequence("."),      new TokenSequenceLowercase(),      new TestCRFTokenSequenceRemoveSpaces(),      new TokenText(),      new OffsetConjunctions(false,                             new int[][]{//{0}, /*{1},{-1,0},{0,1}, */                               //{-2,-1}, {-1,0}, {0,1}, {1,2},                               //{-3,-2,-1}, {-2,-1,0}, {-1,0,1}, {0,1,2}, {1,2,3},                             }),//      new PrintInputAndTarget(),      new TokenSequence2FeatureVectorSequence()    });    
CRF3 savedCRF;    File f = new File("TestObject.obj");    InstanceList instances = new InstanceList(p);    instances.add(new ArrayIterator(data));    InstanceList[] lists = instances.split(new double[]{.5, .5});    CRF3 crf = new CRF3(p.getDataAlphabet(), p.getTargetAlphabet());    crf.priorCost(lists[0]);    crf.addFullyConnectedStatesForLabels();    if (testCostAndGradient) {      Minimizable.ByGradient minable = crf.getMinimizableCRF(lists[0]);      TestMinimizable.testCostAndGradient(minable);    } else {      System.out.println("Training Accuracy before training = " + crf.averageTokenAccuracy(lists[0]));      System.out.println("Testing  Accuracy before training = " + crf.averageTokenAccuracy(lists[1]));      savedCRF = crf;      System.out.println("Training serialized crf.");      crf.train(lists[0]);      double preTrainAcc = crf.averageTokenAccuracy(lists[0]);      double preTestAcc = crf.averageTokenAccuracy(lists[1]);      System.out.println("Training Accuracy after training = " + preTrainAcc);      System.out.println("Testing  Accuracy after training = " + preTestAcc);      try {        ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream(f));        oos.writeObject(crf);        oos.close();      } catch (IOException e) {        System.err.println("Exception writing file: " + e);      }      System.err.println("Wrote out CRF");      System.err.println("CRF parameters. hyperbolicPriorSlope: " + crf.getUseHyperbolicPriorSlope() + ". hyperbolicPriorSharpness: " + crf.getUseHyperbolicPriorSharpness() + ". 
gaussianPriorVariance: " + crf.getGaussianPriorVariance());      // And read it back in      if (useSaved) {        crf = null;        try {          ObjectInputStream ois = new ObjectInputStream(new FileInputStream(f));          crf = (CRF3) ois.readObject();          ois.close();        } catch (IOException e) {          System.err.println("Exception reading file: " + e);        } catch (ClassNotFoundException cnfe) {          System.err.println("Cound not find class reading in object: " + cnfe);        }        System.err.println("Read in CRF.");        crf = savedCRF;        double postTrainAcc = crf.averageTokenAccuracy(lists[0]);        double postTestAcc = crf.averageTokenAccuracy(lists[1]);        System.out.println("Training Accuracy after saving = " + postTrainAcc);        System.out.println("Testing  Accuracy after saving = " + postTestAcc);        assertEquals(postTrainAcc, preTrainAcc, 0.0001);        assertEquals(postTestAcc, preTestAcc, 0.0001);      }    }  }  public void testCostGradient()  {    doTestSpacePrediction(true);  }  public void testTrain()  {    doTestSpacePrediction(false);  }  public void testSerialization()  {    doTestSpacePrediction(false, true);  }  public static Test suite()  {    return new TestSuite(TestCRF3.class);  }  public static void main(String[] args)  {    //args = new String[] {"testCostGradient"};    if (args.length > 0) {      TestCRF3 t = new TestCRF3("testTrain");      if (args[0].equals("testTrain"))        t.testTrain();      else if (args[0].equals("testSerialization"))        t.testSerialization();      else if (args[0].equals("testGetSetParameters"))        t.testGetSetParameters();      else if (args[0].equals("testCostGradient"))        t.testCostGradient();      else if (args[0].equals("testCost"))        t.testCost(Integer.parseInt(args[1]));      else {        System.err.println("Unrecognized test.");        System.exit(-1);      }    } else {      junit.textui.TestRunner.run(suite());    }  }}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -