📄 testinference.java
字号:
DiscretePotential maxPotBp = maxprod.lookupMarginal(var); DiscretePotential maxPotTrp = trp.lookupMarginal(var); maxPotTrp.delogify(); maxPotBp.normalize(); maxPotTrp.normalize(); assertTrue("TRP 1 iter maxprod propagation not the same as plain maxProd!\n" + "Trp " + maxPotTrp + "\n Plain maxprod " + maxPotBp, maxPotBp.almostEquals(maxPotTrp)); } } } // Tests that TRP and max-product propagation return the same // results when TRP is allowed to run to convergence. public void testTrpViterbiEquiv2() { for (int mdlIdx = 0; mdlIdx < trees.length; mdlIdx++) { UndirectedModel mdl = trees[mdlIdx]; ViterbiPropagation maxprod = new ViterbiPropagation(); ViterbiTRP trp = new ViterbiTRP(); maxprod.computeMarginals(mdl); trp.computeMarginals(mdl); // TRP should return same results as viterbi for (Iterator it = mdl.getVerticesIterator(); it.hasNext();) { Variable var = (Variable) it.next(); DiscretePotential maxPotBp = maxprod.lookupMarginal(var); DiscretePotential maxPotTrp = trp.lookupMarginal(var); maxPotTrp.delogify(); assertTrue("TRP maxprod propagation not the same as plain maxProd!\n" + "Trp " + maxPotTrp + "\n Plain maxprod " + maxPotBp, maxPotBp.almostEquals(maxPotTrp)); } } } public void testTreeViterbi() { for (int mdlIdx = 0; mdlIdx < trees.length; mdlIdx++) { UndirectedModel mdl = trees[mdlIdx]; BruteForceInferencer brute = new BruteForceInferencer(); ViterbiPropagation maxprod = new ViterbiPropagation(); DiscretePotential joint = brute.joint(mdl); maxprod.computeMarginals(mdl); for (Iterator it = mdl.getVerticesIterator(); it.hasNext();) { Variable var = (Variable) it.next(); DiscretePotential maxPot = maxprod.lookupMarginal(var); DiscretePotential trueMaxPot = joint.extractMax(var); maxPot.normalize(); trueMaxPot.normalize(); assertTrue("Maximization failed! Normalized returns:\n" + maxPot + "\nTrue: " + trueMaxPot, maxPot.almostEquals(trueMaxPot)); } } logger.info("Test treeViterbi passed: " + trees.length + " models."); } public void testJtViterbi() { JunctionTreeInferencer jti = new JunctionTreeInferencer(); for (int mdlIdx = 0; mdlIdx < models.length; mdlIdx++) { UndirectedModel mdl = models[mdlIdx]; JunctionTree jt = jti.buildJunctionTree(mdl); BruteForceInferencer brute = new BruteForceInferencer(); ViterbiPropagation maxprod = new ViterbiPropagation(); DiscretePotential joint = brute.joint(mdl); maxprod.computeMarginals(jt); for (Iterator it = mdl.getVerticesIterator(); it.hasNext();) { Variable var = (Variable) it.next(); DiscretePotential maxPot = maxprod.lookupMarginal(var); DiscretePotential trueMaxPot = joint.extractMax(var); maxPot.normalize(); trueMaxPot.normalize(); assertTrue("Maximization failed on model " + mdlIdx + " ! Normalized returns:\n" + maxPot + "\nTrue: " + trueMaxPot, maxPot.almostEquals(trueMaxPot)); } } logger.info("Test jtViterbi passed."); } /* public void testMM() throws Exception { testQuery(); testTreeViterbi(); testTrpViterbiEquiv(); testTrpViterbiEquiv2(); testMaxMarginals(); } */ public void testMaxMarginals() throws Exception { for (int mdlIdx = 0; mdlIdx < models.length; mdlIdx++) {// { int mdlIdx = 4; UndirectedModel mdl = models[mdlIdx];// if (mdlIdx != 3) {// Visualizer.showModel(mdl);// mdl.dump(); System.out.println ("***END MDL "+mdlIdx+"***");// } BruteForceInferencer brute = new BruteForceInferencer(); DiscretePotential joint = brute.joint(mdl);// long foo = System.currentTimeMillis ();// System.out.println(foo); for (int infIdx = 0; infIdx < maxMargAlgs.length; infIdx++) { Inferencer inf = (Inferencer) maxMargAlgs[infIdx].newInstance(); if (inf instanceof ViterbiTRP) ((ViterbiTRP)inf).setRandomSeed(42); inf.computeMarginals(mdl); for (Iterator it = mdl.getVerticesIterator(); it.hasNext();) { Variable var = (Variable) it.next(); DiscretePotential maxPot = inf.lookupMarginal(var); DiscretePotential trueMaxPot = joint.extractMax(var); if (maxPot.argmax() != trueMaxPot.argmax()) { logger.warning("Argmax not equal on model " + mdlIdx + " inferencer " + inf + " !\n Potentials:\nReturned: " + maxPot + "\nTrue: " + trueMaxPot); System.err.println("Dump of model " + mdlIdx + " ***"); mdl.dump(); assertTrue (maxPot.argmax() == trueMaxPot.argmax()); } } } } logger.info("Test maxMarginals passed."); } public void testBeliefPropagation() { for (int mdlIdx = 0; mdlIdx < trees.length; mdlIdx++) { UndirectedModel mdl = trees[mdlIdx]; BeliefPropagation prop = new BeliefPropagation(); System.out.println(mdl); prop.computeMarginals(mdl); for (Iterator it = mdl.getVerticesIterator(); it.hasNext();) { Variable var = (Variable) it.next(); DiscretePotential marg1 = treeMargs[mdlIdx][mdl.getIndex(var)]; DiscretePotential marg2 = prop.lookupMarginal(var); try { assertTrue("Test failed on graph " + mdlIdx + " vertex " + var + "\n" + "Model: " + mdl + "\nExpected: " + marg1 + "\nActual: " + marg2, marg1.almostEquals(marg2, 0.011)); } catch (AssertionFailedError e) { System.out.println("*******************************************\nMODEL:\n"); mdl.dump(); System.out.println("*******************************************\nMESSAGES:\n"); prop.dump(); throw e; } } } logger.info("Test beliefPropagation passed."); } protected void setUp() { modelsList = createTestModels(); createTestTrees(); models = (UndirectedModel[]) modelsList.toArray (new UndirectedModel[]{}); computeTestTreeMargs(); } public void testMultiply() { MultinomialPotential p1 = new MultinomialPotential (new Variable[]{}); System.out.println(p1); Variable[] vars = new Variable[]{ new Variable(2), new Variable(2), }; double[] probs = new double[]{1, 3, 5, 6}; MultinomialPotential p2 = new MultinomialPotential (vars, probs); DiscretePotential p3 = p1.multiply(p2); assertTrue("Should be equal: " + p2 + "\n" + p3, p2.almostEquals(p3)); } /* TODO: Not sure how to test this anymore. // Test multiplication of potentials where variables are in // a different order public void testMultiplication2 () { Variable[] vars = new Variable[] { new Variable (2), new Variable (2), }; double[] probs1 = new double[] { 2, 4, 1, 6 }; double[] probs2a = new double[] { 3, 7, 6, 5 }; double[] probs2b = new double[] { 3, 6, 7, 5 }; MultinomialPotential ptl1a = new MultinomialPotential (vars, probs1); MultinomialPotential ptl1b = new MultinomialPotential (vars, probs1); MultinomialPotential ptl2a = new MultinomialPotential (vars, probs2a); Variable[] vars2 = new Variable[] { vars[1], vars[0], }; MultinomialPotential ptl2b = new MultinomialPotential (vars2, probs2b); ptl1a.multiplyBy (ptl2a); ptl1b.multiplyBy (ptl2b); assertTrue (ptl1a.almostEquals (ptl1b)); } */ public void testLogMarginalize () { UndirectedModel mdl = models [0]; Iterator it = mdl.getVerticesIterator(); Variable v1 = (Variable) it.next(); Variable v2 = (Variable) it.next(); Random rand = new Random (3214123); for (int i = 0; i < 10; i++) { DiscretePotential ptl = randomEdgePotential (rand, v1, v2); DiscretePotential logmarg1 = ptl.log().marginalize (v1); DiscretePotential marglog1 = ptl.marginalize (v1).log(); assertTrue ("LogMarg failed! Correct: "+marglog1+" Log-marg: "+logmarg1, logmarg1.almostEquals (marglog1)); DiscretePotential logmarg2 = ptl.log().marginalize (v2); DiscretePotential marglog2 = ptl.marginalize (v2).log(); assertTrue (logmarg2.almostEquals (marglog2)); } } public void testLogNormalize () { UndirectedModel mdl = models [0]; Iterator it = mdl.getVerticesIterator(); Variable v1 = (Variable) it.next(); Variable v2 = (Variable) it.next(); Random rand = new Random (3214123); for (int i = 0; i < 10; i++) { DiscretePotential ptl = randomEdgePotential (rand, v1, v2); DiscretePotential norm1 = ptl.log(); DiscretePotential norm2 = ptl.duplicate(); norm1.normalize(); norm2.normalize(); norm2.logify(); assertTrue ("LogNormalize failed! Correct: "+norm2+" Log-normed: "+norm1, norm1.almostEquals (norm2)); } } public void testSumLogProb () { java.util.Random rand = new java.util.Random (3214123); for (int i = 0; i < 10; i++) { double v1 = rand.nextDouble(); double v2 = rand.nextDouble(); double sum1 = Math.log (v1 + v2); double sum2 = Maths.sumLogProb (Math.log(v1), Math.log (v2));// System.out.println("Summing "+v1+" + "+v2); assertEquals (sum1, sum2, 0.00001); } } public void testInfiniteCost() { Variable[] vars = new Variable[3]; for (int i = 0; i < vars.length; i++) { vars[i] = new Variable(2); } UndirectedModel mdl = new UndirectedModel(vars); mdl.addEdgeWithPotential(vars[0], vars[1], new double[]{2, 6, 4, 8}); mdl.addEdgeWithPotential(vars[1], vars[2], new double[]{1, 0, 0, 1}); mdl.dump(); BeliefPropagation bp = new BeliefPropagation(); mdl.computeMarginals(bp); //below should be true, except potentials have different ranges. //assertTrue (bp.lookupMarginal(vars[1]).almostEquals (bp.lookupMarginal(vars[2]))); } public void testJtCaching() { // clear all caches for (int i = 0; i < models.length; i++) { UndirectedModel model = models[i]; model.setInferenceCache (JunctionTreeInferencer.class, null); } DiscretePotential[][] margs = new DiscretePotential[models.length][]; long stime1 = new Date().getTime(); for (int i = 0; i < models.length; i++) { UndirectedModel model = models[i]; JunctionTreeInferencer inf = new JunctionTreeInferencer(); inf.computeMarginals(model); margs[i] = new DiscretePotential[model.getVerticesCount()]; Iterator it = model.getVerticesIterator(); int j = -1; while (it.hasNext()) { Variable var = (Variable) it.next(); j++; margs[i][j] = inf.lookupMarginal(var); } } long etime1 = new Date().getTime(); long diff1 = etime1 - stime1; logger.info ("Pre-cache took "+diff1+" ms."); long stime2 = new Date().getTime(); for (int i = 0; i < models.length; i++) { UndirectedModel model = models[i]; JunctionTreeInferencer inf = new JunctionTreeInferencer(); inf.computeMarginals(model); Iterator it = model.getVerticesIterator(); int j = -1; while (it.hasNext()) { Variable var = (Variable) it.next(); j++; assertTrue (margs[i][j].almostEquals (inf.lookupMarginal (var))); } } long etime2 = new Date().getTime(); long diff2 = etime2 - stime2; logger.info ("Post-cache took "+diff2+" ms.");// assertTrue (diff2 < diff1); } public void testFindVariable () { UndirectedModel mdl = models [0]; Variable[] vars = new Variable [mdl.getVerticesCount()]; Iterator it = mdl.getVerticesIterator(); while (it.hasNext()) { Variable var = (Variable) it.next(); String name = new String (var.getLabel()); assertTrue (var == mdl.findVariable (name)); } assertTrue (mdl.findVariable ("xsdfasdf") == null); } public void timeMarginalization () { java.util.Random r = new java.util.Random (7732847); Variable[] vars = new Variable[] { new Variable (2), new Variable (2), }; MultinomialPotential ptl = randomEdgePotential (r, vars[0], vars[1]); long stime = System.currentTimeMillis (); for (int i = 0; i < 1000; i++) { DiscretePotential marg = ptl.marginalize (vars[0]); DiscretePotential marg2 = ptl.marginalize (vars[1]); } long etime = System.currentTimeMillis (); logger.info ("Marginalization (2-outcome) took "+(etime-stime)+" ms."); Variable[] vars45 = new Variable[] { new Variable (45), new Variable (45), }; MultinomialPotential ptl45 = randomEdgePotential (r, vars45[0], vars45[1]); stime = System.currentTimeMillis(); for (int i = 0; i < 1000; i++) { DiscretePotential marg = ptl45.marginalize (vars45[0]); DiscretePotential marg2 = ptl45.marginalize (vars45[1]); } etime = System.currentTimeMillis(); logger.info ("Marginalization (45-outcome) took "+(etime-stime)+" ms."); } // using this for profiling public void runJunctionTree () { for (int mdlIdx = 0; mdlIdx < models.length; mdlIdx++) { UndirectedModel model = models[mdlIdx]; JunctionTreeInferencer inf = new JunctionTreeInferencer(); inf.computeMarginals(model); Iterator it = model.getVerticesIterator (); while (it.hasNext()) { Variable var = (Variable) it.next(); inf.lookupMarginal (var); } } } public void testDestructiveAssignment () { Variable vars[] = { new Variable(2), new Variable (2), }; Assignment assn = new Assignment (vars, new int[] { 0, 1 }); assertEquals (0, assn.get (vars[0])); assertEquals (1, assn.get (vars[1])); assn.setValue (vars[0], 1); assertEquals (1, assn.get (vars[0])); assertEquals (1, assn.get (vars[1])); } public void testLoopyConvergence () { Random r = new Random (67); UndirectedModel mdl = createRandomGrid (5, 5, 2, r); LoopyBP loopy = new LoopyBP (); loopy.computeMarginals (mdl); assertTrue (loopy.iterationsUsed() > 8); } public static Test suite() { return new TestSuite(TestInference.class); } public static void main(String[] args) throws Exception { TestSuite theSuite; if (args.length > 0) { theSuite = new TestSuite(); for (int i = 0; i < args.length; i++) { theSuite.addTest(new TestInference(args[i])); } } else { theSuite = (TestSuite) suite(); } junit.textui.TestRunner.run(theSuite); }}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -