potentialmostone.java
来自「是实现关系型贝叶斯网络的一种机器学习算法」· Java 代码 · 共 113 行
JAVA
113 行
package rmn;

import java.util.*;

/**
 * A factor over binary variables whose compatible assignments are exactly those in which at
 * most one of the variables in {@code m_vars} is in state 1.
 *
 * <p>Each outgoing message is a two-entry array indexed by the receiving variable's state
 * (0 or 1). Incoming messages are read from each neighbour with {@code getOldMessage(this)};
 * entries 0 and 1 of those arrays are used as the weights of that neighbour being in state 0
 * or 1.
 *
 * <p>For a receiving variable {@code v}, with the other variables indexed by {@code j}:
 * <ul>
 *   <li>{@code msg[1]} is the weight of "every other variable in state 0":
 *       prod_j in_j[0];</li>
 *   <li>{@code msg[0]} combines (sum or max) that same all-off weight with the weight of every
 *       assignment in which exactly one other variable {@code k} is in state 1:
 *       in_k[1] * prod_{j != k} in_j[0].</li>
 * </ul>
 * All messages of one sweep are derived from shared prefix/suffix accumulators, so a sweep over
 * n variables needs O(n) arithmetic instead of recomputing each product from scratch (O(n^3)).
 */
public class PotentialMostOne extends Potential {

    /**
     * Creates the potential.
     *
     * @param pf factory passed unchanged to the {@code Potential} constructor
     */
    public PotentialMostOne(PotentialFactory pf) {
        super(pf);
    }

    /**
     * Sends one round of messages from this factor to each of its variables.
     *
     * @param bMaximize {@code true} for max-product messages, {@code false} for sum-product
     * @return {@code true} if every new message is within {@code FactorGraph.TOLERANCE} of the
     *     previous one according to {@code MathUtils.approxeq}
     */
    public boolean sendMsgToVars(boolean bMaximize) {
        if (bMaximize) {
            return sendMsgToVarsMax();
        }
        return sendMsgToVarsSum();
    }

    /**
     * Sends max-product messages: {@code msg[0]} is the largest single compatible-assignment
     * weight rather than their sum.
     *
     * @return {@code true} if all messages are within tolerance of their previous values
     */
    public boolean sendMsgToVarsMax() {
        return sendAtMostOneMessages(true);
    }

    /**
     * Sends sum-product messages: {@code msg[0]} is the total weight of all compatible
     * assignments of the other variables.
     *
     * @return {@code true} if all messages are within tolerance of their previous values
     */
    public boolean sendMsgToVarsSum() {
        return sendAtMostOneMessages(false);
    }

    /**
     * Computes, normalizes and stores the message to every variable, for either semiring.
     *
     * <p>NOTE(review): the max branch uses 0 as the identity for "no candidate", which assumes
     * incoming message entries are non-negative weights — confirm against the Variable class.
     *
     * @param maximize {@code true} for max-product, {@code false} for sum-product
     * @return {@code true} if all messages are within tolerance of their previous values
     */
    private boolean sendAtMostOneMessages(boolean maximize) {
        final int n = m_vars.length;

        // Read each incoming message once; off[j]/on[j] are neighbour j's state-0/state-1 weights.
        double[] off = new double[n];
        double[] on = new double[n];
        for (int j = 0; j < n; j++) {
            double[] in = m_vars[j].getOldMessage(this);
            off[j] = in[0];
            on[j] = in[1];
        }

        // prefAllOff[i]: weight of variables [0, i) all in state 0.
        // prefOneOn[i]:  combined weight of variables [0, i) having exactly one in state 1.
        double[] prefAllOff = new double[n + 1];
        double[] prefOneOn = new double[n + 1];
        prefAllOff[0] = 1.0;
        prefOneOn[0] = 0.0;
        for (int i = 0; i < n; i++) {
            prefAllOff[i + 1] = prefAllOff[i] * off[i];
            // Either the single "on" variable is already in the prefix and i is off, or i is it.
            prefOneOn[i + 1] = semiringPlus(maximize,
                    prefOneOn[i] * off[i],
                    prefAllOff[i] * on[i]);
        }

        // sufAllOff[i] / sufOneOn[i]: the same two quantities for variables [i, n).
        double[] sufAllOff = new double[n + 1];
        double[] sufOneOn = new double[n + 1];
        sufAllOff[n] = 1.0;
        sufOneOn[n] = 0.0;
        for (int i = n - 1; i >= 0; i--) {
            sufAllOff[i] = off[i] * sufAllOff[i + 1];
            sufOneOn[i] = semiringPlus(maximize,
                    off[i] * sufOneOn[i + 1],
                    on[i] * sufAllOff[i + 1]);
        }

        boolean bConverged = true;
        for (int v = 0; v < n; v++) {
            // Every variable except v is off.
            double allOff = prefAllOff[v] * sufAllOff[v + 1];
            // Exactly one variable other than v is on: it lies either before v or after v.
            double oneOn = semiringPlus(maximize,
                    prefOneOn[v] * sufAllOff[v + 1],
                    prefAllOff[v] * sufOneOn[v + 1]);

            double[] msg = new double[2];
            // v = 1 is compatible only with all others off.
            msg[1] = allOff;
            // v = 0 is compatible with all others off or exactly one other on. Computed
            // directly, so incoming messages need not be normalized.
            msg[0] = semiringPlus(maximize, allOff, oneOn);
            MathUtils.normalize(msg);
            setMessage(v, msg);

            // Another iteration is needed if any message moved by more than the tolerance.
            if (!MathUtils.approxeq(getMessage(v), getOldMessage(v), FactorGraph.TOLERANCE)) {
                bConverged = false;
            }
        }
        return bConverged;
    }

    /**
     * The semiring "addition" used to combine alternative assignments: max for max-product,
     * ordinary addition for sum-product.
     */
    private static double semiringPlus(boolean maximize, double x, double y) {
        return maximize ? Math.max(x, y) : x + y;
    }
}
⌨️ 快捷键说明
复制代码Ctrl + C
搜索代码Ctrl + F
全屏模式F11
增大字号Ctrl + =
减小字号Ctrl + -
显示快捷键?