⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 beliefpropagation.java

📁 这是一个matlab的java实现。里面有许多内容。请大家慢慢琢磨。
💻 JAVA
字号:
package edu.umass.cs.mallet.grmm;import java.util.*;import salvo.jesus.graph.*;import salvo.jesus.graph.algorithm.DepthFirstGraphTraversal;import gnu.trove.THashSet;import edu.umass.cs.mallet.base.util.MalletLogger;import java.util.logging.Logger;/** *  Pearl's belief propagation algorithm, in the version *  for junction trees by Peot & Schachter. *  * @author Charles Sutton * @version $Id: BeliefPropagation.java,v 1.1 2004/07/15 17:53:31 casutton Exp $ */public class BeliefPropagation extends AbstractInferencer{	protected static Logger logger = MalletLogger.getLogger (BeliefPropagation.class.getName ());	protected boolean normalizeBeliefs = true;	protected boolean inLogSpace = false;	// Remembers whether the last computeMarginals was junction tree,	// undirected, or what.	int mode;	private static final int JUNCTION_TREE = 1;	private static final int UNDIRECTED = 2;	private int totalMessagesSent = 0;	/**	 * Returns the total number of messages this inferencer has sent.	 */	public int getTotalMessagesSent () { return totalMessagesSent; }// {{{ PROPAGATION IN JUNCTION TREES	private JunctionTree jtCurrent;	public void computeMarginals (JunctionTree jt) 	{		jtCurrent = jt;		mode = JUNCTION_TREE;		if (inLogSpace) {			logger.fine ("Running inference in log space...");			jt.logify (); 		}		propagate (jt);		if (normalizeBeliefs) {			jt.normalizeAll ();  		// Necessary if jt originally unnormalized			 		}	}	public DiscretePotential lookupMarginal (Variable var) 	{		switch (mode) {			case JUNCTION_TREE: return lookupMarginal (jtCurrent, var);			case UNDIRECTED: return lookupMarginal (mdlCurrent, var);			default:				throw new IllegalStateException 					("Attempt to call lookupMarginal() before computeMarginals().");		}	}	protected DiscretePotential lookupMarginal (JunctionTree jt, Variable var)	{		return jt.lookupMarginal (var);	}	/* Hugin-style propagation for junction trees */	// bottom-up pass	private void collectEvidence (JunctionTree jt, Clique parent, Clique child)	{		
logger.finer ("collectEvidence "+parent+" --> "+child);		for (Iterator it = jt.getChildren (child).iterator(); it.hasNext();)		{			Clique gchild = (Clique) it.next();			collectEvidence (jt, child, gchild);		}		if (parent != null) {			sendMessage (jt, child, parent);		}	}	// top-down pass	private void distributeEvidence (JunctionTree jt, Clique parent)	{		for (Iterator it = jt.getChildren (parent).iterator(); it.hasNext();)		{			Clique child = (Clique) it.next();			sendMessage (jt, parent, child);			distributeEvidence (jt, child);		}	}		/**	 *  Sends a message from the clique FROM to TO in a junction tree.	 *   This sends a sum-product message, normalized to avoid	 *   underflow.	 *  <P>	 *  Subclasses may override this to send a different kind of	 *   message, for example, max-product.	 */	protected void sendMessage (JunctionTree jt, Clique from, Clique to)	{		totalMessagesSent++;		Collection sepset = jt.getSepset (from, to);		DiscretePotential fromCpf = jt.getCPF (from);		DiscretePotential toCpf = jt.getCPF (to);		DiscretePotential oldSepsetPot = jt.getSepsetPot (from, to);		DiscretePotential lambda = fromCpf.marginalize (sepset);//		System.out.println(lambda);				lambda.normalize();//		System.out.println(lambda);				jt.setSepsetPot (lambda, from, to);//		System.out.println ("MESSAGE "+from+" --> "+to);//		System.out.println (fromCpf);//		System.out.println (lambda);//		System.out.println (toCpf);//		jt.dump();		toCpf.multiplyBy (lambda);		toCpf.divideBy (oldSepsetPot);		toCpf.normalize ();	}	/** Peot & Schachter-style propagation for junction trees */	private void propagate (JunctionTree jt)	{		jtCurrent = jt;		Clique root = (Clique) jt.getRoot ();		collectEvidence (jt, null, root);		distributeEvidence (jt, root);	}// }}}// {{{ PROPAGATION IN UNDIRECTED MODELS  		/**	 *  Array that maps (to, from) to the lambda message sent from node	 * from to node to.  In the future, this could be made sparse (changed	 * to a new class like SparseArrayList.	 
*/ 	protected DiscretePotential[][] messages;	protected DiscretePotential[] bel;	private THashSet marked;	protected UndirectedModel mdlCurrent;	private Variable root;	public DiscretePotential query (UndirectedModel m, Variable var) 	{			throw new UnsupportedOperationException 				("Belief propagation currently only works on junction trees.");	}	public void computeMarginals (UndirectedModel mdl) 	{		initForGraph (mdl);		marked = new THashSet (); lambdaPropagation (mdl, null, root);		marked = new THashSet (); piPropagation (mdl, root);	}	protected int assignedVertexPtls[];	protected void initForGraph (UndirectedModel mdl) 	{		mode = UNDIRECTED;		mdlCurrent = mdl;		int numNodes = mdl.getVerticesCount ();		bel = new DiscretePotential [numNodes];		messages = new DiscretePotential [numNodes][numNodes];		// setup self-messages for vertex potentials		for (Iterator it = mdl.getVerticesIterator(); it.hasNext();) {			Variable var = (Variable) it.next();			DiscretePotential ptl = mdl.potentialOfVertex (var);			if (ptl != null) {				int i = mdl.getIndex (var);				if (inLogSpace) {					logger.fine ("BeliefPropagation: Using log space.");					messages[i][i] = ptl.log();				} else {					messages[i][i] = ptl;				}			}		}		 		// Pick a root arbitrarily		root = (Variable) mdl.getVerticesIterator().next();	}		private void lambdaPropagation (UndirectedModel mdl, Variable parent, Variable child)	{		logger.fine ("lambda propagation "+parent+" , "+child);		marked.add (child);		for (Iterator it = mdl.getAdjacentVertices (child).iterator(); it.hasNext();) {			Variable gchild = (Variable) it.next();			if (!marked.contains (gchild)) {				lambdaPropagation (mdl, child, gchild);			}		}		if (parent != null) {//			sendLambdaMessage (mdl, child, parent);			sendMessage (mdl, child, parent);		}	}		private void piPropagation (UndirectedModel mdl, Variable var)	{		logger.fine ("Pi propagation from "+var);		marked.add (var);		for (Iterator it = mdl.getAdjacentVertices (var).iterator(); it.hasNext();) {	
		Variable child = (Variable) it.next();			if (!marked.contains (child)) {//				sendPiMessage (mdl, var, child);				sendMessage (mdl, var, child);				piPropagation (mdl, child);			}		}	}	protected void sendMessage (UndirectedModel mdl, Variable from, Variable to)	{		totalMessagesSent++;		int fromIdx = mdl.getIndex (from);		int toIdx = mdl.getIndex (to);		DiscretePotential product = mdl.potentialOfEdge (from, to).duplicate();		msgProduct (product, fromIdx, toIdx);				DiscretePotential msg = product.marginalizeOut (from);		msg.normalize ();		messages[toIdx][fromIdx] = msg;		}	protected DiscretePotential lookupMarginal (UndirectedModel mdl, Variable var)	{		int idx = mdl.getIndex (var);		if (bel[idx] == null) {			DiscretePotential marg = msgProduct (null, idx, -1);			marg.delogify (); // can't hurt			if (normalizeBeliefs) {				marg.normalize ();			}			assert marg.varSet().size() == 1 				:"Invalid marginal for var "+var+": "+marg;			assert marg.varSet().contains (var)				:"Invalid marginal for var "+var+": "+marg;			bel[idx] = marg;		}		return bel [idx];	}	protected DiscretePotential msgProduct (DiscretePotential product, int idx, int excludeMsgFrom)	{		if (product == null) {			product = new MultinomialPotential (mdlCurrent.get (idx));		}		if (inLogSpace) {			product.logify ();		}				for (int j = 0; j < messages[idx].length; j++) {			if ((messages[idx][j] != null) && (j != excludeMsgFrom)) {				 product.multiplyBy (messages [idx][j]);			}		}		return product;	}	public void dump()	{		for (int i = 0; i < messages.length; i++) {			for (int j = 0; j < messages[i].length; j++) {				if (messages[i][j] != null) {					System.out.println("Message from "+j+" to "+i);					System.out.println(messages[i][j]);				}							}		}	}	// }}}	}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -