⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 variable.java

📁 是实现关系型贝叶斯网络的一种机器学习算法
💻 JAVA
字号:
package rmn;import java.util.*;public class Variable implements Comparable {  // variable ID  String m_strID;  // cardinality  int m_nCard;  // marginal probabilities  double[] m_marginal;  boolean m_bHidden;  // false if artificial, like "OR" or "SELECT"  boolean m_bNatural;  // true value  int m_nTrueVal;  // inferred value  int m_nInfVal;  // confidence for inferred value  double m_conf;  // attached potentials  Vector m_vecPots;  // array of messages, one message for each potential  double[][] m_msgVarToPot;  double[][] m_oldVarToPot;  public Variable(String strID, int nValues)  {    m_strID = strID;    m_nCard = nValues;    m_bHidden = true;    m_bNatural = true;    m_nTrueVal = m_nInfVal = -1;    m_conf = 0;    m_vecPots = new Vector();  }  public void allocateMessages()  {    m_msgVarToPot = new double[m_vecPots.size()][m_nCard];    m_oldVarToPot = new double[m_vecPots.size()][m_nCard];  }  public void deallocateMessages()  {    m_msgVarToPot = null;    m_oldVarToPot = null;  }  public void setObserved(int nVal)  {    m_nTrueVal = m_nInfVal = nVal;    m_bHidden = false;  }  public void setArgmax()  {    // assume positive potentials    double max = 0;    for (int i = 0; i < m_nCard; i++)      if (m_marginal[i] > max) {        m_nInfVal = i;	max = m_marginal[i];      }    m_conf = max;    assert m_nInfVal != -1;    if (!isHidden())      assert m_nInfVal == m_nTrueVal;  }    public void setConfMarginal()  {    // assume positive potentials    m_conf = 0;    for (int i = 0; i < m_nCard; i++)      if (m_marginal[i] > m_conf)        m_conf = m_marginal[i];  }    public boolean isHidden()  {    return m_bHidden;  }  public void setInfHidden()  {    if (isHidden())      m_nInfVal = -1;  }  public boolean isInfHidden()  {    return m_nInfVal == -1;  }  public void setTrueValue(int nVal)  {    m_nTrueVal = nVal;  }    public int getTrueValue()  {    return m_nTrueVal;  }  public void setInfValue(int nVal)  {    m_nInfVal = nVal;  }    public int getInfValue()  {    
return m_nInfVal;  }  public void attachPotential(Potential pot)  {    m_vecPots.add(pot);  }  public int compareTo(Object obj)  {    Variable var = (Variable) obj;    return m_strID.compareTo(var.m_strID);  }  public void initMessages()  {    for (int i = 0; i < m_vecPots.size(); i++)      initMessage(m_msgVarToPot[i]);  }  public void initMessage(double[] msg)  {    if (isInfHidden()) {      for (int j = 0; j < m_nCard; j++)        msg[j] = 1;    }    else {      for (int j = 0; j < m_nCard; j++)        msg[j] = 0;      msg[getInfValue()] = 1;    }  }  public double[] newMessage()  {    double[] msg = new double[m_nCard];    initMessage(msg);    return msg;  }  public double[] getMessage(Potential pot)  {    for (int i = 0; i < m_vecPots.size(); i++)      if (m_vecPots.get(i) == pot)        return m_msgVarToPot[i];    return null;  }  public double[] getOldMessage(Potential pot)  {    for (int i = 0; i < m_vecPots.size(); i++)      if (m_vecPots.get(i) == pot)        return m_oldVarToPot[i];    return null;  }  public double[] getMessage(int nFactor)  {    return m_msgVarToPot[nFactor];  }  public double[] getOldMessage(int nFactor)  {    return m_oldVarToPot[nFactor];  }  public void setMessage(int nFactor, double[] vector)  {    assert vector.length == m_nCard : vector.length;    System.arraycopy(vector, 0, m_msgVarToPot[nFactor], 0, m_nCard);  }  public void backupMessages()  {    for (int i = 0; i < m_vecPots.size(); i++)      for (int j = 0; j < m_nCard; j++)        m_oldVarToPot[i][j] = m_msgVarToPot[i][j];  }  public void setMarginal(double[] marginal)  {    m_marginal = new double[m_nCard];    for (int i = 0; i < m_nCard; i++)      m_marginal[i] = marginal[i];  }  public String getID()  {    return m_strID;  }  public double getConf()  {    return m_conf;  }  public void setConf(double conf)  {    m_conf = conf;  }  public boolean isAttached()  {    return m_vecPots.size() > 0;  }  public void setArtificial()  {    m_bNatural = false;  }  public void 
cleanup()  {    Vector vecPots = new Vector();    for (int i = 0; i < m_vecPots.size(); i++) {      Potential pot = (Potential) m_vecPots.get(i);      if (pot.m_pf.m_bInference)	vecPots.add(pot);    }    m_vecPots = vecPots;  }      public boolean sendMsgToPots(boolean bMaximize)  {    boolean bConverged = true;    for (int f = 0; f < m_vecPots.size(); f++) {      Potential pot = (Potential) m_vecPots.get(f);      // absorb potential messages      double[] temp = newMessage();      for (int ff = 0; ff < m_vecPots.size(); ff++) {	if (ff != f) {	  Potential potpot = (Potential) m_vecPots.get(ff);	  MathUtils.dotProduct(temp, potpot.getOldMessage(this));	}      }      MathUtils.normalize(temp);      setMessage(f, temp);      // need more time?      if (!MathUtils.approxeq(getMessage(f), getOldMessage(f),			      FactorGraph.TOLERANCE))	bConverged = false;    }        return bConverged;  }}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -