📄 potentialselect.java
字号:
package rmn;import java.util.*;public class PotentialSelect extends Potential { public PotentialSelect(PotentialFactory pf) { super(pf); } public boolean sendMsgToVars(boolean bMaximize) { if (bMaximize) return sendMsgToVarsMax(); else return sendMsgToVarsSum(); } /* public boolean sendMsgToVarsMax() { boolean bConverged = true; // locate "sel" variable, create message Variable varSel = m_vars[0]; double[] matSel = varSel.getOldMessage(this); double[] msgSel = varSel.newMessage(); // compute "sel" message, 0 msgSel[0] = 1.0; for (int v = 1; v < m_vars.length; v++) { Variable varArg = m_vars[v]; double[] matArg = varArg.getOldMessage(this); msgSel[0] = msgSel[0] * matArg[0]; } // compute "sel" message, > 0 for (int k = 1; k < varSel.m_nCard; k++) { double maxRatio = 0.0; for (int v = 1; v < m_vars.length; v++) { Variable varArg = m_vars[v]; double[] matArg = varArg.getOldMessage(this); assert matArg[0] != 0 : v; double ratio = matArg[k] / matArg[0]; if (ratio > maxRatio) maxRatio = ratio; } msgSel[k] = msgSel[0] * maxRatio; } // compute "sel" argument messages for (int v = 1; v < m_vars.length; v++) { Variable varArg = m_vars[v]; double[] matArg = varArg.getOldMessage(this); double msgArg[]= varArg.newMessage(); // argument message, > 0 for (int k = 1; k < varSel.m_nCard; k++) msgArg[k] = msgSel[0] * matSel[k] / matArg[0]; // argument message, 0 double max = msgSel[0] * matSel[0] / matArg[0]; double maxRatio = 0.0; for (int vv = 1; vv < m_vars.length; vv++) { if (vv != v) { Variable vvArg = m_vars[vv]; double[] mmArg = vvArg.getOldMessage(this); for (int k = 1; k < varSel.m_nCard; k++) { double ratio = matSel[k] * mmArg[k] / mmArg[0]; if (ratio > maxRatio) maxRatio = ratio; } } } msgArg[0] = Math.max(max, msgSel[0] * maxRatio / matArg[0]); MathUtils.normalize(msgArg); setMessage(v, msgArg); // need more time? if (!MathUtils.approxeq(getMessage(v), getOldMessage(v), FactorGraph.TOLERANCE)) bConverged = false; } MathUtils.normalize(msgSel); setMessage(0, msgSel); // need more time? if (!MathUtils.approxeq(getMessage(0), getOldMessage(0), FactorGraph.TOLERANCE)) bConverged = false; return bConverged; }*/ public boolean sendMsgToVarsMax() { boolean bConverged = true; // locate "sel" variable, create message Variable varSel = m_vars[0]; double[] matSel = varSel.getOldMessage(this); double[] msgSel = varSel.newMessage(); // compute "sel" message, 0 msgSel[0] = 1.0; for (int v = 1; v < m_vars.length; v++) { Variable varArg = m_vars[v]; double[] matArg = varArg.getOldMessage(this); msgSel[0] = msgSel[0] * matArg[0]; } // compute "sel" message, > 0 for (int k = 1; k < varSel.m_nCard; k++) { double max = 0.0; for (int v = 1; v < m_vars.length; v++) { Variable varArg = m_vars[v]; double[] matArg = varArg.getOldMessage(this); double prod = matArg[k]; for (int vv = 1; vv < m_vars.length; vv++) if (vv != v) { double[] mmat = m_vars[vv].getOldMessage(this); prod = prod * mmat[0]; } if (prod > max ) max = prod; } msgSel[k] = max; } // compute "sel" argument messages for (int v = 1; v < m_vars.length; v++) { Variable varArg = m_vars[v]; double[] matArg = varArg.getOldMessage(this); double msgArg[]= varArg.newMessage(); // argument message, > 0 for (int k = 1; k < varSel.m_nCard; k++) { msgArg[k] = matSel[k]; for (int vv = 1; vv < m_vars.length; vv++) if (vv != v) { double[] mmat = m_vars[vv].getOldMessage(this); msgArg[k] = msgArg[k] * mmat[0]; } } // argument message, 0 double maxZ = matSel[0]; for (int vv = 1; vv < m_vars.length; vv++) if (vv != v) { double[] mmat = m_vars[vv].getOldMessage(this); maxZ = maxZ * mmat[0]; } double maxNZ = 0.0; for (int vv = 1; vv < m_vars.length; vv++) { if (vv != v) { Variable vvArg = m_vars[vv]; double[] mmArg = vvArg.getOldMessage(this); for (int k = 1; k < varSel.m_nCard; k++) { double prod = matSel[k] * mmArg[k]; for (int vvv = 1; vvv < m_vars.length; vvv++) { if (vvv != v && vvv != vv) { double[] mmmat = m_vars[vvv].getOldMessage(this); prod = prod * mmmat[0]; } } if (prod > maxNZ) maxNZ = prod; } } } msgArg[0] = Math.max(maxZ, maxNZ); MathUtils.normalize(msgArg); setMessage(v, msgArg); // need more time? if (!MathUtils.approxeq(getMessage(v), getOldMessage(v), FactorGraph.TOLERANCE)) bConverged = false; } MathUtils.normalize(msgSel); setMessage(0, msgSel); // need more time? if (!MathUtils.approxeq(getMessage(0), getOldMessage(0), FactorGraph.TOLERANCE)) bConverged = false; return bConverged; } public boolean sendMsgToVarsSum() { boolean bConverged = true; // locate "sel" variable, create message Variable varSel = m_vars[0]; double[] matSel = varSel.getOldMessage(this); double[] msgSel = varSel.newMessage(); // compute "sel" message, 0 msgSel[0] = 1.0; for (int v = 1; v < m_vars.length; v++) { Variable varArg = m_vars[v]; double[] matArg = varArg.getOldMessage(this); msgSel[0] = msgSel[0] * matArg[0]; } // compute "sel" message, > 0 for (int k = 1; k < varSel.m_nCard; k++) { msgSel[k] = 0; for (int v = 1; v < m_vars.length; v++) { Variable varArg = m_vars[v]; double[] matArg = varArg.getOldMessage(this); msgSel[k] += msgSel[0] * matArg[k] / matArg[0]; } } // compute "sel" argument messages for (int v = 1; v < m_vars.length; v++) { Variable varArg = m_vars[v]; double[] matArg = varArg.getOldMessage(this); double msgArg[]= varArg.newMessage(); // argument message, > 0 for (int k = 1; k < varSel.m_nCard; k++) msgArg[k] = msgSel[0] * matSel[k] / matArg[0]; // argument message, 0 msgArg[0] = msgSel[0] * matSel[0] / matArg[0]; for (int vv = 1; vv < m_vars.length; vv++) { if (vv != v) { Variable vvArg = m_vars[vv]; double[] mmArg = vvArg.getOldMessage(this); for (int k = 1; k < varSel.m_nCard; k++) { msgArg[0] += msgSel[0] * matSel[k] * mmArg[k] / (matArg[0] * mmArg[0]); } } } MathUtils.normalize(msgArg); setMessage(v, msgArg); // need more time? if (!MathUtils.approxeq(getMessage(v), getOldMessage(v), FactorGraph.TOLERANCE)) bConverged = false; } MathUtils.normalize(msgSel); setMessage(0, msgSel); // need more time? if (!MathUtils.approxeq(getMessage(0), getOldMessage(0), FactorGraph.TOLERANCE)) bConverged = false; return bConverged; }}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -