📄 asyncloopybp.java
字号:
/* Copyright (C) 2003 Univ. of Massachusetts Amherst, Computer Science Dept. This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit). http://www.cs.umass.edu/~mccallum/mallet This software is provided under the terms of the Common Public License, version 1.0, as published by http://www.opensource.org. For further information, see the file `LICENSE' included with this distribution. */package edu.umass.cs.mallet.grmm;import salvo.jesus.graph.Edge;import java.util.Iterator;import java.util.Collections;import java.util.ArrayList;/** * The loopy belief propagation algorithm for approximate inference in * general graphical models. * * Created: Wed Nov 5 19:30:15 2003 * * @author <a href="mailto:casutton@cs.umass.edu">Charles Sutton</a> * @version $Id: AsyncLoopyBP.java,v 1.1 2004/07/15 17:53:31 casutton Exp $ */public class AsyncLoopyBP extends BeliefPropagation { public static final int DEFAULT_MAX_ITER = 1000; private int iterUsed; private int maxIter; private double threshold = 0.0001; DiscretePotential[][] oldMessages; public int iterationsUsed () { return iterUsed; } // Note this does not have the sophisticated Terminator interface // that we've got in TRP. public AsyncLoopyBP() { this (DEFAULT_MAX_ITER); } public AsyncLoopyBP(int maxIter) { this.maxIter = maxIter; } private void initOldMessages (UndirectedModel mdl) { int n = mdl.getVerticesCount (); oldMessages = new DiscretePotential [n][n]; for (Iterator it = mdl.getEdgeSet().iterator(); it.hasNext();) { Edge edge = (Edge) it.next(); Variable v1 = (Variable) edge.getVertexA (); Variable v2 = (Variable) edge.getVertexB (); int i = mdl.getIndex (v1); int j = mdl.getIndex (v2); oldMessages [i][j] = new MultinomialPotential (mdl.get (i)); } } private void copyOldMessages () { for (int i = 0; i < messages.length; i++) { for (int j = 0; j < messages[i].length; j++) { if (messages[i][j] != null) { oldMessages[i][j] = messages[i][j].duplicate(); } } } } private boolean hasConverged () { for (int i = 0; i < oldMessages.length; i++) { for (int j = 0; j < oldMessages[i].length; j++) { DiscretePotential ptl1 = oldMessages [i][j]; DiscretePotential ptl2 = messages [i][j]; if (oldMessages [i][j] != null) { assert messages [i][j] != null : "Message went from nonnull to null "+i+" --> "+j; for (Iterator it = ptl1.assignmentIterator(); it.hasNext();) { Assignment assn = (Assignment) it.next(); double val1 = ptl1.phi (assn); double val2 = ptl2.phi (assn); if (Math.abs (val1 - val2) > threshold) { return false; } } } } } return true; } public void computeMarginals (UndirectedModel mdl) { initOldMessages (mdl); super.initForGraph (mdl); int iter; for (iter = 0; iter < maxIter; iter++) { logger.finer ("***AsyncLoopyBP iteration "+iter); propagate (mdl); if (hasConverged ()) break; copyOldMessages (); } iterUsed = iter; if (iter >= maxIter) { logger.info ("***Loopy BP quitting: not converged after "+maxIter+" iterations."); } else { iterUsed++; // there's an off-by-one b/c of location of above break logger.info ("***AsyncLoopyBP converged: "+iterUsed+" iterations"); } } private void propagate (UndirectedModel mdl) { // Send all messages in random order. ArrayList edges = new ArrayList (mdl.getEdgeSet()); ArrayList edgesRev = new ArrayList (edges.size()); Collections.shuffle (edges); for (Iterator it = edges.iterator(); it.hasNext();) { Edge edge = (Edge) it.next(); Variable v1 = (Variable) edge.getVertexA (); Variable v2 = (Variable) edge.getVertexB (); sendMessage (mdl, v1, v2); edgesRev.add (edge); } Collections.shuffle (edgesRev); for (Iterator it = edgesRev.iterator(); it.hasNext();) { Edge edge = (Edge) it.next(); Variable v1 = (Variable) edge.getVertexA (); Variable v2 = (Variable) edge.getVertexB (); sendMessage (mdl, v2, v1); } } public DiscretePotential lookupMarginal (Variable v1, Variable v2) { int idx1 = mdlCurrent.getIndex (v1); int idx2 = mdlCurrent.getIndex (v2); DiscretePotential edgePtl = mdlCurrent.potentialOfEdge (v1, v2); DiscretePotential product = edgePtl.duplicate(); msgProduct (product, idx1, idx2); msgProduct (product, idx2, idx1); assert product.varSet().size() == 2; product.normalize (); return product; } public DiscretePotential lookupMarginal (Clique c) { switch (c.size ()) { case 1: return lookupMarginal (c.get (0)); case 2: return lookupMarginal (c.get (0), c.get (1)); default: throw new IllegalArgumentException ("AsyncLoopyBP currently only supports node and edge cliques."); } } // xxx Assumes UndirectedModel public double lookupLogJoint (Assignment assn) { double accum = 0.0; // Compute using BP-factorization // prod_s (p(x_s))^-(deg(s)-1) * ... for (Iterator it = mdlCurrent.getVerticesIterator(); it.hasNext();) { Variable var = (Variable) it.next(); DiscretePotential ptl = lookupMarginal (var); int deg = mdlCurrent.getDegree(var); if (deg > 1) accum -= (deg - 1) * Math.log (ptl.phi (assn)); } // ... * prod_{st} p(x_s, x_t) for (Iterator it = mdlCurrent.getEdgeSet().iterator(); it.hasNext();) { Edge edge = (Edge) it.next(); Variable v1 = (Variable) edge.getVertexA (); Variable v2 = (Variable) edge.getVertexB (); DiscretePotential p12 = lookupMarginal (v1, v2); DiscretePotential p1 = lookupMarginal (v1); DiscretePotential p2 = lookupMarginal (v2); accum += Math.log (p12.phi (assn)); } return accum; }}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -