📄 trp.java
字号:
/* Copyright (C) 2003 Univ. of Massachusetts Amherst, Computer Science Dept.
   This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
   http://www.cs.umass.edu/~mccallum/mallet
   This software is provided under the terms of the Common Public License,
   version 1.0, as published by http://www.opensource.org.  For further
   information, see the file `LICENSE' included with this distribution. */

package edu.umass.cs.mallet.grmm;

import java.io.*;
import java.util.*;
import java.util.Iterator;
import salvo.jesus.graph.Graph;
import salvo.jesus.graph.Edge;
import edu.umass.cs.mallet.base.types.Alphabet;
import salvo.jesus.graph.TreeImpl;
import java.util.logging.Logger;
import edu.umass.cs.mallet.base.util.MalletLogger;
import gnu.trove.THashSet;
import gnu.trove.TObjectProcedure;
// Single-type import: this shadows java.util.Arrays from the wildcard import
// above, so java.util.Arrays has to be written fully qualified in this file.
import edu.umass.cs.mallet.base.util.Arrays;
import gnu.trove.TObjectFunction;
import salvo.jesus.graph.GraphImpl;
import salvo.jesus.graph.Visitor;
import salvo.jesus.graph.algorithm.GraphTraversal;
import salvo.jesus.graph.algorithm.BreadthFirstTraversal;
import salvo.jesus.graph.Vertex;

/**
 * Implementation of Wainwright's TRP algorithm for approximate inference
 * in general graphical models.
 *
 * TRP repeatedly runs exact inference over an acyclic subgraph (typically a
 * spanning tree) of the model; the subgraph used at each step is chosen by a
 * pluggable {@link TreeFactory}.
 *
 * @author Charles Sutton
 * @version $Id: TRP.java,v 1.1 2004/07/15 17:53:31 casutton Exp $
 */
public class TRP extends BeliefPropagation {

  /** Logger for this class, obtained from MalletLogger. */
  private static Logger logger = MalletLogger.getLogger (TRP.class.getName());

  /** If true, run lots of paranoid and expensive sanity checks. */
  private boolean paranoid = false;

  // NOTE(review): not read in the visible part of this file; presumably
  // marks cached results as out of date -- confirm against the rest of TRP.
  private boolean cacheStale = true;

  /**
   * Strategy that supplies the tree for each TRP step.  initForGraph
   * installs an AlmostRandomTreeFactory when this is null.
   */
  private TreeFactory factory;

  /**
   * Decides when TRP stops iterating.  initForGraph installs a
   * DefaultConvergenceTerminator when this is null, and otherwise calls
   * reset() on it.
   */
  private TerminationCondition terminator;

  // NOTE(review): not used in the visible part of this file; presumably the
  // exact inferencer run on each tree -- confirm against the rest of TRP.
  private Inferencer inferencer;

  /* Make sure that we've included all edges before we terminate.
     Allocated by initForGraph as an n-by-n array, where n is the number of
     vertices in the model (presumably indexed by the endpoints of an edge
     -- confirm against isEdgeTouched). */
  private int[][] edgeTouched;

  /* Number of nodes in the graph being inferred about.
(equal to vertexPots.length) */ private int numNodes; private boolean hasConverged; private int iterUsed = 0; static private Random rand = new Random (); public TRP () { this (null, null); } public TRP (TreeFactory f) { this (f, null); } public TRP (TreeFactory f, TerminationCondition cond) { factory = f; terminator = cond; inLogSpace = true; } // Accessors public void setTerminator (TerminationCondition cond) { terminator = cond; } // xxx should this be static? public static void setRandomSeed (long seed) { rand = new Random (seed); } public boolean isConverged () { return hasConverged; } public int iterationsUsed () { return iterUsed; } protected void initForGraph (UndirectedModel m) { super.initForGraph (m); int n = m.getVerticesCount (); edgeTouched = new int [n][n]; numNodes = n; hasConverged = false; if (factory == null) { factory = new AlmostRandomTreeFactory (); } if (terminator == null) { terminator = new DefaultConvergenceTerminator (); } else { terminator.reset (); } } private static Tree graphToTree (final Graph g) throws Exception { final Tree tree = new Tree (); Visitor visitor = new Visitor () { public boolean visit (Vertex v1) { for (Iterator it = g.getAdjacentVertices (v1).iterator(); it.hasNext();) { Vertex v2 = (Vertex) it.next(); if (tree.getParent (v1) != v2) { tree.addNode (v1, v2); assert tree.getParent (v2) == v1; } } return true; } }; GraphTraversal traverser = new BreadthFirstTraversal (g); Vertex start = (Vertex) g.getVertexSet().iterator().next(); tree.add (start); traverser.traverse (start, visitor); return tree; } /** * Interface for tree-generation strategies for TRP. * * TRP works by repeatedly doing exact inference over spanning tree * of the original graph. But the trees can be chosen arbitrarily. * In fact, they don't need to be spanning trees; any acyclic * substructure will do. Users of TRP can tell it which strategy * to use by passing in an implementation of TreeFactory. 
*/ public interface TreeFactory { public Tree nextTree (Graph g); }; // This works around what appears to be a bug in OpenJGraph // connected sets. private static class SimpleUnionFind { private List unionFind = new LinkedList(); private Set findSet (Object obj) { for (Iterator it = unionFind.iterator(); it.hasNext();) { Set set = (Set) it.next(); if (set.contains (obj)) return set; } Set newSet = new THashSet(); newSet.add (obj); unionFind.add (newSet); return newSet; } private void union (Object obj1, Object obj2) { Set set1 = findSet (obj1); Set set2 = findSet (obj2); set1.addAll (set2); unionFind.remove (set2); } private boolean doesNotCauseCycle (Edge edge) { Vertex node1 = edge.getVertexA(); Vertex node2 = edge.getVertexB(); return (findSet (node1) != findSet (node2)); } } /** * Generates random spanning trees. */ // TODO This is probably plagued by OpenJGraph bug I worked // around above. Check this. static public class RandomTreeFactory implements TreeFactory { private Tree graphToTree (final Graph g) throws Exception { final Tree tree = new Tree(); Visitor visitor = new Visitor() { public boolean visit (Vertex v1) { for (Iterator it = g.getAdjacentVertices(v1).iterator(); it.hasNext();) { Vertex v2 = (Vertex) it.next(); if (tree.getParent(v1) != v2) { tree.addNode(v1, v2); } } return true; } }; GraphTraversal traverser = new BreadthFirstTraversal(g); Vertex start = (Vertex) g.getVertexSet().iterator().next(); tree.add(start); traverser.traverse(start, visitor); return tree; } public Tree nextTree (Graph fullGraph) { try { Graph g = nextGraph(fullGraph); return graphToTree(g); } catch (Exception e) { e.printStackTrace(); throw new RuntimeException(e); } } public Graph nextGraph (Graph fullGraph) { try { Graph g = new GraphImpl(); ArrayList edges = new ArrayList(fullGraph.getEdgeSet()); /* * At each iteration, select an edge * that keeps the tree connected and * does not create a loop. 
*/ while (!edges.isEmpty()) { /* select a random unused edge */ int j = rand.nextInt(edges.size()); Edge e = (Edge) edges.get(j); Variable node1 = (Variable) e.getVertexA(); Variable node2 = (Variable) e.getVertexB(); if (!g.getVertexSet().contains(node1) || !g.getVertexSet().contains(node2) || !g.isConnected(node1, node2)) { g.addEdge(node1, node2); } edges.remove(j); } return g; } catch (Exception e) { e.printStackTrace(); throw new RuntimeException(e); } } }; // End RandomTreeFactory /** * Always adds edges that have not been touched, after that * adds random edges. */ public class AlmostRandomTreeFactory implements TreeFactory { public Tree nextTree (Graph fullGraph) { SimpleUnionFind unionFind = new SimpleUnionFind(); ArrayList edges = new ArrayList(fullGraph.getEdgeSet ()); ArrayList goodEdges = new ArrayList(fullGraph.getVerticesCount()); Collections.shuffle (edges, rand); // First add all edges that haven't been used so far try { for (Iterator it = edges.iterator(); it.hasNext();) { Edge edge = (Edge) it.next(); if (!isEdgeTouched (edge) && unionFind.doesNotCauseCycle (edge)) { goodEdges.add (edge); unionFind.union (edge.getVertexA(), edge.getVertexB()); it.remove (); } } // Now add as many other edges as possible for (int i = 0; i < edges.size(); i++) { Edge edge = (Edge)edges.get (i); if (unionFind.doesNotCauseCycle (edge)) { goodEdges.add (edge); unionFind.union (edge.getVertexA(), edge.getVertexB()); } } Graph g = new GraphImpl(); for (Iterator it = goodEdges.iterator(); it.hasNext();) { Edge edge = (Edge) it.next(); g.addEdge (edge); } Tree tree = graphToTree (g);// System.out.println(tree); return tree; } catch (Exception e) { e.printStackTrace (); throw new RuntimeException (e); } } }; /** * Generates spanning trees cyclically from a predefined collection. 
*/ static public class TreeListFactory { private List lst; private Iterator it; public TreeListFactory (List l) { lst = l; it = lst.iterator(); } public TreeListFactory (Tree[] arr) { lst = new ArrayList (java.util.Arrays.asList (arr)); it = lst.iterator(); }
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -