📄 globalscoresearchalgorithm.java
字号:
*/
public double cumulativeCV(BayesNet bayesNet) throws Exception {
    m_BayesNet = bayesNet;
    double totalAccuracy = 0.0;
    double totalWeight = 0.0;
    Instances data = bayesNet.m_Instances;
    // re-initialise the CPTs before the incremental pass (presumably to their
    // empty/prior state -- confirm in BayesNet.initCPTs)
    bayesNet.initCPTs();
    int index = 0;
    while (index < data.numInstances()) {
        Instance current = data.instance(index);
        // score the instance first, then feed it to the model, so each
        // instance is evaluated before it contributes to the CPTs
        totalAccuracy += accuracyIncrease(current);
        bayesNet.updateClassifier(current);
        totalWeight += current.weight();
        index++;
    }
    // weighted average; 0.0 / 0.0 yields NaN when the total weight is zero
    return totalAccuracy / totalWeight;
} // cumulativeCV
/**
 * kFoldCV uses k-fold cross validation to measure the accuracy of a Bayes
 * network classifier.
 * <p>
 * The CPTs are first estimated from the complete data set. Folds are
 * contiguous ranges of instances in their current order. For each fold:
 * the fold's influence is removed by passing its instances to
 * updateClassifier with negated weights; the fold is scored with
 * accuracyIncrease using the original weights; the fold's influence is then
 * restored by passing the instances to updateClassifier again with their
 * original weights. On normal return every instance weight is back to its
 * original value.
 * <p>
 * NOTE: if an exception is thrown part-way through a fold, the CPTs and the
 * weights of that fold's instances are left in a modified state.
 *
 * @param bayesNet : Bayes Network containing structure to evaluate
 * @param nNrOfFolds : the number of folds k to perform k-fold cv; must be at
 *        least 1
 * @return accuracy (in interval 0..1) measured using k-fold cv, i.e. the
 *         summed accuracy increase divided by the summed instance weight
 *         (NaN when the total weight is zero)
 * @throws IllegalArgumentException if nNrOfFolds is less than 1
 * @throws Exception passed on by updateClassifier and accuracyIncrease
 */
public double kFoldCV(BayesNet bayesNet, int nNrOfFolds) throws Exception {
    // k == 0 would divide by zero and k < 0 would make nFoldStart decrease
    // forever, so reject both up front
    if (nNrOfFolds < 1) {
        throw new IllegalArgumentException(
            "Number of folds must be at least 1, got " + nNrOfFolds);
    }
    m_BayesNet = bayesNet;
    double fAccuracy = 0.0;
    double fWeight = 0.0;
    Instances instances = bayesNet.m_Instances;
    // estimate CPTs based on complete data set
    bayesNet.estimateCPTs();
    int nFoldStart = 0;
    int nFoldEnd = instances.numInstances() / nNrOfFolds;
    int iFold = 1;
    while (nFoldStart < instances.numInstances()) {
        // remove influence of fold iFold from the probability distribution:
        // the instance is passed to updateClassifier with a negated weight
        for (int iInstance = nFoldStart; iInstance < nFoldEnd; iInstance++) {
            Instance instance = instances.instance(iInstance);
            instance.setWeight(-instance.weight());
            bayesNet.updateClassifier(instance);
        }
        // measure accuracy on fold iFold, flipping each weight back to its
        // original (positive) value first
        for (int iInstance = nFoldStart; iInstance < nFoldEnd; iInstance++) {
            Instance instance = instances.instance(iInstance);
            instance.setWeight(-instance.weight());
            fAccuracy += accuracyIncrease(instance);
            fWeight += instance.weight();
        }
        // restore influence of fold iFold: the weights already hold their
        // original values here, so they must NOT be negated again -- doing so
        // would remove the fold a second time and leave the weights negated
        for (int iInstance = nFoldStart; iInstance < nFoldEnd; iInstance++) {
            bayesNet.updateClassifier(instances.instance(iInstance));
        }
        // go to next fold; long arithmetic avoids int overflow of
        // iFold * numInstances on very large data sets
        nFoldStart = nFoldEnd;
        iFold++;
        nFoldEnd = (int) ((long) iFold * instances.numInstances() / nNrOfFolds);
    }
    return fAccuracy / fWeight;
} // kFoldCV
/**
 * Computes the contribution of a single instance to the accuracy estimate.
 * <p>
 * With probabilistic scoring enabled, the contribution is the probability
 * the network assigns to the instance's actual class, scaled by the instance
 * weight. Otherwise it is the full instance weight when the predicted class
 * equals the actual class, and zero when it does not.
 *
 * @param instance : instance for which to calculate the accuracy increase.
 * @return increase in accuracy due to given instance.
 * @throws Exception passed on by distributionForInstance and classifyInstance
 */
double accuracyIncrease(Instance instance) throws Exception {
    if (!m_bUseProb) {
        // hard scoring: all-or-nothing on a correct prediction
        double predicted = m_BayesNet.classifyInstance(instance);
        return predicted == instance.classValue() ? instance.weight() : 0;
    }
    // probabilistic scoring: weight the probability given to the actual class
    double[] classDistribution = m_BayesNet.distributionForInstance(instance);
    int actualClass = (int) instance.classValue();
    return classDistribution[actualClass] * instance.weight();
} // accuracyIncrease
/**
 * Tells whether class probabilities, rather than hard hit/miss decisions,
 * feed the accuracy estimate.
 *
 * @return use probabilities or not in accuracy estimate
 */
public boolean getUseProb() {
    return this.m_bUseProb;
} // getUseProb
/**
 * Chooses between probabilistic and hard hit/miss accuracy scoring.
 *
 * @param useProb : use probabilities or not in accuracy estimate
 */
public void setUseProb(boolean useProb) {
    this.m_bUseProb = useProb;
} // setUseProb
/**
 * Sets the cross validation strategy used when searching for networks.
 * The tag is accepted only when it refers to the same tag array instance as
 * TAGS_CV_TYPE (identity comparison); any other tag is silently ignored.
 *
 * @param newCVType : cross validation strategy
 */
public void setCVType(SelectedTag newCVType) {
    if (newCVType.getTags() != TAGS_CV_TYPE) {
        return;
    }
    m_nCVType = newCVType.getSelectedTag().getID();
} // setCVType
/**
 * Returns the cross validation strategy currently used when searching for
 * networks, wrapped in a new SelectedTag over TAGS_CV_TYPE.
 *
 * @return cross validation strategy
 */
public SelectedTag getCVType() {
    final int selectedId = m_nCVType;
    return new SelectedTag(selectedId, TAGS_CV_TYPE);
} // getCVType
/**
 * Describes the command-line options: first this algorithm's -S (score
 * type) and -Q (probabilistic scoring), then every option reported by the
 * superclass.
 *
 * @return an enumeration of all the available options
 */
public Enumeration listOptions() {
    Vector result = new Vector(2);
    Option scoreType = new Option(
        "\tScore type (LOO-CV,k-Fold-CV,Cumulative-CV)\n",
        "S",
        1,
        "-S [LOO-CV|k-Fold-CV|Cumulative-CV]");
    Option useProb = new Option("\tUse probabilistic scoring.\n\t(default true)", "Q", 0, "-Q");
    result.addElement(scoreType);
    result.addElement(useProb);
    for (Enumeration inherited = super.listOptions(); inherited.hasMoreElements();) {
        result.addElement(inherited.nextElement());
    }
    return result.elements();
} // listOptions
/**
 * Parses a given list of options. Valid options are:<p>
 *
 * -S [LOO-CV|k-Fold-CV|Cumulative-CV] <br>
 * Cross validation strategy. The value must match exactly; a missing or
 * unrecognised value leaves the current strategy unchanged.<p>
 *
 * -Q <br>
 * Use probabilistic scoring. When the flag is absent, probabilistic scoring
 * is switched off.<p>
 *
 * The remaining options are passed on to the superclass; for those see
 * search algorithm.
 *
 * @param options the list of options as an array of strings
 * @exception Exception if an option is not supported
 */
public void setOptions(String[] options) throws Exception {
    String scoreType = Utils.getOption('S', options);
    if (scoreType.equals("LOO-CV")) {
        setCVType(new SelectedTag(LOOCV, TAGS_CV_TYPE));
    } else if (scoreType.equals("k-Fold-CV")) {
        setCVType(new SelectedTag(KFOLDCV, TAGS_CV_TYPE));
    } else if (scoreType.equals("Cumulative-CV")) {
        setCVType(new SelectedTag(CUMCV, TAGS_CV_TYPE));
    }
    boolean useProb = Utils.getFlag('Q', options);
    setUseProb(useProb);
    super.setOptions(options);
} // setOptions
/**
 * Gets the current settings of the search algorithm.
 * <p>
 * The array holds this algorithm's own options ("-S" with the strategy name
 * and, when probabilistic scoring is on, "-Q"), followed by the options of
 * the superclass, padded with empty strings to the array length.
 *
 * @return an array of strings suitable for passing to setOptions
 */
public String[] getOptions() {
    String[] superOptions = super.getOptions();
    // room for at most three own options: "-S", its value and "-Q"
    String[] options = new String[3 + superOptions.length];
    int current = 0;
    options[current++] = "-S";
    switch (m_nCVType) {
        case (LOOCV) :
            options[current++] = "LOO-CV";
            break;
        case (KFOLDCV) :
            options[current++] = "k-Fold-CV";
            break;
        case (CUMCV) :
            options[current++] = "Cumulative-CV";
            break;
    }
    if (getUseProb()) {
        options[current++] = "-Q";
    }
    // append the superclass options so they are not lost; setOptions hands
    // everything it does not consume to super.setOptions, so this keeps
    // setOptions(getOptions()) a faithful round trip
    for (int iOption = 0; iOption < superOptions.length; iOption++) {
        options[current++] = superOptions[iOption];
    }
    // Fill up rest with empty strings, not nulls!
    while (current < options.length) {
        options[current++] = "";
    }
    return options;
} // getOptions
/**
 * Returns the tip text for the CVType option: an introductory sentence
 * followed by one line per available cross validation strategy.
 *
 * @return a string to describe the CVType option.
 */
public String CVTypeTipText() {
    // the newline after the first sentence keeps it from running straight
    // into the "LOO-CV = ..." entry
    return "Select cross validation strategy to be used in searching for networks.\n" +
        "LOO-CV = Leave one out cross validation\n" +
        "k-Fold-CV = k fold cross validation\n" +
        "Cumulative-CV = cumulative cross validation.";
} // CVTypeTipText
/**
 * Returns the tip text for the UseProb option, contrasting probabilistic
 * scoring with exact-match scoring of the accuracy estimate.
 *
 * @return a string to describe the UseProb option.
 */
public String useProbTipText() {
    return "If set to true, the probability of the class is returned in the estimate of the " +
        "accuracy. If set to false, the accuracy estimate is only increased if the classifier returns " +
        "exactly the correct class.";
} // useProbTipText
/**
 * Gives a short, human-readable description of this search algorithm.
 *
 * @return The string.
 */
public String globalInfo() {
    final String description =
        "This Bayes Network learning algorithm uses cross validation to estimate "
        + "classification accuracy.";
    return description;
} // globalInfo
} // class CVSearchAlgorithm
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -