📄 testinference.java
字号:
/* Copyright (C) 2003 Univ. of Massachusetts Amherst, Computer Science Dept. This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit). http://www.cs.umass.edu/~mccallum/mallet This software is provided under the terms of the Common Public License, version 1.0, as published by http://www.opensource.org. For further information, see the file `LICENSE' included with this distribution. */package edu.umass.cs.mallet.grmm.test;import junit.framework.*;import java.util.Random;import edu.umass.cs.mallet.base.types.Matrix;import edu.umass.cs.mallet.base.types.Matrixn;import edu.umass.cs.mallet.base.util.CollectionUtils;import edu.umass.cs.mallet.base.util.MalletLogger;import java.util.logging.Logger;import salvo.jesus.graph.Edge;import edu.umass.cs.mallet.base.util.Maths;import edu.umass.cs.mallet.grmm.*;import edu.umass.cs.mallet.grmm.TRP.RandomTreeFactory;import java.util.*;import salvo.jesus.graph.Graph;/** * Torture tests of inference in GRMM. Well, actually, they're * not all that torturous, but hopefully they're at least * somewhat disconcerting. * * @author <a href="mailto:casutton@cs.umass.edu">Charles Sutton</a> * @version $Id: TestInference.java,v 1.2 2004/07/22 00:37:38 casutton Exp $ */public class TestInference extends TestCase { private static Logger logger = MalletLogger.getLogger(TestInference.class.getName()); private static double APPX_EPSILON = 0.15; final public Class[] algorithms = { BruteForceInferencer.class, VariableElimination.class, JunctionTreeInferencer.class, }; final public Class[] appxAlgs = { TRP.class, LoopyBP.class, AsyncLoopyBP.class, }; final public Class[] maxMargAlgs = { LoopyMaxProduct.class, ViterbiTRP.class, JunctionTreeMaxProduct.class, }; // only used for logJoint test for now final public Class[] allAlgs = {// BruteForceInferencer.class, JunctionTreeInferencer.class, TRP.class,// VariableElimination.class, AsyncLoopyBP.class, LoopyBP.class, }; final public Class[] treeAlgs = { BeliefPropagation.class, }; List modelsList; UndirectedModel[] models; UndirectedModel[] trees; DiscretePotential[][] treeMargs; public TestInference(String name) { super(name); } private static UndirectedModel createChainGraph() { Variable[] vars = new Variable[5]; UndirectedModel model = new UndirectedModel(); try { // Add all variables to model for (int i = 0; i < 5; i++) { vars[i] = new Variable(2); model.add(vars[i]); } // Add some links double probs[] = {0.9, 0.1, 0.1, 0.9}; for (int i = 0; i < 4; i++) { Variable[] pair = {vars[i], vars[i + 1], }; MultinomialPotential pot = new MultinomialPotential(pair, probs); model.addEdgeWithPotential(vars[i], vars[i + 1], pot); } } catch (Exception e) { e.printStackTrace(); assertTrue(false); } return model; } private static UndirectedModel createTriangle() { Variable[] vars = new Variable[3]; for (int i = 0; i < 3; i++) { vars[i] = new Variable(2); } UndirectedModel model = new UndirectedModel(vars); double[][] pots = new double[][]{{0.2, 0.8, 0.1, 0.9}, {0.7, 0.3, 0.5, 0.5}, {0.6, 0.4, 0.8, 0.2}, {0.35, 0.65}}; // double[][] pots = new double[] [] { { model.addEdgeWithPotential(vars[0], vars[1], pots[0]); model.addEdgeWithPotential(vars[1], vars[2], pots[1]); model.addEdgeWithPotential(vars[2], vars[0], pots[2]); MultinomialPotential pot = new MultinomialPotential(new Variable[]{vars[0]}, pots[3]); model.addPotential(vars[0], pot); return model; } private static MultinomialPotential randomEdgePotential(Random r, Variable v1, Variable v2) { int max1 = v1.getNumOutcomes(); int max2 = v2.getNumOutcomes(); Matrix phi = new Matrixn(new int[]{max1, max2}); for (int i = 0; i < v1.getNumOutcomes(); i++) { for (int j = 0; j < v2.getNumOutcomes(); j++) { phi.setValue(new int[]{i, j}, r.nextDouble ()); // rescale(r.nextDouble())); } } return new MultinomialPotential (new Variable[]{v1, v2}, phi); } private static MultinomialPotential randomNodePotential(Random r, Variable v) { int max = v.getNumOutcomes(); Matrix phi = new Matrixn(new int[]{max}); for (int i = 0; i < v.getNumOutcomes(); i++) { phi.setSingleValue(i, rescale(r.nextDouble())); } return new MultinomialPotential (new Variable[]{v}, phi); } // scale d into range 0.2..0.8 private static double rescale(double d) { return 0.2 + 0.6 * d; } private static UndirectedModel createRandomGraph(int numV, int numOutcomes, Random r) { Variable[] vars = new Variable[numV]; for (int i = 0; i < numV; i++) { vars[i] = new Variable(numOutcomes); } UndirectedModel model = new UndirectedModel(vars); for (int i = 0; i < numV; i++) { boolean hasOne = false; for (int j = i + 1; j < numV; j++) { if (r.nextBoolean()) { hasOne = true; model.addEdgeWithPotential (vars[i], vars[j], randomEdgePotential(r, vars[i], vars[j])); } } // If vars [i] has no edge potential, add a node potential // To keep things simple, we'll require the potential to be normalized. if (!hasOne) { DiscretePotential pot = randomNodePotential(r, vars[i]); pot.normalize(); model.addPotential(vars[i], pot); } } // Ensure exactly one connected component for (int i = 0; i < numV; i++) { for (int j = i + 1; j < numV; j++) { if (!model.isConnected(vars[i], vars[j])) { DiscretePotential ptl = randomEdgePotential(r, vars[i], vars[j]); model.addEdgeWithPotential(vars[i], vars[j], ptl); } } } return model; } private static UndirectedModel createRandomGrid(int w, int h, int maxOutcomes, Random r) { Variable[][] vars = new Variable[w][h]; UndirectedModel mdl = new UndirectedModel(w * h); for (int i = 0; i < w; i++) { for (int j = 0; j < h; j++) { vars[i][j] = new Variable(r.nextInt(maxOutcomes - 1) + 2); mdl.add(vars[i][j]); } } for (int i = 0; i < w; i++) { for (int j = 0; j < h; j++) { DiscretePotential ptl; if (i < w - 1) { ptl = randomEdgePotential(r, vars[i][j], vars[i + 1][j]); mdl.addEdgeWithPotential(vars[i][j], vars[i + 1][j], ptl); } if (j < h - 1) { ptl = randomEdgePotential(r, vars[i][j], vars[i][j + 1]); mdl.addEdgeWithPotential(vars[i][j], vars[i][j + 1], ptl); } } } return mdl; } private UndirectedModel createRandomTree(int nnodes, int maxOutcomes, Random r) { Variable[] vars = new Variable[nnodes]; UndirectedModel mdl = new UndirectedModel(nnodes); for (int i = 0; i < nnodes; i++) { vars[i] = new Variable(r.nextInt(maxOutcomes - 1) + 2); mdl.add(vars[i]); } // Add some random edges for (int i = 0; i < nnodes; i++) { for (int j = i + 1; j < nnodes; j++) { if (!mdl.isConnected(vars[i], vars[j]) && r.nextBoolean()) { DiscretePotential ptl = randomEdgePotential(r, vars[i], vars[j]); mdl.addEdgeWithPotential(vars[i], vars[j], ptl); } } } // Ensure exactly one connected component for (int i = 0; i < nnodes; i++) { for (int j = i + 1; j < nnodes; j++) { if (!mdl.isConnected(vars[i], vars[j])) { System.out.println("forced edge: " + i + " " + j); DiscretePotential ptl = randomEdgePotential(r, vars[i], vars[j]); mdl.addEdgeWithPotential(vars[i], vars[j], ptl); } } } return mdl; } public static List createTestModels() { Random r = new Random(42); // These models are all small so that we can run the brute force // inferencer on them. UndirectedModel[] mdls = new UndirectedModel[]{ createTriangle(), createChainGraph(), createRandomGraph(3, 2, r), createRandomGraph(3, 3, r), createRandomGraph(6, 3, r), createRandomGraph(8, 2, r), createRandomGrid(3, 2, 4, r), createRandomGrid(4, 3, 2, r), }; return new ArrayList(Arrays.asList(mdls)); } public void testFactorizedJoint() throws Exception { Inferencer[][] infs = new Inferencer[allAlgs.length][models.length]; for (int i = 0; i < allAlgs.length; i++) { for (int mdl = 0; mdl < models.length; mdl++) { Inferencer alg = (Inferencer) allAlgs[i].newInstance(); try { alg.computeMarginals(models[mdl]); infs[i][mdl] = alg; } catch (UnsupportedOperationException e) { // LoopyBP only handles edge ptls logger.warning("Skipping (" + mdl + "," + i + ")\n" + e); throw e;// continue; } } } /* Ensure that lookupLogJoint() consistent */ int alg1 = 0; // Brute force for (int alg2 = 1; alg2 < allAlgs.length; alg2++) { for (int mdl = 0; mdl < models.length; mdl++) { Inferencer inf1 = infs[alg1][mdl]; Inferencer inf2 = infs[alg2][mdl]; if ((inf1 == null) || (inf2 == null)) { continue; } Iterator it = models[mdl].assignmentIterator(); while (it.hasNext()) { try { Assignment assn = (Assignment) it.next(); double joint1 = inf1.lookupLogJoint(assn); double joint2 = inf2.lookupLogJoint(assn); logger.finest("logJoint: " + inf1 + " " + inf2 + " Model " + mdl + " INF1: " + joint1 + "\n" + " INF2: " + joint2 + "\n"); assertTrue("logJoint not equal btwn " + inf1 + " " + inf2 + "\n" + " Model " + mdl + "\n" + " INF1: " + joint1 + "\n" + " INF2: " + joint2 + "\n", Math.abs(joint1 - joint2) < APPX_EPSILON); double joint3 = inf1.lookupJoint(assn); assertTrue("logJoint & joint not consistent\n " + "Model " + mdl + "\n" + assn, Maths.almostEquals(joint3, Math.exp(joint1))); } catch (UnsupportedOperationException e) { // VarElim doesn't compute log joints. Let it slide logger.warning("Skipping " + inf1 + " -> " + inf2 + "\n" + e); continue; } } } } } public void testMarginals() throws Exception { DiscretePotential[][][] joints = new DiscretePotential[models.length][][]; int numAlgs = algorithms.length + appxAlgs.length; for (int mdl = 0; mdl < models.length; mdl++) { joints[mdl] = new DiscretePotential[numAlgs][]; } /* Query every known graph with every known alg. */ for (int i = 0; i < algorithms.length; i++) { for (int mdl = 0; mdl < models.length; mdl++) { Inferencer alg = (Inferencer) algorithms[i].newInstance(); logger.fine("Computing marginals for model " + mdl + " alg " + alg); alg.computeMarginals(models[mdl]); int vrt = 0; int numVertices = models[mdl].getVerticesCount(); joints[mdl][i] = new DiscretePotential[numVertices]; for (Iterator it = models[mdl].getVertexSet().iterator(); it.hasNext(); vrt++) { Variable var = (Variable) it.next(); try { joints[mdl][i][vrt] = alg.lookupMarginal(var); assert joints[mdl][i][vrt] != null : "Query returned null for model " + mdl + " vertex " + var + " alg " + alg; } catch (UnsupportedOperationException e) { // Allow unsupported inference to slide with warning logger.warning("Warning: Skipping model " + mdl + " for alg " + alg + "\n Inference unsupported."); } } } } logger.fine("Checking that results are consistent..."); /* Now, make sure the exact marginals are consistent for * the same model. */ for (int mdl = 0; mdl < models.length; mdl++) { int maxV = models[mdl].getVerticesCount(); for (int vrt = 0; vrt < maxV; vrt++) { for (int alg1 = 0; alg1 < algorithms.length; alg1++) { for (int alg2 = 0; alg2 < algorithms.length; alg2++) { DiscretePotential joint1 = joints[mdl][alg1][vrt]; DiscretePotential joint2 = joints[mdl][alg2][vrt]; try { // By the time we get here, a joint is null only if // there was an UnsupportedOperationException. if ((joint1 != null) && (joint2 != null)) { assertTrue(joint1.almostEquals(joint2)); } } catch (AssertionFailedError e) { System.out.println("\n************************************\nTest FAILED\n\n"); System.out.println("Model " + mdl + " Vertex " + vrt); System.out.println("Algs " + alg1 + " and " + alg2 + " not consistent."); System.out.println("MARGINAL from " + alg1); System.out.println(joint1); System.out.println("MARGINAL from " + alg2); System.out.println(joint2); System.out.println("Marginals from " + alg1 + ":"); for (int i = 0; i < maxV; i++) { System.out.println(joints[mdl][alg1][i]); } System.out.println("Marginals from " + alg2 + ":"); for (int i = 0; i < maxV; i++) { System.out.println(joints[mdl][alg2][i]); } models[mdl].dump(); throw e; } } } } } // Compare all approximate algorithms against brute force. logger.fine("Checking the approximate algorithms..."); int alg2 = 0; // Brute force for (int alg1 = algorithms.length; alg1 < numAlgs; alg1++) { Inferencer alg = (Inferencer) appxAlgs[alg1 - algorithms.length].newInstance(); for (int mdl = 0; mdl < models.length; mdl++) { logger.finer("Running inference alg " + alg + " with model " + mdl); try { alg.computeMarginals(models[mdl]); } catch (UnsupportedOperationException e) { // LoopyBP does not support vertex potentials. // We'll let that slide. if (alg instanceof BeliefPropagation) { logger.warning("Skipping model " + mdl + " for alg " + alg + "\nInference unsupported.");
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -