📄 junctiontreeinferencer.java
字号:
/* Copyright (C) 2003 Univ. of Massachusetts Amherst, Computer Science Dept. This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit). http://www.cs.umass.edu/~mccallum/mallet This software is provided under the terms of the Common Public License, version 1.0, as published by http://www.opensource.org. For further information, see the file `LICENSE' included with this distribution. */package edu.umass.cs.mallet.grmm;import java.util.TreeSet;import salvo.jesus.graph.Graph;import salvo.jesus.graph.Vertex;import java.util.ArrayList;import java.util.Comparator;import java.util.Iterator;import java.util.logging.Logger;import edu.umass.cs.mallet.base.util.MalletLogger;import java.util.List;import edu.umass.cs.mallet.base.util.Arrays;import edu.umass.cs.mallet.base.types.Alphabet;import java.util.LinkedList;import salvo.jesus.graph.Edge;import salvo.jesus.graph.GraphImpl;import salvo.jesus.graph.Visitor;import salvo.jesus.graph.algorithm.GraphTraversal;import salvo.jesus.graph.algorithm.BreadthFirstTraversal;import java.util.logging.Level;import java.util.Collections;/** * Does inference in general graphical models using * the Hugin junction tree algorithm. * * Created: Mon Nov 10 23:58:44 2003 * * @author <a href="mailto:casutton@cs.umass.edu">Charles Sutton</a> * @version $Id: JunctionTreeInferencer.java,v 1.1 2004/07/15 17:53:31 casutton Exp $ */public class JunctionTreeInferencer extends AbstractInferencer { private static Logger logger = MalletLogger.getLogger(JunctionTreeInferencer.class.getName()); public JunctionTreeInferencer() { } // JunctionTreeInferencer constructor private boolean contains(Graph g, Vertex v1) { return g.getVertexSet().contains(v1); } private boolean isAdjacent(Graph g, Vertex v1, Vertex v2) { return ((UndirectedModel) g).isAdjacent(v1, v2); } protected JunctionTree jtCurrent; private ArrayList cliques; /** * Returns the number of edges that would be added to a graph if a * given vertex would be removed in the triangulation procedure. * The return value is the number of edges in the elimination * clique of V that are not already present. */ private int newEdgesRequired(Graph mdl, Vertex v) { int rating = 0; for (Iterator it1 = neighborsIterator(mdl,v); it1.hasNext();) { Vertex neighbor1 = (Vertex) it1.next(); Iterator it2 = neighborsIterator(mdl,v); while (it2.hasNext()) { Vertex neighbor2 = (Vertex) it2.next(); if (neighbor1 != neighbor2) { if (!isAdjacent(mdl, neighbor1, neighbor2)) { rating++; } } } }// System.out.println(v+" = "+rating); return rating; } /** * Returns the weight of the clique that would be added to a graph if a * given vertex would be removed in the triangulation procedure. * The return value is the number of edges in the elimination * clique of V that are not already present. */ private int weightRequired(Graph mdl, Vertex v) { int rating = 1; for (Iterator it1 = neighborsIterator(mdl,v); it1.hasNext();) { Variable neighbor = (Variable) it1.next(); rating *= neighbor.getNumOutcomes(); }// System.out.println(v+" = "+rating); return rating; } private void connectNeighbors(Graph mdl, Vertex v) { for (Iterator it1 = neighborsIterator(mdl,v); it1.hasNext();) { Vertex neighbor1 = (Vertex) it1.next(); Iterator it2 = neighborsIterator(mdl,v); while (it2.hasNext()) { Vertex neighbor2 = (Vertex) it2.next(); if (neighbor1 != neighbor2) { if (!isAdjacent(mdl, neighbor1, neighbor2)) { try { mdl.addEdge(neighbor1, neighbor2); } catch (Exception e) { throw new RuntimeException(e); } } } } } } // xx should refactor into Collections.any (Coll, TObjectProc) /* Return true iff a clique in L strictly contains c. */ private boolean findSuperClique(List l, Clique c) { for (Iterator it = l.iterator(); it.hasNext();) { Clique c2 = (Clique) it.next(); if (c2.containsAll(c)) { return true; } } return false; } // works like the obscure <=> operator in Perl. private int cmp(int i1, int i2) { if (i1 < i2) { return -1; } else if (i1 > i2) { return 1; } else { return 0; } } public Variable pickVertexToRemove (UndirectedModel mdl, ArrayList lst) { Iterator it = lst.iterator(); Variable best = (Variable) it.next(); int bestVal1 = newEdgesRequired (mdl, best); int bestVal2 = weightRequired (mdl, best); while (it.hasNext()) { Variable v = (Variable) it.next(); int val = newEdgesRequired (mdl, v); if (val < bestVal1) { best = v; bestVal1 = val; bestVal2 = weightRequired (mdl, v); } else if (val == bestVal1) { int val2 = weightRequired (mdl, v); if (val2 < bestVal2) { best = v; bestVal1 = val; bestVal2 = val2; } } } return best; } /** * Adds edges to graph until it is triangulated. */ private void triangulate(final UndirectedModel mdl) { final UndirectedModel mdl2 = mdl.duplicate(); ArrayList vars = new ArrayList(mdl.getVertexSet()); Alphabet varMap = makeVertexMap(vars); cliques = new ArrayList(); // debug if (logger.isLoggable (Level.FINER)) { logger.finer ("Triangulating model: "+mdl); String ret = ""; for (int i = 0; i < vars.size(); i++) { Variable next = (Variable) vars.get(i); ret += next.toString() + " (" + mdl.getIndex(next) + ")\n "; } logger.finer(ret); } while (!vars.isEmpty()) { Variable v = (Variable) pickVertexToRemove (mdl2, vars); logger.finer("Triangulating vertex " + v); Clique clique = new BitSetClique(varMap, mdl2.getAdjacentVertices(v)); clique.add(v); if (!findSuperClique(cliques, clique)) { cliques.add(clique); if (logger.isLoggable (Level.FINER)) logger.finer(" Elim clique " + clique + " size " + clique.size() + " weight " + clique.weight()); } // must remove V from graph first, because adding the edges// will change the rating of other vertices connectNeighbors(mdl2, v); vars.remove(v); mdl2.remove(v); } if (logger.isLoggable(Level.FINE)) { logger.fine("Triangulation done. Cliques are: "); int totSize = 0, totWeight = 0, maxSize = 0, maxWeight = 0; for (Iterator it = cliques.iterator(); it.hasNext();) { Clique c = (Clique) it.next(); logger.finer(c.toString()); totSize += c.size(); maxSize = Math.max(c.size(), maxSize); totWeight += c.weight(); maxWeight = Math.max(c.weight(), maxWeight); } double sz = cliques.size(); logger.fine("Jt created " + sz + " cliques. Size: avg " + (totSize / sz) + " max " + (maxSize) + " Weight: avg " + (totWeight / sz) + " max " + (maxWeight)); } } private Alphabet makeVertexMap(ArrayList vars) { Alphabet map = new Alphabet (vars.size (), Variable.class); map.lookupIndices(vars.toArray(), true); return map; } private int sepsetSize(BitSetClique[] pair) { assert pair.length == 2; return pair[0].intersectionSize(pair[1]); } private int sepsetCost(Clique[] pair) { assert pair.length == 2; return pair[0].weight() + pair[1].weight(); } // Given two pairs of cliques, returns -1 if the pair o1 should be // added to the tree first. We add pairs that have the largest // mass (number of vertices in common) to ensure that the clique // tree satifies the running intersection property. private Comparator sepsetChooser = new Comparator() { public int compare(Object o1, Object o2) { if (o1 == o2) return 0; BitSetClique[] pair1 = (BitSetClique[]) o1; BitSetClique[] pair2 = (BitSetClique[]) o2; int size1 = sepsetSize(pair1); int size2 = sepsetSize(pair2); int retval = -cmp(size1, size2); if (retval == 0) { // Break ties by adding the sepset with the // smallest cost (sum of weights of connected clusters) int cost1 = sepsetCost(pair1); int cost2 = sepsetCost(pair2); retval = cmp(cost1, cost2); // Still a tie? Break arbitrarily but consistently. if (retval == 0) retval = cmp(o1.hashCode(), o2.hashCode()); } return retval; } }; private JunctionTree graphToJt(final Graph g) { final JunctionTree jt = new JunctionTree(g.getVerticesCount()); Visitor visitor = new Visitor() { public boolean visit(Vertex v1) { for (Iterator it = neighborsIterator(g,v1); it.hasNext();) { Vertex v2 = (Vertex) it.next(); if (jt.getParent(v1) != v2) { jt.addNode(v1, v2); } } return true; } }; GraphTraversal traverser = new BreadthFirstTraversal(g); Vertex start = (Vertex) g.getVertexSet().iterator().next(); jt.add (start); traverser.traverse(start, visitor); return jt; } private JunctionTree buildJtStructure() { TreeSet pq = new TreeSet(sepsetChooser); // Initialize pq with all possible edges... for (Iterator it = cliques.iterator(); it.hasNext();) { BitSetClique c1 = (BitSetClique) it.next(); for (Iterator it2 = cliques.iterator(); it2.hasNext();) { BitSetClique c2 = (BitSetClique) it2.next(); if (c1 == c2) break; pq.add(new BitSetClique[]{c1, c2}); } } // ...and add the edges to jt that come to the top of the queue // and don't cause a cycle. // xxx OK, this sucks. openjgraph doesn't allow adding // disconnected edges to a tree, so what we'll do is create a // Graph frist, then convert it to a Tree. Graph g = new GraphImpl(); // first add every clique to the graph for (Iterator it = cliques.iterator(); it.hasNext();) { Clique c = (Clique) it.next(); try { g.add(c); } catch (Exception e) { throw new RuntimeException(e); } } // then add n - 1 edges int numCliques = cliques.size(); int edgesAdded = 0; while (edgesAdded < numCliques - 1) { Clique[] pair = (Clique[]) pq.first(); pq.remove(pair); if (!g.isConnected(pair[0], pair[1])) { try { g.addEdge(pair[0], pair[1]); edgesAdded++; } catch (Exception e) { throw new RuntimeException(e); } ; } } JunctionTree jt = graphToJt(g); if (logger.isLoggable(Level.FINER)) logger.finer(" jt structure was " + jt); return jt; } private void initJtCpts(UndirectedModel mdl, JunctionTree jt) { for (Iterator it = jt.getVerticesIterator(); it.hasNext();) { Clique c = (Clique) it.next(); DiscretePotential ptl = new MultinomialPotential (c); ptl.logify(); jt.setCPF(c, ptl); } jt.logify(); // Get the sepset potentials in log space for (Iterator it = mdl.potentials().iterator(); it.hasNext();) { DiscretePotential ptl = (DiscretePotential) it.next(); Clique parent = jt.findParentCluster(ptl.varSet()); assert parent != null : "Unable to find parent cluster for ptl " + ptl + "in jt " + jt; DiscretePotential cpf = jt.getCPF(parent); cpf.multiplyBy(ptl.log()); /* debug if (jt.isNaN()) { throw new RuntimeException ("Got a NaN"); } */ } } public void computeMarginals(UndirectedModel mdl) { buildJunctionTree(mdl); BeliefPropagation bp = new BeliefPropagation(); bp.computeMarginals(jtCurrent); totalMessagesSent += bp.getTotalMessagesSent(); } public JunctionTree buildJunctionTree(UndirectedModel mdl) { jtCurrent = (JunctionTree) mdl.getInferenceCache(JunctionTreeInferencer.class); if (jtCurrent != null) { jtCurrent.clearCPFs(); } else { triangulate(mdl); jtCurrent = buildJtStructure(); mdl.setInferenceCache(JunctionTreeInferencer.class, jtCurrent); } initJtCpts(mdl, jtCurrent); return jtCurrent; } public DiscretePotential lookupMarginal(Variable var) { Clique parent = jtCurrent.findParentCluster(var); DiscretePotential cpf = jtCurrent.getCPF(parent); if (logger.isLoggable(Level.FINER)) { logger.finer("Lookup jt marginal: var " + var + " cluster " + parent); logger.finest(" cpf " + cpf); } DiscretePotential marginal = cpf.marginalize(var); marginal.normalize(); marginal.delogify(); return marginal; } public DiscretePotential lookupMarginal(Clique clique) { Clique parent = jtCurrent.findParentCluster(clique); if (parent == null) { throw new UnsupportedOperationException ("No parent cluster in " + jtCurrent + " for clique " + clique); } DiscretePotential cpf = jtCurrent.getCPF(parent); if (logger.isLoggable(Level.FINER)) { logger.finer("Lookup jt marginal: clique " + clique + " cluster " + parent); logger.finest(" cpf " + cpf); } DiscretePotential marginal = cpf.marginalize(clique); marginal.normalize(); marginal.delogify(); return marginal; } public double lookupLogJoint(Assignment assn) { return jtCurrent.lookupLogJoint(assn); } public double dumpLogJoint(Assignment assn) { return jtCurrent.dumpLogJoint(assn); } // test private Iterator neighborsIterator (final Graph g, final Vertex v) { return new Iterator () { Iterator edgeIt = g.getEdges (v).iterator(); public boolean hasNext () { return edgeIt.hasNext(); }; public Object next () { Edge edge = (Edge) edgeIt.next(); return edge.getOppositeVertex (v); } public void remove () { throw new UnsupportedOperationException (); } }; } public void dump () { if (jtCurrent != null) { System.out.println("Current junction tree"); jtCurrent.dump(); } else { System.out.println("NO current junction tree"); } } private int totalMessagesSent = 0; /** * Returns the total number of messages this inferencer has sent. */ public int getTotalMessagesSent () { return totalMessagesSent; }} // JunctionTreeInferencer
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -