margincalculator.java

来自「Weka」· Java 代码 · 共 921 行 · 第 1/2 页

JAVA
921
字号
					bDone[nNode] = true;					if (m_debug) {						System.out.println("adding node " +nNode);					}				}			}						// fill in the values			for (int iPos = 0; iPos < m_nCardinality; iPos++) {				int iCPT = getCPT(m_nNodes, m_nNodes.length, values, order, bayesNet);				m_fi[iCPT] = 1.0;				for (int iNode = 0; iNode < m_nNodes.length; iNode++) {					if (bIsContained[iNode]) {						int nNode = m_nNodes[iNode];						int [] nNodes = bayesNet.getParentSet(nNode).getParents();						int iCPT2 = getCPT(nNodes, bayesNet.getNrOfParents(nNode), values, order, bayesNet);						double f = bayesNet.getDistributions()[nNode][iCPT2].getProbability(values[iNode]);						m_fi[iCPT] *= f;					}				}								// update values				int i = 0;				values[i]++;				while (i < m_nNodes.length && values[i] == bayesNet.getCardinality(m_nNodes[i])) {					values[i] = 0;					i++;					if (i < m_nNodes.length) {						values[i]++;					}				}			}		} // calculatePotentials		JunctionTreeNode(Set clique, BayesNet bayesNet, boolean [] bDone) {			m_bayesNet = bayesNet;			m_children = new Vector();			//////////////////////			// initialize node set			m_nNodes = new int[clique.size()];			int iPos = 0;			m_nCardinality = 1;			for(Iterator nodes = clique.iterator(); nodes.hasNext();) {				int iNode = (Integer) nodes.next();				m_nNodes[iPos++] = iNode;				m_nCardinality *= bayesNet.getCardinality(iNode);			}			////////////////////////////////			// initialize potential function			calculatePotentials(bayesNet, clique, bDone);       } // JunctionTreeNode c'tor		/* check whether this junciton tree node contains node nNode		 * 		 */		boolean contains(int nNode) {			for (int iNode = 0; iNode < m_nNodes.length; iNode++) {				if (m_nNodes[iNode]== nNode){					return true;				}			}			return false;		} // contains				public void setEvidence(int nNode, int iValue) throws Exception {			int [] values = new int[m_nNodes.length];			int [] order = new int[m_bayesNet.getNrOfNodes()];			int nNodeIdx = -1;			for (int iNode = 0; iNode < m_nNodes.length; iNode++) {				order[m_nNodes[iNode]] = iNode;				if (m_nNodes[iNode] == nNode) {					nNodeIdx = iNode;				}			}			if (nNodeIdx < 0) {				throw new Exception("setEvidence: Node " + nNode + " not found in this clique");			}			for (int iPos = 0; iPos < m_nCardinality; iPos++) {				if (values[nNodeIdx] != iValue) {					int iNodeCPT =  getCPT(m_nNodes, m_nNodes.length, values, order, m_bayesNet);					m_P[iNodeCPT] = 0;				}				// update values				int i = 0;				values[i]++;				while (i < m_nNodes.length && values[i] == m_bayesNet.getCardinality(m_nNodes[i])) {					values[i] = 0;					i++;					if (i < m_nNodes.length) {						values[i]++;					}				}			}					// normalize			double sum = 0;			for (int iPos = 0; iPos < m_nCardinality; iPos++) {				sum += m_P[iPos];			}			for (int iPos = 0; iPos < m_nCardinality; iPos++) {				m_P[iPos] /= sum;			}			calcMarginalProbabilities();			updateEvidence(this);		} // setEvidence		void updateEvidence(JunctionTreeNode source) {			if (source != this) {				int [] values = new int[m_nNodes.length];				int [] order = new int[m_bayesNet.getNrOfNodes()];				for (int iNode = 0; iNode < m_nNodes.length; iNode++) {					order[m_nNodes[iNode]] = iNode;				}				int [] nChildNodes = source.m_parentSeparator.m_nNodes;				int nNumChildNodes = nChildNodes.length; 				for (int iPos = 0; iPos < m_nCardinality; iPos++) {					int iNodeCPT =  getCPT(m_nNodes, m_nNodes.length, values, order, m_bayesNet);					int iChildCPT =  getCPT(nChildNodes, nNumChildNodes, values, order, m_bayesNet);					if (source.m_parentSeparator.m_fiParent[iChildCPT] != 0) {						m_P[iNodeCPT] *= source.m_parentSeparator.m_fiChild[iChildCPT]/source.m_parentSeparator.m_fiParent[iChildCPT];					} else {						m_P[iNodeCPT] = 0;					}					// update values					int i = 0;					values[i]++;					while (i < m_nNodes.length && values[i] == m_bayesNet.getCardinality(m_nNodes[i])) {						values[i] = 0;						i++;						if (i < m_nNodes.length) {							values[i]++;						}					}				}						// normalize				double sum = 0;				for (int iPos = 0; iPos < m_nCardinality; iPos++) {					sum += m_P[iPos];				}				for (int iPos = 0; iPos < m_nCardinality; iPos++) {					m_P[iPos] /= sum;				}				calcMarginalProbabilities();			}			for (Iterator child = m_children.iterator(); child.hasNext(); ) {				JunctionTreeNode childNode = (JunctionTreeNode) child.next();				if (childNode != source) {					childNode.initializeDown(true);				}			}						if (m_parentSeparator != null) {				m_parentSeparator.updateFromChild();				m_parentSeparator.m_parentNode.updateEvidence(this);				m_parentSeparator.updateFromParent();			}		} // updateEvidence			} // class JunctionTreeNode	int getCPT(int [] nodeSet, int nNodes, int[] values, int[] order, BayesNet bayesNet) {		int iCPTnew = 0;		for (int iNode = 0; iNode < nNodes; iNode++) {			int nNode = nodeSet[iNode];			iCPTnew = iCPTnew * bayesNet.getCardinality(nNode);			iCPTnew += values[order[nNode]];		}		return iCPTnew;	} // getCPT	int [] getCliqueTree(int [] order, Set [] cliques, Set [] separators) {		int nNodes = order.length;		int [] parentCliques = new int[nNodes];		//for (int i = nNodes - 1; i >= 0; i--) {		for (int i = 0; i < nNodes; i++) {			int iNode = order[i];			parentCliques[iNode] = -1;			if (cliques[iNode] != null && separators[iNode].size() > 0) {				//for (int j = nNodes - 1; j > i; j--) {				for (int j = 0; j < nNodes; j++) {					int iNode2 = order[j];					if (iNode!= iNode2 && cliques[iNode2] != null && cliques[iNode2].containsAll(separators[iNode])) {						parentCliques[iNode] = iNode2;						j = i;						j = 0;						j = nNodes;					}				}							}		}		return parentCliques;	} // getCliqueTree		/** calculate separator sets in clique tree	 * 	 * @param order: maximum cardinality ordering of the graph	 * @param cliques: set of cliques	 * @return set of separator sets	 */	Set [] getSeparators(int [] order, Set [] cliques) {		int nNodes = order.length;		Set [] separators = new HashSet[nNodes];		Set processedNodes = new HashSet(); 		//for (int i = nNodes - 1; i >= 0; i--) {		for (int i = 0; i < nNodes; i++) {			int iNode = order[i];			if (cliques[iNode] != null) {				Set separator = new HashSet();				separator.addAll(cliques[iNode]);				separator.retainAll(processedNodes);				separators[iNode] = separator;				processedNodes.addAll(cliques[iNode]);			}		}		return separators;	} // getSeparators		/**	 * get cliques in a decomposable graph represented by an adjacency matrix	 * 	 * @param order: maximum cardinality ordering of the graph	 * @param bAdjacencyMatrix: decomposable graph	 * @return set of cliques	 */	Set [] getCliques(int[] order, boolean[][] bAdjacencyMatrix) throws Exception {		int nNodes = bAdjacencyMatrix.length;		Set [] cliques = new HashSet[nNodes];		//int[] inverseOrder = new int[nNodes];		//for (int iNode = 0; iNode < nNodes; iNode++) {			//inverseOrder[order[iNode]] = iNode;		//}		// consult nodes in reverse order		for (int i = nNodes - 1; i >= 0; i--) {			int iNode = order[i];			if (iNode == 22) {				int h = 3;				h ++;			}			Set clique = new HashSet();			clique.add(iNode);			for (int j = 0; j < i; j++) {				int iNode2 = order[j];				if (bAdjacencyMatrix[iNode][iNode2]) {					clique.add(iNode2);				}			}						//for (int iNode2 = 0; iNode2 < nNodes; iNode2++) {				//if (bAdjacencyMatrix[iNode][iNode2] && inverseOrder[iNode2] < inverseOrder[iNode]) {					//clique.add(iNode2);				//}			//}			cliques[iNode] = clique;		}		for (int iNode = 0; iNode < nNodes; iNode++) {			for (int iNode2 = 0; iNode2 < nNodes; iNode2++) {				if (iNode != iNode2 && cliques[iNode]!= null && cliques[iNode2]!= null && cliques[iNode].containsAll(cliques[iNode2])) {					cliques[iNode2] = null;				}			}		}				// sanity check		if (m_debug) {		int [] nNodeSet = new int[nNodes];		for (int iNode = 0; iNode < nNodes; iNode++) {			if (cliques[iNode] != null) {				Iterator it = cliques[iNode].iterator();				int k = 0;				while (it.hasNext()) {					nNodeSet[k++] = (Integer) it.next();				}				for (int i = 0; i < cliques[iNode].size(); i++) {					for (int j = 0; j < cliques[iNode].size(); j++) {						if (i!=j && !bAdjacencyMatrix[nNodeSet[i]][nNodeSet[j]]) {							throw new Exception("Non clique" + i + " " + j);						}					}				}			}		}		}		return cliques;	} // getCliques	/**	 * moralize DAG and calculate	 * adjacency matrix representation for a Bayes Network, effecively	 * converting the directed acyclic graph to an undirected graph.	 * 	 * @param bayesNet:	 *            Bayes Network to process	 * @return adjacencies in boolean matrix format	 */	public boolean[][] moralize(BayesNet bayesNet) {		int nNodes = bayesNet.getNrOfNodes();		boolean[][] bAdjacencyMatrix = new boolean[nNodes][nNodes];		for (int iNode = 0; iNode < nNodes; iNode++) {			ParentSet parents = bayesNet.getParentSets()[iNode];			moralizeNode(parents, iNode, bAdjacencyMatrix);		}		return bAdjacencyMatrix;	} // moralize	private void moralizeNode(ParentSet parents, int iNode, boolean[][] bAdjacencyMatrix) {		for (int iParent = 0; iParent < parents.getNrOfParents(); iParent++) {			int nParent = parents.getParent(iParent);			if ( m_debug && !bAdjacencyMatrix[iNode][nParent])				System.out.println("Insert " + iNode + "--" + nParent);			bAdjacencyMatrix[iNode][nParent] = true;			bAdjacencyMatrix[nParent][iNode] = true;			for (int iParent2 = iParent + 1; iParent2 < parents.getNrOfParents(); iParent2++) {				int nParent2 = parents.getParent(iParent2);				if (m_debug && !bAdjacencyMatrix[nParent2][nParent])					System.out.println("Mary " + nParent + "--" + nParent2);				bAdjacencyMatrix[nParent2][nParent] = true;				bAdjacencyMatrix[nParent][nParent2] = true;			}		}		} // moralizeNode		/**	 * Apply Tarjan and Yannakakis (1984) fill in algorithm for graph	 * triangulation. In reverse order, insert edges between any non-adjacent	 * neighbors that are lower numbered in the ordering.	 * 	 * Side effect: input matrix is used as output	 * 	 * @param order:	 *            node ordering	 * @param bAdjacencyMatrix:	 *            boolean matrix representing the graph	 * @return boolean matrix representing the graph with fill ins	 */	public boolean[][] fillIn(int[] order, boolean[][] bAdjacencyMatrix) {		int nNodes = bAdjacencyMatrix.length;		int[] inverseOrder = new int[nNodes];		for (int iNode = 0; iNode < nNodes; iNode++) {			inverseOrder[order[iNode]] = iNode;		}		// consult nodes in reverse order		for (int i = nNodes - 1; i >= 0; i--) {			int iNode = order[i];			// find pairs of neighbors with lower order			for (int j = 0; j < i; j++) {				int iNode2 = order[j];				if (bAdjacencyMatrix[iNode][iNode2]) {					for (int k = j+1; k < i; k++) {						int iNode3 = order[k];						if (bAdjacencyMatrix[iNode][iNode3]) {							// fill in							if (m_debug && (!bAdjacencyMatrix[iNode2][iNode3] || !bAdjacencyMatrix[iNode3][iNode2]) )								System.out.println("Fill in " + iNode2 + "--" + iNode3);							bAdjacencyMatrix[iNode2][iNode3] = true;							bAdjacencyMatrix[iNode3][iNode2] = true;						}					}				}			}		}		return bAdjacencyMatrix;	} // fillIn	/**	 * calculate maximum cardinality ordering; start with first node add node	 * that has most neighbors already ordered till all nodes are in the	 * ordering	 * 	 * This implementation does not assume the graph is connected	 * 	 * @param bAdjacencyMatrix:	 *            n by n matrix with adjacencies in graph of n nodes	 * @return maximum cardinality ordering	 */	int[] getMaxCardOrder(boolean[][] bAdjacencyMatrix) {		int nNodes = bAdjacencyMatrix.length;		int[] order = new int[nNodes];		if (nNodes==0) {return order;}		boolean[] bDone = new boolean[nNodes];		// start with node 0		order[0] = 0;		bDone[0] = true;		// order remaining nodes		for (int iNode = 1; iNode < nNodes; iNode++) {			int nMaxCard = -1;			int iBestNode = -1;			// find node with higest cardinality of previously ordered nodes			for (int iNode2 = 0; iNode2 < nNodes; iNode2++) {				if (!bDone[iNode2]) {					int nCard = 0;					// calculate cardinality for node iNode2					for (int iNode3 = 0; iNode3 < nNodes; iNode3++) {						if (bAdjacencyMatrix[iNode2][iNode3] && bDone[iNode3]) {							nCard++;						}					}					if (nCard > nMaxCard) {						nMaxCard = nCard;						iBestNode = iNode2;					}				}			}			order[iNode] = iBestNode;			bDone[iBestNode] = true;		}		return order;	} // getMaxCardOrder	public void setEvidence(int nNode, int iValue) throws Exception {		if (m_root == null) {			throw new Exception("Junction tree not initialize yet");		}		int iJtNode = 0;		while (iJtNode < jtNodes.length && (jtNodes[iJtNode] == null ||!jtNodes[iJtNode].contains(nNode))) {			iJtNode++;		}		if (jtNodes.length == iJtNode) {			throw new Exception("Could not find node " + nNode + " in junction tree");		}		jtNodes[iJtNode].setEvidence(nNode, iValue);	} // setEvidence		public String toString() {		return m_root.toString();	} // toString	double [][] m_Margins;	public double [] getMargin(int iNode) {		return m_Margins[iNode];	} // getMargin		public static void main(String[] args) {		try {			BIFReader bayesNet = new BIFReader();			bayesNet.processFile(args[0]);			MarginCalculator dc = new MarginCalculator();			dc.calcMargins(bayesNet);			int iNode = 2;			int iValue = 0;			int iNode2 = 4;			int iValue2 = 0;			dc.setEvidence(iNode, iValue);			dc.setEvidence(iNode2, iValue2);			System.out.print(dc.toString());			dc.calcFullMargins(bayesNet);			dc.setEvidence(iNode, iValue);			dc.setEvidence(iNode2, iValue2);			System.out.println("==============");			System.out.print(dc.toString());								} catch (Exception e) {			e.printStackTrace();		}	} // main} // class MarginCalculator

⌨️ 快捷键说明

复制代码Ctrl + C
搜索代码Ctrl + F
全屏模式F11
增大字号Ctrl + =
减小字号Ctrl + -
显示快捷键?