📄 beliefpropagation.java
字号:
package edu.umass.cs.mallet.grmm;import java.util.*;import salvo.jesus.graph.*;import salvo.jesus.graph.algorithm.DepthFirstGraphTraversal;import gnu.trove.THashSet;import edu.umass.cs.mallet.base.util.MalletLogger;import java.util.logging.Logger;/** * Pearl's belief propagation algorithm, in the version * for junction trees by Peot & Schachter. * * @author Charles Sutton * @version $Id: BeliefPropagation.java,v 1.1 2004/07/15 17:53:31 casutton Exp $ */public class BeliefPropagation extends AbstractInferencer{ protected static Logger logger = MalletLogger.getLogger (BeliefPropagation.class.getName ()); protected boolean normalizeBeliefs = true; protected boolean inLogSpace = false; // Remembers whether the last computeMarginals was junction tree, // undirected, or what. int mode; private static final int JUNCTION_TREE = 1; private static final int UNDIRECTED = 2; private int totalMessagesSent = 0; /** * Returns the total number of messages this inferencer has sent. */ public int getTotalMessagesSent () { return totalMessagesSent; }// {{{ PROPAGATION IN JUNCTION TREES private JunctionTree jtCurrent; public void computeMarginals (JunctionTree jt) { jtCurrent = jt; mode = JUNCTION_TREE; if (inLogSpace) { logger.fine ("Running inference in log space..."); jt.logify (); } propagate (jt); if (normalizeBeliefs) { jt.normalizeAll (); // Necessary if jt originally unnormalized } } public DiscretePotential lookupMarginal (Variable var) { switch (mode) { case JUNCTION_TREE: return lookupMarginal (jtCurrent, var); case UNDIRECTED: return lookupMarginal (mdlCurrent, var); default: throw new IllegalStateException ("Attempt to call lookupMarginal() before computeMarginals()."); } } protected DiscretePotential lookupMarginal (JunctionTree jt, Variable var) { return jt.lookupMarginal (var); } /* Hugin-style propagation for junction trees */ // bottom-up pass private void collectEvidence (JunctionTree jt, Clique parent, Clique child) { logger.finer ("collectEvidence "+parent+" --> "+child); for (Iterator it = jt.getChildren (child).iterator(); it.hasNext();) { Clique gchild = (Clique) it.next(); collectEvidence (jt, child, gchild); } if (parent != null) { sendMessage (jt, child, parent); } } // top-down pass private void distributeEvidence (JunctionTree jt, Clique parent) { for (Iterator it = jt.getChildren (parent).iterator(); it.hasNext();) { Clique child = (Clique) it.next(); sendMessage (jt, parent, child); distributeEvidence (jt, child); } } /** * Sends a message from the clique FROM to TO in a junction tree. * This sends a sum-product message, normalized to avoid * underflow. * <P> * Subclasses may override this to send a different kind of * message, for example, max-product. */ protected void sendMessage (JunctionTree jt, Clique from, Clique to) { totalMessagesSent++; Collection sepset = jt.getSepset (from, to); DiscretePotential fromCpf = jt.getCPF (from); DiscretePotential toCpf = jt.getCPF (to); DiscretePotential oldSepsetPot = jt.getSepsetPot (from, to); DiscretePotential lambda = fromCpf.marginalize (sepset);// System.out.println(lambda); lambda.normalize();// System.out.println(lambda); jt.setSepsetPot (lambda, from, to);// System.out.println ("MESSAGE "+from+" --> "+to);// System.out.println (fromCpf);// System.out.println (lambda);// System.out.println (toCpf);// jt.dump(); toCpf.multiplyBy (lambda); toCpf.divideBy (oldSepsetPot); toCpf.normalize (); } /** Peot & Schachter-style propagation for junction trees */ private void propagate (JunctionTree jt) { jtCurrent = jt; Clique root = (Clique) jt.getRoot (); collectEvidence (jt, null, root); distributeEvidence (jt, root); }// }}}// {{{ PROPAGATION IN UNDIRECTED MODELS /** * Array that maps (to, from) to the lambda message sent from node * from to node to. In the future, this could be made sparse (changed * to a new class like SparseArrayList. */ protected DiscretePotential[][] messages; protected DiscretePotential[] bel; private THashSet marked; protected UndirectedModel mdlCurrent; private Variable root; public DiscretePotential query (UndirectedModel m, Variable var) { throw new UnsupportedOperationException ("Belief propagation currently only works on junction trees."); } public void computeMarginals (UndirectedModel mdl) { initForGraph (mdl); marked = new THashSet (); lambdaPropagation (mdl, null, root); marked = new THashSet (); piPropagation (mdl, root); } protected int assignedVertexPtls[]; protected void initForGraph (UndirectedModel mdl) { mode = UNDIRECTED; mdlCurrent = mdl; int numNodes = mdl.getVerticesCount (); bel = new DiscretePotential [numNodes]; messages = new DiscretePotential [numNodes][numNodes]; // setup self-messages for vertex potentials for (Iterator it = mdl.getVerticesIterator(); it.hasNext();) { Variable var = (Variable) it.next(); DiscretePotential ptl = mdl.potentialOfVertex (var); if (ptl != null) { int i = mdl.getIndex (var); if (inLogSpace) { logger.fine ("BeliefPropagation: Using log space."); messages[i][i] = ptl.log(); } else { messages[i][i] = ptl; } } } // Pick a root arbitrarily root = (Variable) mdl.getVerticesIterator().next(); } private void lambdaPropagation (UndirectedModel mdl, Variable parent, Variable child) { logger.fine ("lambda propagation "+parent+" , "+child); marked.add (child); for (Iterator it = mdl.getAdjacentVertices (child).iterator(); it.hasNext();) { Variable gchild = (Variable) it.next(); if (!marked.contains (gchild)) { lambdaPropagation (mdl, child, gchild); } } if (parent != null) {// sendLambdaMessage (mdl, child, parent); sendMessage (mdl, child, parent); } } private void piPropagation (UndirectedModel mdl, Variable var) { logger.fine ("Pi propagation from "+var); marked.add (var); for (Iterator it = mdl.getAdjacentVertices (var).iterator(); it.hasNext();) { Variable child = (Variable) it.next(); if (!marked.contains (child)) {// sendPiMessage (mdl, var, child); sendMessage (mdl, var, child); piPropagation (mdl, child); } } } protected void sendMessage (UndirectedModel mdl, Variable from, Variable to) { totalMessagesSent++; int fromIdx = mdl.getIndex (from); int toIdx = mdl.getIndex (to); DiscretePotential product = mdl.potentialOfEdge (from, to).duplicate(); msgProduct (product, fromIdx, toIdx); DiscretePotential msg = product.marginalizeOut (from); msg.normalize (); messages[toIdx][fromIdx] = msg; } protected DiscretePotential lookupMarginal (UndirectedModel mdl, Variable var) { int idx = mdl.getIndex (var); if (bel[idx] == null) { DiscretePotential marg = msgProduct (null, idx, -1); marg.delogify (); // can't hurt if (normalizeBeliefs) { marg.normalize (); } assert marg.varSet().size() == 1 :"Invalid marginal for var "+var+": "+marg; assert marg.varSet().contains (var) :"Invalid marginal for var "+var+": "+marg; bel[idx] = marg; } return bel [idx]; } protected DiscretePotential msgProduct (DiscretePotential product, int idx, int excludeMsgFrom) { if (product == null) { product = new MultinomialPotential (mdlCurrent.get (idx)); } if (inLogSpace) { product.logify (); } for (int j = 0; j < messages[idx].length; j++) { if ((messages[idx][j] != null) && (j != excludeMsgFrom)) { product.multiplyBy (messages [idx][j]); } } return product; } public void dump() { for (int i = 0; i < messages.length; i++) { for (int j = 0; j < messages[i].length; j++) { if (messages[i][j] != null) { System.out.println("Message from "+j+" to "+i); System.out.println(messages[i][j]); } } } } // }}} }
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -