📄 variable.java
字号:
package rmn;import java.util.*;public class Variable implements Comparable { // variable ID String m_strID; // cardinality int m_nCard; // marginal probabilities double[] m_marginal; boolean m_bHidden; // false if artificial, like "OR" or "SELECT" boolean m_bNatural; // true value int m_nTrueVal; // inferred value int m_nInfVal; // confidence for inferred value double m_conf; // attached potentials Vector m_vecPots; // array of messages, one message for each potential double[][] m_msgVarToPot; double[][] m_oldVarToPot; public Variable(String strID, int nValues) { m_strID = strID; m_nCard = nValues; m_bHidden = true; m_bNatural = true; m_nTrueVal = m_nInfVal = -1; m_conf = 0; m_vecPots = new Vector(); } public void allocateMessages() { m_msgVarToPot = new double[m_vecPots.size()][m_nCard]; m_oldVarToPot = new double[m_vecPots.size()][m_nCard]; } public void deallocateMessages() { m_msgVarToPot = null; m_oldVarToPot = null; } public void setObserved(int nVal) { m_nTrueVal = m_nInfVal = nVal; m_bHidden = false; } public void setArgmax() { // assume positive potentials double max = 0; for (int i = 0; i < m_nCard; i++) if (m_marginal[i] > max) { m_nInfVal = i; max = m_marginal[i]; } m_conf = max; assert m_nInfVal != -1; if (!isHidden()) assert m_nInfVal == m_nTrueVal; } public void setConfMarginal() { // assume positive potentials m_conf = 0; for (int i = 0; i < m_nCard; i++) if (m_marginal[i] > m_conf) m_conf = m_marginal[i]; } public boolean isHidden() { return m_bHidden; } public void setInfHidden() { if (isHidden()) m_nInfVal = -1; } public boolean isInfHidden() { return m_nInfVal == -1; } public void setTrueValue(int nVal) { m_nTrueVal = nVal; } public int getTrueValue() { return m_nTrueVal; } public void setInfValue(int nVal) { m_nInfVal = nVal; } public int getInfValue() { return m_nInfVal; } public void attachPotential(Potential pot) { m_vecPots.add(pot); } public int compareTo(Object obj) { Variable var = (Variable) obj; return m_strID.compareTo(var.m_strID); } public void initMessages() { for (int i = 0; i < m_vecPots.size(); i++) initMessage(m_msgVarToPot[i]); } public void initMessage(double[] msg) { if (isInfHidden()) { for (int j = 0; j < m_nCard; j++) msg[j] = 1; } else { for (int j = 0; j < m_nCard; j++) msg[j] = 0; msg[getInfValue()] = 1; } } public double[] newMessage() { double[] msg = new double[m_nCard]; initMessage(msg); return msg; } public double[] getMessage(Potential pot) { for (int i = 0; i < m_vecPots.size(); i++) if (m_vecPots.get(i) == pot) return m_msgVarToPot[i]; return null; } public double[] getOldMessage(Potential pot) { for (int i = 0; i < m_vecPots.size(); i++) if (m_vecPots.get(i) == pot) return m_oldVarToPot[i]; return null; } public double[] getMessage(int nFactor) { return m_msgVarToPot[nFactor]; } public double[] getOldMessage(int nFactor) { return m_oldVarToPot[nFactor]; } public void setMessage(int nFactor, double[] vector) { assert vector.length == m_nCard : vector.length; System.arraycopy(vector, 0, m_msgVarToPot[nFactor], 0, m_nCard); } public void backupMessages() { for (int i = 0; i < m_vecPots.size(); i++) for (int j = 0; j < m_nCard; j++) m_oldVarToPot[i][j] = m_msgVarToPot[i][j]; } public void setMarginal(double[] marginal) { m_marginal = new double[m_nCard]; for (int i = 0; i < m_nCard; i++) m_marginal[i] = marginal[i]; } public String getID() { return m_strID; } public double getConf() { return m_conf; } public void setConf(double conf) { m_conf = conf; } public boolean isAttached() { return m_vecPots.size() > 0; } public void setArtificial() { m_bNatural = false; } public void cleanup() { Vector vecPots = new Vector(); for (int i = 0; i < m_vecPots.size(); i++) { Potential pot = (Potential) m_vecPots.get(i); if (pot.m_pf.m_bInference) vecPots.add(pot); } m_vecPots = vecPots; } public boolean sendMsgToPots(boolean bMaximize) { boolean bConverged = true; for (int f = 0; f < m_vecPots.size(); f++) { Potential pot = (Potential) m_vecPots.get(f); // absorb potential messages double[] temp = newMessage(); for (int ff = 0; ff < m_vecPots.size(); ff++) { if (ff != f) { Potential potpot = (Potential) m_vecPots.get(ff); MathUtils.dotProduct(temp, potpot.getOldMessage(this)); } } MathUtils.normalize(temp); setMessage(f, temp); // need more time? if (!MathUtils.approxeq(getMessage(f), getOldMessage(f), FactorGraph.TOLERANCE)) bConverged = false; } return bConverged; }}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -