📄 potential.java
字号:
package rmn;import java.util.*;public class Potential { // potential factory, used for getting potential weights PotentialFactory m_pf; // attached variables Variable[] m_vars; // array of messages, one message for each variable double[][] m_msgPotToVar; double[][] m_oldPotToVar; public Potential(PotentialFactory pf) { m_pf = pf; } public void allocateMessages() { m_msgPotToVar = new double[m_vars.length][]; m_oldPotToVar = new double[m_vars.length][]; for (int i = 0; i < m_vars.length; i++) { m_msgPotToVar[i] = new double[m_vars[i].m_nCard]; m_oldPotToVar[i] = new double[m_vars[i].m_nCard]; } } public void deallocateMessages() { m_msgPotToVar = null; m_oldPotToVar = null; } public void attachVariables(Variable[] vars) { // assert vars.length == m_pf.size() : vars.length; // m_vars = new Variable[m_pf.size()]; m_vars = new Variable[vars.length]; for (int i = 0; i < m_vars.length; i++) m_vars[i] = vars[i]; } public double[] getMessage(Variable var) { for (int i = 0; i < m_vars.length; i++) if (m_vars[i].compareTo(var) == 0) return m_msgPotToVar[i]; return null; } public double[] getOldMessage(Variable var) { for (int i = 0; i < m_vars.length; i++) if (m_vars[i].compareTo(var) == 0) return m_oldPotToVar[i]; return null; } public double[] getMessage(int var) { return m_msgPotToVar[var]; } public double[] getOldMessage(int var) { return m_oldPotToVar[var]; } public void setMessage(int var, double[] msg) { assert msg.length == m_vars[var].m_nCard: msg.length; System.arraycopy(msg, 0, m_msgPotToVar[var], 0, msg.length); } public void initMessages() { for (int i = 0; i < m_vars.length; i++) Arrays.fill(m_msgPotToVar[i], 1); } public Matrix newWeightMatrix() { Matrix m = m_pf.newWeightMatrix(); m.fill(1); return m; } public Matrix getWeightMatrix() { return m_pf.getWeightMatrix(); } public void backupMessages() { for (int i = 0; i < m_vars.length; i++) for (int j = 0; j < m_vars[i].m_nCard; j++) m_oldPotToVar[i][j] = m_msgPotToVar[i][j]; } public boolean sendMsgToVars(boolean bMaximize) { boolean bConverged = true; for (int v = 0; v < m_vars.length; v++) { Variable var = m_vars[v]; // absorb variable messages Matrix temp1 = newWeightMatrix(); for (int vv = 0; vv < m_vars.length; vv++) { if (vv != v) { Variable varvar = m_vars[vv]; temp1.dotProduct(varvar.getOldMessage(this), vv); } } temp1.dotProduct(getWeightMatrix()); double[] temp2 = temp1.marginalize(v, bMaximize); MathUtils.normalize(temp2); setMessage(v, temp2); // need more time? if (!MathUtils.approxeq(getMessage(v), getOldMessage(v), FactorGraph.TOLERANCE)) bConverged = false; } return bConverged; }}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -