📄 bayesiannetwork.java
字号:
/*
Bao Jie 2002-04-02
Iowa State University
*/
package weka.classifiers;
import java.io.*;
import java.util.*;
import weka.core.*;
import Probtable;
public class BayesianNetwork extends DistributionClassifier
// Bayesian Network
{
// Debug switch; no use of it is visible in this part of the file.
boolean bDebug = true;
// NOTE(review): the name looks like a typo of "maxDegree" (presumably a cap on the
// number of parents per node) — confirm against its users before renaming.
int maxDrgree = 10;
/** The instances used for training. */
public Instances m_Instances;
/** The attributes of the instances, indexed by attribute index. */
public static Attribute[] m_Att;
/**
 * The Bayesian network, stored as an adjacency matrix:
 * m_Network[i][j] == 1 means there is an edge from i to j;
 * m_Network[i][j] == 0 means there is no edge from i to j
 * (there may still be an edge from j to i).
 */
public static int [][] m_Network;
/** NOTE(review): presumably the path of the network file used with LoadNetwork/saveNetwork — confirm. */
public String m_networkFile;
/** Conditional probability tables, one per attribute node (allocated in BuildTable). */
public Probtable [] m_Table;
// Convert a combination of father (parent) node values to an index number.
// Compresses a high-dimensional conditional prob. table into a 2-dimensional table.
/**
 * Encodes a combination of values as a single table index (mixed-radix number).
 *
 * The least-significant digit is the class value Combination[0][1], with radix
 * m_Instances.numClasses(). Each following row r (1 .. NumberOfFather + 1) adds
 * one digit: Combination[r][0] is an attribute index and Combination[r][1] is
 * that attribute's value, with radix m_Att[attribute].numValues().
 *
 * @param Combination rows of {attribute index, value}; row 0 holds the class value
 * @param Node        the node whose table is being indexed
 * @param network     adjacency matrix used to count the parents of Node
 * @return the index into the node's flattened probability table
 */
public int Combination2Index(int[][] Combination,int Node, int[][]network )
{
    final int attributeCount = m_Instances.numAttributes() - 1;
    // The node itself plus each of its parents contributes one digit.
    final int digitCount = NumberOfFather(Node, network, attributeCount) + 1;

    // Class value is the lowest digit.
    int index = Combination[0][1];
    int weight = m_Instances.numClasses();

    for (int row = 1; row <= digitCount; row++)
    {
        final int[] entry = Combination[row];
        final int radix = m_Att[entry[0]].numValues();
        index += entry[1] * weight;
        weight *= radix;
    }
    return index;
}
/**
 * Prints the first lim rows of a combination to standard output as two
 * tab-separated lines: the attribute indices (column 0 of each row) and then
 * the values (column 1 of each row). No trailing newline is written.
 *
 * @param Combination rows of {attribute index, value}
 * @param lim         number of rows to print
 */
public void PrintCombination(int[][] Combination, int lim)
{
    final String[] rowLabels = { "attribute ", "\nvalue " };
    for (int column = 0; column < rowLabels.length; column++)
    {
        System.out.print(rowLabels[column]);
        for (int row = 0; row < lim; row++)
        {
            System.out.print("\t" + Combination[row][column]);
        }
    }
}
// Recover the father (parent) nodes' combination from an index number.
// The first col is for the value of the class.
/**
 * Decodes a table index back into a combination of values (the inverse of the
 * mixed-radix encoding used for table indices).
 *
 * Row 0 of the result is {-1, class value}. Every attribute i that is a parent
 * of Node (network[i][Node] == 1) or is Node itself gets one further row
 * {i, value}, in increasing order of i.
 *
 * @param Index   the table index to decode
 * @param Node    the node whose table the index belongs to
 * @param network adjacency matrix; network[i][Node] == 1 marks i as a parent of Node
 * @return an array of NumberOfFather + 2 rows of {attribute index, value}
 */
public int[][] Index2Combination(int Index, int Node, int [][] network)
{
    final int attributeCount = m_Instances.numAttributes() - 1;
    final int[][] decoded = new int[NumberOfFather(Node, network, attributeCount) + 2][2];

    // Lowest digit: the class value.
    final int classCount = m_Instances.numClasses();
    decoded[0][0] = -1;
    decoded[0][1] = Index % classCount;
    int remaining = Index / classCount;

    int row = 1;
    for (int att = 0; att < attributeCount; att++)
    {
        final boolean contributes = (network[att][Node] == 1) || (att == Node);
        if (!contributes)
        {
            continue;
        }
        final int radix = m_Att[att].numValues();
        decoded[row][0] = att;
        decoded[row][1] = remaining % radix;
        remaining = remaining / radix;
        row++;
    }
    return decoded;
}
// Return the number of fathers (parents) of a given node.
/**
 * Counts the fathers (parents) of a node: the rows i in [0, dim) for which
 * network[i][Node] == 1.
 *
 * @param Node    index of the node (column of the adjacency matrix)
 * @param network adjacency matrix; network[i][j] == 1 means an edge from i to j
 * @param dim     number of nodes (rows) to inspect
 * @return the number of parents of Node
 */
public int NumberOfFather(int Node, int [][] network, int dim)
{
    int num = 0;
    for (int i = 0; i < dim; i++)
    {
        // Test for an edge with "== 1" instead of summing the raw cell values.
        // The other methods in this class treat exactly the value 1 as an edge,
        // so a stray cell value (for example a byte read from a network file)
        // must not change the parent count that array sizes are derived from.
        if (network[i][Node] == 1)
        {
            num++;
        }
    }
    return num;
}
// return the len of prob. table of certain node
/**
 * Computes the number of cells in the flattened probability table of a node:
 * the number of classes, times the number of values of the node's own
 * attribute, times the number of values of every parent attribute
 * (every i with network[i][Node] == 1).
 *
 * @param Node    index of the node
 * @param network adjacency matrix; network[i][Node] == 1 marks i as a parent of Node
 * @return the table length for Node
 */
public int LengthOfTable(int Node, int [][] network)
{
    int cells = m_Instances.numClasses() * m_Att[Node].numValues();
    final int attributeCount = m_Instances.numAttributes() - 1;

    int candidate = 0;
    while (candidate < attributeCount)
    {
        final boolean isParent = network[candidate][Node] == 1;
        if (isParent)
        {
            cells *= m_Att[candidate].numValues();
        }
        candidate++;
    }
    return cells;
}
// build prob. table for a fixed network
/**
 * Builds the probability table of every attribute node for a fixed network.
 *
 * For each attribute i a new Probtable of LengthOfTable(i, network) cells is
 * allocated. Every instance with a known class increments the cell addressed by
 * its class value, the value of attribute i, and the values of i's parents, and
 * records its class value in the table's type array for that cell. The
 * per-class instance counts are then passed to Probtable.ClassNormalize
 * (presumably turning the counts into per-class probabilities — confirm against
 * Probtable).
 *
 * NOTE(review): table sizes and cell indices are derived from m_Instances and
 * m_Att (via LengthOfTable / Combination2Index), while the instances counted
 * come from inst; callers must keep the two describing the same data set.
 *
 * @param network adjacency matrix; network[i][j] == 1 means an edge from i to j
 * @param inst    the instances to count; replaces m_Table as a side effect
 */
public void BuildTable( int[][] network, Instances inst)
{
    int numAtt = inst.numAttributes() -1;
    m_Table = new Probtable[numAtt];
    for (int i = 0 ; i < numAtt ; i ++)
    {
        // Allocate storage for this node's table.
        m_Table[i] = new Probtable(LengthOfTable(i,network));

        // Row 0 is the class; the remaining rows are the node itself and its
        // parents, in increasing attribute order.
        int[][] Combination = new int[NumberOfFather(i, network,numAtt)+2][2];
        int numC = inst.numClasses();
        Combination[0][0] = -1;
        int j = 1;
        for (int m = 0 ; m < numAtt ; m++)
        {
            if ((network[m][i] == 1) || (m == i))
            {
                Combination[j++][0] = m;
            }
        }

        // Number of samples of each class (Java zero-initializes the array).
        int[] numClass = new int[numC];

        Enumeration enumInsts = inst.enumerateInstances();
        while (enumInsts.hasMoreElements())
        {
            Instance instance = (Instance) enumInsts.nextElement();
            if (instance.classIsMissing())
            {
                continue;
            }
            int classValue = (int) instance.classValue();
            Combination[0][1] = classValue;

            int h = 0;
            for (int attIndex = 0 ; attIndex < numAtt ; attIndex++)
            {
                Attribute attribute = m_Att[attIndex];
                // Only nominal attributes that are a father of this node, or the node itself.
                if (attribute.isNominal()
                    && ((network[attIndex][i] == 1) || (attIndex == i)))
                {
                    h++;
                    Combination[h][1] = (int) instance.value(attribute);
                }
            }

            // Index with the network passed in, not the static m_Network: the
            // table was sized from "network", so using a different graph here
            // would address the wrong cell (or fall outside the table).
            // The index is computed once and reused for both arrays.
            int index = Combination2Index(Combination, i, network);
            m_Table[i].Prob[index] += 1;
            m_Table[i].type[index] = classValue;
            numClass[classValue]++;
        }

        m_Table[i].ClassNormalize(numClass, numC);
    }
}
// print the network out
/**
 * Prints the edges of a network to standard output: a header line with the
 * number of attributes, then one line per edge (network[i][j] == 1), or a
 * single "No edge in the network" line when there is none.
 *
 * @param network adjacency matrix to print; network[i][j] == 1 means an edge from i to j
 * @param dim     number of nodes (rows and columns) to inspect
 */
public void PrintNetwork(int [][]network, int dim)
{
    System.out.println("number of attributes in network: " + dim);
    boolean hasEdge = false;
    for (int i = 0 ; i < dim ; i ++)
    {
        for (int j = 0 ; j < dim ; j++)
        {
            // Read the matrix that was passed in; the previous code always read
            // the static m_Network and so ignored the argument.
            if (network[i][j] == 1)
            {
                hasEdge = true;
                System.out.println("Edge: from " + i +" to "+j);
            }
        }
    }
    if (!hasEdge)
    {
        System.out.println("No edge in the network" );
    }
}
// read network from a saved file
// return: number of node in that network
/**
 * Reads a network from a file written by saveNetwork.
 *
 * File format: one byte holding the number of nodes dim, followed by
 * dim * dim bytes holding the adjacency matrix in row-major order. Because
 * each value is a single byte, dim and every cell are in the range 0..255.
 *
 * On any failure (file missing, unreadable, empty or truncated) a message is
 * written to System.err and 0 is returned; cells read before the failure
 * remain stored in network.
 *
 * NOTE(review): network must be at least dim x dim; this is not checked here.
 *
 * @param network     matrix that receives the cells read from the file
 * @param networkFile path of the file to read
 * @return the number of nodes in the loaded network, or 0 if loading failed
 */
public int LoadNetwork(int [][] network, String networkFile)
{
    int dim = 0;
    FileInputStream s = null;
    try
    {
        s = new FileInputStream(new File(networkFile));
        dim = s.read();
        if (dim < 0)
        {
            // read() returns -1 at end of stream; do not treat it as a size.
            throw new EOFException("file is empty");
        }
        for (int i = 0 ; i < dim ; i ++)
        {
            for (int j = 0 ; j < dim ; j++)
            {
                int cell = s.read();
                if (cell < 0)
                {
                    // Do not store the end-of-stream marker as an edge value.
                    throw new EOFException("file ends before cell (" + i + "," + j + ")");
                }
                network[i][j] = cell;
            }
        }
    }
    catch (IOException e)
    {
        System.err.println("LoadNetwork: cannot read network from " + networkFile + ": " + e);
        dim = 0;
    }
    finally
    {
        // Close the stream on every path, including a failed read.
        if (s != null)
        {
            try
            {
                s.close();
            }
            catch (IOException e)
            {
                System.err.println("LoadNetwork: cannot close " + networkFile + ": " + e);
            }
        }
    }
    return dim;
}
// save network to a file
/**
 * Writes a network to a file in the format read by LoadNetwork: one byte
 * holding dim, followed by dim * dim bytes holding the adjacency matrix in
 * row-major order.
 *
 * Only the low-order 8 bits of dim and of each cell are written
 * (OutputStream.write(int) semantics), so values above 255 do not round-trip.
 *
 * A failure to open, write or close the file is reported on System.err; no
 * exception is propagated for I/O errors.
 *
 * @param network     adjacency matrix to save
 * @param dim         number of nodes (rows and columns) to write
 * @param networkFile path of the file to write
 */
public void saveNetwork(int [][] network, int dim, String networkFile)
{
    FileOutputStream s = null;
    try
    {
        s = new FileOutputStream(new File(networkFile));
        s.write(dim);
        for (int i = 0 ; i < dim ; i ++)
        {
            for (int j = 0 ; j < dim ; j++)
            {
                s.write(network[i][j]);
            }
        }
    }
    catch (IOException e)
    {
        System.err.println("saveNetwork: cannot write network to " + networkFile + ": " + e);
    }
    finally
    {
        // Close the stream on every path, including a failed write.
        if (s != null)
        {
            try
            {
                s.close();
            }
            catch (IOException e)
            {
                System.err.println("saveNetwork: cannot close " + networkFile + ": " + e);
            }
        }
    }
}
// Automatically search for a network structure.
// k is the maximal number of parents of a node.
public void SearchNetwork(int [][] network, int k, Instances inst )
{
System.out.println("Search in progress...");
int numAtt = inst.numAttributes() -1;
// trace("numAtt="+numAtt);
// create a network with no edges
// trace("initialize");
for (int i = 0 ; i < numAtt ; i ++)
{
for ( int j = 0 ; j < numAtt ; j++)
{
network[i][j]=0;
}
}
// PrintNetwork(network,numAtt);
double score = -10e10;//-inf
double maxscore = 0;
// trace("search");
do
{
maxscore = score;
for ( int i = 0 ; i < numAtt ; i ++)
{
for ( int j = i+1 ; j < numAtt ; j++)
{
System.out.print('.');
// trace("start, score = " + score );
// trace("try pair: "+ i + " , " +j );
int oldij = network[i][j];
int oldji = network[j][i];
int max = 0;
// trace("max="+max);
// trace("case 1");
// case 1 : no edge i <-> j
network[i][j] = 0; network[j][i] = 0;
BuildTable(network, inst);
double newscore1 = ScoreFromData(network, inst, numAtt);
if (newscore1 > score)
{ score = newscore1; max = 1; }
// trace("newscore1= "+newscore1);
// trace("case 2");
// case 2: add edge j->i
network[i][j] = 0;network[j][i] = 1;
double newscore2 = 0 ;
if( NumberOfFather(i,network,numAtt) <=k && LoopCheck(network, i, numAtt) == false)
{
// trace("BuildTable");
BuildTable(network, inst);
// trace("ScoreFromData");
newscore2 = ScoreFromData(network, inst, numAtt);
if (newscore2 > score)
{ score = newscore2; max = 2; }
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -