margincalculator.java

来自「Weka」· Java 代码 · 共 921 行 · 第 1/2 页

JAVA
921
字号
package weka.classifiers.bayes.net;import weka.classifiers.bayes.net.BIFReader;import weka.classifiers.bayes.net.ParentSet;import weka.classifiers.bayes.net.estimate.BayesNetEstimator;import weka.classifiers.bayes.net.estimate.DiscreteEstimatorBayes;import weka.classifiers.bayes.BayesNet;import java.io.Serializable;import java.util.*;public class MarginCalculator implements Serializable {	  /** for serialization */	  private static final long serialVersionUID = 650278019241175534L;	  boolean m_debug = false;	  public JunctionTreeNode m_root = null;	JunctionTreeNode [] jtNodes;	public int getNode(String sNodeName) {    	int iNode = 0;    	while (iNode < m_root.m_bayesNet.m_Instances.numAttributes()) {    		if (m_root.m_bayesNet.m_Instances.attribute(iNode).name().equals(sNodeName)) {    			return iNode;    		}	    	iNode++;     	}    	//throw new Exception("Could not find node [[" + sNodeName + "]]");    	return -1;	}	public String toXMLBIF03() {return m_root.m_bayesNet.toXMLBIF03();}		/**	 * Calc marginal distributions of nodes in Bayesian network	 *	 Note that a connected network is assumed. 	 *	 Unconnected networks may give unexpected results.	 * @param bayesNet	 * @return root of junction tree	 */	public void calcMargins(BayesNet bayesNet) throws Exception {		//System.out.println(bayesNet.toString());		boolean[][] bAdjacencyMatrix = moralize(bayesNet);		process(bAdjacencyMatrix, bayesNet);	} // calcMargins	public void calcFullMargins(BayesNet bayesNet) throws Exception {		//System.out.println(bayesNet.toString());		int nNodes = bayesNet.getNrOfNodes();		boolean[][] bAdjacencyMatrix = new boolean[nNodes][nNodes];		for (int iNode = 0; iNode < nNodes; iNode++) {			for (int iNode2 = 0; iNode2 < nNodes; iNode2++) {				bAdjacencyMatrix[iNode][iNode2] = true;			}		}		process(bAdjacencyMatrix, bayesNet);	} // calcMargins			public void process(boolean[][] bAdjacencyMatrix, BayesNet bayesNet) throws Exception {		int[] order = getMaxCardOrder(bAdjacencyMatrix);		bAdjacencyMatrix = fillIn(order, bAdjacencyMatrix);		order = getMaxCardOrder(bAdjacencyMatrix);		Set [] cliques = getCliques(order, bAdjacencyMatrix);		Set [] separators = getSeparators(order, cliques);		int [] parentCliques = getCliqueTree(order, cliques, separators);		// report cliques		int nNodes = bAdjacencyMatrix.length;		if (m_debug) {		for (int i = 0; i < nNodes; i++) {			int iNode = order[i];			if (cliques[iNode] != null) {				System.out.print("Clique " + iNode + " (");				Iterator nodes = cliques[iNode].iterator();				while (nodes.hasNext()) {					int iNode2 = (Integer) nodes.next();					System.out.print(iNode2 + " " + bayesNet.getNodeName(iNode2));					if (nodes.hasNext()) {						System.out.print(",");					}				}				System.out.print(") S(");				nodes = separators[iNode].iterator();				while (nodes.hasNext()) {					int iNode2 = (Integer) nodes.next();					System.out.print(iNode2 + " " + bayesNet.getNodeName(iNode2));					if (nodes.hasNext()) {						System.out.print(",");					}				}				System.out.println(") parent clique " + parentCliques[iNode]);			}				}		}						jtNodes = getJunctionTree(cliques, separators, parentCliques, order, bayesNet);		m_root = null;		for (int iNode = 0; iNode < nNodes; iNode++) {			if (parentCliques[iNode] < 0 && jtNodes[iNode] != null) {				m_root = jtNodes[iNode];				break;			}		}		m_Margins = new double[nNodes][];		initialize(jtNodes, order, cliques, separators, parentCliques);				// sanity check		for (int i = 0; i < nNodes; i++) {			int iNode = order[i];			if (cliques[iNode] != null) {				if (parentCliques[iNode] == -1 && separators[iNode].size() > 0) {					throw new Exception("Something wrong in clique tree");				}			}		}		if (m_debug) {			//System.out.println(m_root.toString());		}	} // process			void initialize(JunctionTreeNode [] jtNodes, int [] order, Set [] cliques, Set [] separators, int [] parentCliques) {		int nNodes = order.length;		for (int i = nNodes - 1; i >= 0; i--) {			int iNode = order[i];			if (jtNodes[iNode]!=null) {				jtNodes[iNode].initializeUp();			}		}			for (int i = 0; i < nNodes; i++) {			int iNode = order[i];			if (jtNodes[iNode]!=null) {				jtNodes[iNode].initializeDown(false);			}		}		} // initialize		JunctionTreeNode [] getJunctionTree(Set [] cliques, Set [] separators, int [] parentCliques, int [] order, BayesNet bayesNet) {		int nNodes = order.length;		JunctionTreeNode root = null;		JunctionTreeNode [] jtns = new JunctionTreeNode[nNodes]; 		boolean [] bDone = new boolean[nNodes];		// create junction tree nodes		for (int i = 0; i < nNodes; i++) {			int iNode = order[i];			if (cliques[iNode] != null) {				jtns[iNode] = new JunctionTreeNode(cliques[iNode], bayesNet, bDone);			}		}		// create junction tree separators		for (int i = 0; i < nNodes; i++) {			int iNode = order[i];			if (cliques[iNode] != null) {				JunctionTreeNode parent = null;				if (parentCliques[iNode] > 0) {					parent = jtns[parentCliques[iNode]];					JunctionTreeSeparator jts = new JunctionTreeSeparator(separators[iNode], bayesNet, jtns[iNode], parent);					jtns[iNode].setParentSeparator(jts);					jtns[parentCliques[iNode]].addChildClique(jtns[iNode]);				} else {					root = jtns[iNode];					}			}		}		return jtns;	} // getJunctionTree		public class JunctionTreeSeparator implements Serializable {		  private static final long serialVersionUID = 6502780192411755343L;		int [] m_nNodes;		int m_nCardinality;		double [] m_fiParent;		double [] m_fiChild;		JunctionTreeNode m_parentNode;		JunctionTreeNode m_childNode;		BayesNet m_bayesNet;				JunctionTreeSeparator(Set separator, BayesNet bayesNet, JunctionTreeNode childNode, JunctionTreeNode parentNode) {			//////////////////////			// initialize node set			m_nNodes = new int[separator.size()];			int iPos = 0;			m_nCardinality = 1;			for(Iterator nodes = separator.iterator(); nodes.hasNext();) {				int iNode = (Integer) nodes.next();				m_nNodes[iPos++] = iNode;				m_nCardinality *= bayesNet.getCardinality(iNode);			}			m_parentNode = parentNode;			m_childNode = childNode;			m_bayesNet = bayesNet;		} // c'tor				/** marginalize junciontTreeNode node over all nodes outside the separator set		 * of the parent clique		 *		 */		public void updateFromParent() {			double [] fis = update(m_parentNode); 			if (fis == null) {				m_fiParent = null;			} else {				m_fiParent = fis;				// normalize				double sum = 0;				for (int iPos = 0; iPos < m_nCardinality; iPos++) {					sum += m_fiParent[iPos];				}				for (int iPos = 0; iPos < m_nCardinality; iPos++) {					m_fiParent[iPos] /= sum;				}			}		} // updateFromParent		/** marginalize junciontTreeNode node over all nodes outside the separator set		 * of the child clique		 *		 */		public void updateFromChild() {			double [] fis = update(m_childNode); 			if (fis == null) {				m_fiChild = null;			} else {				m_fiChild = fis;				// normalize				double sum = 0;				for (int iPos = 0; iPos < m_nCardinality; iPos++) {					sum += m_fiChild[iPos];				}				for (int iPos = 0; iPos < m_nCardinality; iPos++) {					m_fiChild[iPos] /= sum;				}			}		} // updateFromChild				/** marginalize junciontTreeNode node over all nodes outside the separator set		 * 		 * @param node: one of the neighboring junciont tree nodes of this separator		 */		public double [] update(JunctionTreeNode node) {			if (node.m_P == null) {				return null;			}			double [] fi = new double[m_nCardinality];			int [] values = new int[node.m_nNodes.length];			int [] order = new int[m_bayesNet.getNrOfNodes()];			for (int iNode = 0; iNode < node.m_nNodes.length; iNode++) {				order[node.m_nNodes[iNode]] = iNode;			}			// fill in the values			for (int iPos = 0; iPos < node.m_nCardinality; iPos++) {				int iNodeCPT = getCPT(node.m_nNodes, node.m_nNodes.length, values, order, m_bayesNet);				int iSepCPT =  getCPT(m_nNodes, m_nNodes.length, values, order, m_bayesNet);				fi[iSepCPT] += node.m_P[iNodeCPT];				// update values				int i = 0;				values[i]++;				while (i < node.m_nNodes.length && values[i] == m_bayesNet.getCardinality(node.m_nNodes[i])) {					values[i] = 0;					i++;					if (i < node.m_nNodes.length) {						values[i]++;					}				}			}			return fi;		} // update	} // class JunctionTreeSeparator	public class JunctionTreeNode implements Serializable {		  private static final long serialVersionUID = 650278019241175536L;		/** reference Bayes net for information about variables like name, cardinality, etc.		 * but not for relations between nodes **/		BayesNet m_bayesNet;		/** nodes of the Bayes net in this junction node **/		public int [] m_nNodes;		/** cardinality of the instances of variables in this junction node **/		int m_nCardinality;		/** potentials for first network **/		double [] m_fi;		/** distribution over this junction node according to first Bayes network **/		double [] m_P;		double [][] m_MarginalP;						JunctionTreeSeparator m_parentSeparator;		public void setParentSeparator(JunctionTreeSeparator parentSeparator) {m_parentSeparator = parentSeparator;}		public Vector m_children;		public void addChildClique(JunctionTreeNode child) {m_children.add(child);}		public void initializeUp() {			m_P = new double[m_nCardinality];			for (int iPos = 0; iPos < m_nCardinality; iPos++) {				m_P[iPos] = m_fi[iPos];			}			int [] values = new int[m_nNodes.length];			int [] order = new int[m_bayesNet.getNrOfNodes()];			for (int iNode = 0; iNode < m_nNodes.length; iNode++) {				order[m_nNodes[iNode]] = iNode;			}			for (Iterator child = m_children.iterator(); child.hasNext(); ) {				JunctionTreeNode childNode = (JunctionTreeNode) child.next();				JunctionTreeSeparator separator = childNode.m_parentSeparator;			// Update the values			for (int iPos = 0; iPos < m_nCardinality; iPos++) {				int iSepCPT = getCPT(separator.m_nNodes, separator.m_nNodes.length, values, order, m_bayesNet);				int iNodeCPT =  getCPT(m_nNodes, m_nNodes.length, values, order, m_bayesNet);					m_P[iNodeCPT] *= separator.m_fiChild[iSepCPT];									// update values				int i = 0;				values[i]++;				while (i < m_nNodes.length && values[i] == m_bayesNet.getCardinality(m_nNodes[i])) {					values[i] = 0;					i++;					if (i < m_nNodes.length) {						values[i]++;					}				}			}			}			// normalize			double sum = 0;			for (int iPos = 0; iPos < m_nCardinality; iPos++) {				sum += m_P[iPos];			}			for (int iPos = 0; iPos < m_nCardinality; iPos++) {				m_P[iPos] /= sum;			}			if (m_parentSeparator != null) { // not a root node				m_parentSeparator.updateFromChild();			}		} // initializeUp		public void initializeDown(boolean recursively) {			if (m_parentSeparator == null) { // a root node				calcMarginalProbabilities();			} else {			m_parentSeparator.updateFromParent();				int [] values = new int[m_nNodes.length];				int [] order = new int[m_bayesNet.getNrOfNodes()];				for (int iNode = 0; iNode < m_nNodes.length; iNode++) {					order[m_nNodes[iNode]] = iNode;				}								// Update the values				for (int iPos = 0; iPos < m_nCardinality; iPos++) {					int iSepCPT = getCPT(m_parentSeparator.m_nNodes, m_parentSeparator.m_nNodes.length, values, order, m_bayesNet);					int iNodeCPT =  getCPT(m_nNodes, m_nNodes.length, values, order, m_bayesNet);					if ( m_parentSeparator.m_fiChild[iSepCPT] > 0) {						m_P[iNodeCPT] *= m_parentSeparator.m_fiParent[iSepCPT] / m_parentSeparator.m_fiChild[iSepCPT];					} else {						m_P[iNodeCPT] = 0;					}					// update values					int i = 0;					values[i]++;					while (i < m_nNodes.length && values[i] == m_bayesNet.getCardinality(m_nNodes[i])) {						values[i] = 0;						i++;						if (i < m_nNodes.length) {							values[i]++;						}					}				}				// normalize				double sum = 0;				for (int iPos = 0; iPos < m_nCardinality; iPos++) {					sum += m_P[iPos];				}				for (int iPos = 0; iPos < m_nCardinality; iPos++) {					m_P[iPos] /= sum;				}				m_parentSeparator.updateFromChild();				calcMarginalProbabilities();			}			if (recursively) {				for (Iterator child = m_children.iterator(); child.hasNext(); ) {					JunctionTreeNode childNode = (JunctionTreeNode) child.next();					childNode.initializeDown(true);				}						}		} // initializeDown						/** calculate marginal probabilities for the individual nodes in the clique.		 * Store results in m_MarginalP 		 */		void calcMarginalProbabilities() {						// calculate marginal probabilities			int [] values = new int[m_nNodes.length];			int [] order = new int[m_bayesNet.getNrOfNodes()];			m_MarginalP = new double[m_nNodes.length][];			for (int iNode = 0; iNode < m_nNodes.length; iNode++) {				order[m_nNodes[iNode]] = iNode;				m_MarginalP[iNode]=new double[m_bayesNet.getCardinality(m_nNodes[iNode])];			}			for (int iPos = 0; iPos < m_nCardinality; iPos++) {				int iNodeCPT =  getCPT(m_nNodes, m_nNodes.length, values, order, m_bayesNet);				for (int iNode = 0; iNode < m_nNodes.length; iNode++) {					m_MarginalP[iNode][values[iNode]] += m_P[iNodeCPT];				}				// update values				int i = 0;				values[i]++;				while (i < m_nNodes.length && values[i] == m_bayesNet.getCardinality(m_nNodes[i])) {					values[i] = 0;					i++;					if (i < m_nNodes.length) {						values[i]++;					}				}			}						for (int iNode = 0; iNode < m_nNodes.length; iNode++) {				m_Margins[m_nNodes[iNode]] = m_MarginalP[iNode]; 			}		} // calcMarginalProbabilities				public String toString() {			StringBuffer buf = new StringBuffer();			for (int iNode = 0; iNode < m_nNodes.length; iNode++) {				buf.append(m_bayesNet.getNodeName(m_nNodes[iNode]) + ": ");				for (int iValue = 0; iValue < m_MarginalP[iNode].length; iValue++) {					buf.append(m_MarginalP[iNode][iValue] + " ");				}				buf.append('\n');			}			for (Iterator child = m_children.iterator(); child.hasNext(); ) {				JunctionTreeNode childNode = (JunctionTreeNode) child.next();				buf.append("----------------\n");				buf.append(childNode.toString());			}						return buf.toString();		} // toString				void calculatePotentials(BayesNet bayesNet, Set clique, boolean [] bDone) {			m_fi = new double[m_nCardinality];						int [] values = new int[m_nNodes.length];			int [] order = new int[bayesNet.getNrOfNodes()];			for (int iNode = 0; iNode < m_nNodes.length; iNode++) {				order[m_nNodes[iNode]] = iNode;			}			// find conditional probabilities that need to be taken in account			boolean [] bIsContained = new boolean[m_nNodes.length];			for (int iNode = 0; iNode < m_nNodes.length; iNode++) {				int nNode = m_nNodes[iNode];				bIsContained[iNode] = !bDone[nNode];				for (int iParent = 0; iParent < bayesNet.getNrOfParents(nNode); iParent++) {					int nParent = bayesNet.getParent(nNode, iParent);					if (!clique.contains(nParent)) {						bIsContained[iNode] = false;					}				}				if (bIsContained[iNode]) {

⌨️ 快捷键说明

复制代码Ctrl + C
搜索代码Ctrl + F
全屏模式F11
增大字号Ctrl + =
减小字号Ctrl + -
显示快捷键?