⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 junctiontree.java

📁 这是一个matlab的java实现。里面有许多内容，请大家慢慢琢磨。
💻 JAVA
字号:
/* Copyright (C) 2003 Univ. of Massachusetts Amherst, Computer Science Dept.   This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).   http://www.cs.umass.edu/~mccallum/mallet   This software is provided under the terms of the Common Public License,   version 1.0, as published by http://www.opensource.org.  For further   information, see the file `LICENSE' included with this distribution. */package edu.umass.cs.mallet.grmm;import java.util.Set;import salvo.jesus.graph.Vertex;import java.util.HashSet;import java.util.Iterator;import java.util.Collection;import java.util.ArrayList;import java.util.List;import java.util.Arrays;import salvo.jesus.graph.listener.NullGraphListener;import salvo.jesus.graph.GraphAddEdgeEvent;import salvo.jesus.graph.Edge;import java.io.File;import java.util.Date;import gnu.trove.TIntObjectHashMap;import gnu.trove.THashSet;import gnu.trove.TIntObjectIterator;/** * Datastructure for a junction tree. * * Created: Tue Sep 30 10:30:25 2003 * * @author <a href="mailto:casutton@cs.umass.edu">Charles Sutton</a> * @version $Id: JunctionTree.java,v 1.2 2004/07/22 00:37:38 casutton Exp $ */public class JunctionTree extends Tree implements GraphicalModel {  private int numNodes;  private static class Sepset {    Sepset(Set s, DiscretePotential p)    {      set = s;      ptl = p;    }    Set set;    DiscretePotential ptl;  }  private TIntObjectHashMap sepsets;  private DiscretePotential[] cpfs;  public JunctionTree(int size)  {    super();    numNodes = size;    sepsets = new TIntObjectHashMap();    cpfs = new DiscretePotential[size];  } // JunctionTree constructor  public void addNode(Vertex parent1, Vertex child1)  {    super.addNode(parent1, child1);    Clique parent = (Clique) parent1;    Clique child = (Clique) child1;    Set sepset = parent.intersection(child);    int id1 = lookupIndex(parent);    int id2 = lookupIndex(child);    putSepset(id1, id2, new Sepset(sepset, new MultinomialPotential(sepset)));  }  private int 
hashIdxIdx(int id1, int id2)  {    assert (id1 < 65536) && (id2 < 65536);    int id;    if (id1 < id2) {      id = (id1 << 16) | id2;    } else {      id = (id2 << 16) | id1;    }    return id;  }  private void putSepset(int id1, int id2, Sepset sepset)  {    int id = hashIdxIdx(id1, id2);    sepsets.put(id, sepset);  }  private Sepset getSepset(int id1, int id2)  {    int id = hashIdxIdx(id1, id2);    return (Sepset) sepsets.get(id);  }  //  CPF accessors  public DiscretePotential getCPF(Clique c)  {    return cpfs[lookupIndex(c)];  }  public void setCPF(Clique c, DiscretePotential pot)  {    cpfs[lookupIndex(c)] = pot;  }  void clearCPFs()  {    for (int i = 0; i < cpfs.length; i++) {      cpfs[i] = new MultinomialPotential((Clique) lookupVertex(i));    }    TIntObjectIterator it = sepsets.iterator();    while (it.hasNext()) {      it.advance();      Sepset sepset = (Sepset) it.value();      sepset.ptl = new MultinomialPotential(sepset.set);    }  }  Set sepsetPotentials()  {    THashSet set = new THashSet();    TIntObjectIterator it = sepsets.iterator();    while (it.hasNext()) {      it.advance();      DiscretePotential ptl = ((Sepset) it.value()).ptl;      set.add(ptl);    }    return set;  }  void setSepsetPot(DiscretePotential pot, Clique v1, Clique v2)  {    int id1 = lookupIndex(v1);    int id2 = lookupIndex(v2);    getSepset(id1, id2).ptl = pot;  }  public DiscretePotential getSepsetPot(Clique v1, Clique v2)  {    int id1 = lookupIndex(v1);    int id2 = lookupIndex(v2);    return getSepset(id1, id2).ptl;  }  public Collection potentials()  {    HashSet h = new HashSet();    for (int i = 0; i < cpfs.length; i++) {      if (cpfs[i] != null) {        h.add(cpfs[i]);      }    }    return h;  }  public Set getSepset(Clique v1, Clique v2)  {    int id1 = lookupIndex(v1);    int id2 = lookupIndex(v2);    return getSepset(id1, id2).set;  }  public DiscretePotential lookupMarginal(Variable var)  {    Clique c = findParentCluster(var);    DiscretePotential pot = 
getCPF(c);    return pot.marginalize(var);  }  public double lookupLogJoint(Assignment assn)  {    double accum = 0;    for (int i = 0; i < cpfs.length; i++) {      if (cpfs[i] != null) {				double phi = cpfs[i].phi(assn);				if (cpfs[i].isInLogSpace()) {					accum += phi;				} else {					accum += Math.log (phi);				}      }    }    TIntObjectIterator it = sepsets.iterator();    while (it.hasNext()) {      it.advance();      DiscretePotential ptl = ((Sepset) it.value()).ptl;			double phi = ptl.phi (assn);			if (ptl.isInLogSpace()) {				accum -= phi;			} else {				accum -= Math.log (phi);			}    }    return accum;  }  /** Returns a cluster in the tree that contains var. */  public Clique findParentCluster(Variable var)  {    int best = Integer.MAX_VALUE;    Clique retval = null;    // xxx Inefficient    for (Iterator it = getVerticesIterator(); it.hasNext();) {      Clique c = (Clique) it.next();      if (c.contains(var) && c.weight() < best) {        retval = c;        best = c.weight();      }    }    return retval;  }  /**   * Returns a cluster in the tree that contains all the vars in a   *   collection.   */  public Clique findParentCluster(Collection vars)  {    int best = Integer.MAX_VALUE;    Clique retval = null;    // xxx Inefficient    for (Iterator it = getVerticesIterator(); it.hasNext();) {      Clique c = (Clique) it.next();      if (c.containsAll(vars) && c.weight() < best) {        retval = c;        best = c.weight();      }    }    return retval;  }  /** Returns a cluster in the tree that contains exactly the given   * 	variables, or null if no such cluster exists. */  public Clique findCluster(Variable[] vars)  {    List l = Arrays.asList(vars);    for (Iterator it = getVerticesIterator(); it.hasNext();) {      Clique c2 = (Clique) it.next();      if (c2.containsAll(l) && l.containsAll(c2))        return c2;    }    return null;  }  /** Normalizes all potentials in the tree, both node and sepset. 
*/  public void normalizeAll()  {    int n = cpfs.length;    for (int i = 0; i < n; i++) {      if (cpfs[i] != null) {        cpfs[i].normalize();      }    }    TIntObjectIterator it = sepsets.iterator();    while (it.hasNext()) {      it.advance();      DiscretePotential ptl = ((Sepset) it.value()).ptl;      ptl.normalize();    }  }  // I have a feeling I'll regret this in the morning....  void logify()  {    int n = cpfs.length;    for (int i = 0; i < n; i++) {      if (cpfs[i] != null) {        cpfs[i].logify();      }    }    TIntObjectIterator it = sepsets.iterator();    while (it.hasNext()) {      it.advance();      DiscretePotential ptl = ((Sepset) it.value()).ptl;      ptl.logify();    }  }  // even more, I have a feeling I'll regret this in the morning....  public void delogify()  {    int n = cpfs.length;    for (int i = 0; i < n; i++) {      if (cpfs[i] != null) {        cpfs[i].delogify();      }    }    TIntObjectIterator it = sepsets.iterator();    while (it.hasNext()) {      it.advance();      DiscretePotential ptl = ((Sepset) it.value()).ptl;      ptl.delogify();    }  }  int getId(Clique c)  {    return lookupIndex(c);  }// Implementation of GraphicalModel  public double query(Inferencer inferencer, Assignment assn)  {    return inferencer.query(this, assn);  }  public void computeMarginals(Inferencer inferencer)  {    inferencer.computeMarginals(this);  }// Debugging functions  public void dump()  {    int n = cpfs.length;    // This will cause OpenJGraph to print all our nodes and edges    System.out.println(this);    // Now lets print all the cpfs    System.out.println("Vertex CPFs");    for (int i = 0; i < n; i++) {      if (cpfs[i] != null) {        System.out.println("CPF "+i+" "+cpfs[i]);      }    }    // And the sepset potentials    System.out.println("sepset CPFs");    TIntObjectIterator it = sepsets.iterator();    while (it.hasNext()) {      it.advance();      DiscretePotential ptl = ((Sepset) it.value()).ptl;      
System.out.println(ptl);    }		System.out.println ("/End JT");  }  public double dumpLogJoint (Assignment assn)  {    double accum = 0;    for (int i = 0; i < cpfs.length; i++) {      if (cpfs[i] != null) {        double phi = cpfs[i].phi(assn);				if (cpfs[i].isInLogSpace()) {					accum += phi;				} else {					accum += Math.log (phi);				}				System.out.println ("CPF "+i+" accum = "+accum);      }    }    TIntObjectIterator it = sepsets.iterator();    while (it.hasNext()) {      it.advance();      DiscretePotential ptl = ((Sepset) it.value()).ptl;      double phi = ptl.phi(assn);			if (ptl.isInLogSpace()) {				accum -= phi;			} else {				accum -= Math.log (phi);			}			System.out.println("Sepset "+ptl.varSet()+" accum "+accum);    }    return accum;  }  public boolean isNaN()  {    int n = cpfs.length;    for (int i = 0; i < n; i++)      if (cpfs[i].isNaN()) return true;    // And the sepset potentials    TIntObjectIterator it = sepsets.iterator();    while (it.hasNext()) {      it.advance();      DiscretePotential ptl = ((Sepset) it.value()).ptl;			if (ptl.isNaN()) return true;    }		return false;  }// Implementation of edu.umass.cs.mallet.users.casutton.graphical.Compactible  public void decompact()  {    cpfs = new DiscretePotential[numNodes];    clearCPFs();  }  public void compact()  {    cpfs = null;  }} // JunctionTree

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -