// GeneticSearch.java
/*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 2 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
*/
/*
* GeneticSearch.java
* Copyright (C) 2004 Remco Bouckaert
*
*/
package weka.classifiers.bayes.net.search.local;
import weka.classifiers.bayes.BayesNet;
import weka.classifiers.bayes.net.*;
import weka.core.*;
import java.util.*;
/** GeneticSearch is a crude implementation of genetic search for learning
* Bayesian network structures.
*
* @author Remco Bouckaert (rrb@xm.co.nz)
* Version: $Revision$
*/
public class GeneticSearch extends LocalScoreSearchAlgorithm {
/** Number of generations: each one breeds descendants and then selects a new population. */
int m_nRuns = 10;

/** Number of network structures kept in the population between generations. */
int m_nPopulationSize = 10;

/** Number of descendant structures bred from the population in every generation. */
int m_nDescendantPopulationSize = 100;

/** Whether cross-over may be used to breed a descendant. */
boolean m_bUseCrossOver = true;

/** Whether mutation may be used to breed a descendant. */
boolean m_bUseMutation = true;

/**
 * Selection strategy for the next population: tournament selection when true,
 * otherwise the best sub-population of the descendants is taken.
 */
boolean m_bUseTournamentSelection = false;

/** Seed used to (re)create the random number generator at the start of a search. */
int m_nSeed = 1;

/** Random number generator; freshly created from m_nSeed by search(). */
Random m_random = null;

/**
 * Lookup table shared by all BayesNetRepresentation objects (it is static).
 * Entry i * nNodes + i is true, i.e. the bit positions on the diagonal of the
 * nNodes x nNodes arc matrix (self-loops) are marked. It is built lazily by
 * BayesNetRepresentation.isSquare().
 */
static boolean[] g_bIsSquare;
class BayesNetRepresentation {
    /** number of nodes in network **/
    int m_nNodes = 0;

    /** bit representation of parent sets:
     * m_bits[iTail + iHead * m_nNodes] represents arc iTail->iHead.
     * Diagonal entries (iTail == iHead) would encode self-loops; randomInit()
     * and mutate() never pick them.
     */
    boolean [] m_bits;

    /** score of represented network structure **/
    double m_fScore = 0.0;

    /** return score of represented network structure
     * @return the stored score (0.0 until calcScore() ran or a score was copied in)
     */
    public double getScore() {
        return m_fScore;
    } // getScore

    /** c'tor
     * @param nNodes number of nodes in the network
     */
    BayesNetRepresentation (int nNodes) {
        m_nNodes = nNodes;
    } // c'tor

    /** initialize with a random acyclic structure by placing up to
     * m_nNodes arcs at random off-diagonal positions (a position may be
     * drawn more than once). The draw is repeated until the structure is
     * acyclic; afterwards the score is calculated, which also sets the
     * BayesNet parent sets.
     */
    public void randomInit() {
        do {
            m_bits = new boolean [m_nNodes * m_nNodes];
            // with fewer than two nodes every position is on the diagonal,
            // so there is no arc that could be placed
            if (m_nNodes >= 2) {
                for (int i = 0; i < m_nNodes; i++) {
                    m_bits[randomOffDiagonalPosition()] = true;
                }
            }
        } while (hasCycles());
        calcScore();
    } // randomInit

    /** draw a uniformly random bit position that is not on the diagonal
     * (i.e. does not encode a self-loop). Requires m_nNodes >= 2, otherwise
     * no such position exists and this would not terminate.
     * @return bit position in [0, m_nNodes * m_nNodes)
     */
    private int randomOffDiagonalPosition() {
        int iPos;
        do {
            iPos = m_random.nextInt(m_nNodes * m_nNodes);
        } while (isSquare(iPos));
        return iPos;
    } // randomOffDiagonalPosition

    /** calculate score of current network representation.
     * As a side effect, the parent sets of m_BayesNet are cleared and
     * rebuilt from m_bits. (m_BayesNet and calcNodeScore are not declared
     * in this class; presumably they come from the enclosing search
     * algorithm — confirm.)
     */
    void calcScore() {
        // clear current network
        for (int iNode = 0; iNode < m_nNodes; iNode++) {
            ParentSet parentSet = m_BayesNet.getParentSet(iNode);
            while (parentSet.getNrOfParents() > 0) {
                parentSet.deleteLastParent(m_BayesNet.m_Instances);
            }
        }
        // insert arrows
        for (int iNode = 0; iNode < m_nNodes; iNode++) {
            ParentSet parentSet = m_BayesNet.getParentSet(iNode);
            for (int iNode2 = 0; iNode2 < m_nNodes; iNode2++) {
                if (m_bits[iNode2 + iNode * m_nNodes]) {
                    parentSet.addParent(iNode2, m_BayesNet.m_Instances);
                }
            }
        }
        // calc score: sum of the per-node scores
        m_fScore = 0.0;
        for (int iNode = 0; iNode < m_nNodes; iNode++) {
            m_fScore += calcNodeScore(iNode);
        }
    } // calcScore

    /** check whether there are cycles in the network.
     * Repeatedly marks as 'done' a node whose parents are all done; if in
     * some round no such node exists, the remaining nodes contain a cycle.
     *
     * @return true if a cycle is found, false otherwise
     */
    public boolean hasCycles() {
        boolean[] bDone = new boolean[m_nNodes];
        for (int iNode = 0; iNode < m_nNodes; iNode++) {
            // find a node for which all parents are 'done'
            boolean bFound = false;
            for (int iNode2 = 0; !bFound && iNode2 < m_nNodes; iNode2++) {
                if (!bDone[iNode2]) {
                    boolean bHasNoParents = true;
                    for (int iParent = 0; iParent < m_nNodes; iParent++) {
                        if (m_bits[iParent + iNode2 * m_nNodes] && !bDone[iParent]) {
                            bHasNoParents = false;
                        }
                    }
                    if (bHasNoParents) {
                        bDone[iNode2] = true;
                        bFound = true;
                    }
                }
            }
            if (!bFound) {
                return true;
            }
        }
        return false;
    } // hasCycles

    /** create clone of current object (bits and score; the BayesNet parent
     * sets are not touched)
     * @return cloned object
     */
    BayesNetRepresentation copy() {
        BayesNetRepresentation b = new BayesNetRepresentation(m_nNodes);
        b.m_bits = m_bits.clone();
        b.m_fScore = m_fScore;
        return b;
    } // copy

    /** Apply mutation operation: flip exactly one off-diagonal bit such that
     * the structure stays acyclic. A flip that introduces a cycle is undone
     * before another bit is tried, so the result differs from the original
     * structure in a single arc.
     * Calculates the score and, as a side effect, sets the BayesNet parent sets.
     */
    void mutate() {
        // with fewer than two nodes there is no arc that can be flipped
        if (m_nNodes >= 2) {
            // Assuming the structure is acyclic on entry, a valid flip always
            // exists: removing an arc never creates a cycle, and adding one
            // arc to an arc-free network cannot either.
            while (true) {
                int iBit = randomOffDiagonalPosition();
                m_bits[iBit] = !m_bits[iBit];
                if (!hasCycles()) {
                    break;
                }
                // this flip introduced a cycle: undo it and try another bit
                m_bits[iBit] = !m_bits[iBit];
            }
        }
        calcScore();
    } // mutate

    /** Apply cross-over operation: all bits from a random cross-over point
     * onwards are taken from other. Points that yield a cyclic network are
     * rejected (the bits are restored first). The retry terminates with
     * probability 1: cross-over point 0 copies other completely, so it is
     * acyclic provided other is.
     * Calculates the score and, as a side effect, sets the BayesNet parent sets.
     * @param other BayesNetRepresentation to cross over with; expected to
     *        have the same number of nodes
     */
    void crossOver(BayesNetRepresentation other) {
        // keep the original bits so a rejected cross-over can be undone
        boolean [] bits = m_bits.clone();
        int nLength = m_bits.length;
        int iCrossOverPoint = nLength;
        // an empty network has no cross-over point (and nextInt(0) would throw)
        if (nLength > 0) {
            do {
                // restore the tail overwritten by the previous (rejected) attempt
                System.arraycopy(bits, iCrossOverPoint, m_bits, iCrossOverPoint, nLength - iCrossOverPoint);
                // take all bits from cross-over point onwards
                iCrossOverPoint = m_random.nextInt(nLength);
                System.arraycopy(other.m_bits, iCrossOverPoint, m_bits, iCrossOverPoint, nLength - iCrossOverPoint);
            } while (hasCycles());
        }
        calcScore();
    } // crossOver

    /** check if bit position nNum lies on the diagonal (encodes a self-loop
     * iNode->iNode), (re)building the shared g_bIsSquare table if necessary.
     * The table is rebuilt whenever its size does not match the current
     * network. A table built for a different number of nodes has its
     * diagonal in other places and may be too short for nNum.
     * @param nNum bit position to check, in [0, m_nNodes * m_nNodes)
     * @return true if nNum is a diagonal position
     */
    boolean isSquare(int nNum) {
        int nSize = m_nNodes * m_nNodes;
        boolean [] bIsSquare = g_bIsSquare;
        if (bIsSquare == null || bIsSquare.length != nSize) {
            bIsSquare = new boolean [nSize];
            for (int i = 0; i < m_nNodes; i++) {
                bIsSquare[i * m_nNodes + i] = true;
            }
            g_bIsSquare = bIsSquare;
        }
        return bIsSquare[nNum];
    } // isSquare
} // class BayesNetRepresentation
/**
* search determines the network structure/graph of the network
* with a genetic search algorithm.
**/
protected void search(BayesNet bayesNet, Instances instances) throws Exception {
// sanity check
if (getDescendantPopulationSize() < getPopulationSize()) {
throw new Exception ("Descendant PopulationSize should be at least Population Size");
}
if (!getUseCrossOver() && !getUseMutation()) {
throw new Exception ("At least one of mutation or cross-over should be used");
}
m_random = new Random(m_nSeed);
// keeps track of best structure found so far
BayesNet bestBayesNet;
// keeps track of score pf best structure found so far
double fBestScore = 0.0;
for (int iAttribute = 0; iAttribute < instances.numAttributes(); iAttribute++) {
fBestScore += calcNodeScore(iAttribute);
}
// initialize bestBayesNet
bestBayesNet = new BayesNet();
bestBayesNet.m_Instances = instances;
bestBayesNet.initStructure();
copyParentSets(bestBayesNet, bayesNet);
// initialize population
BayesNetRepresentation [] population = new BayesNetRepresentation [getPopulationSize()];
double [] score = new double[getPopulationSize()];
for (int i = 0; i < getPopulationSize(); i++) {
population[i] = new BayesNetRepresentation (instances.numAttributes());
population[i].randomInit();
if (population[i].getScore() > fBestScore) {
copyParentSets(bestBayesNet, bayesNet);
fBestScore = population[i].getScore();
}
}
// go do the search
for (int iRun = 0; iRun < m_nRuns; iRun++) {
// create descendants
BayesNetRepresentation [] descendantPopulation = new BayesNetRepresentation [getDescendantPopulationSize()];
for (int i = 0; i < getDescendantPopulationSize(); i++) {
descendantPopulation[i] = population[m_random.nextInt(getPopulationSize())].copy();
if (getUseMutation()) {
if (getUseCrossOver() && m_random.nextBoolean()) {
descendantPopulation[i].crossOver(population[m_random.nextInt(getPopulationSize())]);
} else {
descendantPopulation[i].mutate();
}
} else {
// use crossover
descendantPopulation[i].crossOver(population[m_random.nextInt(getPopulationSize())]);
}
if (descendantPopulation[i].getScore() > fBestScore) {
copyParentSets(bestBayesNet, bayesNet);
fBestScore = descendantPopulation[i].getScore();
}
}
// select new population
boolean [] bSelected = new boolean [getDescendantPopulationSize()];
for (int i = 0; i < getPopulationSize(); i++) {
int iSelected = 0;
if (m_bUseTournamentSelection) {
// use tournament selection
iSelected = m_random.nextInt(getDescendantPopulationSize());
while (bSelected[iSelected]) {
iSelected = (iSelected + 1) % getDescendantPopulationSize();
}
int iSelected2 = m_random.nextInt(getDescendantPopulationSize());
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -