📄 junctiontree.java
字号:
/* Copyright (C) 2003 Univ. of Massachusetts Amherst, Computer Science Dept. This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit). http://www.cs.umass.edu/~mccallum/mallet This software is provided under the terms of the Common Public License, version 1.0, as published by http://www.opensource.org. For further information, see the file `LICENSE' included with this distribution. */package edu.umass.cs.mallet.grmm;import java.util.Set;import salvo.jesus.graph.Vertex;import java.util.HashSet;import java.util.Iterator;import java.util.Collection;import java.util.ArrayList;import java.util.List;import java.util.Arrays;import salvo.jesus.graph.listener.NullGraphListener;import salvo.jesus.graph.GraphAddEdgeEvent;import salvo.jesus.graph.Edge;import java.io.File;import java.util.Date;import gnu.trove.TIntObjectHashMap;import gnu.trove.THashSet;import gnu.trove.TIntObjectIterator;/** * Datastructure for a junction tree. * * Created: Tue Sep 30 10:30:25 2003 * * @author <a href="mailto:casutton@cs.umass.edu">Charles Sutton</a> * @version $Id: JunctionTree.java,v 1.2 2004/07/22 00:37:38 casutton Exp $ */public class JunctionTree extends Tree implements GraphicalModel { private int numNodes; private static class Sepset { Sepset(Set s, DiscretePotential p) { set = s; ptl = p; } Set set; DiscretePotential ptl; } private TIntObjectHashMap sepsets; private DiscretePotential[] cpfs; public JunctionTree(int size) { super(); numNodes = size; sepsets = new TIntObjectHashMap(); cpfs = new DiscretePotential[size]; } // JunctionTree constructor public void addNode(Vertex parent1, Vertex child1) { super.addNode(parent1, child1); Clique parent = (Clique) parent1; Clique child = (Clique) child1; Set sepset = parent.intersection(child); int id1 = lookupIndex(parent); int id2 = lookupIndex(child); putSepset(id1, id2, new Sepset(sepset, new MultinomialPotential(sepset))); } private int hashIdxIdx(int id1, int id2) { assert (id1 < 65536) && (id2 < 65536); int id; if (id1 < id2) { id = (id1 << 16) | id2; } else { id = (id2 << 16) | id1; } return id; } private void putSepset(int id1, int id2, Sepset sepset) { int id = hashIdxIdx(id1, id2); sepsets.put(id, sepset); } private Sepset getSepset(int id1, int id2) { int id = hashIdxIdx(id1, id2); return (Sepset) sepsets.get(id); } // CPF accessors public DiscretePotential getCPF(Clique c) { return cpfs[lookupIndex(c)]; } public void setCPF(Clique c, DiscretePotential pot) { cpfs[lookupIndex(c)] = pot; } void clearCPFs() { for (int i = 0; i < cpfs.length; i++) { cpfs[i] = new MultinomialPotential((Clique) lookupVertex(i)); } TIntObjectIterator it = sepsets.iterator(); while (it.hasNext()) { it.advance(); Sepset sepset = (Sepset) it.value(); sepset.ptl = new MultinomialPotential(sepset.set); } } Set sepsetPotentials() { THashSet set = new THashSet(); TIntObjectIterator it = sepsets.iterator(); while (it.hasNext()) { it.advance(); DiscretePotential ptl = ((Sepset) it.value()).ptl; set.add(ptl); } return set; } void setSepsetPot(DiscretePotential pot, Clique v1, Clique v2) { int id1 = lookupIndex(v1); int id2 = lookupIndex(v2); getSepset(id1, id2).ptl = pot; } public DiscretePotential getSepsetPot(Clique v1, Clique v2) { int id1 = lookupIndex(v1); int id2 = lookupIndex(v2); return getSepset(id1, id2).ptl; } public Collection potentials() { HashSet h = new HashSet(); for (int i = 0; i < cpfs.length; i++) { if (cpfs[i] != null) { h.add(cpfs[i]); } } return h; } public Set getSepset(Clique v1, Clique v2) { int id1 = lookupIndex(v1); int id2 = lookupIndex(v2); return getSepset(id1, id2).set; } public DiscretePotential lookupMarginal(Variable var) { Clique c = findParentCluster(var); DiscretePotential pot = getCPF(c); return pot.marginalize(var); } public double lookupLogJoint(Assignment assn) { double accum = 0; for (int i = 0; i < cpfs.length; i++) { if (cpfs[i] != null) { double phi = cpfs[i].phi(assn); if (cpfs[i].isInLogSpace()) { accum += phi; } else { accum += Math.log (phi); } } } TIntObjectIterator it = sepsets.iterator(); while (it.hasNext()) { it.advance(); DiscretePotential ptl = ((Sepset) it.value()).ptl; double phi = ptl.phi (assn); if (ptl.isInLogSpace()) { accum -= phi; } else { accum -= Math.log (phi); } } return accum; } /** Returns a cluster in the tree that contains var. */ public Clique findParentCluster(Variable var) { int best = Integer.MAX_VALUE; Clique retval = null; // xxx Inefficient for (Iterator it = getVerticesIterator(); it.hasNext();) { Clique c = (Clique) it.next(); if (c.contains(var) && c.weight() < best) { retval = c; best = c.weight(); } } return retval; } /** * Returns a cluster in the tree that contains all the vars in a * collection. */ public Clique findParentCluster(Collection vars) { int best = Integer.MAX_VALUE; Clique retval = null; // xxx Inefficient for (Iterator it = getVerticesIterator(); it.hasNext();) { Clique c = (Clique) it.next(); if (c.containsAll(vars) && c.weight() < best) { retval = c; best = c.weight(); } } return retval; } /** Returns a cluster in the tree that contains exactly the given * variables, or null if no such cluster exists. */ public Clique findCluster(Variable[] vars) { List l = Arrays.asList(vars); for (Iterator it = getVerticesIterator(); it.hasNext();) { Clique c2 = (Clique) it.next(); if (c2.containsAll(l) && l.containsAll(c2)) return c2; } return null; } /** Normalizes all potentials in the tree, both node and sepset. */ public void normalizeAll() { int n = cpfs.length; for (int i = 0; i < n; i++) { if (cpfs[i] != null) { cpfs[i].normalize(); } } TIntObjectIterator it = sepsets.iterator(); while (it.hasNext()) { it.advance(); DiscretePotential ptl = ((Sepset) it.value()).ptl; ptl.normalize(); } } // I have a feeling I'll regret this in the morning.... void logify() { int n = cpfs.length; for (int i = 0; i < n; i++) { if (cpfs[i] != null) { cpfs[i].logify(); } } TIntObjectIterator it = sepsets.iterator(); while (it.hasNext()) { it.advance(); DiscretePotential ptl = ((Sepset) it.value()).ptl; ptl.logify(); } } // even more, I have a feeling I'll regret this in the morning.... public void delogify() { int n = cpfs.length; for (int i = 0; i < n; i++) { if (cpfs[i] != null) { cpfs[i].delogify(); } } TIntObjectIterator it = sepsets.iterator(); while (it.hasNext()) { it.advance(); DiscretePotential ptl = ((Sepset) it.value()).ptl; ptl.delogify(); } } int getId(Clique c) { return lookupIndex(c); }// Implementation of GraphicalModel public double query(Inferencer inferencer, Assignment assn) { return inferencer.query(this, assn); } public void computeMarginals(Inferencer inferencer) { inferencer.computeMarginals(this); }// Debugging functions public void dump() { int n = cpfs.length; // This will cause OpenJGraph to print all our nodes and edges System.out.println(this); // Now lets print all the cpfs System.out.println("Vertex CPFs"); for (int i = 0; i < n; i++) { if (cpfs[i] != null) { System.out.println("CPF "+i+" "+cpfs[i]); } } // And the sepset potentials System.out.println("sepset CPFs"); TIntObjectIterator it = sepsets.iterator(); while (it.hasNext()) { it.advance(); DiscretePotential ptl = ((Sepset) it.value()).ptl; System.out.println(ptl); } System.out.println ("/End JT"); } public double dumpLogJoint (Assignment assn) { double accum = 0; for (int i = 0; i < cpfs.length; i++) { if (cpfs[i] != null) { double phi = cpfs[i].phi(assn); if (cpfs[i].isInLogSpace()) { accum += phi; } else { accum += Math.log (phi); } System.out.println ("CPF "+i+" accum = "+accum); } } TIntObjectIterator it = sepsets.iterator(); while (it.hasNext()) { it.advance(); DiscretePotential ptl = ((Sepset) it.value()).ptl; double phi = ptl.phi(assn); if (ptl.isInLogSpace()) { accum -= phi; } else { accum -= Math.log (phi); } System.out.println("Sepset "+ptl.varSet()+" accum "+accum); } return accum; } public boolean isNaN() { int n = cpfs.length; for (int i = 0; i < n; i++) if (cpfs[i].isNaN()) return true; // And the sepset potentials TIntObjectIterator it = sepsets.iterator(); while (it.hasNext()) { it.advance(); DiscretePotential ptl = ((Sepset) it.value()).ptl; if (ptl.isNaN()) return true; } return false; }// Implementation of edu.umass.cs.mallet.users.casutton.graphical.Compactible public void decompact() { cpfs = new DiscretePotential[numNodes]; clearCPFs(); } public void compact() { cpfs = null; }} // JunctionTree
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -