📄 bayesiannetwork.java
字号:
// trace("newscore2= "+newscore2);
// trace("case 3");
// case 3: add edge i->j
network[i][j]=1; network[j][i] = 0;
double newscore3 =0 ;
if( NumberOfFather(j,network,numAtt) <=k && LoopCheck(network, j, numAtt) == false)
{
BuildTable(network, inst);
newscore3 = ScoreFromData(network, inst, numAtt);
if (newscore3 > score)
{ score = newscore3; max = 3; }
}
// trace("newscore3= "+newscore3);
// trace("choose best one");
// trace("max="+max);
switch (max)
{
case 0:
network[i][j] = oldij; network[j][i] = oldji;
// trace("keep original graph");
break;
case 1:
network[i][j] = 0; network[j][i] = 0;
// trace("no edge between " + j +" and "+i);
break;
case 2:
network[i][j] = 0; network[j][i] = 1;
// trace("add edge from " + j +" to "+i);
break;
case 3:
network[i][j] = 1; network[j][i] = 0;
// trace("add edge from " + i +" to "+j);
}
BuildTable(network, inst);
// trace("end, score = "+score +" maxscore = " +maxscore);
}
}
}while( score > maxscore);
System.out.println("Search finished");
}
// Counter incremented by DFS each time it reaches its Aim node. LoopCheck does not
// depend on it, so concurrent or nested cycle checks cannot interfere through this field.
static int times = 0;

/**
 * Checks whether startNode lies on a directed cycle of the network.
 *
 * An edge u -> v is encoded as network[u][v] == 1. The method returns true when startNode
 * can be reached again by following edges out of startNode (a self-loop counts). After
 * adding an edge i -> j, any newly created cycle must pass through j, so checking j is
 * enough to detect it.
 *
 * Uses only local state and an explicit stack, so it is safe to call from several
 * classifier instances at once and does not recurse once per node.
 *
 * @param network   adjacency matrix, network[u][v] == 1 means an edge u -> v
 * @param startNode node whose membership in a cycle is tested
 * @param numAtt    number of nodes (rows/columns of network that are inspected)
 * @return true if startNode is on a cycle, false otherwise
 */
public boolean LoopCheck(int [][] network, int startNode, int numAtt)
{
    boolean[] visited = new boolean[numAtt];
    // Each node is pushed at most once (it is marked on push), so numAtt slots suffice.
    int[] stack = new int[numAtt];
    int top = 0;
    stack[top++] = startNode;
    visited[startNode] = true;
    while (top > 0)
    {
        int node = stack[--top];
        for (int next = 0; next < numAtt; next++)
        {
            if (network[node][next] != 1)
                continue;
            if (next == startNode)
                return true; // came back to the start: cycle found
            if (!visited[next])
            {
                visited[next] = true;
                stack[top++] = next;
            }
        }
    }
    return false;
}
/**
 * Recursive depth-first walk along edges network[u][v] == 1 (u -> v).
 *
 * Does not return a value. Every visit to a node equal to Aim (including the initial call
 * and repeat arrivals at an already-visited Aim) increments the shared static counter
 * {@code times}. Nodes already marked in Visited are not expanded again.
 *
 * @param network     adjacency matrix, network[u][v] == 1 means an edge u -> v
 * @param CurrentNode node being visited
 * @param Aim         node whose arrivals are counted in {@code times}
 * @param Visited     visitation marks; updated in place
 * @param numAtt      number of nodes to scan per row
 */
public void DFS(int [][] network,int CurrentNode, int Aim, boolean[] Visited, int numAtt)
{
    if (CurrentNode == Aim) {
        times++;
    }
    if (!Visited[CurrentNode]) {
        Visited[CurrentNode] = true;
        int[] outgoing = network[CurrentNode];
        for (int next = 0; next < numAtt; next++) {
            if (outgoing[next] == 1) {
                DFS(network, next, Aim, Visited, numAtt);
            }
        }
    }
}
/**
 * Scores a candidate network structure against a data set.
 *
 * For each instance, multiplies, over every attribute node j, the table entry
 * m_Table[j].Prob[...] selected by the instance's class value, the value of j itself, and
 * the values of j's parents in {@code network}. Each product p then contributes
 * -p * log(p) to the score. The caller is expected to have rebuilt m_Table for
 * {@code network} (via BuildTable) beforehand.
 *
 * NOTE(review): the complexity penalty is fixed at 0, so denser graphs are not penalised.
 *
 * @param network adjacency matrix being scored, network[m][j] == 1 means m is a parent of j
 * @param inst    data used for scoring
 * @param numAtt  number of non-class attributes (nodes)
 * @return the structure score (higher is preferred by the caller's search)
 */
public double ScoreFromData(int [][] network, Instances inst, int numAtt)
{
    double score = 0;
    // 0 for now; ideally this should grow with the complexity of the network.
    double penalty = 0;
    Enumeration enumInsts = inst.enumerateInstances();
    while (enumInsts.hasMoreElements())
    {
        Instance instance = (Instance) enumInsts.nextElement();
        double p = 1;
        for (int j = 0; j < numAtt; j++)
        {
            // Row 0 holds the class; rows 1.. hold (attribute index, value) for j's parents
            // and j itself, in increasing attribute-index order.
            int[][] combination = new int[NumberOfFather(j, network, numAtt) + 2][2];
            combination[0][0] = -1;
            combination[0][1] = (int) instance.classValue();
            // A single cursor fills both the index and value columns, so every row's value
            // belongs to the attribute named in that row.
            int row = 1;
            for (int m = 0; m < numAtt; m++)
            {
                if ((network[m][j] == 1) || (m == j)) // a parent of this node, or the node itself
                {
                    Attribute attribute = (Attribute) m_Att[m];
                    combination[row][0] = m;
                    if (attribute.isNominal())
                        combination[row][1] = (int) instance.value(attribute);
                    row++;
                }
            }
            // Index against the structure under evaluation, not the stored m_Network.
            int index = Combination2Index(combination, j, network);
            p *= m_Table[j].Prob[index];
        }
        // By convention 0 * log(0) = 0; evaluating it directly would give NaN.
        if (p > 0)
            score -= p * Math.log(p);
    }
    score -= penalty;
    return score;
}
/**
 * Learns the network structure and its probability tables from the training data.
 *
 * Rejects string attributes and a numeric class. Works on a copy of the data from which
 * instances with a missing class have been removed. The structure search starts from the
 * empty graph and allows at most 2 parents per node (the k argument of SearchNetwork).
 * It then builds the tables for the resulting network.
 *
 * @param instances training data (not modified)
 * @throws Exception if the data contains string attributes or the class is numeric,
 *                   or if the search / table building fails
 */
public void buildClassifier(Instances instances) throws Exception
{
    if (instances.checkForStringAttributes()) {
        throw new Exception("Can't handle string attributes!");
    }
    if (instances.classAttribute().isNumeric()) {
        throw new Exception("BayesianNetFixed: Class is numeric!");
    }
    // Copy so the caller's data set is untouched, then drop unlabelled instances.
    // ScoreFromData casts classValue() to int, and a missing class (NaN) would become
    // class 0, silently counting these instances as members of the first class.
    instances = new Instances(instances);
    instances.deleteWithMissingClass();

    // Header-only copy, kept for prediction and printing.
    m_Instances = new Instances(instances, 0);

    // Register the non-class attributes. enumerateAttributes() skips the class attribute
    // wherever it is, so numAtt counts exactly the attribute nodes.
    int numAtt = m_Instances.numAttributes() - 1;
    m_Att = new Attribute[numAtt];
    Enumeration enumAtts = m_Instances.enumerateAttributes();
    int next = 0;
    while (enumAtts.hasMoreElements())
    {
        m_Att[next++] = (Attribute) enumAtts.nextElement();
    }

    // Java zero-initialises int arrays, so this is the empty graph the search starts from.
    m_Network = new int[numAtt][numAtt];
    SearchNetwork(m_Network, 2, instances);

    // Probability tables for the final structure.
    BuildTable(m_Network, instances);
}
/**
 * Calculates the class membership probabilities for the given test instance.
 *
 * For each class value x, multiplies over every attribute node j the table entry
 * m_Table[j].Prob[...] selected by x, the value of j itself, and the values of j's parents
 * in m_Network. The resulting vector is then normalised to sum to 1.
 *
 * @param instance the instance to be classified
 * @return predicted class probability distribution
 * @throws Exception if normalisation fails (for example, every class product is zero)
 */
public double[] distributionForInstance(Instance instance) throws Exception
{
    int numClasses = instance.numClasses();
    double [] probs = new double[numClasses];
    int numAtt = m_Instances.numAttributes() - 1;
    for (int x = 0; x < numClasses; x++)
    {
        double p = 1;
        for (int j = 0; j < numAtt; j++)
        {
            // Row 0 holds the hypothesised class; rows 1.. hold (attribute index, value) for
            // j's parents and j itself, in increasing attribute-index order.
            int[][] combination = new int[NumberOfFather(j, m_Network, numAtt) + 2][2];
            combination[0][0] = -1;
            combination[0][1] = x;
            // A single cursor fills both the index and value columns, so every row's value
            // belongs to the attribute named in that row.
            int row = 1;
            for (int m = 0; m < numAtt; m++)
            {
                if ((m_Network[m][j] == 1) || (m == j)) // a parent of this node, or the node itself
                {
                    Attribute attribute = (Attribute) m_Att[m];
                    combination[row][0] = m;
                    if (attribute.isNominal())
                        combination[row][1] = (int) instance.value(attribute);
                    row++;
                }
            }
            int index = Combination2Index(combination, j, m_Network);
            p *= m_Table[j].Prob[index];
        }
        probs[x] = p;
    }
    Utils.normalize(probs);
    return probs;
}
/**
 * Returns a description of the classifier: class/attribute counts and the list of
 * directed edges of the learned network (by non-class attribute index).
 *
 * @return a description of the classifier as a string.
 */
public String toString() {
    if (m_Instances == null) {
        return "Bayesian Network : No model built yet.";
    }
    try {
        int numAtt = m_Instances.numAttributes() - 1;
        StringBuffer text = new StringBuffer("Bayesian Network ");
        text.append("\nnumber of Class " + Utils.doubleToString(m_Instances.numClasses(), 10, 8) );
        text.append("\nnumber of Attribute: " + Utils.doubleToString(m_Instances.numAttributes(),10, 8) );
        text.append("\nclass attribute's index: " + Utils.doubleToString(m_Instances.classIndex(), 10, 8) );
        text.append("\n\nStructure of the Network");
        boolean hasEdge = false;
        for (int i = 0; i < numAtt; i++)
        {
            for (int j = 0; j < numAtt; j++)
            {
                if (m_Network[i][j] == 1)
                {
                    hasEdge = true;
                    text.append("\n\tEdge: from " + i +" to "+j);
                }
            }
        }
        if (!hasEdge)
            text.append("\n\tNo edge in the network" );
        return text.toString();
    }
    catch (Exception e) {
        // Previously reported "Naive Bayes" by copy-paste; name this classifier correctly.
        return "Can't print Bayesian Network classifier!";
    }
}
/**
 * Command-line entry point: evaluates a fresh BayesianNetwork with Weka's standard
 * evaluation options and prints the report. On failure, only the exception message is
 * written to standard error.
 *
 * @param argv the options
 */
public static void main(String [] argv) {
    try {
        Classifier classifier = new BayesianNetwork();
        String report = Evaluation.evaluateModel(classifier, argv);
        System.out.println(report);
    } catch (Exception e) {
        System.err.println(e.getMessage());
    }
}
/** In debug mode, blocks until a byte is read from standard input; otherwise does nothing. */
public void pause(){
    if (!bDebug)
        return;
    try {
        System.in.read();
    } catch (Exception ignored) {
        // Best-effort debug pause: if stdin cannot be read we simply continue.
    }
}
/** Prints a number to standard output when debug mode (bDebug) is on. */
public void trace(double x)
{
    if (!bDebug)
        return;
    System.out.println(x);
}
/** Prints a message line to standard output when debug mode (bDebug) is on. */
public void trace(String str)
{
    if (!bDebug)
        return;
    System.out.println(str);
}
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -