⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 testinference.java

📁 这是一个matlab的java实现。里面有许多内容。请大家慢慢琢磨。
💻 JAVA
📖 第 1 页 / 共 3 页
字号:
        DiscretePotential maxPotBp = maxprod.lookupMarginal(var);
        DiscretePotential maxPotTrp = trp.lookupMarginal(var);
        maxPotTrp.delogify();
        maxPotBp.normalize();
        maxPotTrp.normalize();
        assertTrue("TRP 1 iter maxprod propagation not the same as plain maxProd!\n" +
                   "Trp " + maxPotTrp + "\n Plain maxprod " + maxPotBp,
                   maxPotBp.almostEquals(maxPotTrp));
      }
    }
  }

  /**
   * Tests that TRP and max-product propagation return the same
   * results when TRP is allowed to run to convergence.
   */
  public void testTrpViterbiEquiv2()
  {
    for (int mdlIdx = 0; mdlIdx < trees.length; mdlIdx++) {
      UndirectedModel mdl = trees[mdlIdx];
      ViterbiPropagation maxprod = new ViterbiPropagation();
      ViterbiTRP trp = new ViterbiTRP();
      maxprod.computeMarginals(mdl);
      trp.computeMarginals(mdl);

      // TRP should return same results as viterbi
      for (Iterator it = mdl.getVerticesIterator(); it.hasNext();) {
        Variable var = (Variable) it.next();
        DiscretePotential maxPotBp = maxprod.lookupMarginal(var);
        DiscretePotential maxPotTrp = trp.lookupMarginal(var);
        maxPotTrp.delogify();
        assertTrue("TRP maxprod propagation not the same as plain maxProd!\n" +
                   "Trp " + maxPotTrp + "\n Plain maxprod " + maxPotBp,
                   maxPotBp.almostEquals(maxPotTrp));
      }
    }
  }

  /**
   * For every tree model, checks that the normalized max-marginals returned by
   * {@code ViterbiPropagation} match the normalized max-marginals extracted
   * from the joint computed by {@code BruteForceInferencer}.
   */
  public void testTreeViterbi()
  {
    for (int mdlIdx = 0; mdlIdx < trees.length; mdlIdx++) {
      UndirectedModel mdl = trees[mdlIdx];
      BruteForceInferencer brute = new BruteForceInferencer();
      ViterbiPropagation maxprod = new ViterbiPropagation();
      DiscretePotential joint = brute.joint(mdl);
      maxprod.computeMarginals(mdl);

      for (Iterator it = mdl.getVerticesIterator(); it.hasNext();) {
        Variable var = (Variable) it.next();
        DiscretePotential maxPot = maxprod.lookupMarginal(var);
        DiscretePotential trueMaxPot = joint.extractMax(var);
        maxPot.normalize();
        trueMaxPot.normalize();
        assertTrue("Maximization failed! Normalized returns:\n" + maxPot
                   + "\nTrue: " + trueMaxPot,
                   maxPot.almostEquals(trueMaxPot));
      }
    }
    logger.info("Test treeViterbi passed: " + trees.length + " models.");
  }

  /**
   * Same comparison as {@link #testTreeViterbi()}, but over every entry of
   * {@code models}, running {@code ViterbiPropagation} on a junction tree
   * built by {@code JunctionTreeInferencer}.
   */
  public void testJtViterbi()
  {
    JunctionTreeInferencer jti = new JunctionTreeInferencer();
    for (int mdlIdx = 0; mdlIdx < models.length; mdlIdx++) {
      UndirectedModel mdl = models[mdlIdx];
      JunctionTree jt = jti.buildJunctionTree(mdl);
      BruteForceInferencer brute = new BruteForceInferencer();
      ViterbiPropagation maxprod = new ViterbiPropagation();
      DiscretePotential joint = brute.joint(mdl);
      maxprod.computeMarginals(jt);

      for (Iterator it = mdl.getVerticesIterator(); it.hasNext();) {
        Variable var = (Variable) it.next();
        DiscretePotential maxPot = maxprod.lookupMarginal(var);
        DiscretePotential trueMaxPot = joint.extractMax(var);
        maxPot.normalize();
        trueMaxPot.normalize();
        assertTrue("Maximization failed on model " + mdlIdx
                   + " ! Normalized returns:\n" + maxPot
                   + "\nTrue: " + trueMaxPot,
                   maxPot.almostEquals(trueMaxPot));
      }
    }
    logger.info("Test jtViterbi passed.");
  }

  /*
  public void testMM() throws Exception
  {
      testQuery();
      testTreeViterbi();
      testTrpViterbiEquiv();
      testTrpViterbiEquiv2();
      testMaxMarginals();
  }
  */

  /**
   * For every model and every inferencer class listed in {@code maxMargAlgs},
   * checks that the argmax of each returned marginal equals the argmax of the
   * max-marginal extracted from the brute-force joint. On a mismatch the
   * potentials are logged and the model is dumped before the test fails.
   *
   * @throws Exception if an inferencer class cannot be instantiated
   */
  public void testMaxMarginals() throws Exception
  {
    for (int mdlIdx = 0; mdlIdx < models.length; mdlIdx++) {
//    { int mdlIdx = 4;
      UndirectedModel mdl = models[mdlIdx];
//      if (mdlIdx != 3) {
//        Visualizer.showModel(mdl);
//          mdl.dump(); System.out.println ("***END MDL "+mdlIdx+"***");
//      }
      BruteForceInferencer brute = new BruteForceInferencer();
      DiscretePotential joint = brute.joint(mdl);
//      long foo = System.currentTimeMillis ();
//      System.out.println(foo);

      for (int infIdx = 0; infIdx < maxMargAlgs.length; infIdx++) {
        Inferencer inf = (Inferencer) maxMargAlgs[infIdx].newInstance();
        // Fixed seed for the randomized inferencer.
        if (inf instanceof ViterbiTRP) {
          ((ViterbiTRP) inf).setRandomSeed(42);
        }
        inf.computeMarginals(mdl);

        for (Iterator it = mdl.getVerticesIterator(); it.hasNext();) {
          Variable var = (Variable) it.next();
          DiscretePotential maxPot = inf.lookupMarginal(var);
          DiscretePotential trueMaxPot = joint.extractMax(var);

          if (maxPot.argmax() != trueMaxPot.argmax()) {
            logger.warning("Argmax not equal on model " + mdlIdx + " inferencer "
                           + inf + " !\n  Potentials:\nReturned: " + maxPot +
                           "\nTrue: " + trueMaxPot);
            System.err.println("Dump of model " + mdlIdx + " ***");
            mdl.dump();
            // Always fails here; kept as an assertion so the failure is reported by JUnit.
            assertTrue (maxPot.argmax() == trueMaxPot.argmax());
          }
        }
      }
    }
    logger.info("Test maxMarginals passed.");
  }

  /**
   * Runs {@code BeliefPropagation} on every tree model and compares each
   * marginal against the corresponding entry of {@code treeMargs}, passing
   * 0.011 as the second argument of {@code almostEquals}. If the assertion
   * fails, the model and the propagation object are dumped to standard output
   * before the failure is rethrown.
   */
  public void testBeliefPropagation()
  {
    for (int mdlIdx = 0; mdlIdx < trees.length; mdlIdx++) {
      UndirectedModel mdl = trees[mdlIdx];
      BeliefPropagation prop = new BeliefPropagation();
      System.out.println(mdl);
      prop.computeMarginals(mdl);

      for (Iterator it = mdl.getVerticesIterator(); it.hasNext();) {
        Variable var = (Variable) it.next();
        DiscretePotential marg1 = treeMargs[mdlIdx][mdl.getIndex(var)];
        DiscretePotential marg2 = prop.lookupMarginal(var);
        try {
          assertTrue("Test failed on graph " + mdlIdx + " vertex " + var + "\n" +
                     "Model: " + mdl + "\nExpected: " + marg1 + "\nActual: " + marg2,
                     marg1.almostEquals(marg2, 0.011));
        } catch (AssertionFailedError e) {
          System.out.println("*******************************************\nMODEL:\n");
          mdl.dump();
          System.out.println("*******************************************\nMESSAGES:\n");
          prop.dump();
          throw e;
        }
      }
    }
    logger.info("Test beliefPropagation passed.");
  }

  /**
   * Builds the fixtures used by the tests: the model list and its array form,
   * the test trees, and the expected tree marginals.
   */
  protected void setUp()
  {
    modelsList = createTestModels();
    createTestTrees();
    models = (UndirectedModel[]) modelsList.toArray
            (new UndirectedModel[]{});
    computeTestTreeMargs();
  }

  /**
   * Checks that multiplying a potential over no variables by a two-variable
   * potential yields a result that {@code almostEquals} the two-variable
   * potential.
   */
  public void testMultiply()
  {
    MultinomialPotential p1 = new MultinomialPotential
            (new Variable[]{});
    System.out.println(p1);

    Variable[] vars = new Variable[]{
      new Variable(2),
      new Variable(2),
    };
    double[] probs = new double[]{1, 3, 5, 6};
    MultinomialPotential p2 = new MultinomialPotential
            (vars, probs);

    DiscretePotential p3 = p1.multiply(p2);
    assertTrue("Should be equal: " + p2 + "\n" + p3,
               p2.almostEquals(p3));
  }

  /* TODO: Not sure how to test this anymore.
  // Test multiplication of potentials where variables are in
  //  a different order
  public void testMultiplication2 ()
  {
    Variable[] vars = new Variable[] {
      new Variable (2),
      new Variable (2),
    };
    double[] probs1 = new double[] { 2, 4, 1, 6 };
    double[] probs2a = new double[] { 3, 7, 6, 5 };
    double[] probs2b = new double[] { 3, 6, 7, 5 };

    MultinomialPotential ptl1a = new MultinomialPotential (vars, probs1);
    MultinomialPotential ptl1b = new MultinomialPotential (vars, probs1);
    MultinomialPotential ptl2a = new MultinomialPotential (vars, probs2a);

    Variable[] vars2 = new Variable[] { vars[1], vars[0], };
    MultinomialPotential ptl2b = new MultinomialPotential (vars2, probs2b);

    ptl1a.multiplyBy (ptl2a);
    ptl1b.multiplyBy (ptl2b);

    assertTrue (ptl1a.almostEquals (ptl1b));
  }
  */

  /**
   * For ten random edge potentials over the first two variables of
   * {@code models[0]}, checks that taking the log and then marginalizing gives
   * the same potential as marginalizing and then taking the log, for each of
   * the two variables.
   */
  public void testLogMarginalize ()
  {
    UndirectedModel mdl = models [0];
    Iterator it = mdl.getVerticesIterator();
    Variable v1 = (Variable) it.next();
    Variable v2 = (Variable) it.next();
    Random rand = new Random (3214123);

    for (int i = 0; i < 10; i++) {
      DiscretePotential ptl = randomEdgePotential (rand, v1, v2);

      DiscretePotential logmarg1 = ptl.log().marginalize (v1);
      DiscretePotential marglog1 = ptl.marginalize (v1).log();
      assertTrue ("LogMarg failed! Correct: "+marglog1+" Log-marg: "+logmarg1,
                  logmarg1.almostEquals (marglog1));

      DiscretePotential logmarg2 = ptl.log().marginalize (v2);
      DiscretePotential marglog2 = ptl.marginalize (v2).log();
      assertTrue (logmarg2.almostEquals (marglog2));
    }
  }

  /**
   * For ten random edge potentials, checks that normalizing the log of a
   * potential gives the same result as normalizing a duplicate of the
   * potential and then calling {@code logify()} on it.
   */
  public void testLogNormalize ()
  {
    UndirectedModel mdl = models [0];
    Iterator it = mdl.getVerticesIterator();
    Variable v1 = (Variable) it.next();
    Variable v2 = (Variable) it.next();
    Random rand = new Random (3214123);

    for (int i = 0; i < 10; i++) {
      DiscretePotential ptl = randomEdgePotential (rand, v1, v2);
      DiscretePotential norm1 = ptl.log();
      DiscretePotential norm2 = ptl.duplicate();
      norm1.normalize();
      norm2.normalize();
      norm2.logify();
      assertTrue ("LogNormalize failed! Correct: "+norm2+" Log-normed: "+norm1,
                  norm1.almostEquals (norm2));
    }
  }

  /**
   * Checks, for ten pairs of random doubles, that
   * {@code Maths.sumLogProb(log a, log b)} is within 0.00001 of
   * {@code log(a + b)}.
   */
  public void testSumLogProb ()
  {
    java.util.Random rand = new java.util.Random (3214123);

    for (int i = 0; i < 10; i++) {
      double v1 = rand.nextDouble();
      double v2 = rand.nextDouble();

      double sum1 = Math.log (v1 + v2);
      double sum2 = Maths.sumLogProb (Math.log(v1), Math.log (v2));
//      System.out.println("Summing "+v1+" + "+v2);
      assertEquals (sum1, sum2, 0.00001);
    }
  }

  /**
   * Builds a three-variable model whose second edge potential contains zero
   * entries, dumps it, and runs belief propagation on it. No assertion is
   * currently made: the only check is commented out below.
   */
  public void testInfiniteCost()
  {
    Variable[] vars = new Variable[3];
    for (int i = 0; i < vars.length; i++) {
      vars[i] = new Variable(2);
    }

    UndirectedModel mdl = new UndirectedModel(vars);
    mdl.addEdgeWithPotential(vars[0], vars[1], new double[]{2, 6, 4, 8});
    mdl.addEdgeWithPotential(vars[1], vars[2], new double[]{1, 0, 0, 1});
    mdl.dump();

    BeliefPropagation bp = new BeliefPropagation();
    mdl.computeMarginals(bp);

    //below should be true, except potentials have different ranges.
    //assertTrue (bp.lookupMarginal(vars[1]).almostEquals (bp.lookupMarginal(vars[2])));
  }

  /**
   * Clears the {@code JunctionTreeInferencer} inference cache on every model,
   * computes all marginals once, then computes them a second time with fresh
   * inferencers and asserts that the second-pass marginals
   * {@code almostEquals} the first-pass ones. The elapsed time of each pass is
   * logged but not asserted on.
   */
  public void testJtCaching()
  {
    // clear all caches
    for (int i = 0; i < models.length; i++) {
      UndirectedModel model = models[i];
      model.setInferenceCache (JunctionTreeInferencer.class, null);
    }

    DiscretePotential[][] margs = new DiscretePotential[models.length][];

    long stime1 = System.currentTimeMillis();
    for (int i = 0; i < models.length; i++) {
      UndirectedModel model = models[i];
      JunctionTreeInferencer inf = new JunctionTreeInferencer();
      inf.computeMarginals(model);

      margs[i] = new DiscretePotential[model.getVerticesCount()];
      Iterator it = model.getVerticesIterator();
      int j = -1;
      while (it.hasNext()) {
        Variable var = (Variable) it.next();
        j++;
        margs[i][j] = inf.lookupMarginal(var);
      }
    }
    long etime1 = System.currentTimeMillis();
    long diff1 = etime1 - stime1;
    logger.info ("Pre-cache took "+diff1+" ms.");

    long stime2 = System.currentTimeMillis();
    for (int i = 0; i < models.length; i++) {
      UndirectedModel model = models[i];
      JunctionTreeInferencer inf = new JunctionTreeInferencer();
      inf.computeMarginals(model);

      Iterator it = model.getVerticesIterator();
      int j = -1;
      while (it.hasNext()) {
        Variable var = (Variable) it.next();
        j++;
        assertTrue (margs[i][j].almostEquals (inf.lookupMarginal (var)));
      }
    }
    long etime2 = System.currentTimeMillis();
    long diff2 = etime2 - stime2;
    logger.info ("Post-cache took "+diff2+" ms.");
//    assertTrue (diff2 < diff1);
  }

  /**
   * Checks that {@code findVariable} returns the very same {@code Variable}
   * object for each vertex label of {@code models[0]}, and {@code null} for a
   * name that matches no variable.
   */
  public void testFindVariable ()
  {
    UndirectedModel mdl = models [0];
    Iterator it = mdl.getVerticesIterator();
    while (it.hasNext()) {
      Variable var = (Variable) it.next();
      // new String(...) yields a distinct String object (presumably so the
      // lookup cannot succeed through reference equality on the label).
      String name = new String (var.getLabel());
      assertTrue (var == mdl.findVariable (name));
    }
    assertTrue (mdl.findVariable ("xsdfasdf") == null);
  }

  /**
   * Logs how long 1000 rounds of marginalizing a random edge potential over
   * each of its two variables take, first for 2-outcome variables and then
   * for 45-outcome variables.
   */
  public void timeMarginalization ()
  {
    java.util.Random r = new java.util.Random (7732847);

    Variable[] vars = new Variable[] {
      new Variable (2),
      new Variable (2),
    };
    MultinomialPotential ptl = randomEdgePotential (r, vars[0], vars[1]);

    long stime = System.currentTimeMillis ();
    for (int i = 0; i < 1000; i++) {
      DiscretePotential marg = ptl.marginalize (vars[0]);
      DiscretePotential marg2 = ptl.marginalize (vars[1]);
    }
    long etime = System.currentTimeMillis ();
    logger.info ("Marginalization (2-outcome) took "+(etime-stime)+" ms.");

    Variable[] vars45 = new Variable[] {
      new Variable (45),
      new Variable (45),
    };
    MultinomialPotential ptl45 = randomEdgePotential (r, vars45[0], vars45[1]);

    stime = System.currentTimeMillis();
    for (int i = 0; i < 1000; i++) {
      DiscretePotential marg = ptl45.marginalize (vars45[0]);
      DiscretePotential marg2 = ptl45.marginalize (vars45[1]);
    }
    etime = System.currentTimeMillis();
    logger.info ("Marginalization (45-outcome) took "+(etime-stime)+" ms.");
  }

  /**
   * Runs junction-tree inference on every model and looks up every marginal,
   * discarding the results.
   */
  // using this for profiling
  public void runJunctionTree ()
  {
    for (int mdlIdx = 0; mdlIdx < models.length; mdlIdx++) {
      UndirectedModel model = models[mdlIdx];
      JunctionTreeInferencer inf = new JunctionTreeInferencer();
      inf.computeMarginals(model);

      Iterator it = model.getVerticesIterator ();
      while (it.hasNext()) {
        Variable var = (Variable) it.next();
        inf.lookupMarginal (var);
      }
    }
  }

  /**
   * Checks that {@code Assignment.setValue} changes the value read back for
   * the given variable and leaves the other variable's value unchanged.
   */
  public void testDestructiveAssignment ()
  {
    Variable vars[] = { new Variable(2), new Variable (2), };
    Assignment assn = new Assignment (vars, new int[] { 0, 1 });
    assertEquals (0, assn.get (vars[0]));
    assertEquals (1, assn.get (vars[1]));

    assn.setValue (vars[0], 1);
    assertEquals (1, assn.get (vars[0]));
    assertEquals (1, assn.get (vars[1]));
  }

  /**
   * Runs {@code LoopyBP} on a model produced by
   * {@code createRandomGrid (5, 5, 2, r)} with a fixed seed and asserts that
   * more than 8 iterations were used.
   */
  public void testLoopyConvergence ()
  {
    Random r = new Random (67);
    UndirectedModel mdl = createRandomGrid (5, 5, 2, r);
    LoopyBP loopy = new LoopyBP ();
    loopy.computeMarginals (mdl);
    assertTrue (loopy.iterationsUsed() > 8);
  }

  /**
   * Returns a suite built from this test class.
   */
  public static Test suite()
  {
    return new TestSuite(TestInference.class);
  }

  /**
   * Runs the tests with the JUnit text runner. When arguments are given, each
   * one is passed to the {@code TestInference} constructor and only those
   * tests are run; otherwise the full suite is run.
   *
   * @param args optional test names
   */
  public static void main(String[] args) throws Exception
  {
    TestSuite theSuite;
    if (args.length > 0) {
      theSuite = new TestSuite();
      for (int i = 0; i < args.length; i++) {
        theSuite.addTest(new TestInference(args[i]));
      }
    } else {
      theSuite = (TestSuite) suite();
    }

    junit.textui.TestRunner.run(theSuite);
  }
}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -