⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 potentialselect.java

📁 是实现关系型贝叶斯网络的一种机器学习算法
💻 JAVA
字号:
package rmn;import java.util.*;public class PotentialSelect extends Potential {  public PotentialSelect(PotentialFactory pf)  {    super(pf);  }  public boolean sendMsgToVars(boolean bMaximize)  {    if (bMaximize)      return sendMsgToVarsMax();    else      return sendMsgToVarsSum();  }    /*  public boolean sendMsgToVarsMax()  {    boolean bConverged = true;    // locate "sel" variable, create message    Variable varSel = m_vars[0];    double[] matSel = varSel.getOldMessage(this);    double[] msgSel = varSel.newMessage();    // compute "sel" message, 0    msgSel[0] = 1.0;    for (int v = 1; v < m_vars.length; v++) {      Variable varArg = m_vars[v];      double[] matArg = varArg.getOldMessage(this);      msgSel[0] = msgSel[0] * matArg[0];    }    // compute "sel" message, > 0    for (int k = 1; k < varSel.m_nCard; k++) {      double maxRatio = 0.0;      for (int v = 1; v < m_vars.length; v++) {	Variable varArg = m_vars[v];	double[] matArg = varArg.getOldMessage(this);	assert matArg[0] != 0 : v;	double ratio = matArg[k] / matArg[0];	if (ratio > maxRatio)	  maxRatio = ratio;      }      msgSel[k] = msgSel[0] * maxRatio;    }    // compute "sel" argument messages    for (int v = 1; v < m_vars.length; v++) {      Variable varArg = m_vars[v];      double[] matArg = varArg.getOldMessage(this);      double msgArg[]= varArg.newMessage();      // argument message, > 0      for (int k = 1; k < varSel.m_nCard; k++)	msgArg[k] = msgSel[0] * matSel[k] / matArg[0];      // argument message, 0      double max = msgSel[0] * matSel[0] / matArg[0];      double maxRatio = 0.0;      for (int vv = 1; vv < m_vars.length; vv++) {	if (vv != v) {	  Variable vvArg = m_vars[vv];	  double[] mmArg = vvArg.getOldMessage(this);	  for (int k = 1; k < varSel.m_nCard; k++) {	    double ratio = matSel[k] * mmArg[k] / mmArg[0];	    if (ratio > maxRatio)	      maxRatio = ratio;	  }	}      }      msgArg[0] = Math.max(max, msgSel[0] * maxRatio / matArg[0]);      MathUtils.normalize(msgArg);      
setMessage(v, msgArg);      // need more time?      if (!MathUtils.approxeq(getMessage(v), getOldMessage(v), 			      FactorGraph.TOLERANCE))	bConverged = false;    }        MathUtils.normalize(msgSel);    setMessage(0, msgSel);        // need more time?    if (!MathUtils.approxeq(getMessage(0), getOldMessage(0), 			    FactorGraph.TOLERANCE))      bConverged = false;        return bConverged;  }*/    public boolean sendMsgToVarsMax()  {    boolean bConverged = true;    // locate "sel" variable, create message    Variable varSel = m_vars[0];    double[] matSel = varSel.getOldMessage(this);    double[] msgSel = varSel.newMessage();        // compute "sel" message, 0    msgSel[0] = 1.0;    for (int v = 1; v < m_vars.length; v++) {      Variable varArg = m_vars[v];      double[] matArg = varArg.getOldMessage(this);      msgSel[0] = msgSel[0] * matArg[0];    }        // compute "sel" message, > 0    for (int k = 1; k < varSel.m_nCard; k++) {      double max = 0.0;      for (int v = 1; v < m_vars.length; v++) {	Variable varArg = m_vars[v];	double[] matArg = varArg.getOldMessage(this);	double prod = matArg[k];	for (int vv = 1; vv < m_vars.length; vv++)	  if (vv != v) {	    double[] mmat = m_vars[vv].getOldMessage(this);	    prod = prod * mmat[0];	  }	if (prod > max )	  max = prod;      }      msgSel[k] = max;    }        // compute "sel" argument messages    for (int v = 1; v < m_vars.length; v++) {      Variable varArg = m_vars[v];      double[] matArg = varArg.getOldMessage(this);      double msgArg[]= varArg.newMessage();            // argument message, > 0      for (int k = 1; k < varSel.m_nCard; k++) {	msgArg[k] = matSel[k];	for (int vv = 1; vv < m_vars.length; vv++)	  if (vv != v) {	    double[] mmat = m_vars[vv].getOldMessage(this);	    msgArg[k] = msgArg[k] * mmat[0];	  }      }            // argument message, 0      double maxZ = matSel[0];      for (int vv = 1; vv < m_vars.length; vv++)	if (vv != v) {	  double[] mmat = m_vars[vv].getOldMessage(this);	  maxZ = 
maxZ * mmat[0];	}            double maxNZ = 0.0;      for (int vv = 1; vv < m_vars.length; vv++) {	if (vv != v) {	  Variable vvArg = m_vars[vv];	  double[] mmArg = vvArg.getOldMessage(this);	  for (int k = 1; k < varSel.m_nCard; k++) {	    double prod = matSel[k] * mmArg[k];	    for (int vvv = 1; vvv < m_vars.length; vvv++) {	      if (vvv != v && vvv != vv) {		double[] mmmat = m_vars[vvv].getOldMessage(this);		prod = prod * mmmat[0];	      }	    }	    if (prod > maxNZ)	      maxNZ = prod;	  }	}      }      msgArg[0] = Math.max(maxZ, maxNZ);            MathUtils.normalize(msgArg);      setMessage(v, msgArg);            // need more time?      if (!MathUtils.approxeq(getMessage(v), getOldMessage(v), 			      FactorGraph.TOLERANCE))	bConverged = false;    }        MathUtils.normalize(msgSel);    setMessage(0, msgSel);        // need more time?    if (!MathUtils.approxeq(getMessage(0), getOldMessage(0), 			    FactorGraph.TOLERANCE))      bConverged = false;        return bConverged;  }    public boolean sendMsgToVarsSum()  {    boolean bConverged = true;    // locate "sel" variable, create message    Variable varSel = m_vars[0];    double[] matSel = varSel.getOldMessage(this);    double[] msgSel = varSel.newMessage();    // compute "sel" message, 0    msgSel[0] = 1.0;    for (int v = 1; v < m_vars.length; v++) {      Variable varArg = m_vars[v];      double[] matArg = varArg.getOldMessage(this);      msgSel[0] = msgSel[0] * matArg[0];    }    // compute "sel" message, > 0    for (int k = 1; k < varSel.m_nCard; k++) {      msgSel[k] = 0;      for (int v = 1; v < m_vars.length; v++) {	Variable varArg = m_vars[v];	double[] matArg = varArg.getOldMessage(this);	msgSel[k] += msgSel[0] * matArg[k] / matArg[0];      }    }    // compute "sel" argument messages    for (int v = 1; v < m_vars.length; v++) {      Variable varArg = m_vars[v];      double[] matArg = varArg.getOldMessage(this);      double msgArg[]= varArg.newMessage();      // argument message, > 0      for (int k 
= 1; k < varSel.m_nCard; k++)	msgArg[k] = msgSel[0] * matSel[k] / matArg[0];      // argument message, 0      msgArg[0] = msgSel[0] * matSel[0] / matArg[0];      for (int vv = 1; vv < m_vars.length; vv++) {	if (vv != v) {	  Variable vvArg = m_vars[vv];	  double[] mmArg = vvArg.getOldMessage(this);	  for (int k = 1; k < varSel.m_nCard; k++) {	    msgArg[0] += msgSel[0] * matSel[k] * mmArg[k] / (matArg[0] * mmArg[0]);	  }	}      }      MathUtils.normalize(msgArg);      setMessage(v, msgArg);      // need more time?      if (!MathUtils.approxeq(getMessage(v), getOldMessage(v), 			      FactorGraph.TOLERANCE))	bConverged = false;    }        MathUtils.normalize(msgSel);    setMessage(0, msgSel);    // need more time?    if (!MathUtils.approxeq(getMessage(0), getOldMessage(0), 			    FactorGraph.TOLERANCE))      bConverged = false;        return bConverged;  }}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -