⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 bayesiannetwork.java

📁 此代码是数据挖掘软件的一个小例子
💻 JAVA
📖 第 1 页 / 共 2 页
字号:
/*
Bao Jie 2002-04-02
Iowa State University
*/

package weka.classifiers;

import java.io.*;
import java.util.*;
import weka.core.*;
import Probtable;

public class BayesianNetwork extends DistributionClassifier 
		// Bayesian Network
{
  /** Debug switch. NOTE(review): not read anywhere in the visible part of this file. */
  boolean bDebug = true;
  
  /**
   * NOTE(review): presumably a limit on node degree (name misspelled "Drgree");
   * not read anywhere in the visible part of this file. The name is kept as-is
   * because code outside this view may reference it.
   */
  int maxDrgree = 10;
  
  /**
   * The instances used for training. The table/index helpers read the class
   * count and attribute count from here.
   */
  public Instances m_Instances;
  
  /**
   * Attribute descriptors indexed by attribute position (m_Att[i].numValues()
   * is the number of values of attribute i). Static, so shared by every
   * BayesianNetwork instance.
   */
  public static Attribute[] m_Att;
  
  /**
   * Adjacency matrix of the Bayesian network (static, shared by all instances):
   * (i,j) = 1 means an edge from attribute i to attribute j;
   * (i,j) = 0 means no edge from i to j (there may still be an edge from j to i).
   */
  public static int [][] m_Network;
  /** NOTE(review): presumably the path of a saved network file; not read in the visible code. */
  public String m_networkFile;
    
  /** Conditional probability tables, one per non-class attribute (allocated and filled by BuildTable). */
  public Probtable [] m_Table;

  
  /**
   * Flattens a combination of class value and parent/own attribute values into
   * a single row index of a node's conditional probability table, compressing
   * the high-dimensional table into a flat one.
   *
   * The encoding is mixed-radix. The class value is the least significant digit,
   * with radix equal to the number of classes. Rows 1..f of Combination follow
   * in order, each with radix equal to that attribute's number of values. Here
   * f is the number of parents of Node in network plus one (the node itself).
   * When the rows are in ascending attribute order, this mirrors the decoding
   * done by Index2Combination.
   *
   * @param Combination row 0 = {-1, class value}; rows 1..f = {attribute index, value}
   * @param Node        the node whose table is being indexed
   * @param network     adjacency matrix; decides how many attribute rows are read
   * @return the flat table index
   */
  public int Combination2Index(int[][] Combination,int Node, int[][]network )
  {
    int rowsToRead = NumberOfFather(Node, network, m_Instances.numAttributes() - 1) + 1;

    int index = Combination[0][1];          // least significant digit: the class value
    int weight = m_Instances.numClasses();  // place value of the next digit
    for (int row = 1; row <= rowsToRead; row++)
    {
      index += Combination[row][1] * weight;
      weight *= m_Att[Combination[row][0]].numValues();
    }
    return index;
  }
  
  /**
   * Prints the first {@code lim} rows of a combination to standard output as
   * two tab-separated lines. The first line lists the attribute indices
   * (column 0) and the second the values (column 1). No trailing newline is
   * printed.
   *
   * @param Combination rows of {attribute index, value}
   * @param lim         number of rows to print
   */
  public void PrintCombination(int[][] Combination, int lim)
  {
    String[] labels = { "attribute ", "\nvalue     " };
    for (int col = 0; col < labels.length; col++)
    {
      System.out.print(labels[col]);
      for (int row = 0; row < lim; row++)
      {
        System.out.print("\t" + Combination[row][col]);
      }
    }
  }
  
  
  /**
   * Decodes a flat table row index back into the combination of class value
   * and parent/own attribute values for node {@code Node}.
   *
   * Row 0 of the result holds the class: column 0 is -1 and column 1 is the
   * class value, which is the least significant digit. The following rows list,
   * in ascending attribute order, every attribute that is a parent of Node in
   * network, plus Node itself. Each row has the attribute index in column 0 and
   * its decoded value in column 1.
   *
   * @param Index   flat table index to decode
   * @param Node    node whose table the index belongs to
   * @param network adjacency matrix defining Node's parents
   * @return the decoded combination, NumberOfFather + 2 rows by 2 columns
   */
  public int[][]  Index2Combination(int Index, int Node, int [][] network)
  {
    int attCount = m_Instances.numAttributes() - 1;
    int[][] decoded = new int[NumberOfFather(Node, network, attCount) + 2][2];

    int remaining = Index;
    int classCount = m_Instances.numClasses();
    decoded[0][0] = -1;
    decoded[0][1] = remaining % classCount;
    remaining /= classCount;   // same as subtracting the remainder, then dividing

    int row = 1;
    for (int att = 0; att < attCount; att++)
    {
      boolean inTable = (network[att][Node] == 1) || (att == Node);
      if (!inTable)
      {
        continue;
      }
      int radix = m_Att[att].numValues();
      decoded[row][0] = att;
      decoded[row][1] = remaining % radix;
      remaining /= radix;
      row++;
    }
    return decoded;
  }
  
  /**
   * Returns the number of parents of {@code Node}. It does this by summing
   * column Node of network over rows 0..dim-1, which for a 0/1 adjacency
   * matrix is the number of incoming edges.
   *
   * @param Node    node whose parents are counted
   * @param network adjacency matrix, where (i,j) = 1 means an edge i -> j
   * @param dim     number of rows (nodes) to consider
   * @return the sum of network[0..dim-1][Node]
   */
  public int NumberOfFather(int Node, int [][] network, int dim)
  {
    int parents = 0;
    int row = 0;
    while (row < dim)
    {
      parents += network[row][Node];
      row++;
    }
    return parents;
  }
  
  /**
   * Returns the number of cells in the conditional probability table of
   * {@code Node}. This is the number of classes, times the number of values
   * of Node, times the number of values of each parent of Node in network.
   *
   * @param Node    node whose table size is computed
   * @param network adjacency matrix, where (i,j) = 1 means an edge i -> j
   * @return the table length
   */
  public int LengthOfTable(int Node, int [][] network)
  {
    int cells = m_Instances.numClasses() * m_Att[Node].numValues();
    int attCount = m_Instances.numAttributes() - 1;
    for (int parent = 0; parent < attCount; parent++)
    {
      if (network[parent][Node] == 1)
      {
        cells *= m_Att[parent].numValues();
      }
    }
    return cells;
  }
  
  /**
   * Builds the conditional probability table of every non-class attribute for a
   * fixed network structure. For each instance with a known class, it counts
   * the matching table cell, then normalizes each table by the per-class counts.
   *
   * Replaces m_Table with a fresh array of numAttributes()-1 tables.
   * NOTE(review): assumes the class attribute is the last attribute of inst;
   * confirm against callers.
   *
   * @param network adjacency matrix of the structure to evaluate; (i,j) = 1 means edge i -> j
   * @param inst    training instances
   */
  public void BuildTable( int[][] network, Instances inst)
  {
    int numAtt = inst.numAttributes() - 1;
    int numC = inst.numClasses();

    m_Table = new Probtable[numAtt];
    for (int i = 0; i < numAtt; i++)
    {
      m_Table[i] = new Probtable(LengthOfTable(i, network));

      // Row 0 holds the class value. Rows 1.. hold {attribute index, value}
      // for every parent of node i and node i itself, in ascending attribute
      // order, which is the order Combination2Index expects.
      int[][] Combination = new int[NumberOfFather(i, network, numAtt) + 2][2];
      Combination[0][0] = -1;
      int j = 1;
      for (int m = 0; m < numAtt; m++)
      {
        if ((network[m][i] == 1) || (m == i))
        {
          Combination[j++][0] = m;
        }
      }

      int[] numClass = new int[numC]; // samples per class; Java zero-initialises arrays
      Enumeration enumInsts = inst.enumerateInstances();
      while (enumInsts.hasMoreElements())
      {
        Instance instance = (Instance) enumInsts.nextElement();
        if (instance.classIsMissing())
        {
          continue;
        }
        int classValue = (int) instance.classValue();
        Combination[0][1] = classValue;

        // NOTE(review): only nominal attributes fill a value row, while the
        // rows above were assigned to every parent. A numeric parent would
        // shift the values. Missing attribute values are not skipped either;
        // confirm how Instance.value() encodes them.
        int h = 0;
        for (int attIndex = 0; attIndex < numAtt; attIndex++)
        {
          Attribute attribute = m_Att[attIndex];
          if (attribute.isNominal() && ((network[attIndex][i] == 1) || (attIndex == i)))
          {
            h++;
            Combination[h][1] = (int) instance.value(attribute);
          }
        }

        // Index with the network being built, not the static m_Network.
        // During structure search they can differ, and Combination was sized
        // from `network`.
        int cell = Combination2Index(Combination, i, network);
        m_Table[i].Prob[cell] += 1;
        m_Table[i].type[cell] = classValue;
        numClass[classValue]++;
      }
      m_Table[i].ClassNormalize(numClass, numC);
    }
  }
  
  /**
   * Prints every directed edge of {@code network} (restricted to the first dim
   * nodes) to standard output. If there are no edges, it prints a single
   * "No edge" line instead.
   *
   * @param network adjacency matrix to print, where (i,j) = 1 means an edge i -> j
   * @param dim     number of nodes to consider
   */
  public void PrintNetwork(int [][]network, int dim)
  {
    System.out.println("number of attributes in network: " + dim);
    boolean hasEdge = false;
    for (int i = 0; i < dim; i++)
    {
      for (int j = 0; j < dim; j++)
      {
        // Read the matrix that was passed in rather than the static m_Network,
        // so candidate structures that are not m_Network print correctly.
        if (network[i][j] == 1)
        {
          hasEdge = true;
          System.out.println("Edge: from " + i + " to " + j);
        }
      }
    }
    if (!hasEdge)
    {
      System.out.println("No edge in the network");
    }
  }
  /**
   * Reads a network structure in the byte format written by saveNetwork.
   * The format is one byte holding the node count dim, followed by the
   * dim x dim adjacency matrix in row-major order, one byte per entry.
   *
   * I/O failures, including a file that ends early, are reported on System.err
   * and yield 0. In that case the contents of network are unspecified. The file
   * is always closed.
   *
   * @param network     destination adjacency matrix; must have at least dim rows
   * @param networkFile path of the file to read
   * @return number of nodes read, or 0 if the file could not be read completely
   * @throws IllegalArgumentException if the stored network has more nodes than network has rows
   */
  public int LoadNetwork(int [][] network, String networkFile)
  {
    InputStream s = null;
    try
    {
      s = new BufferedInputStream(new FileInputStream(networkFile));
      int dim = readNetworkByte(s);
      if (dim > network.length)
      {
        throw new IllegalArgumentException("stored network has " + dim
            + " nodes but the destination matrix has only " + network.length + " rows");
      }
      for (int i = 0; i < dim; i++)
      {
        for (int j = 0; j < dim; j++)
        {
          network[i][j] = readNetworkByte(s);
        }
      }
      return dim;
    }
    catch (IOException e)
    {
      System.err.println("LoadNetwork: cannot read network from " + networkFile + ": " + e);
      return 0;
    }
    finally
    {
      if (s != null)
      {
        try { s.close(); } catch (IOException ignored) { /* nothing useful to do on close failure */ }
      }
    }
  }

  /** Reads one byte from s, failing with EOFException instead of returning -1 at end of file. */
  private static int readNetworkByte(InputStream s) throws IOException
  {
    int b = s.read();
    if (b < 0)
    {
      throw new EOFException("unexpected end of network file");
    }
    return b;
  }
  
  /**
   * Writes the first dim x dim entries of {@code network} to networkFile in the
   * format read by LoadNetwork. The format is one byte holding dim, then the
   * matrix row by row, one byte per entry.
   *
   * Saving is best effort: I/O failures are reported on System.err rather than
   * thrown, but the file is always closed.
   *
   * @param network     adjacency matrix to save (entries expected to be 0 or 1)
   * @param dim         number of nodes to save
   * @param networkFile path of the file to write
   * @throws IllegalArgumentException if dim cannot be stored in the single-byte header
   */
  public void saveNetwork(int [][] network, int dim, String networkFile)
  {
    // OutputStream.write(int) keeps only the low 8 bits, so a dim outside
    // 0..255 would silently produce a file that LoadNetwork misreads.
    if (dim < 0 || dim > 255)
    {
      throw new IllegalArgumentException("network size " + dim + " does not fit the one-byte file header");
    }
    OutputStream s = null;
    try
    {
      s = new BufferedOutputStream(new FileOutputStream(networkFile));
      s.write(dim);
      for (int i = 0; i < dim; i++)
      {
        for (int j = 0; j < dim; j++)
        {
          s.write(network[i][j]);
        }
      }
      s.close();   // flushes the buffer; a failure here is reported like any other write error
      s = null;
    }
    catch (IOException e)
    {
      System.err.println("saveNetwork: cannot write network to " + networkFile + ": " + e);
    }
    finally
    {
      if (s != null)
      {
        try { s.close(); } catch (IOException ignored) { /* primary failure already reported */ }
      }
    }
  }
  
  // Automatically  search a network struture
  // k is the maximal parents number of a node
  public void SearchNetwork(int [][] network, int k, Instances inst )
  {
  	System.out.println("Search in progress...");
  	int numAtt = inst.numAttributes() -1;
// 	trace("numAtt="+numAtt);
  	
  	// create a network with no edges
// 	trace("initialize");
  	for (int i = 0 ; i < numAtt ; i ++)
  	{
  		for ( int j = 0 ; j < numAtt ; j++)
  		{
  			network[i][j]=0;
  		}
  	}
// 	PrintNetwork(network,numAtt);
  	
  	double score = -10e10;//-inf
  	double maxscore = 0;
// 	trace("search");
  	do
  	{
  		maxscore = score;
  		for ( int i = 0  ; i < numAtt ; i ++)
  		{ 
		  	for ( int j = i+1 ; j < numAtt ; j++)
  			{
  				System.out.print('.');
// 				trace("start, score = " + score );

//   				trace("try pair: "+ i + " , " +j );
 				  				
  				int oldij = network[i][j];
  				int oldji = network[j][i];
  				int max = 0;
//				trace("max="+max);
  				
// 				trace("case 1");
  				// case 1 : no edge i <-> j
  				network[i][j] = 0; network[j][i] = 0;
  				BuildTable(network, inst);
  				double newscore1 = ScoreFromData(network, inst, numAtt);
  				if (newscore1 > score) 
  				{ score = newscore1; max = 1; }
//  				trace("newscore1= "+newscore1);
  				
// 				trace("case 2");
  				// case 2: add edge j->i
  				network[i][j] = 0;network[j][i] = 1;
  				double newscore2 = 0 ;
  				if( NumberOfFather(i,network,numAtt) <=k && LoopCheck(network, i, numAtt) == false)
  				{
// 					trace("BuildTable");
  					BuildTable(network, inst);
//    				trace("ScoreFromData");
					newscore2 = ScoreFromData(network, inst, numAtt);
  					if (newscore2 > score) 
  					{ score = newscore2; max = 2; }
  				}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -