⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 bnet.cs

📁 基于BP算法的贝叶斯网络参数学习
💻 CS
字号:
using System;
using System.IO;
using System.Xml;
using System.Collections;
using System.Diagnostics;

namespace bs
{
	/// <summary>
	/// One discrete variable of the network: its name, its index, its number
	/// of states, its parent nodes, its conditional probability table and the
	/// evidence currently set on it.
	/// </summary>
	class BNode
	{
		/// <summary>Parent nodes of this node (the network stores BNode instances here).</summary>
		public ArrayList Parents;

		/// <summary>
		/// Conditional probability table. It is not allocated by the constructor;
		/// the network allocates it with one row per combination of parent states
		/// and one column per state of this node.
		/// </summary>
		public double[,] CPT;

		/// <summary>
		/// Evidence value (a state index). -1 is the default and the value
		/// restored by <see cref="Clear"/>. Also reachable via the
		/// <see cref="evidence"/> property.
		/// </summary>
		public int Evidence=-1;

		// Immutable identity of the node, fixed at construction.
		private readonly string nodeName;
		private readonly int nodeIndex;
		private readonly int stateCount;

		/// <summary>Creates a node with an empty parent list and no evidence.</summary>
		/// <param name="name">Node name, stored exactly as given.</param>
		/// <param name="id">Identifier reported by <see cref="ID"/>.</param>
		/// <param name="range">Number of states, reported by <see cref="Range"/>.</param>
		public BNode(string name, int id, int range)
		{
			this.Parents = new ArrayList();
			this.nodeName = name;
			this.nodeIndex = id;
			this.stateCount = range;
			this.Clear();
		}

		/// <summary>Resets <see cref="Evidence"/> to -1.</summary>
		public void Clear()
		{
			this.Evidence = -1;
		}

		/// <summary>Property wrapper that reads and writes the <see cref="Evidence"/> field.</summary>
		public int evidence
		{
			get { return this.Evidence; }
			set { this.Evidence = value; }
		}

		/// <summary>Name passed to the constructor.</summary>
		public string Name
		{
			get { return this.nodeName; }
		}

		/// <summary>Identifier passed to the constructor.</summary>
		public int ID
		{
			get { return this.nodeIndex; }
		}

		/// <summary>Number of states passed to the constructor.</summary>
		public int Range
		{
			get { return this.stateCount; }
		}
	}
	/// <summary>
	/// A discrete Bayesian network loaded from an XML description. Nodes are
	/// kept in the order they are created. Each node's CPT is read from the
	/// XML, except the last column, which is derived so every row sums to 1.
	/// </summary>
	public class BNet
	{
		// How far below zero the derived last CPT column may fall before the
		// row is rejected. Values such as 0.9 and 0.1 are not exactly
		// representable in binary, so 1.0 - 0.9 - 0.1 evaluates to about
		// -1.4e-17 even though the row does sum to one.
		private const double NormalizationTolerance = 1e-9;

		// Nodes (BNode instances) in creation order. CreateNode gives each new
		// node an ID equal to its position in this list.
		private ArrayList m_bnNodes;

		/// <summary>Creates an empty network.</summary>
		public BNet()
		{
			m_bnNodes = new ArrayList();
		}

		/// <summary>
		/// The network's nodes (BNode instances) in creation order. This is the
		/// internal list itself, not a copy.
		/// </summary>
		public ArrayList Nodes
		{
			get { return m_bnNodes; }
		}

		/// <summary>Calls <c>Clear()</c> on every node, resetting its evidence.</summary>
		public void ResetNodes()
		{
			foreach (BNode node in m_bnNodes)
			{
				node.Clear();
			}
		}

		/// <summary>
		/// Sets evidence from a string such as <c>"a=1;b=0"</c>. Entries are
		/// separated by ';' and must have the form name=value. Names are trimmed
		/// and lower-cased (culture-invariantly) before matching, which agrees
		/// with how <see cref="Build"/> stores node names. Entries that are not
		/// of the form name=value, or that name no known node, are ignored.
		/// </summary>
		/// <param name="s">The evidence string.</param>
		/// <exception cref="ArgumentNullException"><paramref name="s"/> is null.</exception>
		/// <exception cref="FormatException">A value for a known node is not an integer.</exception>
		/// <exception cref="OverflowException">A value for a known node does not fit in an int.</exception>
		public void SetNodes(string s)
		{
			if (s == null)
			{
				throw new ArgumentNullException("s");
			}

			foreach (string entry in s.Split(';'))
			{
				string[] pair = entry.Split('=');
				if (pair.Length != 2)
				{
					continue;
				}

				BNode node = FindNode(pair[0].Trim().ToLowerInvariant());
				if (node != null)
				{
					// Only parse the value when the node exists.
					node.Evidence = Convert.ToInt32(pair[1]);
				}
			}
		}

		/// <summary>
		/// Loads nodes from an XML file and appends them to the network. Each
		/// child element of the <c>BNetNodes</c> root element becomes a node, in
		/// document order. If a node fails to load, the nodes created before it
		/// remain in the network.
		/// </summary>
		/// <param name="xmlfile">Path or URL passed to <see cref="XmlDocument.Load(string)"/>.</param>
		/// <exception cref="XmlException">The file is not well-formed XML.</exception>
		/// <exception cref="InvalidDataException">A node description is structurally invalid.</exception>
		/// <exception cref="FormatException">A number in the file cannot be parsed.</exception>
		public void Build(string xmlfile)
		{
			XmlDocument doc = new XmlDocument();
			doc.Load(xmlfile);
			XmlElement root = doc.DocumentElement;

			XmlNodeList nodeList = root.SelectNodes("/BNetNodes/*");

			foreach (XmlNode node in nodeList)
			{
				CreateNode(node);
			}
		}

		// Returns the first node whose Name equals lowerName, or null if none does.
		private BNode FindNode(string lowerName)
		{
			foreach (BNode node in m_bnNodes)
			{
				if (node.Name == lowerName)
				{
					return node;
				}
			}
			return null;
		}

		// Creates one BNode from its XML element and appends it to m_bnNodes.
		// Layout read by this method:
		//   <name Range="k">
		//     <Parents><parentName/>...</Parents>
		//     <CPT_Col><CP>p</CP>...</CPT_Col>      (exactly k-1 of these)
		//   </name>
		// Each CPT_Col supplies one CPT column with one CP per row, where rows
		// correspond to combinations of parent states. The k-th column is
		// derived from the others. Parents must appear earlier in the file,
		// because only already-created nodes can be found.
		private void CreateNode(XmlNode theXmlNode)
		{
			string name = theXmlNode.Name.ToLowerInvariant();
			int range = ReadRange(theXmlNode, name);

			// Nodes are appended at the end, so the new node's ID is its index.
			BNode newBNode = new BNode(name, m_bnNodes.Count, range);

			foreach (XmlNode parentXml in theXmlNode.SelectNodes("Parents/*"))
			{
				string parentName = parentXml.Name.ToLowerInvariant();
				BNode parent = FindNode(parentName);
				if (parent == null)
				{
					// Skipping the parent would give the CPT too few rows.
					throw new InvalidDataException(
						"Unknown parent '" + parentName + "' of node '" + name + "'");
				}
				newBNode.Parents.Add(parent);
			}

			XmlNodeList colNodes = theXmlNode.SelectNodes("CPT_Col");
			if (range != colNodes.Count + 1)
			{
				throw new InvalidDataException("CPT cols mismatch");
			}

			// One row for every combination of parent states.
			int tableRows = 1;
			foreach (BNode parent in newBNode.Parents)
			{
				tableRows *= parent.Range;
			}

			newBNode.CPT = new double[tableRows, range];
			ReadExplicitColumns(newBNode.CPT, colNodes, tableRows);
			FillLastColumn(newBNode.CPT, tableRows, range);

			m_bnNodes.Add(newBNode);
		}

		// Reads the mandatory integer Range attribute (the node's number of states).
		private static int ReadRange(XmlNode theXmlNode, string name)
		{
			XmlAttributeCollection attrs = theXmlNode.Attributes;
			XmlNode rangeAttr = (attrs == null) ? null : attrs.GetNamedItem("Range");
			if (rangeAttr == null)
			{
				throw new InvalidDataException("Node '" + name + "' has no Range attribute");
			}
			// XmlConvert parses with invariant rules, independent of the current culture.
			return XmlConvert.ToInt32(rangeAttr.Value);
		}

		// Fills columns 0..colNodes.Count-1 of cpt from the CP children of each
		// CPT_Col element. Values are parsed with XmlConvert, which is
		// culture-invariant, so "0.25" reads the same on every locale.
		private static void ReadExplicitColumns(double[,] cpt, XmlNodeList colNodes, int tableRows)
		{
			for (int col = 0; col < colNodes.Count; ++col)
			{
				XmlNodeList cpNodes = colNodes[col].SelectNodes("CP");
				if (cpNodes.Count != tableRows)
				{
					throw new InvalidDataException("CPT Rows mismatch");
				}

				for (int row = 0; row < tableRows; ++row)
				{
					double p = XmlConvert.ToDouble(cpNodes[row].InnerText);
					if (double.IsNaN(p) || p < 0.0 || p > 1.0)
					{
						throw new InvalidDataException("Probability out of range");
					}
					cpt[row, col] = p;
				}
			}
		}

		// Sets the last column of each row to 1 minus the sum of the other
		// columns. A deficit no larger than NormalizationTolerance is treated
		// as rounding error and clamped to 0.
		private static void FillLastColumn(double[,] cpt, int tableRows, int range)
		{
			for (int row = 0; row < tableRows; ++row)
			{
				double rest = 1.0;
				for (int col = 0; col < range - 1; ++col)
				{
					rest -= cpt[row, col];
				}

				if (rest < -NormalizationTolerance)
				{
					throw new InvalidDataException("Probability does not normalize");
				}
				cpt[row, range - 1] = Math.Max(rest, 0.0);
			}
		}

		/// <summary>
		/// Writes a plain-text dump of the network (each node's name, its
		/// parents and its CPT rows) to <paramref name="fileName"/>. The
		/// <see cref="StreamWriter(string)"/> constructor creates or overwrites
		/// the file.
		/// </summary>
		/// <param name="fileName">Path of the output file.</param>
		public void PrintNet(string fileName)
		{
			// The using block closes the file even if writing throws.
			using (StreamWriter w = new StreamWriter(fileName))
			{
				w.WriteLine("The BeliefNet Layout");

				foreach (BNode node in m_bnNodes)
				{
					w.WriteLine("");
					w.WriteLine("Node: {0}", node.Name);
					w.WriteLine("    Total {0} Parents", node.Parents.Count);
					if (node.Parents.Count > 0)
					{
						w.Write("    Parents:");
						foreach (BNode pnode in node.Parents)
						{
							w.Write("   {0}", pnode.Name);
						}
						w.WriteLine("");
					}

					w.WriteLine("    CPT");
					int rows = node.CPT.GetLength(0);
					int cols = node.CPT.GetLength(1);
					for (int r = 0; r < rows; ++r)
					{
						// Each row: four-space indent, then "   " + value per column.
						w.Write("    ");
						for (int c = 0; c < cols; ++c)
						{
							w.Write("   ");
							w.Write(node.CPT[r, c].ToString());
						}
						w.WriteLine();
					}
				}
			}
		}
	}
}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -