margincalculator.java
来自「Weka」· Java 代码 · 共 921 行 · 第 1/2 页
JAVA
921 行
bDone[nNode] = true; if (m_debug) { System.out.println("adding node " +nNode); } } } // fill in the values for (int iPos = 0; iPos < m_nCardinality; iPos++) { int iCPT = getCPT(m_nNodes, m_nNodes.length, values, order, bayesNet); m_fi[iCPT] = 1.0; for (int iNode = 0; iNode < m_nNodes.length; iNode++) { if (bIsContained[iNode]) { int nNode = m_nNodes[iNode]; int [] nNodes = bayesNet.getParentSet(nNode).getParents(); int iCPT2 = getCPT(nNodes, bayesNet.getNrOfParents(nNode), values, order, bayesNet); double f = bayesNet.getDistributions()[nNode][iCPT2].getProbability(values[iNode]); m_fi[iCPT] *= f; } } // update values int i = 0; values[i]++; while (i < m_nNodes.length && values[i] == bayesNet.getCardinality(m_nNodes[i])) { values[i] = 0; i++; if (i < m_nNodes.length) { values[i]++; } } } } // calculatePotentials JunctionTreeNode(Set clique, BayesNet bayesNet, boolean [] bDone) { m_bayesNet = bayesNet; m_children = new Vector(); ////////////////////// // initialize node set m_nNodes = new int[clique.size()]; int iPos = 0; m_nCardinality = 1; for(Iterator nodes = clique.iterator(); nodes.hasNext();) { int iNode = (Integer) nodes.next(); m_nNodes[iPos++] = iNode; m_nCardinality *= bayesNet.getCardinality(iNode); } //////////////////////////////// // initialize potential function calculatePotentials(bayesNet, clique, bDone); } // JunctionTreeNode c'tor /* check whether this junciton tree node contains node nNode * */ boolean contains(int nNode) { for (int iNode = 0; iNode < m_nNodes.length; iNode++) { if (m_nNodes[iNode]== nNode){ return true; } } return false; } // contains public void setEvidence(int nNode, int iValue) throws Exception { int [] values = new int[m_nNodes.length]; int [] order = new int[m_bayesNet.getNrOfNodes()]; int nNodeIdx = -1; for (int iNode = 0; iNode < m_nNodes.length; iNode++) { order[m_nNodes[iNode]] = iNode; if (m_nNodes[iNode] == nNode) { nNodeIdx = iNode; } } if (nNodeIdx < 0) { throw new Exception("setEvidence: Node " + nNode + " not found in this clique"); } for (int iPos = 0; iPos < m_nCardinality; iPos++) { if (values[nNodeIdx] != iValue) { int iNodeCPT = getCPT(m_nNodes, m_nNodes.length, values, order, m_bayesNet); m_P[iNodeCPT] = 0; } // update values int i = 0; values[i]++; while (i < m_nNodes.length && values[i] == m_bayesNet.getCardinality(m_nNodes[i])) { values[i] = 0; i++; if (i < m_nNodes.length) { values[i]++; } } } // normalize double sum = 0; for (int iPos = 0; iPos < m_nCardinality; iPos++) { sum += m_P[iPos]; } for (int iPos = 0; iPos < m_nCardinality; iPos++) { m_P[iPos] /= sum; } calcMarginalProbabilities(); updateEvidence(this); } // setEvidence void updateEvidence(JunctionTreeNode source) { if (source != this) { int [] values = new int[m_nNodes.length]; int [] order = new int[m_bayesNet.getNrOfNodes()]; for (int iNode = 0; iNode < m_nNodes.length; iNode++) { order[m_nNodes[iNode]] = iNode; } int [] nChildNodes = source.m_parentSeparator.m_nNodes; int nNumChildNodes = nChildNodes.length; for (int iPos = 0; iPos < m_nCardinality; iPos++) { int iNodeCPT = getCPT(m_nNodes, m_nNodes.length, values, order, m_bayesNet); int iChildCPT = getCPT(nChildNodes, nNumChildNodes, values, order, m_bayesNet); if (source.m_parentSeparator.m_fiParent[iChildCPT] != 0) { m_P[iNodeCPT] *= source.m_parentSeparator.m_fiChild[iChildCPT]/source.m_parentSeparator.m_fiParent[iChildCPT]; } else { m_P[iNodeCPT] = 0; } // update values int i = 0; values[i]++; while (i < m_nNodes.length && values[i] == m_bayesNet.getCardinality(m_nNodes[i])) { values[i] = 0; i++; if (i < m_nNodes.length) { values[i]++; } } } // normalize double sum = 0; for (int iPos = 0; iPos < m_nCardinality; iPos++) { sum += m_P[iPos]; } for (int iPos = 0; iPos < m_nCardinality; iPos++) { m_P[iPos] /= sum; } calcMarginalProbabilities(); } for (Iterator child = m_children.iterator(); child.hasNext(); ) { JunctionTreeNode childNode = (JunctionTreeNode) child.next(); if (childNode != source) { childNode.initializeDown(true); } } if (m_parentSeparator != null) { m_parentSeparator.updateFromChild(); m_parentSeparator.m_parentNode.updateEvidence(this); m_parentSeparator.updateFromParent(); } } // updateEvidence } // class JunctionTreeNode int getCPT(int [] nodeSet, int nNodes, int[] values, int[] order, BayesNet bayesNet) { int iCPTnew = 0; for (int iNode = 0; iNode < nNodes; iNode++) { int nNode = nodeSet[iNode]; iCPTnew = iCPTnew * bayesNet.getCardinality(nNode); iCPTnew += values[order[nNode]]; } return iCPTnew; } // getCPT int [] getCliqueTree(int [] order, Set [] cliques, Set [] separators) { int nNodes = order.length; int [] parentCliques = new int[nNodes]; //for (int i = nNodes - 1; i >= 0; i--) { for (int i = 0; i < nNodes; i++) { int iNode = order[i]; parentCliques[iNode] = -1; if (cliques[iNode] != null && separators[iNode].size() > 0) { //for (int j = nNodes - 1; j > i; j--) { for (int j = 0; j < nNodes; j++) { int iNode2 = order[j]; if (iNode!= iNode2 && cliques[iNode2] != null && cliques[iNode2].containsAll(separators[iNode])) { parentCliques[iNode] = iNode2; j = i; j = 0; j = nNodes; } } } } return parentCliques; } // getCliqueTree /** calculate separator sets in clique tree * * @param order: maximum cardinality ordering of the graph * @param cliques: set of cliques * @return set of separator sets */ Set [] getSeparators(int [] order, Set [] cliques) { int nNodes = order.length; Set [] separators = new HashSet[nNodes]; Set processedNodes = new HashSet(); //for (int i = nNodes - 1; i >= 0; i--) { for (int i = 0; i < nNodes; i++) { int iNode = order[i]; if (cliques[iNode] != null) { Set separator = new HashSet(); separator.addAll(cliques[iNode]); separator.retainAll(processedNodes); separators[iNode] = separator; processedNodes.addAll(cliques[iNode]); } } return separators; } // getSeparators /** * get cliques in a decomposable graph represented by an adjacency matrix * * @param order: maximum cardinality ordering of the graph * @param bAdjacencyMatrix: decomposable graph * @return set of cliques */ Set [] getCliques(int[] order, boolean[][] bAdjacencyMatrix) throws Exception { int nNodes = bAdjacencyMatrix.length; Set [] cliques = new HashSet[nNodes]; //int[] inverseOrder = new int[nNodes]; //for (int iNode = 0; iNode < nNodes; iNode++) { //inverseOrder[order[iNode]] = iNode; //} // consult nodes in reverse order for (int i = nNodes - 1; i >= 0; i--) { int iNode = order[i]; if (iNode == 22) { int h = 3; h ++; } Set clique = new HashSet(); clique.add(iNode); for (int j = 0; j < i; j++) { int iNode2 = order[j]; if (bAdjacencyMatrix[iNode][iNode2]) { clique.add(iNode2); } } //for (int iNode2 = 0; iNode2 < nNodes; iNode2++) { //if (bAdjacencyMatrix[iNode][iNode2] && inverseOrder[iNode2] < inverseOrder[iNode]) { //clique.add(iNode2); //} //} cliques[iNode] = clique; } for (int iNode = 0; iNode < nNodes; iNode++) { for (int iNode2 = 0; iNode2 < nNodes; iNode2++) { if (iNode != iNode2 && cliques[iNode]!= null && cliques[iNode2]!= null && cliques[iNode].containsAll(cliques[iNode2])) { cliques[iNode2] = null; } } } // sanity check if (m_debug) { int [] nNodeSet = new int[nNodes]; for (int iNode = 0; iNode < nNodes; iNode++) { if (cliques[iNode] != null) { Iterator it = cliques[iNode].iterator(); int k = 0; while (it.hasNext()) { nNodeSet[k++] = (Integer) it.next(); } for (int i = 0; i < cliques[iNode].size(); i++) { for (int j = 0; j < cliques[iNode].size(); j++) { if (i!=j && !bAdjacencyMatrix[nNodeSet[i]][nNodeSet[j]]) { throw new Exception("Non clique" + i + " " + j); } } } } } } return cliques; } // getCliques /** * moralize DAG and calculate * adjacency matrix representation for a Bayes Network, effecively * converting the directed acyclic graph to an undirected graph. * * @param bayesNet: * Bayes Network to process * @return adjacencies in boolean matrix format */ public boolean[][] moralize(BayesNet bayesNet) { int nNodes = bayesNet.getNrOfNodes(); boolean[][] bAdjacencyMatrix = new boolean[nNodes][nNodes]; for (int iNode = 0; iNode < nNodes; iNode++) { ParentSet parents = bayesNet.getParentSets()[iNode]; moralizeNode(parents, iNode, bAdjacencyMatrix); } return bAdjacencyMatrix; } // moralize private void moralizeNode(ParentSet parents, int iNode, boolean[][] bAdjacencyMatrix) { for (int iParent = 0; iParent < parents.getNrOfParents(); iParent++) { int nParent = parents.getParent(iParent); if ( m_debug && !bAdjacencyMatrix[iNode][nParent]) System.out.println("Insert " + iNode + "--" + nParent); bAdjacencyMatrix[iNode][nParent] = true; bAdjacencyMatrix[nParent][iNode] = true; for (int iParent2 = iParent + 1; iParent2 < parents.getNrOfParents(); iParent2++) { int nParent2 = parents.getParent(iParent2); if (m_debug && !bAdjacencyMatrix[nParent2][nParent]) System.out.println("Mary " + nParent + "--" + nParent2); bAdjacencyMatrix[nParent2][nParent] = true; bAdjacencyMatrix[nParent][nParent2] = true; } } } // moralizeNode /** * Apply Tarjan and Yannakakis (1984) fill in algorithm for graph * triangulation. In reverse order, insert edges between any non-adjacent * neighbors that are lower numbered in the ordering. * * Side effect: input matrix is used as output * * @param order: * node ordering * @param bAdjacencyMatrix: * boolean matrix representing the graph * @return boolean matrix representing the graph with fill ins */ public boolean[][] fillIn(int[] order, boolean[][] bAdjacencyMatrix) { int nNodes = bAdjacencyMatrix.length; int[] inverseOrder = new int[nNodes]; for (int iNode = 0; iNode < nNodes; iNode++) { inverseOrder[order[iNode]] = iNode; } // consult nodes in reverse order for (int i = nNodes - 1; i >= 0; i--) { int iNode = order[i]; // find pairs of neighbors with lower order for (int j = 0; j < i; j++) { int iNode2 = order[j]; if (bAdjacencyMatrix[iNode][iNode2]) { for (int k = j+1; k < i; k++) { int iNode3 = order[k]; if (bAdjacencyMatrix[iNode][iNode3]) { // fill in if (m_debug && (!bAdjacencyMatrix[iNode2][iNode3] || !bAdjacencyMatrix[iNode3][iNode2]) ) System.out.println("Fill in " + iNode2 + "--" + iNode3); bAdjacencyMatrix[iNode2][iNode3] = true; bAdjacencyMatrix[iNode3][iNode2] = true; } } } } } return bAdjacencyMatrix; } // fillIn /** * calculate maximum cardinality ordering; start with first node add node * that has most neighbors already ordered till all nodes are in the * ordering * * This implementation does not assume the graph is connected * * @param bAdjacencyMatrix: * n by n matrix with adjacencies in graph of n nodes * @return maximum cardinality ordering */ int[] getMaxCardOrder(boolean[][] bAdjacencyMatrix) { int nNodes = bAdjacencyMatrix.length; int[] order = new int[nNodes]; if (nNodes==0) {return order;} boolean[] bDone = new boolean[nNodes]; // start with node 0 order[0] = 0; bDone[0] = true; // order remaining nodes for (int iNode = 1; iNode < nNodes; iNode++) { int nMaxCard = -1; int iBestNode = -1; // find node with higest cardinality of previously ordered nodes for (int iNode2 = 0; iNode2 < nNodes; iNode2++) { if (!bDone[iNode2]) { int nCard = 0; // calculate cardinality for node iNode2 for (int iNode3 = 0; iNode3 < nNodes; iNode3++) { if (bAdjacencyMatrix[iNode2][iNode3] && bDone[iNode3]) { nCard++; } } if (nCard > nMaxCard) { nMaxCard = nCard; iBestNode = iNode2; } } } order[iNode] = iBestNode; bDone[iBestNode] = true; } return order; } // getMaxCardOrder public void setEvidence(int nNode, int iValue) throws Exception { if (m_root == null) { throw new Exception("Junction tree not initialize yet"); } int iJtNode = 0; while (iJtNode < jtNodes.length && (jtNodes[iJtNode] == null ||!jtNodes[iJtNode].contains(nNode))) { iJtNode++; } if (jtNodes.length == iJtNode) { throw new Exception("Could not find node " + nNode + " in junction tree"); } jtNodes[iJtNode].setEvidence(nNode, iValue); } // setEvidence public String toString() { return m_root.toString(); } // toString double [][] m_Margins; public double [] getMargin(int iNode) { return m_Margins[iNode]; } // getMargin public static void main(String[] args) { try { BIFReader bayesNet = new BIFReader(); bayesNet.processFile(args[0]); MarginCalculator dc = new MarginCalculator(); dc.calcMargins(bayesNet); int iNode = 2; int iValue = 0; int iNode2 = 4; int iValue2 = 0; dc.setEvidence(iNode, iValue); dc.setEvidence(iNode2, iValue2); System.out.print(dc.toString()); dc.calcFullMargins(bayesNet); dc.setEvidence(iNode, iValue); dc.setEvidence(iNode2, iValue2); System.out.println("=============="); System.out.print(dc.toString()); } catch (Exception e) { e.printStackTrace(); } } // main} // class MarginCalculator
⌨️ 快捷键说明
复制代码Ctrl + C
搜索代码Ctrl + F
全屏模式F11
增大字号Ctrl + =
减小字号Ctrl + -
显示快捷键?