📄 hillclimber.java
字号:
/*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 2 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
*/
/*
* HillClimber.java
* Copyright (C) 2004 Remco Bouckaert
*
*/
package weka.classifiers.bayes.net.search.local;
import weka.classifiers.bayes.BayesNet;
import weka.classifiers.bayes.net.ParentSet;
import weka.core.*;
import java.util.*;
import java.io.Serializable;
/** HillClimber implements hill climbing using local search
* for learning the structure of a Bayesian network.
*
* @author Remco Bouckaert (rrb@xm.co.nz)
* Version: $Revision$
*/
public class HillClimber extends LocalScoreSearchAlgorithm {
/** the Operation class contains info on operations performed
* on the current Bayesian network.
*/
class Operation implements Serializable {
// constants indicating the type of an operation
final static int OPERATION_ADD = 0;
final static int OPERATION_DEL = 1;
final static int OPERATION_REVERSE = 2;
/** c'tor **/
public Operation() {
}
/** c'tor + initializers
*
* @param nTail
* @param nHead
* @param nOperation
*/
public Operation(int nTail, int nHead, int nOperation) {
m_nHead = nHead;
m_nTail = nTail;
m_nOperation = nOperation;
}
/** compare this operation with another
* @param other: operation to compare with
* @return true if operation is the same
*/
public boolean equals(Operation other) {
if (other == null) {
return false;
}
return (( m_nOperation == other.m_nOperation) &&
(m_nHead == other.m_nHead) &&
(m_nTail == other.m_nTail));
} // equals
/** number of the tail node **/
public int m_nTail;
/** number of the head node **/
public int m_nHead;
/** type of operation (ADD, DEL, REVERSE) **/
public int m_nOperation;
/** change of score due to this operation **/
public double m_fDeltaScore = -1E100;
} // class Operation
/** Cache that remembers the change in score for single steps in the
 * search space. Add and delete deltas are stored per arc, indexed
 * [tail][head]; a reversal delta is derived from the two.
 */
class Cache {
    /** change in score due to adding an arc, indexed [tail][head] **/
    double [] [] m_fDeltaScoreAdd;
    /** change in score due to deleting an arc, indexed [tail][head] **/
    double [] [] m_fDeltaScoreDel;

    /** c'tor
     * @param nNrOfNodes: number of nodes in network, used to size both tables
     */
    Cache(int nNrOfNodes) {
        m_fDeltaScoreAdd = new double [nNrOfNodes][nNrOfNodes];
        m_fDeltaScoreDel = new double [nNrOfNodes][nNrOfNodes];
    }

    /** store a cache entry; ADD operations go to the add table,
     * every other operation type goes to the delete table.
     * @param oOperation: operation whose delta is stored
     * @param fValue: score delta to store
     */
    public void put(Operation oOperation, double fValue) {
        double [] [] fTable = (oOperation.m_nOperation == Operation.OPERATION_ADD)
                ? m_fDeltaScoreAdd
                : m_fDeltaScoreDel;
        fTable[oOperation.m_nTail][oOperation.m_nHead] = fValue;
    } // put

    /** look up a cache entry
     * @param oOperation: operation whose delta is requested
     * @return cached delta; for a reversal, the delete delta of tail -> head
     *         plus the add delta of head -> tail; 0 for an unknown type
     */
    public double get(Operation oOperation) {
        int nType = oOperation.m_nOperation;
        int iTail = oOperation.m_nTail;
        int iHead = oOperation.m_nHead;
        if (nType == Operation.OPERATION_ADD) {
            return m_fDeltaScoreAdd[iTail][iHead];
        }
        if (nType == Operation.OPERATION_DEL) {
            return m_fDeltaScoreDel[iTail][iHead];
        }
        if (nType == Operation.OPERATION_REVERSE) {
            return m_fDeltaScoreDel[iTail][iHead] + m_fDeltaScoreAdd[iHead][iTail];
        }
        // unknown operation type: no cached value
        return 0;
    } // get
} // class Cache
/** cache for storing score differences; created by initCache and set
 * back to null at the end of search to free its memory **/
Cache m_Cache = null;
/** use the arc reversal operator; when set, getOptimalOperation also
 * considers reversing arcs (presumably read through getUseArcReversal()
 * — defined elsewhere, confirm) **/
boolean m_bUseArcReversal = false;
/**
 * search determines the network structure/graph of the network
 * with a greedy hill climbing algorithm: it repeatedly applies the
 * best available operation as long as that operation strictly
 * improves the score (delta score > 0).
 * The score cache is released when the search ends, including
 * when it ends with an exception.
 * @param bayesNet: Bayes network to be learned
 * @param instances: data set to learn from
 * @throws Exception if scoring or applying an operation fails
 **/
protected void search(BayesNet bayesNet, Instances instances) throws Exception {
    try {
        initCache(bayesNet, instances);
        // go do the search
        Operation oOperation = getOptimalOperation(bayesNet, instances);
        while ((oOperation != null) && (oOperation.m_fDeltaScore > 0)) {
            performOperation(bayesNet, instances, oOperation);
            oOperation = getOptimalOperation(bayesNet, instances);
        }
    } finally {
        // free up memory, also when initCache or an operation threw
        m_Cache = null;
    }
} // search
/** initCache builds a fresh score-change cache for the current network.
 * It first lets updateCache process every node's current parent set,
 * then stores, for every ordered pair of distinct nodes, the change in
 * score obtained by adding the tail node as an extra parent of the head.
 * @param bayesNet: Bayes network to be learned
 * @param instances: data set to learn from
 * @throws Exception if scoring fails
 */
void initCache(BayesNet bayesNet, Instances instances) throws Exception {
    final int nNodes = instances.numAttributes();
    m_Cache = new Cache(nNodes);

    // updateCache is defined elsewhere; presumably it fills the entries
    // that depend on each node's existing parents — confirm
    for (int iNode = 0; iNode < nNodes; iNode++) {
        updateCache(iNode, nNodes, bayesNet.getParentSet(iNode));
    }

    // score of each node with its current parent set
    final double[] fNodeScore = new double[nNodes];
    for (int iNode = 0; iNode < nNodes; iNode++) {
        fNodeScore[iNode] = calcNodeScore(iNode);
    }

    for (int iHead = 0; iHead < nNodes; iHead++) {
        for (int iTail = 0; iTail < nNodes; iTail++) {
            if (iHead == iTail) {
                continue; // a node is never added as its own parent
            }
            final Operation oAdd = new Operation(iTail, iHead, Operation.OPERATION_ADD);
            m_Cache.put(oAdd, calcScoreWithExtraParent(iHead, iTail) - fNodeScore[iHead]);
        }
    }
} // initCache
/** isNotTabu checks whether an operation is allowed, i.e. not on a
 * tabu (forbidden) list. The base hill climber places no restrictions
 * on operations, so this always returns true; a subclass may override
 * it to forbid operations.
 * @param oOperation: operation to be checked (ignored here)
 * @return true if the operation is allowed; always true for this class
 */
boolean isNotTabu(Operation oOperation) {
    return true;
} // isNotTabu
/** getOptimalOperation finds the best operation that can be performed
 * on the Bayes network. Arc additions and deletions are always
 * considered; arc reversals only when getUseArcReversal() is true.
 * @param bayesNet: Bayes network to apply operation on
 * @param instances: data set to learn from
 * @return best operation found, or null if the candidate still carries
 *         the -1E100 "nothing found" delta score after all searches
 * @throws Exception if scoring fails
 */
Operation getOptimalOperation(BayesNet bayesNet, Instances instances) throws Exception {
    // each findBestArcTo* call receives the best candidate so far and
    // returns the (possibly replaced) best candidate
    Operation oCandidate = findBestArcToAdd(bayesNet, instances, new Operation());
    oCandidate = findBestArcToDelete(bayesNet, instances, oCandidate);
    if (getUseArcReversal()) {
        oCandidate = findBestArcToReverse(bayesNet, instances, oCandidate);
    }
    boolean bNothingFound = (oCandidate.m_fDeltaScore == -1E100);
    return bNothingFound ? null : oCandidate;
} // getOptimalOperation
/** performOperation applies an operation
 * on the Bayes network and updates the cache (through
 * applyArcAddition / applyArcDeletion).
 * Arcs are written tail -> head in the debug trace: an addition makes
 * the tail node a parent of the head node.
 * @param bayesNet: Bayes network to apply operation on
 * @param instances: data set to learn from
 * @param oOperation: operation to perform; unknown types are ignored
 * @throws Exception
 */
void performOperation(BayesNet bayesNet, Instances instances, Operation oOperation) throws Exception {
    final int iTail = oOperation.m_nTail;
    final int iHead = oOperation.m_nHead;
    // NOTE(review): print() emits no line terminator — confirm a caller appends one
    switch (oOperation.m_nOperation) {
    case Operation.OPERATION_ADD:
        applyArcAddition(bayesNet, iHead, iTail, instances);
        if (bayesNet.getDebug()) {
            System.out.print("Add " + iTail + " -> " + iHead);
        }
        break;
    case Operation.OPERATION_DEL:
        applyArcDeletion(bayesNet, iHead, iTail, instances);
        if (bayesNet.getDebug()) {
            System.out.print("Del " + iTail + " -> " + iHead);
        }
        break;
    case Operation.OPERATION_REVERSE:
        // turn tail -> head into head -> tail
        applyArcDeletion(bayesNet, iHead, iTail, instances);
        applyArcAddition(bayesNet, iTail, iHead, instances);
        if (bayesNet.getDebug()) {
            System.out.print("Rev " + iTail + " -> " + iHead);
        }
        break;
    }
} // performOperation
/** applyArcAddition adds the arc iTail -> iHead by making iTail a parent
 * of iHead, then refreshes the cache entries for iHead.
 * @param bayesNet: Bayes network to modify
 * @param iHead: node that receives the new parent
 * @param iTail: node that becomes a parent
 * @param instances: data set to learn from
 */
void applyArcAddition(BayesNet bayesNet, int iHead, int iTail, Instances instances) {
    final ParentSet headParents = bayesNet.getParentSet(iHead);
    headParents.addParent(iTail, instances);
    // only iHead's parent set changed, so only its cache entries need refreshing
    final int nNrOfAtts = instances.numAttributes();
    updateCache(iHead, nNrOfAtts, headParents);
} // applyArcAddition
void applyArcDeletion(BayesNet bayesNet, int iHead, int iTail, Instances instances) {
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -