⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 junctiontreeinferencer.java

📁 这是一个matlab的java实现。里面有许多内容。请大家慢慢琢磨。
💻 JAVA
字号:
/* Copyright (C) 2003 Univ. of Massachusetts Amherst, Computer Science Dept.   This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).   http://www.cs.umass.edu/~mccallum/mallet   This software is provided under the terms of the Common Public License,   version 1.0, as published by http://www.opensource.org.  For further   information, see the file `LICENSE' included with this distribution. */package edu.umass.cs.mallet.grmm;import java.util.TreeSet;import salvo.jesus.graph.Graph;import salvo.jesus.graph.Vertex;import java.util.ArrayList;import java.util.Comparator;import java.util.Iterator;import java.util.logging.Logger;import edu.umass.cs.mallet.base.util.MalletLogger;import java.util.List;import edu.umass.cs.mallet.base.util.Arrays;import edu.umass.cs.mallet.base.types.Alphabet;import java.util.LinkedList;import salvo.jesus.graph.Edge;import salvo.jesus.graph.GraphImpl;import salvo.jesus.graph.Visitor;import salvo.jesus.graph.algorithm.GraphTraversal;import salvo.jesus.graph.algorithm.BreadthFirstTraversal;import java.util.logging.Level;import java.util.Collections;/** * Does inference in general graphical models using *  the Hugin junction tree algorithm. 
* * Created: Mon Nov 10 23:58:44 2003 * * @author <a href="mailto:casutton@cs.umass.edu">Charles Sutton</a> * @version $Id: JunctionTreeInferencer.java,v 1.1 2004/07/15 17:53:31 casutton Exp $ */public class JunctionTreeInferencer extends AbstractInferencer {  private static Logger logger = MalletLogger.getLogger(JunctionTreeInferencer.class.getName());  public JunctionTreeInferencer()  {  } // JunctionTreeInferencer constructor  private boolean contains(Graph g, Vertex v1)  {    return g.getVertexSet().contains(v1);  }  private boolean isAdjacent(Graph g, Vertex v1, Vertex v2)  {    return ((UndirectedModel) g).isAdjacent(v1, v2);  }  protected JunctionTree jtCurrent;  private ArrayList cliques;  /**   * Returns the number of edges that would be added to a graph if a   *  given vertex would be removed in the triangulation procedure.   *  The return value is the number of edges in the elimination   *  clique of V that are not already present.   */  private int newEdgesRequired(Graph mdl, Vertex v)  {    int rating = 0;    for (Iterator it1 = neighborsIterator(mdl,v); it1.hasNext();) {      Vertex neighbor1 = (Vertex) it1.next();      Iterator it2 = neighborsIterator(mdl,v);      while (it2.hasNext()) {        Vertex neighbor2 = (Vertex) it2.next();        if (neighbor1 != neighbor2) {          if (!isAdjacent(mdl, neighbor1, neighbor2)) {            rating++;          }        }      }    }//		System.out.println(v+" = "+rating);    return rating;  }  /**   * Returns the weight of the clique that would be added to a graph if a   *  given vertex would be removed in the triangulation procedure.   *  The return value is the number of edges in the elimination   *  clique of V that are not already present.   
*/  private int weightRequired(Graph mdl, Vertex v)  {    int rating = 1;    for (Iterator it1 = neighborsIterator(mdl,v); it1.hasNext();) {      Variable neighbor = (Variable) it1.next();      rating *= neighbor.getNumOutcomes();    }//		System.out.println(v+" = "+rating);    return rating;  }  private void connectNeighbors(Graph mdl, Vertex v)  {    for (Iterator it1 = neighborsIterator(mdl,v); it1.hasNext();) {      Vertex neighbor1 = (Vertex) it1.next();      Iterator it2 = neighborsIterator(mdl,v);      while (it2.hasNext()) {        Vertex neighbor2 = (Vertex) it2.next();        if (neighbor1 != neighbor2) {          if (!isAdjacent(mdl, neighbor1, neighbor2)) {            try {              mdl.addEdge(neighbor1, neighbor2);            } catch (Exception e) {              throw new RuntimeException(e);            }          }        }      }    }  }  // xx should refactor into Collections.any (Coll, TObjectProc)  /* Return true iff a clique in L strictly contains c. */  private boolean findSuperClique(List l, Clique c)  {    for (Iterator it = l.iterator(); it.hasNext();) {      Clique c2 = (Clique) it.next();      if (c2.containsAll(c)) {        return true;      }    }    return false;  }  // works like the obscure <=> operator in Perl.  
private int cmp(int i1, int i2)  {    if (i1 < i2) {      return -1;    } else if (i1 > i2) {      return 1;    } else {      return 0;    }  }	public Variable pickVertexToRemove (UndirectedModel mdl, ArrayList lst)	{		Iterator it = lst.iterator();		Variable best = (Variable) it.next();		int bestVal1 = newEdgesRequired (mdl, best);		int bestVal2 = weightRequired (mdl, best);		while (it.hasNext()) {			Variable v = (Variable) it.next();			int val = newEdgesRequired (mdl, v);			if (val < bestVal1) {				best = v;				bestVal1 = val;				bestVal2 = weightRequired (mdl, v);			} else if (val == bestVal1) {				int val2 = weightRequired (mdl, v);				if (val2 < bestVal2) {					best = v;					bestVal1 = val;					bestVal2 = val2;				}			}		}				return best;	}  /**   * Adds edges to graph until it is triangulated.   */  private void triangulate(final UndirectedModel mdl)  {    final UndirectedModel mdl2 = mdl.duplicate();    ArrayList vars = new ArrayList(mdl.getVertexSet());    Alphabet varMap = makeVertexMap(vars);    cliques = new ArrayList();    // debug    if (logger.isLoggable (Level.FINER)) {			logger.finer ("Triangulating model: "+mdl);      String ret = "";      for (int i = 0; i < vars.size(); i++) {        Variable next = (Variable) vars.get(i);        ret += next.toString() + " (" + mdl.getIndex(next) + ")\n  ";      }      logger.finer(ret);    }    while (!vars.isEmpty()) {      Variable v = (Variable) pickVertexToRemove (mdl2, vars);      logger.finer("Triangulating vertex " + v);      Clique clique = new BitSetClique(varMap, mdl2.getAdjacentVertices(v));      clique.add(v);      if (!findSuperClique(cliques, clique)) {        cliques.add(clique);        if (logger.isLoggable (Level.FINER))          logger.finer("  Elim clique " + clique + " size " + clique.size() + " weight " + clique.weight());      }      // must remove V from graph first, because adding the edges//  will change the rating of other vertices      connectNeighbors(mdl2, v);      vars.remove(v);     
 mdl2.remove(v);    }    if (logger.isLoggable(Level.FINE)) {      logger.fine("Triangulation done. Cliques are: ");      int totSize = 0, totWeight = 0, maxSize = 0, maxWeight = 0;      for (Iterator it = cliques.iterator(); it.hasNext();) {        Clique c = (Clique) it.next();        logger.finer(c.toString());        totSize += c.size();        maxSize = Math.max(c.size(), maxSize);        totWeight += c.weight();        maxWeight = Math.max(c.weight(), maxWeight);      }      double sz = cliques.size();      logger.fine("Jt created " + sz + " cliques. Size: avg " + (totSize / sz)                  + " max " + (maxSize) + " Weight: avg " + (totWeight / sz)                  + " max " + (maxWeight));    }  }  private Alphabet makeVertexMap(ArrayList vars)  {    Alphabet map = new Alphabet (vars.size (), Variable.class);    map.lookupIndices(vars.toArray(), true);    return map;  }  private int sepsetSize(BitSetClique[] pair)  {    assert pair.length == 2;    return pair[0].intersectionSize(pair[1]);  }  private int sepsetCost(Clique[] pair)  {    assert pair.length == 2;    return pair[0].weight() + pair[1].weight();  }  // Given two pairs of cliques, returns -1 if the pair o1 should be  // added to the tree first.  We add pairs that have the largest  // mass (number of vertices in common) to ensure that the clique  // tree satifies the running intersection property.  
private Comparator sepsetChooser = new Comparator() {    public int compare(Object o1, Object o2)    {      if (o1 == o2) return 0;      BitSetClique[] pair1 = (BitSetClique[]) o1;      BitSetClique[] pair2 = (BitSetClique[]) o2;      int size1 = sepsetSize(pair1);      int size2 = sepsetSize(pair2);      int retval = -cmp(size1, size2);      if (retval == 0) {        // Break ties by adding the sepset with the        //  smallest cost (sum of weights of connected clusters)        int cost1 = sepsetCost(pair1);        int cost2 = sepsetCost(pair2);        retval = cmp(cost1, cost2);        // Still a tie? Break arbitrarily but consistently.        if (retval == 0)          retval = cmp(o1.hashCode(), o2.hashCode());      }      return retval;    }  };  private JunctionTree graphToJt(final Graph g)  {    final JunctionTree jt = new JunctionTree(g.getVerticesCount());    Visitor visitor = new Visitor() {      public boolean visit(Vertex v1)      {        for (Iterator it = neighborsIterator(g,v1); it.hasNext();) {          Vertex v2 = (Vertex) it.next();          if (jt.getParent(v1) != v2) {            jt.addNode(v1, v2);          }        }        return true;      }    };    GraphTraversal traverser = new BreadthFirstTraversal(g);    Vertex start = (Vertex) g.getVertexSet().iterator().next();    jt.add (start);    traverser.traverse(start, visitor);    return jt;  }  private JunctionTree buildJtStructure()  {    TreeSet pq = new TreeSet(sepsetChooser);    // Initialize pq with all possible edges...    for (Iterator it = cliques.iterator(); it.hasNext();) {      BitSetClique c1 = (BitSetClique) it.next();      for (Iterator it2 = cliques.iterator(); it2.hasNext();) {        BitSetClique c2 = (BitSetClique) it2.next();        if (c1 == c2) break;        pq.add(new BitSetClique[]{c1, c2});      }    }    // ...and add the edges to jt that come to the top of the queue    //  and don't cause a cycle.    // xxx OK, this sucks.  
openjgraph doesn't allow adding    //  disconnected edges to a tree, so what we'll do is create a    //  Graph frist, then convert it to a Tree.    Graph g = new GraphImpl();    // first add every clique to the graph    for (Iterator it = cliques.iterator(); it.hasNext();) {      Clique c = (Clique) it.next();      try {        g.add(c);      } catch (Exception e) {        throw new RuntimeException(e);      }    }    // then add n - 1 edges    int numCliques = cliques.size();    int edgesAdded = 0;    while (edgesAdded < numCliques - 1) {      Clique[] pair = (Clique[]) pq.first();      pq.remove(pair);      if (!g.isConnected(pair[0], pair[1])) {        try {          g.addEdge(pair[0], pair[1]);          edgesAdded++;        } catch (Exception e) {          throw new RuntimeException(e);        }        ;      }    }    JunctionTree jt = graphToJt(g);    if (logger.isLoggable(Level.FINER))      logger.finer("  jt structure was " + jt);    return jt;  }  private void initJtCpts(UndirectedModel mdl, JunctionTree jt)  {    for (Iterator it = jt.getVerticesIterator(); it.hasNext();) {      Clique c = (Clique) it.next();			DiscretePotential ptl = new MultinomialPotential (c);			ptl.logify();      jt.setCPF(c, ptl);    }		jt.logify();  // Get the sepset potentials in log space    for (Iterator it = mdl.potentials().iterator(); it.hasNext();) {      DiscretePotential ptl = (DiscretePotential) it.next();      Clique parent = jt.findParentCluster(ptl.varSet());      assert parent != null              : "Unable to find parent cluster for ptl " + ptl + "in jt " + jt;      DiscretePotential cpf = jt.getCPF(parent);      cpf.multiplyBy(ptl.log());			/* debug			if (jt.isNaN()) {				throw new RuntimeException ("Got a NaN");			}			*/    }  }  public void computeMarginals(UndirectedModel mdl)  {    buildJunctionTree(mdl);    BeliefPropagation bp = new BeliefPropagation();    bp.computeMarginals(jtCurrent);		totalMessagesSent += bp.getTotalMessagesSent();  }  public JunctionTree 
buildJunctionTree(UndirectedModel mdl)  {    jtCurrent = (JunctionTree) mdl.getInferenceCache(JunctionTreeInferencer.class);    if (jtCurrent != null) {      jtCurrent.clearCPFs();    } else {      triangulate(mdl);      jtCurrent = buildJtStructure();      mdl.setInferenceCache(JunctionTreeInferencer.class, jtCurrent);    }    initJtCpts(mdl, jtCurrent);    return jtCurrent;  }  public DiscretePotential lookupMarginal(Variable var)  {    Clique parent = jtCurrent.findParentCluster(var);    DiscretePotential cpf = jtCurrent.getCPF(parent);    if (logger.isLoggable(Level.FINER)) {      logger.finer("Lookup jt marginal: var " + var + " cluster " + parent);      logger.finest(" cpf " + cpf);    }    DiscretePotential marginal = cpf.marginalize(var);    marginal.normalize();		marginal.delogify();    return marginal;  }  public DiscretePotential lookupMarginal(Clique clique)  {    Clique parent = jtCurrent.findParentCluster(clique);    if (parent == null) {      throw new UnsupportedOperationException              ("No parent cluster in " + jtCurrent + " for clique " + clique);    }    DiscretePotential cpf = jtCurrent.getCPF(parent);    if (logger.isLoggable(Level.FINER)) {      logger.finer("Lookup jt marginal: clique " + clique + " cluster " + parent);      logger.finest("  cpf " + cpf);    }    DiscretePotential marginal = cpf.marginalize(clique);    marginal.normalize();		marginal.delogify();    return marginal;  }  public double lookupLogJoint(Assignment assn)  {    return jtCurrent.lookupLogJoint(assn);  }  public double dumpLogJoint(Assignment assn)  {    return jtCurrent.dumpLogJoint(assn);  }	// test	private Iterator neighborsIterator (final Graph g, final Vertex v) {		return new Iterator () {				Iterator edgeIt = g.getEdges (v).iterator();				public boolean hasNext () { return edgeIt.hasNext(); };				public Object next () { 					Edge edge = (Edge) edgeIt.next();					return edge.getOppositeVertex (v);				}				public void remove () {					throw new 
UnsupportedOperationException ();				}			};	}	public void dump ()	{		if (jtCurrent != null) {			System.out.println("Current junction tree");			jtCurrent.dump();		} else {			System.out.println("NO current junction tree");		}	}					private int totalMessagesSent = 0;	/**	 * Returns the total number of messages this inferencer has sent.	 */	public int getTotalMessagesSent () { return totalMessagesSent; }} // JunctionTreeInferencer

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -