📄 bayesnet.java
字号:
} // distributionForInstance
/**
 * Calculates the counts of the Dirichlet distribution behind the class
 * membership probabilities of a test instance.
 * <p>
 * For every candidate class value, the count held by each attribute's
 * conditional distribution is looked up and summed. The row of that
 * distribution is selected from the instance's values for the attribute's
 * parents, with the candidate class value substituted whenever a parent is
 * the class attribute.
 * NOTE(review): missing values are not handled here — presumably callers
 * pass complete instances; confirm.
 *
 * @param instance the instance to be classified
 * @return one summed count per class value (array of length m_NumClasses)
 * @exception Exception if there is a problem generating the prediction
 */
public double[] countsForInstance(Instance instance) throws Exception {
  // New Java arrays are zero-initialised, so each class total starts at 0.0.
  double[] fCounts = new double[m_NumClasses];
  for (int classValue = 0; classValue < m_NumClasses; classValue++) {
    final int nAttributes = m_Instances.numAttributes();
    final int classIdx = m_Instances.classIndex();
    double total = 0;
    for (int node = 0; node < nAttributes; node++) {
      // Mixed-radix row index into the node's CPT: each parent contributes one
      // digit. It is kept as a double because instance.value() yields doubles.
      double cptRow = 0;
      final int nParents = m_ParentSets[node].getNrOfParents();
      for (int p = 0; p < nParents; p++) {
        final int parent = m_ParentSets[node].getParent(p);
        if (parent != classIdx) {
          cptRow = cptRow * m_Instances.attribute(parent).numValues() + instance.value(parent);
        } else {
          cptRow = cptRow * m_NumClasses + classValue;
        }
      }
      DiscreteEstimatorBayes estimator =
          (DiscreteEstimatorBayes) m_Distributions[node][(int) cptRow];
      if (node != classIdx) {
        total += estimator.getCount(instance.value(node));
      } else {
        total += estimator.getCount(classValue);
      }
    }
    fCounts[classValue] += total;
  }
  return fCounts;
} // countsForInstance
/**
 * Returns an enumeration describing the available options.
 * <p>
 * Note that -D is a "switch off" flag: setOptions() sets the ADTree usage
 * to false when -D is present, and getOptions() emits -D only when the
 * ADTree is not used. The description below reflects that.
 *
 * @return an enumeration of all the available options
 */
public Enumeration<Option> listOptions() {
  Vector<Option> newVector = new Vector<Option>(4);
  newVector.addElement(new Option("\tDo not use ADTree data structure\n", "D", 0, "-D"));
  newVector.addElement(new Option("\tBIF file to compare with\n", "B", 1, "-B <BIF file>"));
  newVector.addElement(new Option("\tSearch algorithm\n", "Q", 1, "-Q weka.classifiers.bayes.net.search.SearchAlgorithm"));
  newVector.addElement(new Option("\tEstimator algorithm\n", "E", 1, "-E weka.classifiers.bayes.net.estimate.SimpleEstimator"));
  return newVector.elements();
} // listOptions
/**
 * Parses a given list of options. Valid options are:<p>
 *
 * -D<br>
 *  Do not use the ADTree data structure.<p>
 *
 * -B &lt;BIF file&gt;<br>
 *  BIF file to compare with.<p>
 *
 * -Q searchAlgorithmClass [-- searchOptions]<br>
 *  Search algorithm (required). Its options are the ones between the
 *  first "--" and the following "-E".<p>
 *
 * -E estimatorClass [-- estimatorOptions]<br>
 *  Estimator algorithm (required). Its options are obtained through
 *  Utils.partitionOptions.<p>
 *
 * @param options the list of options as an array of strings
 * @exception Exception if -Q or -E is missing, or an option is not supported
 */
public void setOptions(String[] options) throws Exception {
  // The presence of -D switches the ADTree off.
  final boolean disableADTree = Utils.getFlag('D', options);
  m_bUseADTree = !disableADTree;

  final String sBIFFile = Utils.getOption('B', options);
  final boolean hasBIFFile = sBIFFile != null && !sBIFFile.equals("");
  if (hasBIFFile) {
    setBIFFile(sBIFFile);
  }

  final String searchAlgorithmName = Utils.getOption('Q', options);
  if (searchAlgorithmName.length() == 0) {
    throw new Exception("A searchAlgorithmName must be specified with" + " the -Q option.");
  }
  // Everything between the first "--" and "-E" belongs to the search algorithm.
  final String[] searchOptions = partitionOptions(options);
  final SearchAlgorithm search =
      (SearchAlgorithm) Utils.forName(SearchAlgorithm.class, searchAlgorithmName, searchOptions);
  setSearchAlgorithm(search);

  final String estimatorName = Utils.getOption('E', options);
  if (estimatorName.length() == 0) {
    throw new Exception("A estimatorName must be specified with" + " the -E option.");
  }
  final String[] estimatorOptions = Utils.partitionOptions(options);
  final BayesNetEstimator estimator =
      (BayesNetEstimator) Utils.forName(BayesNetEstimator.class, estimatorName, estimatorOptions);
  setEstimator(estimator);

  Utils.checkForRemainingOptions(options);
} // setOptions
/**
 * Extracts the secondary (search algorithm) options from the supplied
 * options array.
 * <p>
 * The secondary options are those after the first "--" and before the next
 * "-E". When such a pair is found, the "--" itself and every extracted
 * option are blanked out (set to "") in the original array. If the first
 * "--" is not followed by a "-E", or there is no "--" at all, nothing is
 * extracted and the array is left untouched.
 *
 * @param options the input array of options; modified in place
 * @return the secondary options, padded with "" up to the number of entries
 *         that follow the "--", or an empty array if nothing was extracted
 */
public static String [] partitionOptions(String [] options) {
  for (int dash = 0; dash < options.length; dash++) {
    if (!options[dash].equals("--")) {
      continue;
    }
    // Only the first "--" is considered; it must be followed by a "-E".
    int stop = dash + 1;
    while (stop < options.length && !options[stop].equals("-E")) {
      stop++;
    }
    if (stop == options.length) {
      return new String[0];
    }
    options[dash] = "";
    final int first = dash + 1;
    String[] secondary = new String[options.length - first];
    for (int k = 0; k < secondary.length; k++) {
      secondary[k] = "";
    }
    for (int k = first; k < stop; k++) {
      secondary[k - first] = options[k];
      options[k] = "";
    }
    return secondary;
  }
  return new String[0];
}
/**
 * Gets the current settings of the classifier.
 * <p>
 * The layout is: [-D] [-B file] -Q searchClass -- searchOptions
 * -E estimatorClass -- estimatorOptions. Unused trailing slots are filled
 * with empty strings.
 *
 * @return an array of strings suitable for passing to setOptions
 */
public String[] getOptions() {
  String[] searchOptions = m_SearchAlgorithm.getOptions();
  String[] estimatorOptions = m_BayesNetEstimator.getOptions();
  // Fixed part: "-D" (1) + "-B <file>" (2) + "-Q <class>" (2) + "--" (1)
  // + "-E <class>" (2) + "--" (1) = 9 slots when -D and -B are both emitted.
  String[] options = new String[9 + searchOptions.length + estimatorOptions.length];
  int current = 0;
  if (!m_bUseADTree) {
    options[current++] = "-D";
  }
  if (m_otherBayesNet != null) {
    options[current++] = "-B";
    options[current++] = m_otherBayesNet.getFileName();
  }
  options[current++] = "-Q";
  options[current++] = getSearchAlgorithm().getClass().getName();
  options[current++] = "--";
  System.arraycopy(searchOptions, 0, options, current, searchOptions.length);
  current += searchOptions.length;
  options[current++] = "-E";
  options[current++] = getEstimator().getClass().getName();
  options[current++] = "--";
  System.arraycopy(estimatorOptions, 0, options, current, estimatorOptions.length);
  current += estimatorOptions.length;
  // Fill up rest with empty strings, not nulls!
  while (current < options.length) {
    options[current++] = "";
  }
  return options;
} // getOptions
/**
 * Sets the algorithm used to search for the network structure.
 *
 * @param newSearchAlgorithm the search algorithm to use from now on
 */
public void setSearchAlgorithm(SearchAlgorithm newSearchAlgorithm) {
  this.m_SearchAlgorithm = newSearchAlgorithm;
}
/**
 * Returns the algorithm used to search for the network structure.
 *
 * @return the currently configured search algorithm
 */
public SearchAlgorithm getSearchAlgorithm() {
  return this.m_SearchAlgorithm;
}
/**
 * Sets the estimator algorithm used in calculating the CPTs.
 *
 * @param newBayesNetEstimator the estimator to use
 */
public void setEstimator(BayesNetEstimator newBayesNetEstimator) {
  m_BayesNetEstimator = newBayesNetEstimator;
}
/**
 * Returns the estimator used to calculate the CPTs.
 *
 * @return the currently configured BayesNetEstimator
 */
public BayesNetEstimator getEstimator() {
  return this.m_BayesNetEstimator;
}
/**
 * Sets whether the ADTree data structure is used.
 *
 * @param bUseADTree true to use the ADTree, false otherwise
 */
public void setUseADTree(boolean bUseADTree) {
  this.m_bUseADTree = bUseADTree;
}
/**
 * Tells whether the ADTree data structure is used.
 *
 * @return true if the ADTree is used, false otherwise
 */
public boolean getUseADTree() {
  return this.m_bUseADTree;
}
/**
 * Loads the network stored in a BIF file so that this classifier can be
 * compared with it (see toString()).
 * <p>
 * This is best-effort: if the name is null/empty or the file cannot be
 * processed, no comparison network is kept. A read failure is reported on
 * standard error instead of being silently ignored. Errors (as opposed to
 * exceptions) are no longer swallowed.
 *
 * @param sBIFFile name of the BIF file to compare with
 */
public void setBIFFile(String sBIFFile) {
  if (sBIFFile == null || sBIFFile.length() == 0) {
    m_otherBayesNet = null;
    return;
  }
  try {
    m_otherBayesNet = new BIFReader().processFile(sBIFFile);
  } catch (Exception e) {
    m_otherBayesNet = null;
    System.err.println("Warning: could not read BIF file '" + sBIFFile + "': " + e);
  }
}
/**
 * Returns the name of the BIF file holding the comparison network.
 *
 * @return the BIF file name, or "" when no comparison network is loaded
 */
public String getBIFFile() {
  return (m_otherBayesNet == null) ? "" : m_otherBayesNet.getFileName();
}
/**
 * Returns a textual description of the classifier.
 * <p>
 * The description states whether the ADTree is used. Once a model has been
 * built, it also lists the network structure (each node followed by its
 * parents) and the Bayes, BDeu, MDL, ENTROPY and AIC log scores. When a
 * comparison network was loaded from a BIF file, it adds the missing,
 * extra and reversed arcs and the divergence with respect to that network.
 *
 * @return a description of the classifier as a string.
 */
@Override
public String toString() {
  StringBuilder sb = new StringBuilder("Bayes Network Classifier");
  sb.append('\n').append(m_bUseADTree ? "Using " : "not using ").append("ADTree");
  if (m_Instances == null) {
    sb.append(": No model built yet.");
    return sb.toString();
  }
  // flatten the network down to text
  final int nAttributes = m_Instances.numAttributes();
  sb.append("\n#attributes=").append(nAttributes);
  sb.append(" #classindex=").append(m_Instances.classIndex());
  sb.append("\nNetwork structure (nodes followed by parents)\n");
  for (int node = 0; node < nAttributes; node++) {
    sb.append(m_Instances.attribute(node).name())
        .append('(')
        .append(m_Instances.attribute(node).numValues())
        .append("): ");
    for (int p = 0; p < m_ParentSets[node].getNrOfParents(); p++) {
      sb.append(m_Instances.attribute(m_ParentSets[node].getParent(p)).name()).append(' ');
    }
    sb.append('\n');
    // The per-node distributions are deliberately not printed: too much detail.
  }
  sb.append("LogScore Bayes: ").append(measureBayesScore()).append('\n');
  sb.append("LogScore BDeu: ").append(measureBDeuScore()).append('\n');
  sb.append("LogScore MDL: ").append(measureMDLScore()).append('\n');
  sb.append("LogScore ENTROPY: ").append(measureEntropyScore()).append('\n');
  sb.append("LogScore AIC: ").append(measureAICScore()).append('\n');
  if (m_otherBayesNet != null) {
    sb.append("Missing: ").append(m_otherBayesNet.missingArcs(this))
        .append(" Extra: ").append(m_otherBayesNet.extraArcs(this))
        .append(" Reversed: ").append(m_otherBayesNet.reversedArcs(this))
        .append('\n');
    sb.append("Divergence: ").append(m_otherBayesNet.divergence(this)).append('\n');
  }
  return sb.toString();
} // toString
/**
 * Returns the type of graph this classifier represents.
 *
 * @return Drawable.BayesNet
 */
public int graphType() {
  return Drawable.BayesNet;
}
/**
 * Returns this Bayes network as a graph description in XMLBIF version 0.3
 * format, as produced by toXMLBIF03().
 *
 * @return a String holding this BayesNet in XMLBIF 0.3 format
 * @exception Exception propagated from toXMLBIF03()
 */
public String graph() throws Exception {
  final String xmlBif03 = toXMLBIF03();
  return xmlBif03;
}
/**
* Returns a description of the classifier in XML BIF 0.3 format.
* See http://www-2.cs.cmu.edu/~fgcozman/Research/InterchangeFormat/
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -