📄 testinference.java
字号:
continue; } else { throw e; } } /* lookup all marginals */ int vrt = 0; int maxV = models[mdl].getVerticesCount(); joints[mdl][alg1] = new DiscretePotential[maxV]; for (Iterator it = models[mdl].getVertexSet().iterator(); it.hasNext(); vrt++) { Variable var = (Variable) it.next(); logger.finer("Lookup marginal for model " + mdl + " vrt " + var + " alg " + alg); DiscretePotential ptl = alg.lookupMarginal(var); joints[mdl][alg1][vrt] = ptl.duplicate(); } for (vrt = 0; vrt < maxV; vrt++) { DiscretePotential joint1 = joints[mdl][alg1][vrt]; DiscretePotential joint2 = joints[mdl][alg2][vrt]; try { assertTrue(joint1.almostEquals(joint2, APPX_EPSILON)); } catch (AssertionFailedError e) { System.out.println("\n************************************\nAppx Marginal Test FAILED\n\n"); System.out.println("Inferencer: " + alg); System.out.println("Model " + mdl + " Vertex " + vrt); System.out.println(joint1); System.out.println(joint2); models[mdl].dump(); System.out.println("All marginals:"); for (int i = 0; i < maxV; i++) { System.out.println(joints[mdl][alg1][i]); } System.out.println("Correct marginals:"); for (int i = 0; i < maxV; i++) { System.out.println(joints[mdl][alg2][i]); } throw e; } } } } System.out.println("Tested " + models.length + " undirected models."); } public void testQuery () throws Exception { java.util.Random rand = new java.util.Random (15667); for (int mdlIdx = 0; mdlIdx < models.length; mdlIdx++) { UndirectedModel mdl = models [mdlIdx]; int size = rand.nextInt (3) + 2; Collection vars = CollectionUtils.subset (mdl.getVertexSet(), size, rand); Variable[] varArr = (Variable[]) vars.toArray (new Variable [0]); Assignment assn = new Assignment (varArr, new int [size]); BruteForceInferencer brute = new BruteForceInferencer(); DiscretePotential joint = brute.joint(mdl); double marginal = joint.marginalize(vars).phi (assn); for (int algIdx = 0; algIdx < appxAlgs.length; algIdx++) { Inferencer alg = (Inferencer) appxAlgs[algIdx].newInstance(); double returned = mdl.query (alg, assn); assertEquals ("Failure on model "+mdlIdx+" alg "+alg, marginal, returned, APPX_EPSILON); } } logger.info ("Test testQuery passed."); } // Tests the measurement of numbers of messages sent public void testNumMessages () { for (int mdlIdx = 0; mdlIdx < models.length; mdlIdx++) { UndirectedModel mdl = models [mdlIdx]; TRP trp = new TRP (); trp.computeMarginals (mdl); int expectedMessages = (mdl.getVerticesCount() - 1) * 2 * trp.iterationsUsed(); assertEquals (expectedMessages, trp.getTotalMessagesSent ()); LoopyBP loopy = new LoopyBP (); loopy.computeMarginals (mdl); expectedMessages = mdl.getEdgeSet().size() * 2 * loopy.iterationsUsed(); assertEquals (expectedMessages, loopy.getTotalMessagesSent ()); } } private UndirectedModel createJtChain() { int numNodes = 4; Variable[] nodes = new Variable[numNodes]; for (int i = 0; i < numNodes; i++) { nodes[i] = new Variable(2); } DiscretePotential[] pots = new MultinomialPotential[]{ new MultinomialPotential(new Variable[]{nodes[0], nodes[1]}, new double[]{1, 2, 5, 4}), new MultinomialPotential(new Variable[]{nodes[1], nodes[2]}, new double[]{4, 2, 4, 1}), new MultinomialPotential(new Variable[]{nodes[2], nodes[3]}, new double[]{7, 3, 6, 9}), }; for (int i = 0; i < pots.length; i++) { pots[i].normalize(); } UndirectedModel uGraph = new UndirectedModel(); for (int i = 0; i < numNodes; i++) { uGraph.add(nodes[i]); } for (int i = 0; i < numNodes - 1; i++) { uGraph.addEdgeWithPotential(nodes[i], nodes[(i + 1)], pots[i]); } return uGraph; } private void createTestTrees() { Random r = new Random(185); trees = new UndirectedModel[]{ createJtChain(), createRandomGrid(5, 1, 3, r), createRandomGrid(6, 1, 2, r), createRandomTree(10, 2, r), createRandomTree(10, 2, r), createRandomTree(8, 3, r), createRandomTree(8, 3, r), }; modelsList.addAll(Arrays.asList(trees)); } private void computeTestTreeMargs() { treeMargs = new DiscretePotential[trees.length][]; BruteForceInferencer brute = new BruteForceInferencer(); for (int i = 0; i < trees.length; i++) { UndirectedModel mdl = trees[i]; DiscretePotential joint = brute.joint(mdl); treeMargs[i] = new DiscretePotential[mdl.getVerticesCount()]; for (Iterator it = mdl.getVerticesIterator(); it.hasNext();) { Variable var = (Variable) it.next(); treeMargs[i][mdl.getIndex(var)] = joint.marginalize(var); } } } // Takes an undirected model that is a tree parameterized by // normalized edge potentials and converts it into a junction tree. private JunctionTree graphToJt(UndirectedModel tree) { int numNodes = tree.getVerticesCount(); JunctionTree jt = new JunctionTree(numNodes - 1); Clique[] edgeCliques = new Clique[tree.getEdgeSet().size()]; int i = 0; for (Iterator it = tree.getEdgeSet().iterator(); it.hasNext();) { Edge e = (Edge) it.next(); Variable v1 = (Variable) e.getVertexA(); Variable v2 = (Variable) e.getVertexB(); DiscretePotential pot = tree.potentialOfEdge(v1, v2); edgeCliques[i] = new HashClique(pot.varSet()); jt.setCPF(edgeCliques[i], pot.duplicate()); try { jt.add(edgeCliques[i]); } catch (Exception exception) { } i++; } for (int j = 0; j < edgeCliques.length; j++) { for (int k = j + 1; k < edgeCliques.length; k++) { Set set = edgeCliques[k].intersection(edgeCliques[j]); if (!set.isEmpty()) { jt.addNode(edgeCliques[j], edgeCliques[k]); } } } return jt; }public void testJtConsistency() { for (int mdlIdx = 0; mdlIdx < models.length; mdlIdx++) { UndirectedModel mdl = models[mdlIdx]; JunctionTreeInferencer jti = new JunctionTreeInferencer(); JunctionTree jt = jti.buildJunctionTree(mdl); for (Iterator it = jt.getVerticesIterator(); it.hasNext();) { Clique parent = (Clique) it.next(); for (Iterator it2 = jt.getChildren(parent).iterator(); it2.hasNext();) { Clique child = (Clique) it2.next(); DiscretePotential ptl = jt.getSepsetPot(parent, child); Set intersection = parent.intersection (child); assertTrue (intersection.equals (ptl.varSet())); } } }} private void compareTrpJoint(DiscretePotential joint, TRP trp) { Assignment assn = null; double prob1 = 0.0, prob2 = 0.0; try { Clique all = new HashClique(joint.varSet()); for (Iterator it = all.assignmentIterator(); it.hasNext();) { assn = (Assignment) it.next(); prob1 = trp.lookupJoint(assn); prob2 = joint.phi(assn);// assertTrue (Maths.almostEquals (prob1, prob2)); assertTrue(Math.abs(prob1 - prob2) < 0.01); } } catch (AssertionFailedError e) { System.out.println("*****************************************\nTEST FAILURE in compareTrpJoint"); System.out.println("*****************************************\nat"); System.out.println(assn); System.out.println("Expected: " + prob2); System.out.println("TRP: " + prob1); System.out.println("*****************************************\nExpected joint"); System.out.println(joint); System.out.println("*****************************************\nTRP dump"); trp.dump(); throw e; } } public void testTrp() { final UndirectedModel model = createTriangle(); TRP trp = new TRP(new RandomTreeFactory(), new TRP.IterationTerminator(200)); BruteForceInferencer brute = new BruteForceInferencer(); DiscretePotential joint = brute.joint(model); trp.computeMarginals(model); // Check joint// DiscretePotential joint = brute.joint (model); compareTrpJoint(joint, trp); // Check all marginals try { for (Iterator it = model.getVerticesIterator(); it.hasNext();) { Variable var = (Variable) it.next(); DiscretePotential marg1 = trp.lookupMarginal(var); DiscretePotential marg2 = joint.marginalize (var); assertTrue(marg1.almostEquals(marg2, APPX_EPSILON)); } for (Iterator it = model.getEdgeSet().iterator(); it.hasNext();) { Edge e = (Edge) it.next(); Variable v1 = (Variable) e.getVertexA(); Variable v2 = (Variable) e.getVertexB(); DiscretePotential marg1 = trp.lookupMarginal(v1, v2); DiscretePotential marg2 = joint.marginalize (new Variable[]{v1, v2}); assertTrue(marg1.almostEquals(marg2, APPX_EPSILON)); } } catch (AssertionFailedError e) { System.out.println("\n*************************************\nTEST FAILURE in compareTrpMargs");// System.out.println(marg1);// System.out.println(marg2); System.out.println("*************************************\nComplete model:\n\n"); model.dump(); System.out.println("*************************************\nTRP margs:\n\n"); trp.dump(); System.out.println("**************************************\nAll correct margs:\n"); for (Iterator it2 = model.getVerticesIterator(); it2.hasNext();) { Variable v2 = (Variable) it2.next(); brute.computeMarginals (model); System.out.println(brute.lookupMarginal(v2)); } throw e; } } public void testTrpJoint() { UndirectedModel model = createTriangle(); TRP trp = new TRP(new RandomTreeFactory(), new TRP.IterationTerminator(25)); trp.computeMarginals(model); // For each assignment to the model, check that // TRP.lookupLogJoint and TRP.lookupJoint are consistent Clique all = new HashClique(model.getVertexSet()); for (Iterator it = all.assignmentIterator(); it.hasNext();) { Assignment assn = (Assignment) it.next(); double log = trp.lookupLogJoint(assn); double prob = trp.lookupJoint(assn); assertTrue(Maths.almostEquals(Math.exp(log), prob)); } logger.info("Test trpJoint passed."); } /** Tests that running TRP doesn't inadvertantly change potentials in the original graph. */ public void testTrpNonDestructivity() { UndirectedModel model = createTriangle(); TRP trp = new TRP(new RandomTreeFactory(), new TRP.IterationTerminator(25)); BruteForceInferencer brute = new BruteForceInferencer(); DiscretePotential joint1 = brute.joint(model); trp.computeMarginals(model); DiscretePotential joint2 = brute.joint(model); assertTrue(joint1.almostEquals(joint2)); logger.info("Test trpNonDestructivity passed."); } // Verify that potentialOfVertex and potentialOfEdge (which use // caches) are consistent with the potentials set. public void testUndirectedCaches() { for (int i = 0; i < models.length; i++) { UndirectedModel mdl = models[i]; verifyCachesConsistent (mdl); } logger.info("Test undirectedCaches passed."); } private void verifyCachesConsistent (UndirectedModel mdl) { DiscretePotential pot, pot2, pot3; for (Iterator it = mdl.potentials().iterator(); it.hasNext();) { pot = (DiscretePotential) it.next(); // System.out.println("Testing model "+i+" potential "+pot); Object[] vars = pot.varSet().toArray(); switch (vars.length) { case 1: pot2 = mdl.potentialOfVertex((Variable) vars[0]); assertTrue(pot == pot2); break; case 2: Variable var1 = (Variable) vars[0]; Variable var2 = (Variable) vars[1]; pot2 = mdl.potentialOfEdge(var1, var2); pot3 = mdl.potentialOfEdge(var2, var1); assertTrue(pot == pot2); assertTrue(pot2 == pot3); break; // Potentials of size > 2 aren't now cached. default: break; } } } // Verify that potentialOfVertex and potentialOfEdge (which use // caches) are consistent with the potentials set even if a vertex is removed. public void testUndirectedCachesAfterRemove () { DiscretePotential pot, pot2, pot3; for (int i = 0; i < models.length; i++) { UndirectedModel mdl = models[i]; mdl = mdl.duplicate (); mdl.remove (mdl.get (0)); // Verify that indexing correct for (Iterator it = mdl.getVerticesIterator(); it.hasNext();) { Variable var = (Variable) it.next (); int idx = mdl.getIndex (var); assertTrue (idx >= 0); assertTrue (idx < mdl.getVerticesCount()); } // Verify that caches consistent verifyCachesConsistent (mdl); } logger.info("Test undirectedCaches passed."); } public void testTrpReuse() { TRP trp1 = new TRP(new TRP.RandomTreeFactory(), new TRP.IterationTerminator(25)); for (int i = 0; i < models.length; i++) { trp1.computeMarginals(models[i]); } // Hard to do automatically right now... logger.info("Please ensure that all instantiations above run for 25 iterations."); // Ensure that all edges touched works... UndirectedModel mdl = models[0]; final Tree tree = new TRP.RandomTreeFactory().nextTree(mdl); TRP trp2 = new TRP(new TRP.TreeFactory() { public Tree nextTree(Graph g) { return tree; } }); trp2.computeMarginals(mdl); logger.info("Ensure that the above instantiation ran for 1000 iterations with a warning."); } // Verify that variable indices are consistent in undirectected // models. public void testUndirectedIndices() { for (int mdlIdx = 0; mdlIdx < models.length; mdlIdx++) { UndirectedModel mdl = models[mdlIdx]; for (Iterator it = mdl.getVerticesIterator(); it.hasNext();) { Variable var1 = (Variable) it.next(); Variable var2 = mdl.get(mdl.getIndex(var1)); assertTrue("Mismatch in Variable index for " + var1 + " vs " + var2 + " in model " + mdlIdx + "\n" + mdl, var1 == var2); } } logger.info("Test undirectedIndices passed."); } // Tests that TRP and max-product propagation return the same // results when TRP runs for exactly one iteration. public void testTrpViterbiEquiv() { for (int mdlIdx = 0; mdlIdx < trees.length; mdlIdx++) { UndirectedModel mdl = trees[mdlIdx]; ViterbiPropagation maxprod = new ViterbiPropagation(); ViterbiTRP trp = new ViterbiTRP(new TRP.RandomTreeFactory(), new TRP.IterationTerminator(1)); maxprod.computeMarginals(mdl); trp.computeMarginals(mdl); // TRP should return same results as viterbi for (Iterator it = mdl.getVerticesIterator(); it.hasNext();) { Variable var = (Variable) it.next();
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -