⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 bayesiannetwork.java

📁 java BayesianNetwork 源码
💻 JAVA
📖 第 1 页 / 共 2 页
字号:
//  				trace("newscore2= "+newscore2);
// 				trace("case 3");
  				// case 3: add edge i->j
  				network[i][j]=1; network[j][i] = 0;
  				double newscore3 =0 ;
  				if( NumberOfFather(j,network,numAtt) <=k && LoopCheck(network, j, numAtt) == false)
  				{
  					BuildTable(network, inst);
  					newscore3 = ScoreFromData(network, inst, numAtt);
  					if (newscore3 > score) 
  					{ score = newscore3; max = 3; }
  				}
//  				trace("newscore3= "+newscore3);
  				
// 				trace("choose best one");
//				trace("max="+max);
  				switch (max)
  				{
  					case 0:
  					    network[i][j] = oldij; network[j][i] = oldji;
//  					    trace("keep original graph");
  					    break;
  					case 1:
  					    network[i][j] = 0; network[j][i] = 0;
//  					    trace("no edge between " + j +" and "+i);
  					    break;
  					case 2:
  						network[i][j] = 0; network[j][i] = 1;
//  					    trace("add edge from " + j +" to "+i);
  						break;
  					case 3:
  						network[i][j] = 1; network[j][i] = 0;
//  						trace("add edge from " + i +" to "+j);  						
  				}  
				BuildTable(network, inst);
//  			trace("end, score = "+score +" maxscore = " +maxscore);				
 			}
  		}
  	}while( score >  maxscore);
   	System.out.println("Search finished");
  }
  
  /**
   * Hit counter written by {@link #DFS}: incremented every time DFS arrives at
   * its Aim node. NOTE: it is static, so it is shared by all instances of this
   * class; LoopCheck below no longer depends on it.
   */
  static int times = 0;

  /**
   * Checks whether {@code startNode} lies on a directed cycle of the network,
   * i.e. whether it can reach itself through one or more edges.
   *
   * The traversal state is kept in local arrays (no shared static counter),
   * so concurrent calls on different instances cannot interfere.
   *
   * @param network   adjacency matrix; network[a][b] == 1 denotes an edge a -> b
   * @param startNode the node to test
   * @param numAtt    number of nodes (non-class attributes)
   * @return true if a path of length >= 1 leads from startNode back to itself
   */
  public boolean LoopCheck(int [][] network, int startNode, int numAtt)
  {
    boolean[] visited = new boolean[numAtt];
    // Each node is pushed at most once (marked on push) and startNode is never
    // pushed, so numAtt slots are always enough.
    int[] stack = new int[numAtt];
    int top = 0;

    // Seed the search with the direct children of startNode.
    for (int j = 0; j < numAtt; j++) {
      if (network[startNode][j] == 1) {
        if (j == startNode) {
          return true; // self-loop
        }
        visited[j] = true;
        stack[top++] = j;
      }
    }

    // Iterative depth-first search; reaching startNode again closes a cycle.
    while (top > 0) {
      int node = stack[--top];
      for (int j = 0; j < numAtt; j++) {
        if (network[node][j] == 1) {
          if (j == startNode) {
            return true;
          }
          if (!visited[j]) {
            visited[j] = true;
            stack[top++] = j;
          }
        }
      }
    }
    return false;
  }
  
  /**
   * Recursive depth-first traversal starting at {@code CurrentNode} and
   * following edges network[u][v] == 1.
   *
   * Nothing is returned: every arrival at {@code Aim} (including the initial
   * call when CurrentNode == Aim) increments the static counter {@code times}.
   * A node already marked in {@code Visited} is counted but not expanded
   * again, which makes the traversal terminate on cyclic graphs.
   *
   * @param network     adjacency matrix; network[u][v] == 1 denotes an edge u -> v
   * @param CurrentNode node being visited
   * @param Aim         node whose arrivals are counted
   * @param Visited     per-node expansion marks, updated in place
   * @param numAtt      number of nodes
   */
  public void DFS(int [][] network,int CurrentNode, int Aim, boolean[] Visited, int numAtt)
  {
    if (CurrentNode == Aim) {
      times++;
    }
    if (Visited[CurrentNode]) {
      return;
    }
    Visited[CurrentNode] = true;

    int[] outgoing = network[CurrentNode];
    for (int next = 0; next < numAtt; next++) {
      boolean isChild = outgoing[next] == 1;
      if (isChild) {
        DFS(network, next, Aim, Visited, numAtt);
      }
    }
  }
  
  /**
   * Scores a candidate network structure against the training data.
   *
   * For every training instance the joint probability p of its class value and
   * attribute values is computed as the product, over all nodes j, of the
   * table entry m_Table[j].Prob[...] selected by the class and the values of j
   * and its parents. The score accumulates -p * log(p) over all instances and
   * then subtracts a complexity penalty (currently 0). The structure search
   * keeps a change when this score increases.
   *
   * m_Table is expected to have been built for {@code network} (the search
   * calls BuildTable(network, inst) right before scoring).
   *
   * @param network adjacency matrix; network[m][j] == 1 means m is a parent of j
   * @param inst    the training instances
   * @param numAtt  number of nodes (non-class attributes)
   * @return the structure score (higher is better)
   */
  public double ScoreFromData(int [][] network, Instances inst, int numAtt)
  {
    double score = 0;
    // Complexity penalty: 0 for now, although it should depend on the
    // complexity of the network.
    double penalty = 0;

    for (int n = 0; n < inst.numInstances(); n++)
    {
      Instance instance = inst.instance(n);
      // NOTE(review): instances with a missing class are not skipped; their
      // class value is (int) NaN == 0. Confirm this matches BuildTable.
      double p = 1;

      for (int j = 0; j < numAtt; j++)
      {
        // Row 0 is the class; rows 1.. are the parents of j and j itself,
        // in increasing attribute index. Column 0 = attribute, column 1 = value.
        int[][] Combination = new int[NumberOfFather(j, network, numAtt) + 2][2];
        Combination[0][0] = -1; // class
        Combination[0][1] = (int) instance.classValue();

        int t = 1;
        for (int m = 0; m < numAtt; m++)
        {
          if ((network[m][j] == 1) || (m == j))
          {
            Combination[t][0] = m;
            Attribute attribute = (Attribute) m_Att[m];
            // Fill the value in the SAME row as its attribute so keys and
            // values stay aligned; non-nominal attributes keep value 0.
            if (attribute.isNominal())
            {
              Combination[t][1] = (int) instance.value(attribute);
            }
            t++;
          }
        }

        int index = Combination2Index(Combination, j, network);
        p *= m_Table[j].Prob[index];
      }

      // Convention 0 * log(0) = 0: a zero-probability instance must not turn
      // the whole score into NaN (NaN makes every later comparison false).
      if (p > 0)
      {
        score -= p * Math.log(p);
      }
    }

    score -= penalty;
    return score;
  }
  
  /**
   * Builds the classifier from the training data.
   *
   * Validates the data, keeps an empty copy of the header, registers every
   * non-class attribute as a network node, searches for a structure starting
   * from the empty graph, and finally builds the probability tables for the
   * chosen structure.
   *
   * @param instances the training data
   * @throws Exception if the data has string attributes or a numeric class
   */
  public void buildClassifier(Instances instances) throws Exception
  {
    if (instances.checkForStringAttributes()) {
      throw new Exception("Can't handle string attributes!");
    }
    if (instances.classAttribute().isNumeric()) {
      throw new Exception("BayesianNetFixed: Class is numeric!");
    }

    // Header only (no rows); used later for attribute/class counts.
    m_Instances = new Instances(instances, 0);

    // Every attribute except the class becomes one node of the network.
    int nodeCount = m_Instances.numAttributes() - 1;
    m_Att = new Attribute[nodeCount];
    int slot = 0;
    for (Enumeration attrs = m_Instances.enumerateAttributes(); attrs.hasMoreElements(); ) {
      m_Att[slot++] = (Attribute) attrs.nextElement();
    }

    // Empty graph to start from (Java zero-initialises int arrays).
    m_Network = new int[nodeCount][nodeCount];

    // NOTE(review): 2 is presumably the per-node parent limit k checked in the
    // search via NumberOfFather(...) <= k; confirm against SearchNetwork.
    SearchNetwork(m_Network, 2, instances);

    // Probability tables for the final structure.
    BuildTable(m_Network, instances);
  }
	
 /**
   * Calculates the class membership probabilities for the given test instance.
   *
   * For each class value x the unnormalised score is the product, over all
   * nodes j, of m_Table[j].Prob[...] selected by class x and the instance's
   * values for j and its parents. The product is accumulated as a sum of logs
   * and rescaled by the largest class score before exponentiating, so that
   * many small factors do not underflow to zero.
   *
   * @param instance the instance to be classified
   * @return predicted class probability distribution
   * @throws Exception propagated from Utils.normalize, e.g. when every class
   *         has probability zero
   */
  public double[] distributionForInstance(Instance instance) throws Exception
  {
    int numClasses = instance.numClasses();
    int numAtt = m_Instances.numAttributes() - 1;
    double[] logProbs = new double[numClasses];
    double maxLog = Double.NEGATIVE_INFINITY;

    for (int x = 0; x < numClasses; x++)
    {
      double logP = 0;
      for (int j = 0; j < numAtt; j++)
      {
        // Row 0 is the class; rows 1.. are the parents of j and j itself,
        // in increasing attribute index. Column 0 = attribute, column 1 = value.
        int[][] Combination = new int[NumberOfFather(j, m_Network, numAtt) + 2][2];
        Combination[0][0] = -1; // class
        Combination[0][1] = x;

        int t = 1;
        for (int m = 0; m < numAtt; m++)
        {
          if ((m_Network[m][j] == 1) || (m == j))
          {
            Combination[t][0] = m;
            Attribute attribute = (Attribute) m_Att[m];
            // Value goes in the SAME row as its attribute; non-nominal
            // attributes keep value 0.
            if (attribute.isNominal())
            {
              Combination[t][1] = (int) instance.value(attribute);
            }
            t++;
          }
        }

        int index = Combination2Index(Combination, j, m_Network);
        logP += Math.log(m_Table[j].Prob[index]);
      }
      logProbs[x] = logP;
      maxLog = Math.max(maxLog, logP);
    }

    double[] probs = new double[numClasses];
    // If every class scored log(0) the array stays all-zero, exactly as the
    // plain product would have produced, and normalize reports it.
    if (maxLog != Double.NEGATIVE_INFINITY)
    {
      for (int x = 0; x < numClasses; x++)
      {
        probs[x] = Math.exp(logProbs[x] - maxLog);
      }
    }

    Utils.normalize(probs);
    return probs;
  }
    /**
   * Returns a description of the classifier: the number of classes and
   * attributes of the training header, the class index, and every directed
   * edge of the learned network (by node position).
   *
   * @return a description of the classifier as a string.
   */
  @Override
  public String toString() {
    if (m_Instances == null) {
      return "Bayesian Network : No model built yet.";
    }
    try {
      StringBuilder text = new StringBuilder("Bayesian Network ");
      text.append("\nnumber of Class " + Utils.doubleToString(m_Instances.numClasses(), 10, 8));
      text.append("\nnumber of Attribute: " + Utils.doubleToString(m_Instances.numAttributes(), 10, 8));
      text.append("\nclass attribute's index: " + Utils.doubleToString(m_Instances.classIndex(), 10, 8));

      text.append("\n\nStructure of the Network");
      int numAtt = m_Instances.numAttributes() - 1;
      boolean hasEdge = false;
      for (int i = 0; i < numAtt; i++) {
        for (int j = 0; j < numAtt; j++) {
          if (m_Network[i][j] == 1) {
            hasEdge = true;
            text.append("\n\tEdge: from " + i + " to " + j);
          }
        }
      }
      if (!hasEdge) {
        text.append("\n\tNo edge in the network");
      }
      return text.toString();
    } catch (Exception e) {
      // Presumably reached when m_Network is missing or does not match the
      // header (e.g. buildClassifier failed after setting m_Instances).
      return "Can't print Bayesian Network classifier!";
    }
  }
    /**
   * Command-line entry point: evaluates a fresh BayesianNetwork with Weka's
   * standard evaluation using the given options and prints the report. Any
   * failure (including option/usage errors) is reported via its message on
   * standard error.
   *
   * @param argv the options
   */
  public static void main(String [] argv) {
    try {
      Classifier scheme = new BayesianNetwork();
      String report = Evaluation.evaluateModel(scheme, argv);
      System.out.println(report);
    } catch (Exception e) {
      System.err.println(e.getMessage());
    }
  }
  
  /** When debugging is enabled, waits until one byte is read from standard input. */
  public void pause(){
    if (!bDebug) {
      return;
    }
    try {
      System.in.read();
    } catch (Exception ignored) {
      // Best-effort debug pause: if stdin cannot be read we simply continue.
    }
  }
  /** Prints a number on standard output, but only when debugging is enabled. */
  public void trace(double x)
  {
    if (!bDebug) {
      return;
    }
    System.out.println(x);
  }
  /** Prints a message on standard output, but only when debugging is enabled. */
  public void trace(String str)
  {
    if (!bDebug) {
      return;
    }
    System.out.println(str);
  }
}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -