⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 testinference.java

📁 这是一个matlab的java实现。里面有许多内容。请大家慢慢琢磨。
💻 JAVA
📖 第 1 页 / 共 3 页
字号:
            continue;          } else {            throw e;          }        }        /* lookup all marginals */        int vrt = 0;        int maxV = models[mdl].getVerticesCount();        joints[mdl][alg1] = new DiscretePotential[maxV];        for (Iterator it = models[mdl].getVertexSet().iterator();             it.hasNext();             vrt++) {          Variable var = (Variable) it.next();          logger.finer("Lookup marginal for model " + mdl + " vrt " + var + " alg " + alg);          DiscretePotential ptl = alg.lookupMarginal(var);          joints[mdl][alg1][vrt] = ptl.duplicate();        }        for (vrt = 0; vrt < maxV; vrt++) {          DiscretePotential joint1 = joints[mdl][alg1][vrt];          DiscretePotential joint2 = joints[mdl][alg2][vrt];          try {            assertTrue(joint1.almostEquals(joint2, APPX_EPSILON));          } catch (AssertionFailedError e) {            System.out.println("\n************************************\nAppx Marginal Test FAILED\n\n");            System.out.println("Inferencer: " + alg);            System.out.println("Model " + mdl + " Vertex " + vrt);            System.out.println(joint1);            System.out.println(joint2);            models[mdl].dump();            System.out.println("All marginals:");            for (int i = 0; i < maxV; i++) {              System.out.println(joints[mdl][alg1][i]);            }            System.out.println("Correct marginals:");            for (int i = 0; i < maxV; i++) {              System.out.println(joints[mdl][alg2][i]);            }            throw e;          }        }      }    }    System.out.println("Tested " + models.length + " undirected models.");  }	public void testQuery () throws Exception	{		java.util.Random rand = new java.util.Random (15667);		for (int mdlIdx = 0; mdlIdx < models.length; mdlIdx++) {			UndirectedModel mdl = models [mdlIdx];			int size = rand.nextInt (3) + 2;			Collection vars = CollectionUtils.subset (mdl.getVertexSet(), size, rand);			
Variable[] varArr = (Variable[]) vars.toArray (new Variable [0]);			Assignment assn = new Assignment (varArr, new int [size]);						BruteForceInferencer brute = new BruteForceInferencer();			DiscretePotential joint = brute.joint(mdl);			double marginal = joint.marginalize(vars).phi (assn);			for (int algIdx = 0; algIdx < appxAlgs.length; algIdx++) {				Inferencer alg = (Inferencer) appxAlgs[algIdx].newInstance();				double returned = mdl.query (alg, assn);				assertEquals ("Failure on model "+mdlIdx+" alg "+alg, marginal, returned, APPX_EPSILON);			}		}		logger.info ("Test testQuery passed.");	}	// Tests the measurement of numbers of messages sent	public void testNumMessages ()	{		for (int mdlIdx = 0; mdlIdx < models.length; mdlIdx++) {			UndirectedModel mdl = models [mdlIdx];						TRP trp = new TRP ();			trp.computeMarginals (mdl);			int expectedMessages = (mdl.getVerticesCount() - 1) * 2 														 * trp.iterationsUsed();			assertEquals (expectedMessages, trp.getTotalMessagesSent ());			LoopyBP loopy = new LoopyBP ();			loopy.computeMarginals (mdl);			expectedMessages = mdl.getEdgeSet().size() * 2 												 * loopy.iterationsUsed();			assertEquals (expectedMessages, loopy.getTotalMessagesSent ());		}	}  private UndirectedModel createJtChain()  {    int numNodes = 4;    Variable[] nodes = new Variable[numNodes];    for (int i = 0; i < numNodes; i++) {      nodes[i] = new Variable(2);    }    DiscretePotential[] pots = new MultinomialPotential[]{      new MultinomialPotential(new Variable[]{nodes[0], nodes[1]},                               new double[]{1, 2, 5, 4}),      new MultinomialPotential(new Variable[]{nodes[1], nodes[2]},                               new double[]{4, 2, 4, 1}),      new MultinomialPotential(new Variable[]{nodes[2], nodes[3]},                               new double[]{7, 3, 6, 9}),    };    for (int i = 0; i < pots.length; i++) {      pots[i].normalize();    }    UndirectedModel uGraph = new UndirectedModel();    for (int i = 0; i 
< numNodes; i++) {      uGraph.add(nodes[i]);    }    for (int i = 0; i < numNodes - 1; i++) {      uGraph.addEdgeWithPotential(nodes[i], nodes[(i + 1)], pots[i]);    }    return uGraph;  }  private void createTestTrees()  {    Random r = new Random(185);    trees = new UndirectedModel[]{      createJtChain(),      createRandomGrid(5, 1, 3, r),      createRandomGrid(6, 1, 2, r),      createRandomTree(10, 2, r),      createRandomTree(10, 2, r),      createRandomTree(8, 3, r),      createRandomTree(8, 3, r),    };    modelsList.addAll(Arrays.asList(trees));  }  private void computeTestTreeMargs()  {    treeMargs = new DiscretePotential[trees.length][];    BruteForceInferencer brute = new BruteForceInferencer();    for (int i = 0; i < trees.length; i++) {      UndirectedModel mdl = trees[i];      DiscretePotential joint = brute.joint(mdl);      treeMargs[i] = new DiscretePotential[mdl.getVerticesCount()];      for (Iterator it = mdl.getVerticesIterator(); it.hasNext();) {        Variable var = (Variable) it.next();        treeMargs[i][mdl.getIndex(var)] = joint.marginalize(var);      }    }  }  // Takes an undirected model that is a tree parameterized by  // normalized edge potentials and converts it into a junction tree.  
private JunctionTree graphToJt(UndirectedModel tree)  {    int numNodes = tree.getVerticesCount();    JunctionTree jt = new JunctionTree(numNodes - 1);    Clique[] edgeCliques = new Clique[tree.getEdgeSet().size()];    int i = 0;    for (Iterator it = tree.getEdgeSet().iterator(); it.hasNext();) {      Edge e = (Edge) it.next();      Variable v1 = (Variable) e.getVertexA();      Variable v2 = (Variable) e.getVertexB();      DiscretePotential pot = tree.potentialOfEdge(v1, v2);      edgeCliques[i] = new HashClique(pot.varSet());      jt.setCPF(edgeCliques[i], pot.duplicate());      try {        jt.add(edgeCliques[i]);      } catch (Exception exception) {      }      i++;    }    for (int j = 0; j < edgeCliques.length; j++) {      for (int k = j + 1; k < edgeCliques.length; k++) {        Set set = edgeCliques[k].intersection(edgeCliques[j]);        if (!set.isEmpty()) {          jt.addNode(edgeCliques[j], edgeCliques[k]);        }      }    }    return jt;  }public void testJtConsistency() {  for (int mdlIdx = 0; mdlIdx < models.length; mdlIdx++) {    UndirectedModel mdl = models[mdlIdx];    JunctionTreeInferencer jti = new JunctionTreeInferencer();    JunctionTree jt = jti.buildJunctionTree(mdl);    for (Iterator it = jt.getVerticesIterator(); it.hasNext();) {      Clique parent = (Clique) it.next();      for (Iterator it2 = jt.getChildren(parent).iterator(); it2.hasNext();) {        Clique child = (Clique) it2.next();        DiscretePotential ptl = jt.getSepsetPot(parent, child);        Set intersection = parent.intersection (child);        assertTrue (intersection.equals (ptl.varSet()));      }    }  }}  private void compareTrpJoint(DiscretePotential joint, TRP trp)  {    Assignment assn = null;    double prob1 = 0.0, prob2 = 0.0;    try {      Clique all = new HashClique(joint.varSet());      for (Iterator it = all.assignmentIterator(); it.hasNext();) {        assn = (Assignment) it.next();        prob1 = trp.lookupJoint(assn);        prob2 = joint.phi(assn);//			
	assertTrue (Maths.almostEquals (prob1, prob2));        assertTrue(Math.abs(prob1 - prob2) < 0.01);      }    } catch (AssertionFailedError e) {      System.out.println("*****************************************\nTEST FAILURE in compareTrpJoint");      System.out.println("*****************************************\nat");      System.out.println(assn);      System.out.println("Expected: " + prob2);      System.out.println("TRP: " + prob1);      System.out.println("*****************************************\nExpected joint");      System.out.println(joint);      System.out.println("*****************************************\nTRP dump");      trp.dump();      throw e;    }  }  public void testTrp()  {    final UndirectedModel model = createTriangle();    TRP trp = new TRP(new RandomTreeFactory(),                      new TRP.IterationTerminator(200));    BruteForceInferencer brute = new BruteForceInferencer();    DiscretePotential joint = brute.joint(model);    trp.computeMarginals(model);    // Check joint//		DiscretePotential joint = brute.joint (model);    compareTrpJoint(joint, trp);    // Check all marginals    try {      for (Iterator it = model.getVerticesIterator(); it.hasNext();) {        Variable var = (Variable) it.next();        DiscretePotential marg1 = trp.lookupMarginal(var);        DiscretePotential marg2 = joint.marginalize (var);        assertTrue(marg1.almostEquals(marg2, APPX_EPSILON));      }      for (Iterator it = model.getEdgeSet().iterator(); it.hasNext();) {        Edge e = (Edge) it.next();        Variable v1 = (Variable) e.getVertexA();        Variable v2 = (Variable) e.getVertexB();        DiscretePotential marg1 = trp.lookupMarginal(v1, v2);        DiscretePotential marg2 = joint.marginalize                (new Variable[]{v1, v2});        assertTrue(marg1.almostEquals(marg2, APPX_EPSILON));      }    } catch (AssertionFailedError e) {      System.out.println("\n*************************************\nTEST FAILURE in compareTrpMargs");//    
System.out.println(marg1);//    System.out.println(marg2);      System.out.println("*************************************\nComplete model:\n\n");      model.dump();      System.out.println("*************************************\nTRP margs:\n\n");      trp.dump();      System.out.println("**************************************\nAll correct margs:\n");      for (Iterator it2 = model.getVerticesIterator(); it2.hasNext();) {        Variable v2 = (Variable) it2.next();				brute.computeMarginals (model);        System.out.println(brute.lookupMarginal(v2));      }      throw e;    }  }  public void testTrpJoint()  {    UndirectedModel model = createTriangle();    TRP trp = new TRP(new RandomTreeFactory(),                      new TRP.IterationTerminator(25));    trp.computeMarginals(model);    // For each assignment to the model, check that    // TRP.lookupLogJoint and TRP.lookupJoint are consistent    Clique all = new HashClique(model.getVertexSet());    for (Iterator it = all.assignmentIterator(); it.hasNext();) {      Assignment assn = (Assignment) it.next();      double log = trp.lookupLogJoint(assn);      double prob = trp.lookupJoint(assn);      assertTrue(Maths.almostEquals(Math.exp(log), prob));    }    logger.info("Test trpJoint passed.");  }  /** Tests that running TRP doesn't inadvertantly change potentials   in the original graph. */  public void testTrpNonDestructivity()  {    UndirectedModel model = createTriangle();    TRP trp = new TRP(new RandomTreeFactory(),                      new TRP.IterationTerminator(25));    BruteForceInferencer brute = new BruteForceInferencer();    DiscretePotential joint1 = brute.joint(model);    trp.computeMarginals(model);    DiscretePotential joint2 = brute.joint(model);    assertTrue(joint1.almostEquals(joint2));    logger.info("Test trpNonDestructivity passed.");  }  // Verify that potentialOfVertex and potentialOfEdge (which use  // caches) are consistent with the potentials set.  
public void testUndirectedCaches()  {    for (int i = 0; i < models.length; i++) {       UndirectedModel mdl = models[i];       verifyCachesConsistent (mdl);    }    logger.info("Test undirectedCaches passed.");  }  private void verifyCachesConsistent (UndirectedModel mdl)   {	DiscretePotential pot, pot2, pot3;         for (Iterator it = mdl.potentials().iterator(); it.hasNext();) {            pot = (DiscretePotential) it.next();    //				System.out.println("Testing model "+i+" potential "+pot);                Object[] vars = pot.varSet().toArray();            switch (vars.length) {              case 1:                pot2 = mdl.potentialOfVertex((Variable) vars[0]);                assertTrue(pot == pot2);                break;                      case 2:                Variable var1 = (Variable) vars[0];                Variable var2 = (Variable) vars[1];                pot2 = mdl.potentialOfEdge(var1, var2);                pot3 = mdl.potentialOfEdge(var2, var1);                assertTrue(pot == pot2);                assertTrue(pot2 == pot3);                break;                    // Potentials of size > 2 aren't now cached.              default:                break;            }          }  }	// Verify that potentialOfVertex and potentialOfEdge (which use	 // caches) are consistent with the potentials set even if a vertex is removed.	 
public void testUndirectedCachesAfterRemove ()	 {		 DiscretePotential pot, pot2, pot3;		 for (int i = 0; i < models.length; i++) {			 UndirectedModel mdl = models[i];			 mdl = mdl.duplicate ();			 mdl.remove (mdl.get (0));			 			 // Verify that indexing correct			 for (Iterator it = mdl.getVerticesIterator(); it.hasNext();) {			 	Variable var = (Variable) it.next ();			 	int idx = mdl.getIndex (var);			 	assertTrue (idx >= 0);			 	assertTrue (idx < mdl.getVerticesCount());			 }			 			 // Verify that caches consistent			verifyCachesConsistent (mdl);		 }		 logger.info("Test undirectedCaches passed.");	 }  public void testTrpReuse()  {    TRP trp1 = new TRP(new TRP.RandomTreeFactory(),                       new TRP.IterationTerminator(25));    for (int i = 0; i < models.length; i++) {      trp1.computeMarginals(models[i]);    }    // Hard to do automatically right now...    logger.info("Please ensure that all instantiations above run for 25 iterations.");    // Ensure that all edges touched works...    UndirectedModel mdl = models[0];    final Tree tree = new TRP.RandomTreeFactory().nextTree(mdl);    TRP trp2 = new TRP(new TRP.TreeFactory() {      public Tree nextTree(Graph g)      {        return tree;      }    });    trp2.computeMarginals(mdl);    logger.info("Ensure that the above instantiation ran for 1000 iterations with a warning.");  }  // Verify that variable indices are consistent in undirectected  // models.  
public void testUndirectedIndices()  {    for (int mdlIdx = 0; mdlIdx < models.length; mdlIdx++) {      UndirectedModel mdl = models[mdlIdx];      for (Iterator it = mdl.getVerticesIterator(); it.hasNext();) {        Variable var1 = (Variable) it.next();        Variable var2 = mdl.get(mdl.getIndex(var1));        assertTrue("Mismatch in Variable index for " + var1 + " vs "                   + var2 + " in model " + mdlIdx + "\n" + mdl,                   var1 == var2);      }    }    logger.info("Test undirectedIndices passed.");  }  // Tests that TRP and max-product propagation return the same  // results when TRP runs for exactly one iteration.  public void testTrpViterbiEquiv()  {    for (int mdlIdx = 0; mdlIdx < trees.length; mdlIdx++) {      UndirectedModel mdl = trees[mdlIdx];      ViterbiPropagation maxprod = new ViterbiPropagation();      ViterbiTRP trp = new ViterbiTRP(new TRP.RandomTreeFactory(),                                      new TRP.IterationTerminator(1));      maxprod.computeMarginals(mdl);      trp.computeMarginals(mdl);      // TRP should return same results as viterbi      for (Iterator it = mdl.getVerticesIterator(); it.hasNext();) {        Variable var = (Variable) it.next();

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -