📄 bayesiannaive.java
字号:
/*
Bao Jie 2002-04-02
Iowa State University
*/
package weka.classifiers;
import java.io.*;
import java.util.*;
import weka.core.*;
public class BayesianNaive extends Classifier {// for Naive Bayes Classifier
/** All the counts for nominal attributes. */
private double [][][] m_Counts;
/** The means for numeric attributes. */
private double [][] m_Means;
/** The standard deviations for numeric attributes. */
private double [][] m_Devs;
/** The prior probabilities of the classes. */
private double [] m_Priors;
/** The instances used for training. */
private Instances m_Instances;
/** Constant for normal distribution. */
private static double NORM_CONST = Math.sqrt(2 * Math.PI);
/**
* Generates the classifier.
*
* @param instances set of instances serving as training data
* @exception Exception if the classifier has not been generated successfully
*/
public void buildClassifier(Instances instances) throws Exception {
int attIndex = 0;
double sum;
if (instances.checkForStringAttributes()) {
throw new Exception("Can't handle string attributes!");
}
if (instances.classAttribute().isNumeric()) {
throw new Exception("Naive Bayes: Class is numeric!");
}
m_Instances = new Instances(instances, 0);
// Reserve space
m_Counts = new double[instances.numClasses()]
[instances.numAttributes() - 1][0];
m_Means = new double[instances.numClasses()]
[instances.numAttributes() - 1];
m_Devs = new double[instances.numClasses()]
[instances.numAttributes() - 1];
m_Priors = new double[instances.numClasses()];
Enumeration enum = instances.enumerateAttributes();
while (enum.hasMoreElements())
{
Attribute attribute = (Attribute) enum.nextElement();
if (attribute.isNominal())
{
for (int j = 0; j < instances.numClasses(); j++)
{
m_Counts[j][attIndex] = new double[attribute.numValues()];
}
}
else
{
for (int j = 0; j < instances.numClasses(); j++) {
m_Counts[j][attIndex] = new double[1];
}
}
attIndex++;
}
// Compute counts and sums
Enumeration enumInsts = instances.enumerateInstances();
// System.out.println("\n number of instances: " + m_Instances.numInstances() );
while (enumInsts.hasMoreElements())
{
// System.out.print('.');
Instance instance = (Instance) enumInsts.nextElement();
if (!instance.classIsMissing())
{
Enumeration enumAtts = instances.enumerateAttributes();
attIndex = 0;
while (enumAtts.hasMoreElements())
{
Attribute attribute = (Attribute) enumAtts.nextElement();
if (!instance.isMissing(attribute))
{
if (attribute.isNominal())
{
m_Counts[(int)instance.classValue()][attIndex][(int)instance.value(attribute)]++;
}
else
{
m_Means[(int)instance.classValue()][attIndex] += instance.value(attribute);
m_Counts[(int)instance.classValue()][attIndex][0]++;
}
}
attIndex++;
}
m_Priors[(int)instance.classValue()]++;
}
}
// Compute means
Enumeration enumAtts = instances.enumerateAttributes();
attIndex = 0;
while (enumAtts.hasMoreElements()) {
Attribute attribute = (Attribute) enumAtts.nextElement();
if (attribute.isNumeric()) {
for (int j = 0; j < instances.numClasses(); j++) {
if (m_Counts[j][attIndex][0] < 2) {
throw new Exception("attribute " + attribute.name() +
": less than two values for class " +
instances.classAttribute().value(j));
}
m_Means[j][attIndex] /= m_Counts[j][attIndex][0];
}
}
attIndex++;
}
// Compute standard deviations
enumInsts = instances.enumerateInstances();
while (enumInsts.hasMoreElements()) {
Instance instance =
(Instance) enumInsts.nextElement();
if (!instance.classIsMissing()) {
enumAtts = instances.enumerateAttributes();
attIndex = 0;
while (enumAtts.hasMoreElements()) {
Attribute attribute = (Attribute) enumAtts.nextElement();
if (!instance.isMissing(attribute)) {
if (attribute.isNumeric()) {
m_Devs[(int)instance.classValue()][attIndex] +=
(m_Means[(int)instance.classValue()][attIndex]-
instance.value(attribute))*
(m_Means[(int)instance.classValue()][attIndex]-
instance.value(attribute));
}
}
attIndex++;
}
}
}
enumAtts = instances.enumerateAttributes();
attIndex = 0;
while (enumAtts.hasMoreElements()) {
Attribute attribute = (Attribute) enumAtts.nextElement();
if (attribute.isNumeric()) {
for (int j = 0; j < instances.numClasses(); j++) {
if (m_Devs[j][attIndex] <= 0) {
throw new Exception("attribute " + attribute.name() +
": standard deviation is 0 for class " +
instances.classAttribute().value(j));
}
else {
m_Devs[j][attIndex] /= m_Counts[j][attIndex][0] - 1;
m_Devs[j][attIndex] = Math.sqrt(m_Devs[j][attIndex]);
}
}
}
attIndex++;
}
// Normalize counts
enumAtts = instances.enumerateAttributes();
attIndex = 0;
while (enumAtts.hasMoreElements()) {
Attribute attribute = (Attribute) enumAtts.nextElement();
if (attribute.isNominal()) {
for (int j = 0; j < instances.numClasses(); j++) {
sum = Utils.sum(m_Counts[j][attIndex]);
for (int i = 0; i < attribute.numValues(); i++) {
m_Counts[j][attIndex][i] =
(m_Counts[j][attIndex][i] + 1)
/ (sum + (double)attribute.numValues());
}
}
}
attIndex++;
}
// Normalize priors
sum = Utils.sum(m_Priors);
for (int j = 0; j < instances.numClasses(); j++)
m_Priors[j] = (m_Priors[j] + 1)
/ (sum + (double)instances.numClasses());
}
/**
* Calculates the class membership probabilities for the given test instance.
*
* @param instance the instance to be classified
* @return predicted class probability distribution
* @exception Exception if distribution can't be computed
*/
public double[] distributionForInstance(Instance instance) throws Exception {
double [] probs = new double[instance.numClasses()];
int attIndex;
for (int j = 0; j < instance.numClasses(); j++) {
probs[j] = 1;
Enumeration enumAtts = instance.enumerateAttributes();
attIndex = 0;
while (enumAtts.hasMoreElements()) {
Attribute attribute = (Attribute) enumAtts.nextElement();
if (!instance.isMissing(attribute)) {
if (attribute.isNominal()) {
probs[j] *= m_Counts[j][attIndex][(int)instance.value(attribute)];
} else {
probs[j] *= normalDens(instance.value(attribute),
m_Means[j][attIndex],
m_Devs[j][attIndex]);}
}
attIndex++;
}
probs[j] *= m_Priors[j];
}
// Normalize probabilities
Utils.normalize(probs);
return probs;
}
/**
* Classifies the given test instance. The instance has to belong to a
* dataset when it's being classified.
*
* @param instance the instance to be classified
* @return the predicted most likely class for the instance or
* Instance.missingValue() if no prediction is made
* @exception Exception if an error occurred during the prediction
*/
public double classifyInstance(Instance instance) throws Exception {
double [] dist = distributionForInstance(instance);
if (dist == null) {
throw new Exception("Null distribution predicted");
}
switch (instance.classAttribute().type())
{
case Attribute.NOMINAL:
double max = 0;
int maxIndex = 0;
for (int i = 0; i < dist.length; i++)
{
if (dist[i] > max)
{
maxIndex = i;
max = dist[i];
}
}
if (max > 0)
{
return maxIndex;
}
else
{
return Instance.missingValue();
}
case Attribute.NUMERIC:
return dist[0];
default:
return Instance.missingValue();
}
}
/**
* Returns a description of the classifier.
*
* @return a description of the classifier as a string.
*/
public String toString() {
if (m_Instances == null) {
return "My Naive Bayes (simple): No model built yet.";
}
try {
StringBuffer text = new StringBuffer("My Naive Bayes (simple)");
int attIndex;
text.append("\nnumber of Class " + Utils.doubleToString(m_Instances.numClasses(), 10, 8) );
text.append("\nnumber of Attribute: " + Utils.doubleToString(m_Instances.numAttributes(), 10, 8) );
text.append("\nclass attribute's index: " + Utils.doubleToString(m_Instances.classIndex(), 10, 8));
text.append("\nnumber of instances: " + Utils.doubleToString(m_Instances.numInstances(), 10, 8) );
for (int i = 0; i < m_Instances.numClasses(); i++) {
text.append("\n\nClass " + m_Instances.classAttribute().value(i)
+ ": P(C) = "
+ Utils.doubleToString(m_Priors[i], 10, 8)
+ "\n\n");
Enumeration enumAtts = m_Instances.enumerateAttributes();
attIndex = 0;
while (enumAtts.hasMoreElements()) {
Attribute attribute = (Attribute) enumAtts.nextElement();
text.append("Attribute " + attribute.name() + "\n");
if (attribute.isNominal()) {
for (int j = 0; j < attribute.numValues(); j++) {
text.append(attribute.value(j) + "\t");
}
text.append("\n");
for (int j = 0; j < attribute.numValues(); j++)
text.append(Utils.
doubleToString(m_Counts[i][attIndex][j], 10, 8)
+ "\t");
} else {
text.append("Mean: " + Utils.
doubleToString(m_Means[i][attIndex], 10, 8) + "\t");
text.append("Standard Deviation: "
+ Utils.doubleToString(m_Devs[i][attIndex], 10, 8));
}
text.append("\n\n");
attIndex++;
}
}
return text.toString();
} catch (Exception e) {
return "Can't print Naive Bayes classifier!";
}
}
/**
* Density function of normal distribution.
*/
private double normalDens(double x, double mean, double stdDev) {
double diff = x - mean;
return (1 / (NORM_CONST * stdDev))
* Math.exp(-(diff * diff / (2 * stdDev * stdDev)));
}
/**
* Main method for testing this class.
*
* @param argv the options
*/
public static void main(String [] argv) {
Classifier scheme;
try {
scheme = new BayesianNaive();
System.out.println(Evaluation.evaluateModel(scheme, argv));
} catch (Exception e) {
System.err.println(e.getMessage());
}
}
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -