📄 loopymaxproduct.java
字号:
/* Copyright (C) 2003 Univ. of Massachusetts Amherst, Computer Science Dept. This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit). http://www.cs.umass.edu/~mccallum/mallet This software is provided under the terms of the Common Public License, version 1.0, as published by http://www.opensource.org. For further information, see the file `LICENSE' included with this distribution. */package edu.umass.cs.mallet.grmm;import salvo.jesus.graph.Edge;import java.util.Iterator;import java.util.logging.Level;/** * The loopy belief propagation algorithm for approximate inference in * general graphical models. * * Created: Wed Nov 5 19:30:15 2003 * * @author <a href="mailto:casutton@cs.umass.edu">Charles Sutton</a> * @version $Id: LoopyMaxProduct.java,v 1.1 2004/07/15 17:53:31 casutton Exp $ */public class LoopyMaxProduct extends ViterbiPropagation { public static final int DEFAULT_MAX_ITER = 1000; private int maxIter; private double threshold = 0.0001; public LoopyMaxProduct() { this (DEFAULT_MAX_ITER); } public LoopyMaxProduct(int maxIter) { this.maxIter = maxIter; } DiscretePotential[][] oldMessages; private boolean hasConverged = false; public boolean hasConverged() { return hasConverged; } private void initOldMessages (UndirectedModel mdl) { int n = mdl.getVerticesCount (); oldMessages = new DiscretePotential [n][n]; for (Iterator it = mdl.getEdgeSet().iterator(); it.hasNext();) { Edge edge = (Edge) it.next(); Variable v1 = (Variable) edge.getVertexA (); Variable v2 = (Variable) edge.getVertexB (); int i = mdl.getIndex (v1); int j = mdl.getIndex (v2); oldMessages [i][j] = new MultinomialPotential (mdl.get (i));// oldMessages [i][j].logify(); } } private boolean checkConvergence () { for (int i = 0; i < oldMessages.length; i++) { for (int j = 0; j < oldMessages[i].length; j++) { DiscretePotential ptl1 = oldMessages [i][j]; DiscretePotential ptl2 = messages [i][j]; if (oldMessages [i][j] != null) { assert messages [i][j] != null : "Message went from nonnull to null "+i+" --> "+j; if (logger.isLoggable (Level.FINER)) logger.finer ("["+i+"]["+j+"]\n"+oldMessages[i][j]+messages[i][j]); for (Iterator it = ptl1.assignmentIterator(); it.hasNext();) { Assignment assn = (Assignment) it.next(); double val1 = ptl1.phi (assn); double val2 = ptl2.phi (assn); if (Math.abs (Math.exp(val1) - Math.exp(val2)) > threshold) { return false; } } } } } return true; } public void computeMarginals (UndirectedModel mdl) { hasConverged = false; initOldMessages (mdl); int iter; for (iter = 0; iter < maxIter; iter++) { logger.fine ("***LoopyMaxProduct iteration "+iter); super.initForGraph (mdl); propagate (mdl); if (checkConvergence ()) break; oldMessages = messages; } if (iter >= maxIter) { hasConverged = false; logger.info ("***Loopy BP quitting: not converged after "+maxIter +" iterations."); } else { hasConverged = true; logger.info ("***LoopyMaxProduct finished: "+iter+" iterations"); } } private void propagate (UndirectedModel mdl) { // Send all messages. In this implementation, we send messages // synchronously, so it doesn't matter what order we send // messages in. for (Iterator it = mdl.getEdgeSet().iterator(); it.hasNext();) { Edge edge = (Edge) it.next(); Variable v1 = (Variable) edge.getVertexA (); Variable v2 = (Variable) edge.getVertexB (); sendMessage (mdl, v1, v2); sendMessage (mdl, v2, v1); } } protected DiscretePotential msgProduct (int idx, int excludeMsgFrom) {// logger.finest ("Multipling messages to "+idx+" (except for "// + excludeMsgFrom + ")"); DiscretePotential product = new MultinomialPotential ();// product.logify(); for (int j = 0; j < oldMessages[idx].length; j++) { if ((oldMessages[idx][j] != null) && (j != excludeMsgFrom)) {// logger.finest ("Multiplying in oldMessages["+idx+"]["+j+"] ="+// oldMessages[idx][j]); product.multiplyBy (oldMessages [idx][j]); } } return product; } public DiscretePotential lookupMarginal (Variable v1, Variable v2) { int idx1 = mdlCurrent.getIndex (v1); int idx2 = mdlCurrent.getIndex (v2); DiscretePotential product1 = msgProduct (idx1, idx2); DiscretePotential product2 = msgProduct (idx2, idx1); product1.multiplyBy (product2); DiscretePotential edgePtl = mdlCurrent.potentialOfEdge (v1, v2); product1.multiplyBy (edgePtl); DiscretePotential vertexPtl1 = mdlCurrent.potentialOfVertex (v1); if (vertexPtl1 != null) { product1.multiplyBy (vertexPtl1); } DiscretePotential vertexPtl2 = mdlCurrent.potentialOfVertex (v2); if (vertexPtl2 != null) { product1.multiplyBy (vertexPtl2); } DiscretePotential marg = product1.extractMax (new Variable[] { v1, v2 }); return marg; } public DiscretePotential lookupMarginal (Clique c) { switch (c.size ()) { case 1: return lookupMarginal (c.get (0)); case 2: return lookupMarginal (c.get (0), c.get (1)); default: throw new IllegalArgumentException ("LoopyBP currently only supports node and edge cliques."); } }}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -