⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 potential.java

📁 是实现关系型贝叶斯网络的一种机器学习算法
💻 JAVA
字号:
package rmn;import java.util.*;public class Potential {  // potential factory, used for getting potential weights  PotentialFactory m_pf;    // attached variables  Variable[] m_vars;  // array of messages, one message for each variable  double[][] m_msgPotToVar;  double[][] m_oldPotToVar;  public Potential(PotentialFactory pf)  {    m_pf = pf;  }  public void allocateMessages()  {    m_msgPotToVar = new double[m_vars.length][];    m_oldPotToVar = new double[m_vars.length][];    for (int i = 0; i < m_vars.length; i++) {      m_msgPotToVar[i] = new double[m_vars[i].m_nCard];      m_oldPotToVar[i] = new double[m_vars[i].m_nCard];    }  }  public void deallocateMessages()  {    m_msgPotToVar = null;    m_oldPotToVar = null;  }  public void attachVariables(Variable[] vars)  {    //    assert vars.length == m_pf.size() : vars.length;    //    m_vars = new Variable[m_pf.size()];    m_vars = new Variable[vars.length];    for (int i = 0; i < m_vars.length; i++)      m_vars[i] = vars[i];  }    public double[] getMessage(Variable var)  {    for (int i = 0; i < m_vars.length; i++)       if (m_vars[i].compareTo(var) == 0)        return m_msgPotToVar[i];    return null;  }  public double[] getOldMessage(Variable var)  {    for (int i = 0; i < m_vars.length; i++)       if (m_vars[i].compareTo(var) == 0)        return m_oldPotToVar[i];    return null;  }  public double[] getMessage(int var)  {    return m_msgPotToVar[var];  }  public double[] getOldMessage(int var)  {    return m_oldPotToVar[var];  }  public void setMessage(int var, double[] msg)  {    assert msg.length == m_vars[var].m_nCard: msg.length;    System.arraycopy(msg, 0, m_msgPotToVar[var], 0, msg.length);  }  public void initMessages()  {    for (int i = 0; i < m_vars.length; i++)      Arrays.fill(m_msgPotToVar[i], 1);  }  public Matrix newWeightMatrix()  {    Matrix m = m_pf.newWeightMatrix();    m.fill(1);    return m;  }  public Matrix getWeightMatrix()  {    return m_pf.getWeightMatrix();  }  public void 
backupMessages()  {    for (int i = 0; i < m_vars.length; i++)      for (int j = 0; j < m_vars[i].m_nCard; j++)        m_oldPotToVar[i][j] = m_msgPotToVar[i][j];  }  public boolean sendMsgToVars(boolean bMaximize)  {    boolean bConverged = true;    for (int v = 0; v < m_vars.length; v++) {      Variable var = m_vars[v];      // absorb variable messages      Matrix temp1 = newWeightMatrix();      for (int vv = 0; vv < m_vars.length; vv++) {	if (vv != v) {	  Variable varvar = m_vars[vv];	  temp1.dotProduct(varvar.getOldMessage(this), vv);	}      }      temp1.dotProduct(getWeightMatrix());      double[] temp2 = temp1.marginalize(v, bMaximize);      MathUtils.normalize(temp2);      setMessage(v, temp2);      // need more time?      if (!MathUtils.approxeq(getMessage(v), getOldMessage(v),			      FactorGraph.TOLERANCE))	bConverged = false;    }        return bConverged;  }}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -