📄 trp.java
字号:
public Tree nextTree () {
      // If no more trees, rewind to the start of the list and cycle again.
      if (!it.hasNext ()) {
        it = lst.iterator ();
      }
      return (Tree) it.next ();
    }
  };


  // ---------------------------------------------------------------------
  // Termination conditions
  // ---------------------------------------------------------------------
  // xxx Will this need to be subclassed from outside?  Will such
  // subclasses need access to the private state of TRP?

  /**
   * Decides, once per TRP iteration, whether propagation should keep going.
   * <p>
   * The TRP instance is passed to {@link #shouldContinue} instead of being
   * stored, so that a terminator copied into a cloned TRP still inspects
   * the clone rather than the original.
   */
  static public interface TerminationCondition extends Cloneable {

    /** Returns true if TRP should run another iteration on {@code trp}. */
    public boolean shouldContinue (TRP trp);

    /** Clears any per-run state so the condition can be reused. */
    public void reset ();

    // Redeclared public (Object.clone is protected) so that TRP.clone can
    // deep-copy a terminator through this interface type.
    public Object clone () throws CloneNotSupportedException;
  }


  /**
   * Allows exactly {@code max} iterations: the first {@code max} calls to
   * {@link #shouldContinue} return true and every later call returns false
   * (until {@link #reset} is called).
   */
  static public class IterationTerminator implements TerminationCondition {

    int current;
    int max;

    /** @param m maximum number of iterations to allow */
    public IterationTerminator (int m) {
      max = m;
      reset ();
    }

    public void reset () {
      current = 0;
    }

    public boolean shouldContinue (TRP trp) {
      current++;
      boolean keepGoing = current <= max;
      // Log only when actually stopping (current > max).  Testing
      // current >= max here would report "quitting" on the last iteration,
      // which still runs.
      if (!keepGoing) {
        logger.finest ("***TRP quitting: Iteration "+current+" >= "+max);
      }
      return keepGoing;
    }

    public Object clone () throws CloneNotSupportedException {
      return super.clone ();
    }
  }


  /**
   * Stops once no message value has changed by more than {@code delta}
   * (absolute difference, checked per assignment) since the previous call.
   * The first call after construction or {@link #reset} always returns true,
   * because there is no earlier snapshot to compare against.
   */
  public static class ConvergenceTerminator implements TerminationCondition {

    DiscretePotential[][] oldMessages;
    double delta = 0.01;

    public ConvergenceTerminator () {}

    /** @param delta largest per-entry change still considered "converged" */
    public ConvergenceTerminator (double delta) {
      this.delta = delta;
    }

    public void reset () {
      oldMessages = null;
    }

    /** Stores a fresh snapshot holding a duplicate of every non-null message of {@code trp}. */
    private void copyMessages (TRP trp) {
      int n = trp.messages.length;
      oldMessages = new DiscretePotential [n][n];
      for (int i = 0; i < n; i++) {
        for (int j = 0; j < n; j++) {
          if (trp.messages[i][j] != null) {
            oldMessages[i][j] = trp.messages[i][j].duplicate ();
          }
        }
      }
    }

    /**
     * Returns true if every message in the snapshot is within {@code delta}
     * of the current message at the same position.  Any structural change
     * (a message appearing or disappearing, or a differently sized message
     * array) counts as "not converged".
     */
    private boolean checkForConvergence (TRP trp) {
      DiscretePotential[][] current = trp.messages;
      // A snapshot of a different size (e.g. the terminator was reused on
      // another model without reset()) cannot be compared entrywise.
      if (current.length != oldMessages.length) {
        return false;
      }
      for (int i = 0; i < oldMessages.length; i++) {
        for (int j = 0; j < oldMessages[i].length; j++) {
          DiscretePotential ptl1 = oldMessages [i][j];
          DiscretePotential ptl2 = current [i][j];
          if (ptl1 == null) {
            // A message that has appeared since the last snapshot is a
            // change, not evidence of convergence.
            if (ptl2 != null) {
              return false;
            }
            continue;
          }
          assert ptl2 != null : "Message went from nonnull to null "+i+" --> "+j;
          if (ptl2 == null) {
            // Assertions disabled: treat the vanished message as a change
            // instead of failing with a NullPointerException below.
            return false;
          }
          for (Iterator it = ptl1.assignmentIterator (); it.hasNext ();) {
            Assignment assn = (Assignment) it.next ();
            double val1 = ptl1.phi (assn);
            double val2 = ptl2.phi (assn);
            if (Math.abs (val1 - val2) > delta) {
              return false;
            }
          }
        }
      }
      return true;
    }

    public boolean shouldContinue (TRP trp) {
      boolean retval = true;
      if (oldMessages != null) {
        retval = !checkForConvergence (trp);
      }
      copyMessages (trp);
      return retval;
    }

    // A shallow copy is sufficient: copyMessages() always installs a new
    // array rather than mutating the (shared) old one.
    public Object clone () throws CloneNotSupportedException {
      return super.clone ();
    }
  }


  /**
   * Runs until convergence, but never stops before every model edge has been
   * used at least once, and always stops after {@code maxIter} iterations
   * (1000 with the no-argument constructor).
   */
  public static class DefaultConvergenceTerminator implements TerminationCondition {

    ConvergenceTerminator cterminator;
    IterationTerminator iterminator;
    String msg;

    /** Equivalent to {@code DefaultConvergenceTerminator (0.001, 1000)}. */
    public DefaultConvergenceTerminator () {
      this (0.001, 1000);
    }

    /**
     * @param delta   convergence threshold passed to {@link ConvergenceTerminator}
     * @param maxIter hard cap passed to {@link IterationTerminator}
     */
    public DefaultConvergenceTerminator (double delta, int maxIter) {
      cterminator = new ConvergenceTerminator (delta);
      iterminator = new IterationTerminator (maxIter);
      msg = "***TRP quitting: over "+maxIter+" iterations";
    }

    public void reset () {
      iterminator.reset ();
      cterminator.reset ();
    }

    // Terminate if converged or at an insanely high number of iterations.
    public boolean shouldContinue (TRP trp) {
      boolean notAllTouched = !trp.allEdgesTouched ();

      // The iteration counter is advanced on every call, whatever else happens.
      if (!iterminator.shouldContinue (trp)) {
        logger.warning (msg);
        if (notAllTouched) {
          logger.warning ("***TRP warning: Not all edges used!");
        }
        return false;
      }

      // The convergence check (and its snapshot) only starts once every
      // edge has been touched.
      if (notAllTouched) {
        return true;
      }
      return cterminator.shouldContinue (trp);
    }

    public Object clone () throws CloneNotSupportedException {
      DefaultConvergenceTerminator dup = (DefaultConvergenceTerminator) super.clone ();
      dup.iterminator = (IterationTerminator) iterminator.clone ();
      dup.cterminator = (ConvergenceTerminator) cterminator.clone ();
      return dup;
    }
  }


  // ---------------------------------------------------------------------
  // And now, the heart of TRP
  // ---------------------------------------------------------------------

  /**
   * Runs TRP on {@code m}.  The loop asks the tree factory for the next tree
   * and propagates messages over it, until the termination condition says to
   * stop.  The number of iterations performed is recorded in {@code iterUsed}.
   */
  public void computeMarginals (UndirectedModel m) {
    initForGraph (m);
    int iter = 0;
    while (terminator.shouldContinue (this)) {
      logger.finer ("TRP iteration "+iter);
      iter++;
      Tree tree = factory.nextTree (m);
      propagate (tree);
    }
    iterUsed = iter;
    logger.info ("***TRP used "+iter+" iterations.");
  }

  /**
   * One TRP iteration over {@code tree}.  It first collects messages toward
   * the root (lambda pass), then distributes them back out (pi pass).
   */
  private void propagate (Tree tree) {
    Variable root = (Variable) tree.getRoot ();
    lambdaPropagation (tree, null, root);
    piPropagation (tree, root);
  }

  /**
   * Upward pass over the subtree rooted at {@code child}.  It recurses into
   * every child of {@code child} first.  Then it sends child's message to
   * {@code parent} (if any) and marks that edge as used.
   */
  private void lambdaPropagation (Tree tree, Variable parent, Variable child) {
    logger.finer ("TRP lambdaPropagation from "+parent);
    Iterator it = tree.getChildren (child).iterator ();
    while (it.hasNext ()) {
      Variable gchild = (Variable) it.next ();
      lambdaPropagation (tree, child, gchild);
    }
    if (parent != null) {
      sendMessage (mdlCurrent, child, parent);
      // Edge bookkeeping is done here because the upward pass reaches each
      // (parent, child) pair of the tree exactly once.
      touchEdge (parent, child);
    }
  }

  /**
   * Downward pass: sends {@code parent}'s message to each of its children
   * in {@code tree}, then recurses into that child.
   */
  private void piPropagation (Tree tree, Variable parent) {
    logger.finer ("TRP piPropagation from "+parent);
    Iterator it = tree.getChildren (parent).iterator ();
    while (it.hasNext ()) {
      Variable child = (Variable) it.next ();
      sendMessage (mdlCurrent, parent, child);
      piPropagation (tree, child);
    }
  }

  /** Returns true once every edge of the current model has been touched at least once. */
  private boolean allEdgesTouched () {
    Iterator it = mdlCurrent.getEdgeSet ().iterator ();
    while (it.hasNext ()) {
      Edge edge = (Edge) it.next ();
      Variable v1 = (Variable) edge.getVertexA ();
      Variable v2 = (Variable) edge.getVertexB ();
      int idx1 = mdlCurrent.getIndex (v1);
      int idx2 = mdlCurrent.getIndex (v2);
      if (edgeTouched [idx1][idx2] == 0) {
        logger.finest ("***TRP continuing: edge "+idx1+","+idx2
                       +" not touched.");
        return false;
      }
    }
    return true;
  }

  /** Increments the use count of edge (v1, v2) in both directions. */
  private void touchEdge (Variable v1, Variable v2) {
    int idx1 = mdlCurrent.getIndex (v1);
    int idx2 = mdlCurrent.getIndex (v2);
    edgeTouched[idx1][idx2]++;
    edgeTouched[idx2][idx1]++;
  }

  /** Returns true if edge {@code e} has been used in at least one propagation. */
  private boolean isEdgeTouched (Edge e) {
    Variable v1 = (Variable) e.getVertexA ();
    Variable v2 = (Variable) e.getVertexB ();
    int idx1 = mdlCurrent.getIndex (v1);
    int idx2 = mdlCurrent.getIndex (v2);
    return (edgeTouched[idx1][idx2] > 0);
  }

  /**
   * Returns the normalized pairwise marginal over the edge (v1, v2).  It
   * starts from a copy of the edge potential, applies msgProduct for both
   * endpoint directions, then delogifies and normalizes the result.
   * NOTE(review): msgProduct and delogify are defined elsewhere; presumably
   * msgProduct multiplies in the messages arriving at the first variable
   * from all neighbours except the second, and potentials are kept in log
   * space until delogify().  Confirm.
   */
  public DiscretePotential lookupMarginal (Variable v1, Variable v2) {
    int idx1 = mdlCurrent.getIndex (v1);
    int idx2 = mdlCurrent.getIndex (v2);
    DiscretePotential edgePtl = mdlCurrent.potentialOfEdge (v1, v2);
    DiscretePotential product = edgePtl.duplicate ();
    msgProduct (product, idx1, idx2);
    msgProduct (product, idx2, idx1);
    assert product.varSet ().size () == 2;
    product.delogify ();
    product.normalize ();
    return product;
  }

  // xxx this should be added to the caching inference interface...
  // perhaps to AbstractCachingInferencer

  /**
   * Returns the marginal of a one- or two-variable clique.
   *
   * @throws IllegalArgumentException if the clique has any other size
   */
  public DiscretePotential lookupMarginal (Clique c) {
    switch (c.size ()) {
      case 1:
        return lookupMarginal (c.get (0));
      case 2:
        return lookupMarginal (c.get (0), c.get (1));
      default:
        throw new IllegalArgumentException
          ("TRP currently only supports node and edge cliques.");
    }
  }

  /** Directed models are not supported; always throws. */
  public DiscretePotential query (DirectedModel m, Variable var) {
    throw new UnsupportedOperationException
      ("GRMM doesn't yet do directed models.");
  }

  /**
   * Returns the log of the joint probability of {@code assn}, built from the
   * computed marginals using the tree (Bethe) factorization
   *   p(x) = prod_{st} p(x_s, x_t) * prod_s p(x_s)^-(deg(s)-1).
   * (Originally copy-pasted from LoopyBP.)
   */
  public double lookupLogJoint (Assignment assn) {
    double accum = 0.0;

    // prod_s (p(x_s))^-(deg(s)-1) * ...
    for (Iterator it = mdlCurrent.getVerticesIterator (); it.hasNext ();) {
      Variable var = (Variable) it.next ();
      int deg = mdlCurrent.getDegree (var);
      // Degree 1 has exponent 0 and is skipped.  Degree 0 (an isolated
      // vertex) has exponent +1 and MUST be included.  The former
      // "deg > 1" test dropped such vertices from the joint.
      if (deg != 1) {
        DiscretePotential ptl = lookupMarginal (var);
        accum -= (deg - 1) * Math.log (ptl.phi (assn));
      }
    }

    // ... * prod_{st} p(x_s, x_t)
    for (Iterator it = mdlCurrent.getEdgeSet ().iterator (); it.hasNext ();) {
      Edge edge = (Edge) it.next ();
      Variable v1 = (Variable) edge.getVertexA ();
      Variable v2 = (Variable) edge.getVertexB ();
      DiscretePotential p12 = lookupMarginal (v1, v2);
      accum += Math.log (p12.phi (assn));
    }

    return accum;
  }

  /**
   * Copies this TRP.  The termination condition is deep-copied so that the
   * clone's iteration and convergence state is independent.  Every other
   * field is shared with the original via Object.clone's shallow copy.
   * NOTE(review): confirm initForGraph reallocates per-run arrays such as
   * messages and edgeTouched before a clone is used on its own.
   */
  public Object clone () {
    try {
      TRP dup = (TRP) super.clone ();
      if (terminator != null) {
        dup.terminator = (TerminationCondition) terminator.clone ();
      }
      return dup;
    } catch (CloneNotSupportedException e) {
      // should never happen: TRP and its terminators are Cloneable
      throw new RuntimeException (e);
    }
  }
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -