margincalculator.java
来自「Weka」· Java 代码 · 共 921 行 · 第 1/2 页
JAVA
921 行
package weka.classifiers.bayes.net;import weka.classifiers.bayes.net.BIFReader;import weka.classifiers.bayes.net.ParentSet;import weka.classifiers.bayes.net.estimate.BayesNetEstimator;import weka.classifiers.bayes.net.estimate.DiscreteEstimatorBayes;import weka.classifiers.bayes.BayesNet;import java.io.Serializable;import java.util.*;public class MarginCalculator implements Serializable { /** for serialization */ private static final long serialVersionUID = 650278019241175534L; boolean m_debug = false; public JunctionTreeNode m_root = null; JunctionTreeNode [] jtNodes; public int getNode(String sNodeName) { int iNode = 0; while (iNode < m_root.m_bayesNet.m_Instances.numAttributes()) { if (m_root.m_bayesNet.m_Instances.attribute(iNode).name().equals(sNodeName)) { return iNode; } iNode++; } //throw new Exception("Could not find node [[" + sNodeName + "]]"); return -1; } public String toXMLBIF03() {return m_root.m_bayesNet.toXMLBIF03();} /** * Calc marginal distributions of nodes in Bayesian network * Note that a connected network is assumed. * Unconnected networks may give unexpected results. * @param bayesNet * @return root of junction tree */ public void calcMargins(BayesNet bayesNet) throws Exception { //System.out.println(bayesNet.toString()); boolean[][] bAdjacencyMatrix = moralize(bayesNet); process(bAdjacencyMatrix, bayesNet); } // calcMargins public void calcFullMargins(BayesNet bayesNet) throws Exception { //System.out.println(bayesNet.toString()); int nNodes = bayesNet.getNrOfNodes(); boolean[][] bAdjacencyMatrix = new boolean[nNodes][nNodes]; for (int iNode = 0; iNode < nNodes; iNode++) { for (int iNode2 = 0; iNode2 < nNodes; iNode2++) { bAdjacencyMatrix[iNode][iNode2] = true; } } process(bAdjacencyMatrix, bayesNet); } // calcMargins public void process(boolean[][] bAdjacencyMatrix, BayesNet bayesNet) throws Exception { int[] order = getMaxCardOrder(bAdjacencyMatrix); bAdjacencyMatrix = fillIn(order, bAdjacencyMatrix); order = getMaxCardOrder(bAdjacencyMatrix); Set [] cliques = getCliques(order, bAdjacencyMatrix); Set [] separators = getSeparators(order, cliques); int [] parentCliques = getCliqueTree(order, cliques, separators); // report cliques int nNodes = bAdjacencyMatrix.length; if (m_debug) { for (int i = 0; i < nNodes; i++) { int iNode = order[i]; if (cliques[iNode] != null) { System.out.print("Clique " + iNode + " ("); Iterator nodes = cliques[iNode].iterator(); while (nodes.hasNext()) { int iNode2 = (Integer) nodes.next(); System.out.print(iNode2 + " " + bayesNet.getNodeName(iNode2)); if (nodes.hasNext()) { System.out.print(","); } } System.out.print(") S("); nodes = separators[iNode].iterator(); while (nodes.hasNext()) { int iNode2 = (Integer) nodes.next(); System.out.print(iNode2 + " " + bayesNet.getNodeName(iNode2)); if (nodes.hasNext()) { System.out.print(","); } } System.out.println(") parent clique " + parentCliques[iNode]); } } } jtNodes = getJunctionTree(cliques, separators, parentCliques, order, bayesNet); m_root = null; for (int iNode = 0; iNode < nNodes; iNode++) { if (parentCliques[iNode] < 0 && jtNodes[iNode] != null) { m_root = jtNodes[iNode]; break; } } m_Margins = new double[nNodes][]; initialize(jtNodes, order, cliques, separators, parentCliques); // sanity check for (int i = 0; i < nNodes; i++) { int iNode = order[i]; if (cliques[iNode] != null) { if (parentCliques[iNode] == -1 && separators[iNode].size() > 0) { throw new Exception("Something wrong in clique tree"); } } } if (m_debug) { //System.out.println(m_root.toString()); } } // process void initialize(JunctionTreeNode [] jtNodes, int [] order, Set [] cliques, Set [] separators, int [] parentCliques) { int nNodes = order.length; for (int i = nNodes - 1; i >= 0; i--) { int iNode = order[i]; if (jtNodes[iNode]!=null) { jtNodes[iNode].initializeUp(); } } for (int i = 0; i < nNodes; i++) { int iNode = order[i]; if (jtNodes[iNode]!=null) { jtNodes[iNode].initializeDown(false); } } } // initialize JunctionTreeNode [] getJunctionTree(Set [] cliques, Set [] separators, int [] parentCliques, int [] order, BayesNet bayesNet) { int nNodes = order.length; JunctionTreeNode root = null; JunctionTreeNode [] jtns = new JunctionTreeNode[nNodes]; boolean [] bDone = new boolean[nNodes]; // create junction tree nodes for (int i = 0; i < nNodes; i++) { int iNode = order[i]; if (cliques[iNode] != null) { jtns[iNode] = new JunctionTreeNode(cliques[iNode], bayesNet, bDone); } } // create junction tree separators for (int i = 0; i < nNodes; i++) { int iNode = order[i]; if (cliques[iNode] != null) { JunctionTreeNode parent = null; if (parentCliques[iNode] > 0) { parent = jtns[parentCliques[iNode]]; JunctionTreeSeparator jts = new JunctionTreeSeparator(separators[iNode], bayesNet, jtns[iNode], parent); jtns[iNode].setParentSeparator(jts); jtns[parentCliques[iNode]].addChildClique(jtns[iNode]); } else { root = jtns[iNode]; } } } return jtns; } // getJunctionTree public class JunctionTreeSeparator implements Serializable { private static final long serialVersionUID = 6502780192411755343L; int [] m_nNodes; int m_nCardinality; double [] m_fiParent; double [] m_fiChild; JunctionTreeNode m_parentNode; JunctionTreeNode m_childNode; BayesNet m_bayesNet; JunctionTreeSeparator(Set separator, BayesNet bayesNet, JunctionTreeNode childNode, JunctionTreeNode parentNode) { ////////////////////// // initialize node set m_nNodes = new int[separator.size()]; int iPos = 0; m_nCardinality = 1; for(Iterator nodes = separator.iterator(); nodes.hasNext();) { int iNode = (Integer) nodes.next(); m_nNodes[iPos++] = iNode; m_nCardinality *= bayesNet.getCardinality(iNode); } m_parentNode = parentNode; m_childNode = childNode; m_bayesNet = bayesNet; } // c'tor /** marginalize junciontTreeNode node over all nodes outside the separator set * of the parent clique * */ public void updateFromParent() { double [] fis = update(m_parentNode); if (fis == null) { m_fiParent = null; } else { m_fiParent = fis; // normalize double sum = 0; for (int iPos = 0; iPos < m_nCardinality; iPos++) { sum += m_fiParent[iPos]; } for (int iPos = 0; iPos < m_nCardinality; iPos++) { m_fiParent[iPos] /= sum; } } } // updateFromParent /** marginalize junciontTreeNode node over all nodes outside the separator set * of the child clique * */ public void updateFromChild() { double [] fis = update(m_childNode); if (fis == null) { m_fiChild = null; } else { m_fiChild = fis; // normalize double sum = 0; for (int iPos = 0; iPos < m_nCardinality; iPos++) { sum += m_fiChild[iPos]; } for (int iPos = 0; iPos < m_nCardinality; iPos++) { m_fiChild[iPos] /= sum; } } } // updateFromChild /** marginalize junciontTreeNode node over all nodes outside the separator set * * @param node: one of the neighboring junciont tree nodes of this separator */ public double [] update(JunctionTreeNode node) { if (node.m_P == null) { return null; } double [] fi = new double[m_nCardinality]; int [] values = new int[node.m_nNodes.length]; int [] order = new int[m_bayesNet.getNrOfNodes()]; for (int iNode = 0; iNode < node.m_nNodes.length; iNode++) { order[node.m_nNodes[iNode]] = iNode; } // fill in the values for (int iPos = 0; iPos < node.m_nCardinality; iPos++) { int iNodeCPT = getCPT(node.m_nNodes, node.m_nNodes.length, values, order, m_bayesNet); int iSepCPT = getCPT(m_nNodes, m_nNodes.length, values, order, m_bayesNet); fi[iSepCPT] += node.m_P[iNodeCPT]; // update values int i = 0; values[i]++; while (i < node.m_nNodes.length && values[i] == m_bayesNet.getCardinality(node.m_nNodes[i])) { values[i] = 0; i++; if (i < node.m_nNodes.length) { values[i]++; } } } return fi; } // update } // class JunctionTreeSeparator public class JunctionTreeNode implements Serializable { private static final long serialVersionUID = 650278019241175536L; /** reference Bayes net for information about variables like name, cardinality, etc. * but not for relations between nodes **/ BayesNet m_bayesNet; /** nodes of the Bayes net in this junction node **/ public int [] m_nNodes; /** cardinality of the instances of variables in this junction node **/ int m_nCardinality; /** potentials for first network **/ double [] m_fi; /** distribution over this junction node according to first Bayes network **/ double [] m_P; double [][] m_MarginalP; JunctionTreeSeparator m_parentSeparator; public void setParentSeparator(JunctionTreeSeparator parentSeparator) {m_parentSeparator = parentSeparator;} public Vector m_children; public void addChildClique(JunctionTreeNode child) {m_children.add(child);} public void initializeUp() { m_P = new double[m_nCardinality]; for (int iPos = 0; iPos < m_nCardinality; iPos++) { m_P[iPos] = m_fi[iPos]; } int [] values = new int[m_nNodes.length]; int [] order = new int[m_bayesNet.getNrOfNodes()]; for (int iNode = 0; iNode < m_nNodes.length; iNode++) { order[m_nNodes[iNode]] = iNode; } for (Iterator child = m_children.iterator(); child.hasNext(); ) { JunctionTreeNode childNode = (JunctionTreeNode) child.next(); JunctionTreeSeparator separator = childNode.m_parentSeparator; // Update the values for (int iPos = 0; iPos < m_nCardinality; iPos++) { int iSepCPT = getCPT(separator.m_nNodes, separator.m_nNodes.length, values, order, m_bayesNet); int iNodeCPT = getCPT(m_nNodes, m_nNodes.length, values, order, m_bayesNet); m_P[iNodeCPT] *= separator.m_fiChild[iSepCPT]; // update values int i = 0; values[i]++; while (i < m_nNodes.length && values[i] == m_bayesNet.getCardinality(m_nNodes[i])) { values[i] = 0; i++; if (i < m_nNodes.length) { values[i]++; } } } } // normalize double sum = 0; for (int iPos = 0; iPos < m_nCardinality; iPos++) { sum += m_P[iPos]; } for (int iPos = 0; iPos < m_nCardinality; iPos++) { m_P[iPos] /= sum; } if (m_parentSeparator != null) { // not a root node m_parentSeparator.updateFromChild(); } } // initializeUp public void initializeDown(boolean recursively) { if (m_parentSeparator == null) { // a root node calcMarginalProbabilities(); } else { m_parentSeparator.updateFromParent(); int [] values = new int[m_nNodes.length]; int [] order = new int[m_bayesNet.getNrOfNodes()]; for (int iNode = 0; iNode < m_nNodes.length; iNode++) { order[m_nNodes[iNode]] = iNode; } // Update the values for (int iPos = 0; iPos < m_nCardinality; iPos++) { int iSepCPT = getCPT(m_parentSeparator.m_nNodes, m_parentSeparator.m_nNodes.length, values, order, m_bayesNet); int iNodeCPT = getCPT(m_nNodes, m_nNodes.length, values, order, m_bayesNet); if ( m_parentSeparator.m_fiChild[iSepCPT] > 0) { m_P[iNodeCPT] *= m_parentSeparator.m_fiParent[iSepCPT] / m_parentSeparator.m_fiChild[iSepCPT]; } else { m_P[iNodeCPT] = 0; } // update values int i = 0; values[i]++; while (i < m_nNodes.length && values[i] == m_bayesNet.getCardinality(m_nNodes[i])) { values[i] = 0; i++; if (i < m_nNodes.length) { values[i]++; } } } // normalize double sum = 0; for (int iPos = 0; iPos < m_nCardinality; iPos++) { sum += m_P[iPos]; } for (int iPos = 0; iPos < m_nCardinality; iPos++) { m_P[iPos] /= sum; } m_parentSeparator.updateFromChild(); calcMarginalProbabilities(); } if (recursively) { for (Iterator child = m_children.iterator(); child.hasNext(); ) { JunctionTreeNode childNode = (JunctionTreeNode) child.next(); childNode.initializeDown(true); } } } // initializeDown /** calculate marginal probabilities for the individual nodes in the clique. * Store results in m_MarginalP */ void calcMarginalProbabilities() { // calculate marginal probabilities int [] values = new int[m_nNodes.length]; int [] order = new int[m_bayesNet.getNrOfNodes()]; m_MarginalP = new double[m_nNodes.length][]; for (int iNode = 0; iNode < m_nNodes.length; iNode++) { order[m_nNodes[iNode]] = iNode; m_MarginalP[iNode]=new double[m_bayesNet.getCardinality(m_nNodes[iNode])]; } for (int iPos = 0; iPos < m_nCardinality; iPos++) { int iNodeCPT = getCPT(m_nNodes, m_nNodes.length, values, order, m_bayesNet); for (int iNode = 0; iNode < m_nNodes.length; iNode++) { m_MarginalP[iNode][values[iNode]] += m_P[iNodeCPT]; } // update values int i = 0; values[i]++; while (i < m_nNodes.length && values[i] == m_bayesNet.getCardinality(m_nNodes[i])) { values[i] = 0; i++; if (i < m_nNodes.length) { values[i]++; } } } for (int iNode = 0; iNode < m_nNodes.length; iNode++) { m_Margins[m_nNodes[iNode]] = m_MarginalP[iNode]; } } // calcMarginalProbabilities public String toString() { StringBuffer buf = new StringBuffer(); for (int iNode = 0; iNode < m_nNodes.length; iNode++) { buf.append(m_bayesNet.getNodeName(m_nNodes[iNode]) + ": "); for (int iValue = 0; iValue < m_MarginalP[iNode].length; iValue++) { buf.append(m_MarginalP[iNode][iValue] + " "); } buf.append('\n'); } for (Iterator child = m_children.iterator(); child.hasNext(); ) { JunctionTreeNode childNode = (JunctionTreeNode) child.next(); buf.append("----------------\n"); buf.append(childNode.toString()); } return buf.toString(); } // toString void calculatePotentials(BayesNet bayesNet, Set clique, boolean [] bDone) { m_fi = new double[m_nCardinality]; int [] values = new int[m_nNodes.length]; int [] order = new int[bayesNet.getNrOfNodes()]; for (int iNode = 0; iNode < m_nNodes.length; iNode++) { order[m_nNodes[iNode]] = iNode; } // find conditional probabilities that need to be taken in account boolean [] bIsContained = new boolean[m_nNodes.length]; for (int iNode = 0; iNode < m_nNodes.length; iNode++) { int nNode = m_nNodes[iNode]; bIsContained[iNode] = !bDone[nNode]; for (int iParent = 0; iParent < bayesNet.getNrOfParents(nNode); iParent++) { int nParent = bayesNet.getParent(nNode, iParent); if (!clique.contains(nParent)) { bIsContained[iNode] = false; } } if (bIsContained[iNode]) {
⌨️ 快捷键说明
复制代码Ctrl + C
搜索代码Ctrl + F
全屏模式F11
增大字号Ctrl + =
减小字号Ctrl + -
显示快捷键?