⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 testinference.java

📁 这是一个matlab的java实现。里面有许多内容。请大家慢慢琢磨。
💻 JAVA
📖 第 1 页 / 共 3 页
字号:
/* Copyright (C) 2003 Univ. of Massachusetts Amherst, Computer Science Dept.   This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).   http://www.cs.umass.edu/~mccallum/mallet   This software is provided under the terms of the Common Public License,   version 1.0, as published by http://www.opensource.org.  For further   information, see the file `LICENSE' included with this distribution. */package edu.umass.cs.mallet.grmm.test;import junit.framework.*;import java.util.Random;import edu.umass.cs.mallet.base.types.Matrix;import edu.umass.cs.mallet.base.types.Matrixn;import edu.umass.cs.mallet.base.util.CollectionUtils;import edu.umass.cs.mallet.base.util.MalletLogger;import java.util.logging.Logger;import salvo.jesus.graph.Edge;import edu.umass.cs.mallet.base.util.Maths;import edu.umass.cs.mallet.grmm.*;import edu.umass.cs.mallet.grmm.TRP.RandomTreeFactory;import java.util.*;import salvo.jesus.graph.Graph;/** *  Torture tests of inference in GRMM.  Well, actually, they're *   not all that torturous, but hopefully they're at least *   somewhat disconcerting. 
* * @author <a href="mailto:casutton@cs.umass.edu">Charles Sutton</a> * @version $Id: TestInference.java,v 1.2 2004/07/22 00:37:38 casutton Exp $ */public class TestInference extends TestCase {  private static Logger logger = MalletLogger.getLogger(TestInference.class.getName());  private static double APPX_EPSILON = 0.15;  final public Class[] algorithms = {    BruteForceInferencer.class,    VariableElimination.class,    JunctionTreeInferencer.class,  };  final public Class[] appxAlgs = {    TRP.class,    LoopyBP.class,		AsyncLoopyBP.class,  };  final public Class[] maxMargAlgs = {    LoopyMaxProduct.class,    ViterbiTRP.class,    JunctionTreeMaxProduct.class,  };  // only used for logJoint test for now  final public Class[] allAlgs = {//    BruteForceInferencer.class,    JunctionTreeInferencer.class,    TRP.class,//	  VariableElimination.class,		AsyncLoopyBP.class,    LoopyBP.class,  };  final public Class[] treeAlgs = {    BeliefPropagation.class,  };  List modelsList;  UndirectedModel[] models;  UndirectedModel[] trees;  DiscretePotential[][] treeMargs;  public TestInference(String name)  {    super(name);  }  private static UndirectedModel createChainGraph()  {    Variable[] vars = new Variable[5];    UndirectedModel model = new UndirectedModel();    try {      // Add all variables to model      for (int i = 0; i < 5; i++) {        vars[i] = new Variable(2);        model.add(vars[i]);      }      // Add some links      double probs[] = {0.9, 0.1, 0.1, 0.9};      for (int i = 0; i < 4; i++) {        Variable[] pair = {vars[i], vars[i + 1], };        MultinomialPotential pot = new MultinomialPotential(pair, probs);        model.addEdgeWithPotential(vars[i], vars[i + 1], pot);      }    } catch (Exception e) {      e.printStackTrace();      assertTrue(false);    }    return model;  }  private static UndirectedModel createTriangle()  {    Variable[] vars = new Variable[3];    for (int i = 0; i < 3; i++) {      vars[i] = new Variable(2);    }    UndirectedModel 
model = new UndirectedModel(vars);    double[][] pots = new double[][]{{0.2, 0.8, 0.1, 0.9},                                     {0.7, 0.3, 0.5, 0.5},                                     {0.6, 0.4, 0.8, 0.2},                                     {0.35, 0.65}};    // double[][] pots = new double[] [] { {    model.addEdgeWithPotential(vars[0], vars[1], pots[0]);    model.addEdgeWithPotential(vars[1], vars[2], pots[1]);    model.addEdgeWithPotential(vars[2], vars[0], pots[2]);    MultinomialPotential pot = new MultinomialPotential(new Variable[]{vars[0]}, pots[3]);    model.addPotential(vars[0], pot);    return model;  }  private static MultinomialPotential randomEdgePotential(Random r,                                                   Variable v1, Variable v2)  {    int max1 = v1.getNumOutcomes();    int max2 = v2.getNumOutcomes();    Matrix phi = new Matrixn(new int[]{max1, max2});    for (int i = 0; i < v1.getNumOutcomes(); i++) {      for (int j = 0; j < v2.getNumOutcomes(); j++) {        phi.setValue(new int[]{i, j}, r.nextDouble ()); // rescale(r.nextDouble()));      }    }    return new MultinomialPotential            (new Variable[]{v1, v2}, phi);  }  private static MultinomialPotential randomNodePotential(Random r, Variable v)  {    int max = v.getNumOutcomes();    Matrix phi = new Matrixn(new int[]{max});    for (int i = 0; i < v.getNumOutcomes(); i++) {      phi.setSingleValue(i, rescale(r.nextDouble()));    }    return new MultinomialPotential            (new Variable[]{v}, phi);  }  // scale d into range 0.2..0.8  private static double rescale(double d)  {    return 0.2 + 0.6 * d;  }  private static UndirectedModel createRandomGraph(int numV, int numOutcomes, Random r)  {    Variable[] vars = new Variable[numV];    for (int i = 0; i < numV; i++) {      vars[i] = new Variable(numOutcomes);    }    UndirectedModel model = new UndirectedModel(vars);    for (int i = 0; i < numV; i++) {      boolean hasOne = false;      for (int j = i + 1; j < numV; j++) {      
  if (r.nextBoolean()) {          hasOne = true;          model.addEdgeWithPotential                  (vars[i], vars[j], randomEdgePotential(r, vars[i], vars[j]));        }      }      // If vars [i] has no edge potential, add a node potential      //  To keep things simple, we'll require the potential to be normalized.      if (!hasOne) {        DiscretePotential pot = randomNodePotential(r, vars[i]);        pot.normalize();        model.addPotential(vars[i], pot);      }    }    // Ensure exactly one connected component    for (int i = 0; i < numV; i++) {      for (int j = i + 1; j < numV; j++) {        if (!model.isConnected(vars[i], vars[j])) {          DiscretePotential ptl = randomEdgePotential(r, vars[i], vars[j]);          model.addEdgeWithPotential(vars[i], vars[j], ptl);        }      }    }    return model;  }  private static UndirectedModel createRandomGrid(int w, int h, int maxOutcomes, Random r)  {    Variable[][] vars = new Variable[w][h];    UndirectedModel mdl = new UndirectedModel(w * h);    for (int i = 0; i < w; i++) {      for (int j = 0; j < h; j++) {        vars[i][j] = new Variable(r.nextInt(maxOutcomes - 1) + 2);        mdl.add(vars[i][j]);      }    }    for (int i = 0; i < w; i++) {      for (int j = 0; j < h; j++) {        DiscretePotential ptl;        if (i < w - 1) {          ptl = randomEdgePotential(r, vars[i][j], vars[i + 1][j]);          mdl.addEdgeWithPotential(vars[i][j], vars[i + 1][j], ptl);        }        if (j < h - 1) {          ptl = randomEdgePotential(r, vars[i][j], vars[i][j + 1]);          mdl.addEdgeWithPotential(vars[i][j], vars[i][j + 1], ptl);        }      }    }    return mdl;  }  private UndirectedModel createRandomTree(int nnodes, int maxOutcomes, Random r)  {    Variable[] vars = new Variable[nnodes];    UndirectedModel mdl = new UndirectedModel(nnodes);    for (int i = 0; i < nnodes; i++) {      vars[i] = new Variable(r.nextInt(maxOutcomes - 1) + 2);      mdl.add(vars[i]);    }    //  Add some random edges    
for (int i = 0; i < nnodes; i++) {      for (int j = i + 1; j < nnodes; j++) {        if (!mdl.isConnected(vars[i], vars[j]) && r.nextBoolean()) {          DiscretePotential ptl = randomEdgePotential(r, vars[i], vars[j]);          mdl.addEdgeWithPotential(vars[i], vars[j], ptl);        }      }    }    // Ensure exactly one connected component    for (int i = 0; i < nnodes; i++) {      for (int j = i + 1; j < nnodes; j++) {        if (!mdl.isConnected(vars[i], vars[j])) {          System.out.println("forced edge: " + i + " " + j);          DiscretePotential ptl = randomEdgePotential(r, vars[i], vars[j]);          mdl.addEdgeWithPotential(vars[i], vars[j], ptl);        }      }    }    return mdl;  }  public static List createTestModels()  {    Random r = new Random(42);    // These models are all small so that we can run the brute force    // inferencer on them.    UndirectedModel[] mdls = new UndirectedModel[]{      createTriangle(),      createChainGraph(),      createRandomGraph(3, 2, r),      createRandomGraph(3, 3, r),      createRandomGraph(6, 3, r),      createRandomGraph(8, 2, r),      createRandomGrid(3, 2, 4, r),      createRandomGrid(4, 3, 2, r),    };    return new ArrayList(Arrays.asList(mdls));  }  public void testFactorizedJoint() throws Exception  {    Inferencer[][] infs = new Inferencer[allAlgs.length][models.length];    for (int i = 0; i < allAlgs.length; i++) {      for (int mdl = 0; mdl < models.length; mdl++) {        Inferencer alg = (Inferencer) allAlgs[i].newInstance();        try {          alg.computeMarginals(models[mdl]);          infs[i][mdl] = alg;        } catch (UnsupportedOperationException e) {          // LoopyBP only handles edge ptls          logger.warning("Skipping (" + mdl + "," + i + ")\n" + e);					throw e;//          continue;        }      }    }    /* Ensure that lookupLogJoint() consistent */    int alg1 = 0;  // Brute force    for (int alg2 = 1; alg2 < allAlgs.length; alg2++) {      for (int mdl = 0; mdl < 
models.length; mdl++) {        Inferencer inf1 = infs[alg1][mdl];        Inferencer inf2 = infs[alg2][mdl];        if ((inf1 == null) || (inf2 == null)) {          continue;        }        Iterator it = models[mdl].assignmentIterator();        while (it.hasNext()) {          try {            Assignment assn = (Assignment) it.next();            double joint1 = inf1.lookupLogJoint(assn);            double joint2 = inf2.lookupLogJoint(assn);            logger.finest("logJoint: " + inf1 + " " + inf2                          + "  Model " + mdl                          + "  INF1: " + joint1 + "\n"                          + "  INF2: " + joint2 + "\n");            assertTrue("logJoint not equal btwn " + inf1 + " " + inf2 + "\n"                       + "  Model " + mdl + "\n"                       + "  INF1: " + joint1 + "\n"                       + "  INF2: " + joint2 + "\n",                       Math.abs(joint1 - joint2) < APPX_EPSILON);            double joint3 = inf1.lookupJoint(assn);            assertTrue("logJoint & joint not consistent\n  "                       + "Model " + mdl + "\n" + assn,                       Maths.almostEquals(joint3, Math.exp(joint1)));          } catch (UnsupportedOperationException e) {            // VarElim doesn't compute log joints. Let it slide            logger.warning("Skipping " + inf1 + " -> " + inf2 + "\n" + e);            continue;          }        }      }    }  }  public void testMarginals() throws Exception  {    DiscretePotential[][][] joints = new DiscretePotential[models.length][][];    int numAlgs = algorithms.length + appxAlgs.length;    for (int mdl = 0; mdl < models.length; mdl++) {      joints[mdl] = new DiscretePotential[numAlgs][];    }    /* Query every known graph with every known alg. 
*/    for (int i = 0; i < algorithms.length; i++) {      for (int mdl = 0; mdl < models.length; mdl++) {        Inferencer alg = (Inferencer) algorithms[i].newInstance();        logger.fine("Computing marginals for model " + mdl + " alg " + alg);        alg.computeMarginals(models[mdl]);        int vrt = 0;        int numVertices = models[mdl].getVerticesCount();        joints[mdl][i] = new DiscretePotential[numVertices];        for (Iterator it = models[mdl].getVertexSet().iterator();             it.hasNext();             vrt++) {          Variable var = (Variable) it.next();          try {            joints[mdl][i][vrt] = alg.lookupMarginal(var);            assert joints[mdl][i][vrt] != null                    : "Query returned null for model " + mdl + " vertex " + var + " alg " + alg;          } catch (UnsupportedOperationException e) {            // Allow unsupported inference to slide with warning            logger.warning("Warning: Skipping model " + mdl + " for alg " + alg                           + "\n  Inference unsupported.");          }        }      }    }    logger.fine("Checking that results are consistent...");    /* Now, make sure the exact marginals are consistent for     *  the same model.                       */    for (int mdl = 0; mdl < models.length; mdl++) {      int maxV = models[mdl].getVerticesCount();      for (int vrt = 0; vrt < maxV; vrt++) {        for (int alg1 = 0; alg1 < algorithms.length; alg1++) {          for (int alg2 = 0; alg2 < algorithms.length; alg2++) {            DiscretePotential joint1 = joints[mdl][alg1][vrt];            DiscretePotential joint2 = joints[mdl][alg2][vrt];            try {              // By the time we get here, a joint is null only if              // there was an UnsupportedOperationException.              
if ((joint1 != null) && (joint2 != null)) {                assertTrue(joint1.almostEquals(joint2));              }            } catch (AssertionFailedError e) {              System.out.println("\n************************************\nTest FAILED\n\n");              System.out.println("Model " + mdl + " Vertex " + vrt);              System.out.println("Algs " + alg1 + " and " + alg2 + " not consistent.");              System.out.println("MARGINAL from " + alg1);              System.out.println(joint1);              System.out.println("MARGINAL from " + alg2);              System.out.println(joint2);              System.out.println("Marginals from " + alg1 + ":");              for (int i = 0; i < maxV; i++) {                System.out.println(joints[mdl][alg1][i]);              }              System.out.println("Marginals from " + alg2 + ":");              for (int i = 0; i < maxV; i++) {                System.out.println(joints[mdl][alg2][i]);              }              models[mdl].dump();              throw e;            }          }        }      }    }    // Compare all approximate algorithms against brute force.    logger.fine("Checking the approximate algorithms...");    int alg2 = 0; // Brute force    for (int alg1 = algorithms.length; alg1 < numAlgs; alg1++) {      Inferencer alg = (Inferencer) appxAlgs[alg1 - algorithms.length].newInstance();      for (int mdl = 0; mdl < models.length; mdl++) {        logger.finer("Running inference alg " + alg + " with model " + mdl);        try {          alg.computeMarginals(models[mdl]);        } catch (UnsupportedOperationException e) {          // LoopyBP does not support vertex potentials.          //  We'll let that slide.          if (alg instanceof BeliefPropagation) {            logger.warning("Skipping model " + mdl + " for alg " + alg                           + "\nInference unsupported.");

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -