⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 trp.java

📁 这是一个matlab的java实现。里面有许多内容。请大家慢慢琢磨。
💻 JAVA
📖 第 1 页 / 共 2 页
字号:
		public Tree nextTree () {			// If no more trees, rewind.			if (!it.hasNext ()) {				it = lst.iterator ();			}			return (Tree) it.next();		}	};	// Termination conditions	// will this need to be subclassed from outside?  Will such	// subclasses need access to the private state of TRP?	static public interface TerminationCondition extends Cloneable {		// This takes the instances of trp as a parameter so that if a		//  TRP instance is cloned, and the terminator copied over, it		//  will still work.		public boolean shouldContinue (TRP trp);		public void reset ();    // boy do I hate Java cloning		public Object clone () throws CloneNotSupportedException;	}	static public class IterationTerminator implements TerminationCondition	{		int current;		int max;		public void reset () { current = 0; }		public IterationTerminator (int m) { max = m; reset (); };		public boolean shouldContinue (TRP trp) { 			current++;			if (current >= max) {				logger.finest ("***TRP quitting: Iteration "+current+" >= "+max);			}			return current <= max; 		};		public Object clone () throws CloneNotSupportedException		{			return super.clone ();		}	}	public static class ConvergenceTerminator implements TerminationCondition	{		DiscretePotential[][] oldMessages;		double delta = 0.01;		public ConvergenceTerminator () {};		public ConvergenceTerminator (double delta) { this.delta = delta; }		public void reset () 		{ 			oldMessages = null;		}		private void copyMessages (TRP trp) 		{			int n = trp.messages.length;			oldMessages = new DiscretePotential [n][n];			for (int i = 0; i < n; i++) {				for (int j = 0; j < n; j++) {					if (trp.messages[i][j] != null) {						oldMessages[i][j] = trp.messages[i][j].duplicate ();					}				}			}		}		private boolean checkForConvergence (TRP trp) 		{			for (int i = 0; i < oldMessages.length; i++) {				for (int j = 0; j < oldMessages[i].length; j++) {					DiscretePotential ptl1 = oldMessages [i][j];					DiscretePotential ptl2  = trp.messages [i][j];					if (oldMessages 
[i][j] != null) {						assert trp.messages [i][j] != null 							: "Message went from nonnull to null "+i+" --> "+j;												for (Iterator it = ptl1.assignmentIterator(); it.hasNext();) {							Assignment assn = (Assignment) it.next();							double val1 = ptl1.phi (assn);							double val2 = ptl2.phi (assn);							if (Math.abs (val1 - val2) > delta) {								return false;							}						}					}				}			}			return true;		}		public boolean shouldContinue (TRP trp)		{			boolean retval = true;			if (oldMessages != null) 				retval = !checkForConvergence (trp);			copyMessages(trp);						return retval;		}		public Object clone () throws CloneNotSupportedException		{			return super.clone ();		}			}	// Runs until convergence, but doesn't stop until all edges have	// been used at least once, and always stops after 1000 iterations.	public static class DefaultConvergenceTerminator implements TerminationCondition	{		ConvergenceTerminator cterminator; 		IterationTerminator iterminator; 		String msg;				public DefaultConvergenceTerminator () { this (0.001, 1000); }		public DefaultConvergenceTerminator (double delta, int maxIter) {			cterminator = new ConvergenceTerminator (delta);			iterminator = new IterationTerminator (maxIter);			msg = "***TRP quitting: over "+maxIter+" iterations";		}		public void reset () 		{			iterminator.reset ();			cterminator.reset ();		}		// Terminate if converged or at insanely high # of iterations		public boolean shouldContinue (TRP trp)		{			boolean notAllTouched = !trp.allEdgesTouched ();			if (!iterminator.shouldContinue (trp)) {				logger.warning (msg);				if (notAllTouched) {					logger.warning ("***TRP warning: Not all edges used!");				}				return false;			}						if (notAllTouched) {				return true;			} else {				return cterminator.shouldContinue (trp);			}		}		public Object clone () throws CloneNotSupportedException		{			DefaultConvergenceTerminator dup = (DefaultConvergenceTerminator)																				 super.clone ();			dup.iterminator = 
(IterationTerminator) iterminator.clone ();			dup.cterminator = (ConvergenceTerminator) cterminator.clone ();			return dup;		}			}	// And now, the heart of TRP:	public void computeMarginals (UndirectedModel m) {		initForGraph (m);				int iter = 0;		while (terminator.shouldContinue (this)) {			logger.finer ("TRP iteration "+(iter++));			Tree tree = factory.nextTree (m);			propagate (tree);		}		iterUsed = iter;		logger.info ("***TRP used "+iter+" iterations.");	}	private void propagate (Tree tree) {		Variable root = (Variable) tree.getRoot ();		lambdaPropagation (tree, null, root);		piPropagation (tree, root);	}	private void lambdaPropagation (Tree tree, Variable parent, Variable child)	{		logger.finer ("TRP lambdaPropagation from "+parent);		Iterator it = tree.getChildren (child).iterator();		while (it.hasNext()) {			Variable gchild = (Variable) it.next();			lambdaPropagation (tree, child, gchild);		}		if (parent != null) {			sendMessage (mdlCurrent, child, parent);			// a bit sneaky to put this here...			
touchEdge (parent, child);		}	}	private void piPropagation (Tree tree, Variable parent)	{		logger.finer ("TRP piPropagation from "+parent);		Iterator it = tree.getChildren (parent).iterator();		while (it.hasNext()) {			Variable child = (Variable) it.next();			sendMessage (mdlCurrent, parent, child);			piPropagation (tree, child);		}	}	private boolean allEdgesTouched ()	{		Iterator it = mdlCurrent.getEdgeSet().iterator();		while (it.hasNext()) {			Edge edge = (Edge) it.next();			Variable v1 = (Variable) edge.getVertexA();			Variable v2 = (Variable) edge.getVertexB();			int idx1 = mdlCurrent.getIndex (v1);			int idx2 = mdlCurrent.getIndex (v2);			if (edgeTouched [idx1][idx2] == 0) {				logger.finest ("***TRP continuing: edge "+idx1+","+idx2											 +" not touched.");				return false;			}		}		return true;	}	private void touchEdge (Variable v1, Variable v2)	{		int idx1 = mdlCurrent.getIndex (v1);		int idx2 = mdlCurrent.getIndex (v2);		edgeTouched[idx1][idx2]++;		edgeTouched[idx2][idx1]++;	}	private boolean isEdgeTouched (Edge e)	{		Variable v1 = (Variable) e.getVertexA();		Variable v2 = (Variable) e.getVertexB();		int idx1 = mdlCurrent.getIndex (v1);		int idx2 = mdlCurrent.getIndex (v2);		return (edgeTouched[idx1][idx2] > 0);	}	public DiscretePotential lookupMarginal (Variable v1, Variable v2)	{		int idx1 = mdlCurrent.getIndex (v1);		int idx2 = mdlCurrent.getIndex (v2);		DiscretePotential edgePtl = mdlCurrent.potentialOfEdge (v1, v2);		DiscretePotential product = edgePtl.duplicate();		msgProduct (product, idx1, idx2);		msgProduct (product, idx2, idx1);		assert product.varSet().size() == 2;		product.delogify ();		product.normalize ();		return product;	}	// xxx this should be added to the caching inference interface...	
//  perhaps to AbstractCachingInferencer	public DiscretePotential lookupMarginal (Clique c) {		switch (c.size ()) {			case 1: return lookupMarginal (c.get (0));			case 2: return lookupMarginal (c.get (0), c.get (1));			default: throw new IllegalArgumentException 				("TRP currently only supports node and edge cliques.");		}	}			public DiscretePotential query (DirectedModel m, Variable var) 	{		throw new UnsupportedOperationException 			("GRMM doesn't yet do directed models.");	}	// copy-paste from LoopyBP	public double lookupLogJoint (Assignment assn)	{		double accum = 0.0;		// Compute using BP-factorization 		// prod_s (p(x_s))^-(deg(s)-1) * ...		for (Iterator it = mdlCurrent.getVerticesIterator(); it.hasNext();) {			Variable var = (Variable) it.next();			DiscretePotential ptl = lookupMarginal (var);			int deg = mdlCurrent.getDegree(var);			if (deg > 1)				accum -= (deg - 1) * Math.log (ptl.phi (assn));		}		 		// ... * prod_{st} p(x_s, x_t)		for (Iterator it = mdlCurrent.getEdgeSet().iterator(); it.hasNext();) {			Edge edge = (Edge) it.next();			Variable v1 = (Variable) edge.getVertexA ();			Variable v2 = (Variable) edge.getVertexB ();			DiscretePotential p12 = lookupMarginal (v1, v2);			DiscretePotential p1 = lookupMarginal (v1);			DiscretePotential p2 = lookupMarginal (v2);			accum += Math.log (p12.phi (assn));		}				return accum;	}	// Deep copy termination condition	public Object clone () {		try {			TRP dup = (TRP) super.clone ();			if (terminator != null)				dup.terminator = (TerminationCondition) terminator.clone ();			return dup;		} catch (CloneNotSupportedException e) {			// should never happen			throw new RuntimeException (e);		}	}}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -