⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 loopymaxproduct.java

📁 这是一个matlab的java实现。里面有许多内容。请大家慢慢琢磨。
💻 JAVA
字号:
/* Copyright (C) 2003 Univ. of Massachusetts Amherst, Computer Science Dept.
   This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
   http://www.cs.umass.edu/~mccallum/mallet
   This software is provided under the terms of the Common Public License,
   version 1.0, as published by http://www.opensource.org.  For further
   information, see the file `LICENSE' included with this distribution. */
package edu.umass.cs.mallet.grmm;

import salvo.jesus.graph.Edge;

import java.util.Iterator;
import java.util.logging.Level;

/**
 * Loopy (iterative) message passing for approximate inference in general
 *  graphical models, built on top of {@link ViterbiPropagation}.
 * <p>
 * Each round sends a message in both directions along every edge of the
 *  model.  Rounds are repeated until no message entry moves by more than
 *  an internal threshold, or until the configured maximum number of rounds
 *  has been performed.
 *
 * Created: Wed Nov  5 19:30:15 2003
 *
 * @author <a href="mailto:casutton@cs.umass.edu">Charles Sutton</a>
 * @version $Id: LoopyMaxProduct.java,v 1.1 2004/07/15 17:53:31 casutton Exp $
 */
public class LoopyMaxProduct extends ViterbiPropagation {

	/** Number of rounds used when the no-argument constructor is chosen. */
	public static final int DEFAULT_MAX_ITER = 1000;

	/** Upper bound on the number of message-passing rounds. */
	private int maxIter;

	/** Largest per-assignment change that still counts as "unchanged". */
	private double threshold = 0.0001;

	/** Message table from the previous round, indexed by two vertex indices. */
	DiscretePotential[][] oldMessages;

	/** Outcome of the most recent {@link #computeMarginals} call. */
	private boolean hasConverged = false;

	/** Creates an instance limited to {@link #DEFAULT_MAX_ITER} rounds. */
	public LoopyMaxProduct()
	{
		this (DEFAULT_MAX_ITER);
	}

	/**
	 * Creates an instance limited to the given number of rounds.
	 *
	 * @param maxIter maximum number of message-passing rounds to run
	 */
	public LoopyMaxProduct(int maxIter)
	{
		this.maxIter = maxIter;
	}

	/**
	 * @return whether the last {@link #computeMarginals} run stopped because
	 *  the messages stabilised (as opposed to exhausting its rounds)
	 */
	public boolean hasConverged()
	{
		return hasConverged;
	}

	/**
	 * Allocates a fresh previous-round table sized by the model's vertex
	 *  count and seeds one entry per edge.
	 * <p>
	 * For an edge with endpoints A and B the entry [index(A)][index(B)] is
	 *  set to a new {@link MultinomialPotential} over the variable at
	 *  index(A).
	 *
	 * @param mdl model whose edges determine which entries are seeded
	 */
	private void initOldMessages (UndirectedModel mdl)
	{
		int vertexCount = mdl.getVerticesCount ();
		oldMessages = new DiscretePotential [vertexCount][vertexCount];

		// NOTE(review): only the A-to-B slot of each edge is seeded here, while
		//  propagate() sends messages in both directions -- confirm whether the
		//  reverse slot is intentionally left null for the first round.
		Iterator edges = mdl.getEdgeSet().iterator();
		while (edges.hasNext()) {
			Edge link = (Edge) edges.next();
			Variable endA = (Variable) link.getVertexA ();
			Variable endB = (Variable) link.getVertexB ();
			int row = mdl.getIndex (endA);
			int col = mdl.getIndex (endB);
			// (A logify() call on the new potential was disabled in the original.)
			oldMessages [row][col] = new MultinomialPotential (mdl.get (row));
		}
	}

	/**
	 * Compares the previous-round table with the current <tt>messages</tt>
	 *  table, entry by entry.
	 *
	 * @return <tt>false</tt> as soon as some assignment of some non-null
	 *  previous-round entry differs from the current entry by more than the
	 *  threshold; <tt>true</tt> if no such assignment exists
	 */
	private boolean checkConvergence ()
	{
		for (int i = 0; i < oldMessages.length; i++) {
			for (int j = 0; j < oldMessages[i].length; j++) {
				DiscretePotential previous = oldMessages [i][j];
				DiscretePotential current = messages [i][j];

				// Slots that held no message last round are not compared.
				if (previous == null) {
					continue;
				}

				assert current != null
						: "Message went from nonnull to null "+i+" --> "+j;

				if (logger.isLoggable (Level.FINER)) {
					logger.finer ("["+i+"]["+j+"]\n"+previous+current);
				}

				// NOTE(review): exp() is applied to phi() before comparing, which
				//  presumably means phi() yields log-space values -- confirm, since
				//  the logify() calls elsewhere in this class are disabled.
				Iterator assignments = previous.assignmentIterator();
				while (assignments.hasNext()) {
					Assignment assignment = (Assignment) assignments.next();
					double before = previous.phi (assignment);
					double after = current.phi (assignment);
					double drift = Math.abs (Math.exp(before) - Math.exp(after));
					if (drift > threshold) {
						return false;
					}
				}
			}
		}
		return true;
	}

	/**
	 * Runs message passing on the given model until the messages stop
	 *  changing or the round limit is reached, and records the outcome for
	 *  {@link #hasConverged()}.
	 *
	 * @param mdl model to run inference on
	 */
	public void computeMarginals (UndirectedModel mdl)
	{
		hasConverged = false;
		initOldMessages (mdl);

		boolean stable = false;
		int round = 0;
		while (round < maxIter) {
			logger.fine ("***LoopyMaxProduct iteration "+round);
			super.initForGraph (mdl);
			propagate (mdl);
			if (checkConvergence ()) {
				stable = true;
				break;
			}
			// Not stable yet: this round's table becomes the reference for the next.
			oldMessages = messages;
			round++;
		}

		hasConverged = stable;
		if (stable) {
			// NOTE(review): 'round' is the zero-based index of the converging round,
			//  so the count logged here is one less than the rounds actually run.
			logger.info ("***LoopyMaxProduct finished: "+round+" iterations");
		} else {
			logger.info ("***Loopy BP quitting: not converged after "+maxIter
									 +" iterations.");
		}
	}

	/**
	 * Performs one round: for every edge, sends a message from each endpoint
	 *  to the other.
	 *
	 * @param mdl model whose edges are traversed
	 */
	private void propagate (UndirectedModel mdl)
	{
		// Messages are sent synchronously (products are taken over the
		//  previous-round table), so the edge order does not matter.
		Iterator edges = mdl.getEdgeSet().iterator();
		while (edges.hasNext()) {
			Edge link = (Edge) edges.next();
			Variable endA = (Variable) link.getVertexA ();
			Variable endB = (Variable) link.getVertexB ();
			sendMessage (mdl, endA, endB);
			sendMessage (mdl, endB, endA);
		}
	}

	/**
	 * Multiplies together the non-null previous-round entries in row
	 *  <tt>idx</tt>, skipping column <tt>excludeMsgFrom</tt>.
	 *
	 * @param idx row of the previous-round table to combine
	 * @param excludeMsgFrom column to leave out of the product
	 * @return a new potential holding the product
	 */
	protected DiscretePotential msgProduct (int idx, int excludeMsgFrom)
	{
		// (A logify() call on the fresh potential was disabled in the original.)
		DiscretePotential product = new MultinomialPotential ();
		DiscretePotential[] row = oldMessages[idx];
		for (int col = 0; col < row.length; col++) {
			DiscretePotential msg = row[col];
			if (msg == null || col == excludeMsgFrom) {
				continue;
			}
			product.multiplyBy (msg);
		}
		return product;
	}

	/**
	 * Builds the pairwise result for two variables: combines the message
	 *  products of both variables (each excluding the other), the edge
	 *  potential, and whichever vertex potentials are present, then applies
	 *  <tt>extractMax</tt> over the pair.
	 *
	 * @param v1 first variable
	 * @param v2 second variable
	 * @return the potential returned by <tt>extractMax</tt> for (v1, v2)
	 */
	public DiscretePotential lookupMarginal (Variable v1, Variable v2)
	{
		int first = mdlCurrent.getIndex (v1);
		int second = mdlCurrent.getIndex (v2);

		DiscretePotential joint = msgProduct (first, second);
		DiscretePotential other = msgProduct (second, first);
		joint.multiplyBy (other);

		DiscretePotential pairwise = mdlCurrent.potentialOfEdge (v1, v2);
		joint.multiplyBy (pairwise);

		// Vertex potentials are optional; multiply in only those that exist.
		DiscretePotential local1 = mdlCurrent.potentialOfVertex (v1);
		if (local1 != null) {
			joint.multiplyBy (local1);
		}
		DiscretePotential local2 = mdlCurrent.potentialOfVertex (v2);
		if (local2 != null) {
			joint.multiplyBy (local2);
		}

		return joint.extractMax (new Variable[] { v1, v2 });
	}

	/**
	 * Dispatches on clique size to the single-variable or pairwise lookup.
	 *
	 * @param c clique of size 1 or 2
	 * @return the result of the matching lookup
	 * @throws IllegalArgumentException if the clique has any other size
	 */
	public DiscretePotential lookupMarginal (Clique c)
	{
		int size = c.size ();
		if (size == 1) {
			return lookupMarginal (c.get (0));
		}
		if (size == 2) {
			return lookupMarginal (c.get (0), c.get (1));
		}
		throw new IllegalArgumentException
			("LoopyBP currently only supports node and edge cliques.");
	}
}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -