⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 gbninfer.cs

📁 基于BP算法的贝叶斯网络参数学习
💻 CS
字号:
using System;
using System.Collections;
using System.Windows.Forms;

namespace bs
{
	/// <summary>
	/// Common base for inference engines over a <c>gBNet</c>: keeps a reference to
	/// the network being queried and declares the belief query that concrete
	/// engines (such as gBElim) implement.
	/// </summary>
	abstract class gBNInfer
	{
		/// <summary>
		/// Remembers the network that the concrete engine will run inference on.
		/// </summary>
		/// <param name="net">The network to query.</param>
		protected gBNInfer(gBNet net)
		{
			this.m_net = net;
		}

		/// <summary>Network queried by this engine; assigned by the constructor.</summary>
		protected gBNet m_net;

		/// <summary>
		/// Answers a belief query for assignment <paramref name="x"/> given
		/// observations <paramref name="o"/>.
		/// NOTE(review): the string encoding of both arguments is defined by the
		/// network (gBElim passes them to gBNet.SetNodes) — confirm there.
		/// </summary>
		/// <param name="x">Query assignment.</param>
		/// <param name="o">Observed evidence.</param>
		/// <returns>The belief value computed by the concrete engine.</returns>
		public abstract double GetBelief(string x, string o);
	}

	/// <summary>
	/// Exact inference on a network of binary-valued nodes by bucket elimination.
	/// </summary>
	/// <remarks>
	/// One bucket is built per node. A bucket's parentNodes lists the node's own
	/// parents first (in the order of gBNode.Parents), followed by every other node
	/// that its child buckets depend on. Each bucket is attached as a child of the
	/// bucket of its highest-ID dependency; buckets with no dependencies are roots
	/// of independent trees and are evaluated separately.
	/// NOTE(review): the construction assumes gBNode.ID equals the node's index in
	/// m_net.Nodes and that parents always have smaller IDs than their children
	/// (topological order) — confirm against gBNet.
	/// </remarks>
	class gBElim : gBNInfer
	{
		// One bucket per node, indexed like m_net.Nodes.
		private ArrayList m_buckets;

		// Buckets with an empty parentNodes list. They are never attached under
		// another bucket, so each must be summed on its own and the results
		// multiplied (needed for networks made of several disconnected parts).
		private ArrayList m_rootBuckets;

		public gBElim(gBNet net) : base(net)
		{
			m_buckets = new ArrayList();
			m_rootBuckets = new ArrayList();

			PrepareBuckets();
		}

		/// <summary>
		/// Elimination bucket for a single node.
		/// </summary>
		private class Bucket
		{
			public Bucket(int i) 
			{
				id = i;
				parentNodes = new ArrayList();
				childBuckets = new ArrayList();
				memo = new Hashtable();
			}

			// Index of the node eliminated in this bucket.
			public int id;
			// Nodes (gBNode) this bucket's result depends on; the owning node's own
			// parents come first so they occupy the low bits of an assignment.
			public ArrayList parentNodes;
			// Buckets (Bucket) attached under this one.
			public ArrayList childBuckets;
			// Sum results for the current evaluation, keyed by the (boxed int)
			// parent assignment; cleared before every top-level evaluation.
			public Hashtable memo;
		}

		/// <summary>
		/// Computes the probability of the query assignment given the observations.
		/// </summary>
		/// <param name="x">Query assignment, in the format accepted by gBNet.SetNodes.</param>
		/// <param name="o">Observed evidence in the same format; null or empty means no evidence.</param>
		/// <returns>
		/// P(x, o) / P(o), or P(x) when <paramref name="o"/> is empty.
		/// NOTE(review): the division is not guarded, so a zero P(o) yields NaN or Infinity.
		/// </returns>
		public override double GetBelief(string x, string o)
		{
			m_net.ResetNodes();

			double norm = 1.0;

			if( o != null && o.Length > 0 )
			{
				// Normaliser: probability of the observations alone.
				m_net.SetNodes(o);
				norm = JointProbability();
			}

			// x is set on top of o (no reset in between), giving P(x, o).
			// NOTE(review): relies on SetNodes adding to, not replacing, existing evidence — confirm.
			m_net.SetNodes(x);

			return JointProbability()/norm;
		}

		/// <summary>
		/// Builds the bucket tree, processing nodes from the highest index down to 0.
		/// </summary>
		private void PrepareBuckets()
		{
			ArrayList nodes = m_net.Nodes;

			for(int i=0; i<nodes.Count; ++i)
				m_buckets.Add(new Bucket(i));

			// Go through all buckets from the bottom up: child buckets are only ever
			// attached while processing higher indices (given the topological-order
			// assumption), so bucket i's child list is complete when it is processed.
			for(int i=nodes.Count-1; i>=0; --i)
			{
				gBNode theNode = (gBNode)nodes[i];
				Bucket theBuck = (Bucket)m_buckets[i];

				// Own parents first: Sum reads the low bits of the assignment as the CPT row.
				foreach(gBNode node in theNode.Parents)
					theBuck.parentNodes.Add(node);

				// Whatever a child bucket depends on, other than the node eliminated
				// here, becomes a dependency of this bucket as well.
				foreach(Bucket nxtBuck in theBuck.childBuckets)
				{
					foreach(gBNode node in nxtBuck.parentNodes)
						if( node.ID!=i && !theBuck.parentNodes.Contains(node) ) 
							theBuck.parentNodes.Add(node);
				}

				// Attach under the bucket of the highest-ID dependency; with no
				// dependencies this bucket starts an independent tree.
				int max_nid = FindMaxNodeId(theBuck.parentNodes);

				if( max_nid >= 0 )
					((Bucket)m_buckets[max_nid]).childBuckets.Add(theBuck);
				else
					m_rootBuckets.Add(theBuck);
			}
		}

		/// <summary>
		/// Probability of the evidence currently set on the network: the product of
		/// the sums of all root buckets (a single root for a connected network).
		/// </summary>
		private double JointProbability()
		{
			ClearMemo();

			double pr = 1.0;
			foreach(Bucket root in m_rootBuckets)
				pr *= SumMemo(root.id, 0);
			return pr;
		}

		/// <summary>
		/// Sums out bucket <paramref name="nid"/> together with every bucket attached below it.
		/// </summary>
		/// <param name="nid">Index of the bucket (and node) to eliminate.</param>
		/// <param name="para">
		/// Values of the bucket's parentNodes: bit j holds the 0/1 value of parentNodes[j].
		/// </param>
		/// <returns>
		/// Sum over the node's values (only its evidence value, if it has one) of its CPT
		/// entry times the sums of all child buckets.
		/// </returns>
		protected double Sum(int nid, int para)
		{
			// Fresh cache: evidence or CPTs may have changed since the previous call.
			ClearMemo();
			return SumMemo(nid, para);
		}

		/// <summary>
		/// Memoized implementation of <see cref="Sum"/>. Within one evaluation the
		/// result depends only on (nid, para), so each pair is computed once; without
		/// this the recursion is exponential in depth (e.g. 2^(n-1) calls for a chain).
		/// </summary>
		private double SumMemo(int nid, int para)
		{
			Bucket theBuck = (Bucket)m_buckets[nid];

			object cached = theBuck.memo[para];
			if( cached != null )
				return (double)cached;

			gBNode theNode = (gBNode)m_net.Nodes[nid];

			// The node's own parents are parentNodes[0..p_cnt-1], so the low p_cnt
			// bits of para form the CPT row index.
			int p_cnt = theNode.Parents.Count;
			int cond = ( para & ((1<<p_cnt)-1));

			double pr = 0.0;

			// Sum over both values of this binary node; Evidence == -1 means unobserved.
			for(int e=0; e<2; ++e)
			{
				if( theNode.Evidence != -1 && theNode.Evidence != e )
					continue;

				double tmpPr = theNode.CPT[cond, e];

				// Multiply in each child bucket's contribution for this value of the node.
				foreach(Bucket nxtBuck in theBuck.childBuckets)
					tmpPr *= SumMemo(nxtBuck.id, ChildPara(theBuck, nxtBuck, para, e));

				pr += tmpPr;
			}

			theBuck.memo[para] = pr;
			return pr;
		}

		/// <summary>
		/// Encodes the parent assignment for child bucket <paramref name="nxtBuck"/>.
		/// Nodes it shares with <paramref name="theBuck"/> take their bit from
		/// <paramref name="para"/>. The one node it does not share is the node
		/// eliminated in <paramref name="theBuck"/>, which takes value <paramref name="e"/>.
		/// </summary>
		private static int ChildPara(Bucket theBuck, Bucket nxtBuck, int para, int e)
		{
			int next_para = 0;

			for(int j=0; j<nxtBuck.parentNodes.Count; ++j)
			{
				gBNode pnode = (gBNode)nxtBuck.parentNodes[j];

				int pos = theBuck.parentNodes.IndexOf(pnode);

				next_para += (pos>=0) ? (((para>>pos) & 1)<<j) : (e<<j);
			}

			return next_para;
		}

		/// <summary>
		/// Empties every bucket's Sum cache before a new evaluation.
		/// </summary>
		private void ClearMemo()
		{
			foreach(Bucket b in m_buckets)
				b.memo.Clear();
		}

		/// <summary>
		/// Returns the largest gBNode.ID in <paramref name="nodes"/>, or -1 when the list is empty.
		/// </summary>
		private int FindMaxNodeId(ArrayList nodes)
		{
			int max = -1;
			foreach(gBNode node in nodes)
				if( node.ID > max ) max = node.ID;
			return max;
		}
		
	}
}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -