⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 factorgraph.java

📁 实现了关系型贝叶斯网络（一种机器学习算法）
💻 JAVA
字号:
package rmn;import java.util.*;/** * Factor graph implementation. Used here for representing a Markov Network. *  * @author Razvan Bunescu */public class FactorGraph {  static double TOLERANCE = 1e-3;  static int MAX_ITER = 50;  // set of variable nodes (entity attributes)  TreeSet m_setVars;  // set of potential nodes (clique potentials)  Vector m_vecPots;  public FactorGraph()  {    m_setVars = new TreeSet();    m_vecPots = new Vector();  }  public void addEdges(Potential pot, Variable[] vars)  {    // create edges between potential node and variable nodes     for (int i = 0; i < vars.length; i++)      vars[i].attachPotential(pot);    pot.attachVariables(vars);    // register potential node in factor graph    m_vecPots.add(pot);    // register variable nodes in factor graph    for (int i = 0; i < vars.length; i++)      m_setVars.add(vars[i]);  }  public void allocateMessages()  {    // allocate variable messages    Iterator it = m_setVars.iterator();    while (it.hasNext()) {      Variable var = (Variable) it.next();      var.allocateMessages();    }    // allocate potentials messages    for (int i = 0; i < m_vecPots.size(); i++) {      Potential pot = (Potential) m_vecPots.get(i);      pot.allocateMessages();    }  }  public void deallocateMessages()  {    // deallocate variable messages    Iterator it = m_setVars.iterator();    while (it.hasNext()) {      Variable var = (Variable) it.next();      var.deallocateMessages();    }    // deallocate potentials messages    for (int i = 0; i < m_vecPots.size(); i++) {      Potential pot = (Potential) m_vecPots.get(i);      pot.deallocateMessages();    }  }  public void backupMessages()  {    // backup variable messages    Iterator it = m_setVars.iterator();    while (it.hasNext()) {      Variable var = (Variable) it.next();      var.backupMessages();    }    // backup potentials messages    for (int i = 0; i < m_vecPots.size(); i++) {      Potential pot = (Potential) m_vecPots.get(i);      pot.backupMessages();    }    
  }  public void cleanup()  {    Vector vecPots = new Vector();    for (int i = 0; i < m_vecPots.size(); i++) {      Potential pot = (Potential) m_vecPots.get(i);      if (pot.m_pf.m_bInference)	vecPots.add(pot);    }    m_vecPots = vecPots;        Iterator it = m_setVars.iterator();    while (it.hasNext()) {      Variable var = (Variable) it.next();      var.cleanup();    }  }  // bMaximize - true means use max-product, false means use sum-product  public boolean beliefPropagation(boolean bMaximize)  {    // initialize messages    Iterator it = m_setVars.iterator();    while (it.hasNext()) {      Variable var = (Variable) it.next();      var.initMessages();    }    for (int i = 0; i < m_vecPots.size(); i++) {      Potential pot = (Potential) m_vecPots.get(i);      pot.initMessages();    }        // belief propagation    boolean bConverged = false;    int nIter = 0;    int maxIter = 2 * Math.max(m_setVars.size(), m_vecPots.size());    System.out.println("maxIter = " + maxIter);    maxIter = Math.min(maxIter, MAX_ITER);    while (!bConverged && nIter < maxIter) {      backupMessages();      bConverged = true;      // send messages to neighbor potentials      it = m_setVars.iterator();      while (it.hasNext()) {        Variable var = (Variable) it.next();	boolean bConv = var.sendMsgToPots(bMaximize);	bConverged = bConverged && bConv;      }            // send messages to neighbor variables      for (int f = 0; f < m_vecPots.size(); f++) {        Potential pot = (Potential) m_vecPots.get(f);	boolean bConv = pot.sendMsgToVars(bMaximize);	bConverged = bConverged && bConv;      }            if (nIter == 0)        bConverged = false;      nIter++;    }        System.out.println("Stopped after " + nIter + " iterations!");        // absorb potential messages & set marginals    it = m_setVars.iterator();    while (it.hasNext()) {      Variable var = (Variable) it.next();      double[] var_prod = var.newMessage();      for (int f = 0; f < var.m_vecPots.size(); f++) {	
Potential pot = (Potential) var.m_vecPots.get(f);	MathUtils.dotProduct(var_prod, pot.getMessage(var));      }      MathUtils.normalize(var_prod);      var.setMarginal(var_prod);    }        return true;  }  public void setExactMPE()  {    Iterator it = m_setVars.iterator();    while (it.hasNext()) {      Variable var = (Variable) it.next();      if (var.m_bNatural) {	beliefPropagation(true);	var.setArgmax();      }    }  }  public void setMPE()  {    beliefPropagation(true);        Iterator it = m_setVars.iterator();    while (it.hasNext()) {      Variable var = (Variable) it.next();      var.setArgmax();    }  }  public void setConfMarginal()  {    beliefPropagation(false);        Iterator it = m_setVars.iterator();    while (it.hasNext()) {      Variable var = (Variable) it.next();      var.setConfMarginal();    }  }  public void setInfHidden()  {    Iterator it = m_setVars.iterator();    while (it.hasNext()) {      Variable var = (Variable) it.next();      var.setInfHidden();    }  }    public void computeTrueCounts()  {    for (int p = 0; p < m_vecPots.size(); p++) {      Potential pot = (Potential) m_vecPots.get(p);      int pos[] = new int[pot.m_vars.length];      for (int v = 0; v < pot.m_vars.length; v++) {	pos[v] = pot.m_vars[v].getTrueValue();      }      pot.m_pf.incTrueCounts(pos);    }  }  public void computeInfCounts()  {    for (int p = 0; p < m_vecPots.size(); p++) {      Potential pot = (Potential) m_vecPots.get(p);      int pos[] = new int[pot.m_vars.length];      for (int v = 0; v < pot.m_vars.length; v++) {	pos[v] = pot.m_vars[v].getInfValue();      }      pot.m_pf.incInfCounts(pos);    }  }  public Vector computeCounts()  {    Vector vecPF = new Vector();    // compute true counts    for (int p = 0; p < m_vecPots.size(); p++) {      Potential pot = (Potential) m_vecPots.get(p);      if (pot.m_pf.isLearning()) {	if (!pot.m_pf.m_bUsed) {	  vecPF.add(pot.m_pf);	  pot.m_pf.m_bUsed = true;	}	int pos[] = new int[pot.m_vars.length];	for (int v = 0; v 
< pot.m_vars.length; v++) {	  pos[v] = pot.m_vars[v].getTrueValue();	}	pot.m_pf.incTrueCounts(pos);      }    }    // compute inf counts    for (int p = 0; p < m_vecPots.size(); p++) {      Potential pot = (Potential) m_vecPots.get(p);      if (pot.m_pf.isLearning()) {      	if (!pot.m_pf.m_bUsed) {	  vecPF.add(pot.m_pf);	  pot.m_pf.m_bUsed = true;	}	int pos[] = new int[pot.m_vars.length];	for (int v = 0; v < pot.m_vars.length; v++) {	  pos[v] = pot.m_vars[v].getInfValue();	}	pot.m_pf.incInfCounts(pos);      }    }        return vecPF;  }   public TreeSet getSetVars() {    return m_setVars;  }  /**   * Main method for testing this class.   */  static public void main(String[] args)  {    // create potentials    double[] wt1 = {0.5, 0.5};    PotentialFactory1 pf1 = new PotentialFactory1(wt1);    Potential pot1 = pf1.newInstance();    double[][] wt2 = {{0.5, 0.5}, {0.5, 0.5}};    PotentialFactory2 pf2 = new PotentialFactory2(wt2);    Potential pot2 = pf2.newInstance();    double[][] wt3 = {{0.25, 0.35}, {0.45, 0.55}};    PotentialFactory2 pf3 = new PotentialFactory2(wt3);    Potential pot3 = pf3.newInstance();    double[][][] wt4 = {{{0.5, 0.35}, {0.35, 0.5}}, {{0.35, 0.5}, {0.5, 0.35}}};    PotentialFactory3 pf4 = new PotentialFactory3(wt4);    Potential pot4 = pf4.newInstance();    // create variables    Variable var1 = new Variable("var1", 2);    Variable var2 = new Variable("var2", 2);    var2.setObserved(0);    Variable var3 = new Variable("var3", 2);    Variable var4 = new Variable("var4", 2);    // create factor graph    FactorGraph fg = new FactorGraph();    Variable[] dom1 = new Variable[1];    dom1[0] = var1;    fg.addEdges(pot1, dom1);    Variable[] dom2 = new Variable[2];    dom2[0] = var1;    dom2[1] = var2;    fg.addEdges(pot2, dom2);    Variable[] dom3 = new Variable[2];    dom3[0] = var1;    dom3[1] = var3;    fg.addEdges(pot3, dom3);    Variable[] dom4 = new Variable[3];    dom4[0] = var2;    dom4[1] = var3;    dom4[2] = var4;    fg.addEdges(pot4, 
dom4);        fg.allocateMessages();    // inference - marginals    fg.beliefPropagation(false);        // display marginals    Iterator it = fg.m_setVars.iterator();    int x = 0;    while (it.hasNext()) {      Variable var = (Variable) it.next();      System.out.print("Variable " + (x + 1) + " has marginal ");      for (int i = 0; i < var.m_nCard; i++)        System.out.print(var.m_marginal[i] + " ");      System.out.println();      x++;    }    // inference - MPE    fg.setMPE(); // if fast approximate computation is desired    // fg.setExactMPE(); // if exact computation desired        // display MPE    it = fg.m_setVars.iterator();    x = 0;    while (it.hasNext()) {      Variable var = (Variable) it.next();      System.out.println("Variable " + (x + 1) + " has MPE " + var.getInfValue());      x++;    }  }}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -