📄 factorgraph.java
字号:
package rmn;import java.util.*;/** * Factor graph implementation. Used here for representing a Markov Network. * * @author Razvan Bunescu */public class FactorGraph { static double TOLERANCE = 1e-3; static int MAX_ITER = 50; // set of variable nodes (entity attributes) TreeSet m_setVars; // set of potential nodes (clique potentials) Vector m_vecPots; public FactorGraph() { m_setVars = new TreeSet(); m_vecPots = new Vector(); } public void addEdges(Potential pot, Variable[] vars) { // create edges between potential node and variable nodes for (int i = 0; i < vars.length; i++) vars[i].attachPotential(pot); pot.attachVariables(vars); // register potential node in factor graph m_vecPots.add(pot); // register variable nodes in factor graph for (int i = 0; i < vars.length; i++) m_setVars.add(vars[i]); } public void allocateMessages() { // allocate variable messages Iterator it = m_setVars.iterator(); while (it.hasNext()) { Variable var = (Variable) it.next(); var.allocateMessages(); } // allocate potentials messages for (int i = 0; i < m_vecPots.size(); i++) { Potential pot = (Potential) m_vecPots.get(i); pot.allocateMessages(); } } public void deallocateMessages() { // deallocate variable messages Iterator it = m_setVars.iterator(); while (it.hasNext()) { Variable var = (Variable) it.next(); var.deallocateMessages(); } // deallocate potentials messages for (int i = 0; i < m_vecPots.size(); i++) { Potential pot = (Potential) m_vecPots.get(i); pot.deallocateMessages(); } } public void backupMessages() { // backup variable messages Iterator it = m_setVars.iterator(); while (it.hasNext()) { Variable var = (Variable) it.next(); var.backupMessages(); } // backup potentials messages for (int i = 0; i < m_vecPots.size(); i++) { Potential pot = (Potential) m_vecPots.get(i); pot.backupMessages(); } } public void cleanup() { Vector vecPots = new Vector(); for (int i = 0; i < m_vecPots.size(); i++) { Potential pot = (Potential) m_vecPots.get(i); if (pot.m_pf.m_bInference) vecPots.add(pot); } m_vecPots = vecPots; Iterator it = m_setVars.iterator(); while (it.hasNext()) { Variable var = (Variable) it.next(); var.cleanup(); } } // bMaximize - true means use max-product, false means use sum-product public boolean beliefPropagation(boolean bMaximize) { // initialize messages Iterator it = m_setVars.iterator(); while (it.hasNext()) { Variable var = (Variable) it.next(); var.initMessages(); } for (int i = 0; i < m_vecPots.size(); i++) { Potential pot = (Potential) m_vecPots.get(i); pot.initMessages(); } // belief propagation boolean bConverged = false; int nIter = 0; int maxIter = 2 * Math.max(m_setVars.size(), m_vecPots.size()); System.out.println("maxIter = " + maxIter); maxIter = Math.min(maxIter, MAX_ITER); while (!bConverged && nIter < maxIter) { backupMessages(); bConverged = true; // send messages to neighbor potentials it = m_setVars.iterator(); while (it.hasNext()) { Variable var = (Variable) it.next(); boolean bConv = var.sendMsgToPots(bMaximize); bConverged = bConverged && bConv; } // send messages to neighbor variables for (int f = 0; f < m_vecPots.size(); f++) { Potential pot = (Potential) m_vecPots.get(f); boolean bConv = pot.sendMsgToVars(bMaximize); bConverged = bConverged && bConv; } if (nIter == 0) bConverged = false; nIter++; } System.out.println("Stopped after " + nIter + " iterations!"); // absorb potential messages & set marginals it = m_setVars.iterator(); while (it.hasNext()) { Variable var = (Variable) it.next(); double[] var_prod = var.newMessage(); for (int f = 0; f < var.m_vecPots.size(); f++) { Potential pot = (Potential) var.m_vecPots.get(f); MathUtils.dotProduct(var_prod, pot.getMessage(var)); } MathUtils.normalize(var_prod); var.setMarginal(var_prod); } return true; } public void setExactMPE() { Iterator it = m_setVars.iterator(); while (it.hasNext()) { Variable var = (Variable) it.next(); if (var.m_bNatural) { beliefPropagation(true); var.setArgmax(); } } } public void setMPE() { beliefPropagation(true); Iterator it = m_setVars.iterator(); while (it.hasNext()) { Variable var = (Variable) it.next(); var.setArgmax(); } } public void setConfMarginal() { beliefPropagation(false); Iterator it = m_setVars.iterator(); while (it.hasNext()) { Variable var = (Variable) it.next(); var.setConfMarginal(); } } public void setInfHidden() { Iterator it = m_setVars.iterator(); while (it.hasNext()) { Variable var = (Variable) it.next(); var.setInfHidden(); } } public void computeTrueCounts() { for (int p = 0; p < m_vecPots.size(); p++) { Potential pot = (Potential) m_vecPots.get(p); int pos[] = new int[pot.m_vars.length]; for (int v = 0; v < pot.m_vars.length; v++) { pos[v] = pot.m_vars[v].getTrueValue(); } pot.m_pf.incTrueCounts(pos); } } public void computeInfCounts() { for (int p = 0; p < m_vecPots.size(); p++) { Potential pot = (Potential) m_vecPots.get(p); int pos[] = new int[pot.m_vars.length]; for (int v = 0; v < pot.m_vars.length; v++) { pos[v] = pot.m_vars[v].getInfValue(); } pot.m_pf.incInfCounts(pos); } } public Vector computeCounts() { Vector vecPF = new Vector(); // compute true counts for (int p = 0; p < m_vecPots.size(); p++) { Potential pot = (Potential) m_vecPots.get(p); if (pot.m_pf.isLearning()) { if (!pot.m_pf.m_bUsed) { vecPF.add(pot.m_pf); pot.m_pf.m_bUsed = true; } int pos[] = new int[pot.m_vars.length]; for (int v = 0; v < pot.m_vars.length; v++) { pos[v] = pot.m_vars[v].getTrueValue(); } pot.m_pf.incTrueCounts(pos); } } // compute inf counts for (int p = 0; p < m_vecPots.size(); p++) { Potential pot = (Potential) m_vecPots.get(p); if (pot.m_pf.isLearning()) { if (!pot.m_pf.m_bUsed) { vecPF.add(pot.m_pf); pot.m_pf.m_bUsed = true; } int pos[] = new int[pot.m_vars.length]; for (int v = 0; v < pot.m_vars.length; v++) { pos[v] = pot.m_vars[v].getInfValue(); } pot.m_pf.incInfCounts(pos); } } return vecPF; } public TreeSet getSetVars() { return m_setVars; } /** * Main method for testing this class. */ static public void main(String[] args) { // create potentials double[] wt1 = {0.5, 0.5}; PotentialFactory1 pf1 = new PotentialFactory1(wt1); Potential pot1 = pf1.newInstance(); double[][] wt2 = {{0.5, 0.5}, {0.5, 0.5}}; PotentialFactory2 pf2 = new PotentialFactory2(wt2); Potential pot2 = pf2.newInstance(); double[][] wt3 = {{0.25, 0.35}, {0.45, 0.55}}; PotentialFactory2 pf3 = new PotentialFactory2(wt3); Potential pot3 = pf3.newInstance(); double[][][] wt4 = {{{0.5, 0.35}, {0.35, 0.5}}, {{0.35, 0.5}, {0.5, 0.35}}}; PotentialFactory3 pf4 = new PotentialFactory3(wt4); Potential pot4 = pf4.newInstance(); // create variables Variable var1 = new Variable("var1", 2); Variable var2 = new Variable("var2", 2); var2.setObserved(0); Variable var3 = new Variable("var3", 2); Variable var4 = new Variable("var4", 2); // create factor graph FactorGraph fg = new FactorGraph(); Variable[] dom1 = new Variable[1]; dom1[0] = var1; fg.addEdges(pot1, dom1); Variable[] dom2 = new Variable[2]; dom2[0] = var1; dom2[1] = var2; fg.addEdges(pot2, dom2); Variable[] dom3 = new Variable[2]; dom3[0] = var1; dom3[1] = var3; fg.addEdges(pot3, dom3); Variable[] dom4 = new Variable[3]; dom4[0] = var2; dom4[1] = var3; dom4[2] = var4; fg.addEdges(pot4, dom4); fg.allocateMessages(); // inference - marginals fg.beliefPropagation(false); // display marginals Iterator it = fg.m_setVars.iterator(); int x = 0; while (it.hasNext()) { Variable var = (Variable) it.next(); System.out.print("Variable " + (x + 1) + " has marginal "); for (int i = 0; i < var.m_nCard; i++) System.out.print(var.m_marginal[i] + " "); System.out.println(); x++; } // inference - MPE fg.setMPE(); // if fast approximate computation is desired // fg.setExactMPE(); // if exact computation desired // display MPE it = fg.m_setVars.iterator(); x = 0; while (it.hasNext()) { Variable var = (Variable) it.next(); System.out.println("Variable " + (x + 1) + " has MPE " + var.getInfValue()); x++; } }}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -