📄 globalscoresearchalgorithm.java
字号:
/*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 2 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
*/
/*
* GlobalScoreSearchAlgorithm.java
* Copyright (C) 2004 Remco Bouckaert
*
*/
package weka.classifiers.bayes.net.search.global;
import java.util.Enumeration;
import java.util.Vector;
import weka.classifiers.bayes.BayesNet;
import weka.classifiers.bayes.net.ParentSet;
import weka.classifiers.bayes.net.search.SearchAlgorithm;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.SelectedTag;
import weka.core.Tag;
import weka.core.Utils;
/** The GlobalScoreSearchAlgorithm class supports Bayes net structure search algorithms
* that are based on cross validation (as opposed to, for example,
* score-based or conditional-independence-based search algorithms).
*
* @author Remco Bouckaert
* @version $Revision$
*/
public class GlobalScoreSearchAlgorithm extends SearchAlgorithm {
/**
* Bayes network whose structure is being scored. Assigned by leaveOneOutCV
* (and possibly by other methods of this class not shown here) and read by
* the calcScoreWith* methods.
**/
BayesNet m_BayesNet;
/** toggle between scoring using accuracy = 0-1 loss (when false) or class probabilities (when true); defaults to true **/
boolean m_bUseProb = true;
/** number of folds used by k-fold cross validation, i.e. when m_nCVType is KFOLDCV; defaults to 10 **/
int m_nNrOfFolds = 10;
/* Identifiers of the cross validation strategies that calcScore dispatches
 * on; these are the valid values of m_nCVType. */
/** leave-one-out cross validation **/
final static int LOOCV = 0;
/** k-fold cross validation (with m_nNrOfFolds folds) **/
final static int KFOLDCV = 1;
/** cumulative cross validation **/
final static int CUMCV = 2;
/**
* Display names for the cross validation types, keyed by the constants above.
* NOTE(review): presumably used to expose m_nCVType as a SelectedTag option;
* that code is not visible here.
**/
public static final Tag[] TAGS_CV_TYPE =
{
new Tag(LOOCV, "LOO-CV"),
new Tag(KFOLDCV, "k-Fold-CV"),
new Tag(CUMCV, "Cumulative-CV")
};
/**
* Holds the cross validation strategy used by calcScore to measure the
* quality of a network; one of LOOCV, KFOLDCV or CUMCV (default LOOCV).
*/
int m_nCVType = LOOCV;
/**
 * Scores a network by cross validation, using the strategy selected by
 * m_nCVType: leave-one-out (LOOCV), cumulative (CUMCV) or k-fold with
 * m_nNrOfFolds folds (KFOLDCV). The data set used is the m_Instances
 * associated with the Bayes network.
 *
 * @param bayesNet Bayes network containing the structure to evaluate
 * @return accuracy (in interval 0..1) measured using cross validation
 * @throws Exception when m_nCVType holds an unrecognized value, or as passed
 *         on by the selected cross validation routine
 */
public double calcScore(BayesNet bayesNet) throws Exception {
    final int cvType = m_nCVType;
    if (cvType == LOOCV) {
        return leaveOneOutCV(bayesNet);
    }
    if (cvType == CUMCV) {
        return cumulativeCV(bayesNet);
    }
    if (cvType == KFOLDCV) {
        return kFoldCV(bayesNet, m_nNrOfFolds);
    }
    // no strategy matched: report the offending value
    throw new Exception("Unrecognized cross validation type encountered: " + cvType);
} // calcScore
/**
 * Calculates the network score with a candidate parent temporarily added to
 * the parent set of a node. The parent set is restored before returning,
 * even when scoring throws, so the network is always left as it was found.
 *
 * @param nNode node whose parent set is extended
 * @param nCandidateParent candidate parent to add to the existing parent set
 * @return the cross validation score of the network with the extra parent
 *         (see calcScore), or -1e100 if nCandidateParent already is a parent
 *         of nNode
 * @throws Exception passed on by calcScore
 */
public double calcScoreWithExtraParent(int nNode, int nCandidateParent) throws Exception {
    ParentSet oParentSet = m_BayesNet.getParentSet(nNode);
    Instances instances = m_BayesNet.m_Instances;
    // sanity check: nCandidateParent should not be in the parent set already
    if (oParentSet.contains(nCandidateParent)) {
        return -1e100;
    }
    // temporarily add the candidate parent
    oParentSet.addParent(nCandidateParent, instances);
    try {
        return calcScore(m_BayesNet);
    } finally {
        // undo the temporary change even if scoring failed
        oParentSet.deleteLastParent(instances);
    }
} // calcScoreWithExtraParent
/**
 * Calculates the network score with a parent temporarily removed from the
 * parent set of a node. The parent is reinserted before returning, even when
 * scoring throws, so the network is always left as it was found.
 *
 * @param nNode node whose parent set is reduced
 * @param nCandidateParent parent to delete from the existing parent set
 * @return the cross validation score of the network without the parent
 *         (see calcScore), or -1e100 if nCandidateParent is not a parent
 *         of nNode
 * @throws Exception passed on by calcScore
 */
public double calcScoreWithMissingParent(int nNode, int nCandidateParent) throws Exception {
    ParentSet oParentSet = m_BayesNet.getParentSet(nNode);
    Instances instances = m_BayesNet.m_Instances;
    // sanity check: nCandidateParent should be in the parent set already
    if (!oParentSet.contains(nCandidateParent)) {
        return -1e100;
    }
    // temporarily remove the parent; the value returned by deleteParent is
    // handed back to addParent for reinsertion (presumably its old position)
    int iParent = oParentSet.deleteParent(nCandidateParent, instances);
    try {
        return calcScore(m_BayesNet);
    } finally {
        // reinsert the temporarily deleted parent even if scoring failed
        oParentSet.addParent(nCandidateParent, iParent, instances);
    }
} // calcScoreWithMissingParent
/**
 * Calculates the network score with the arrow nCandidateParent -&gt; nNode
 * temporarily reversed into nNode -&gt; nCandidateParent. Both parent sets are
 * restored before returning, even when scoring throws.
 *
 * @param nNode node that currently has nCandidateParent as a parent
 * @param nCandidateParent parent of nNode whose arrow is reversed
 * @return the cross validation score of the network with the reversed arrow
 *         (see calcScore), or -1e100 if nCandidateParent is not a parent
 *         of nNode
 * @throws Exception passed on by calcScore
 */
public double calcScoreWithReversedParent(int nNode, int nCandidateParent) throws Exception {
    ParentSet oParentSet = m_BayesNet.getParentSet(nNode);
    ParentSet oParentSet2 = m_BayesNet.getParentSet(nCandidateParent);
    Instances instances = m_BayesNet.m_Instances;
    // sanity check: nCandidateParent should be in the parent set already
    if (!oParentSet.contains(nCandidateParent)) {
        return -1e100;
    }
    // remove the arrow nCandidateParent -> nNode ...
    int iParent = oParentSet.deleteParent(nCandidateParent, instances);
    try {
        // ... and add the arrow nNode -> nCandidateParent
        oParentSet2.addParent(nNode, instances);
        try {
            return calcScore(m_BayesNet);
        } finally {
            // undo the added arrow first, mirroring the order of set-up
            oParentSet2.deleteLastParent(instances);
        }
    } finally {
        // then restore the original arrow, even if anything above failed
        oParentSet.addParent(nCandidateParent, iParent, instances);
    }
} // calcScoreWithReversedParent
/**
* LeaveOneOutCV returns the accuracy calculated using Leave One Out
* cross validation. The dataset used is m_Instances associated with
* the Bayes Network.
* <p>
* The CPTs are first estimated from the complete data set. Each instance is
* then left out by negating its weight and passing it to updateClassifier
* (presumably subtracting its counts from the network), scored, and put back
* by negating its weight again and calling updateClassifier a second time.
* <p>
* Side effect: sets m_BayesNet to bayesNet. If bayesNet.m_Instances holds no
* instances, both sums stay 0.0 and the result is NaN (0.0 / 0.0).
* @param bayesNet : Bayes Network containing structure to evaluate
* @return accuracy (in interval 0..1) measured using leave one out cv.
* @throws Exception passed on by updateClassifier (and possibly by
* estimateCPTs / accuracyIncrease)
*/
public double leaveOneOutCV(BayesNet bayesNet) throws Exception {
m_BayesNet = bayesNet;
double fAccuracy = 0.0;
double fWeight = 0.0;
Instances instances = bayesNet.m_Instances;
// start from CPTs estimated on all instances
bayesNet.estimateCPTs();
for (int iInstance = 0; iInstance < instances.numInstances(); iInstance++) {
Instance instance = instances.instance(iInstance);
// leave this instance out: negate its weight and update the network with it
// NOTE(review): there is no try/finally, so if anything below throws, the
// instance keeps its negated weight and the network is left modified
instance.setWeight(-instance.weight());
bayesNet.updateClassifier(instance);
fAccuracy += accuracyIncrease(instance);
// NOTE(review): the weight is still negated here, so fWeight sums negated
// weights; the final ratio is only positive if accuracyIncrease also scales
// its result by the (negated) instance weight -- confirm in accuracyIncrease
fWeight += instance.weight();
// restore the original weight and add the instance back to the network
instance.setWeight(-instance.weight());
bayesNet.updateClassifier(instance);
}
return fAccuracy / fWeight;
} // leaveOneOutCV
/**
* CumulativeCV returns the accuracy calculated using cumulative
* cross validation. The idea is to run through the data set and
* try to classify each of the instances based on the previously
* seen data.
* The data set used is m_Instances associated with the Bayes Network.
* @param bayesNet : Bayes Network containing structure to evaluate
* @return accuracy (in interval 0..1) measured using leave one out cv.
* @throws Exception passed on by updateClassifier
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -