⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 bayesiannaive.java

📁 此代码为数据挖掘软件的一个例子
💻 JAVA
字号:
/*
Bao Jie 2002-04-02
Iowa State University
*/

package weka.classifiers;

import java.io.*;
import java.util.*;
import weka.core.*;


public class BayesianNaive extends Classifier {// for Naive Bayes Classifier

  /** All the counts for nominal attributes. */
  private double [][][] m_Counts;
  
  /** The means for numeric attributes. */
  private double [][] m_Means;

  /** The standard deviations for numeric attributes. */
  private double [][] m_Devs;

  /** The prior probabilities of the classes. */
  private double [] m_Priors;

  /** The instances used for training. */
  private Instances m_Instances;

  /** Constant for normal distribution. */
  private static double NORM_CONST = Math.sqrt(2 * Math.PI);

  /**
   * Generates the classifier.
   *
   * @param instances set of instances serving as training data 
   * @exception Exception if the classifier has not been generated successfully
   */
  public void buildClassifier(Instances instances) throws Exception {

    int attIndex = 0;
    double sum;
    
    if (instances.checkForStringAttributes()) {
      throw new Exception("Can't handle string attributes!");
    }
    if (instances.classAttribute().isNumeric()) {
      throw new Exception("Naive Bayes: Class is numeric!");
    }
    
    m_Instances = new Instances(instances, 0);
    
    // Reserve space
    m_Counts = new double[instances.numClasses()]
      [instances.numAttributes() - 1][0];
    m_Means = new double[instances.numClasses()]
      [instances.numAttributes() - 1];
    m_Devs = new double[instances.numClasses()]
      [instances.numAttributes() - 1];
    m_Priors = new double[instances.numClasses()];
    Enumeration enum = instances.enumerateAttributes();
    while (enum.hasMoreElements()) 
    {
      Attribute attribute = (Attribute) enum.nextElement();
      if (attribute.isNominal()) 
      {
		for (int j = 0; j < instances.numClasses(); j++) 
		{
	  		m_Counts[j][attIndex] = new double[attribute.numValues()];
		}
      }
      else 
      {
		for (int j = 0; j < instances.numClasses(); j++) {
	  		m_Counts[j][attIndex] = new double[1];
		}
      }
      attIndex++;
    }
    
    // Compute counts and sums
    Enumeration enumInsts = instances.enumerateInstances();
// 	System.out.println("\n number of instances: " + m_Instances.numInstances() );
    while (enumInsts.hasMoreElements()) 
    {
//    System.out.print('.');
      Instance instance = (Instance) enumInsts.nextElement();
      if (!instance.classIsMissing()) 
      {
		Enumeration enumAtts = instances.enumerateAttributes();
		attIndex = 0;
		while (enumAtts.hasMoreElements()) 
		{
	  		Attribute attribute = (Attribute) enumAtts.nextElement();
	  		if (!instance.isMissing(attribute)) 
	  		{
	    		if (attribute.isNominal()) 
	    		{
	      			m_Counts[(int)instance.classValue()][attIndex][(int)instance.value(attribute)]++;
	    		}
	    		else 
	    		{
	      			m_Means[(int)instance.classValue()][attIndex] += instance.value(attribute);
	      			m_Counts[(int)instance.classValue()][attIndex][0]++;
	    		}
	     }
	     attIndex++;
	   }
	   m_Priors[(int)instance.classValue()]++;
      }
    }
    
    // Compute means
    Enumeration enumAtts = instances.enumerateAttributes();
    attIndex = 0;
    while (enumAtts.hasMoreElements()) {
      Attribute attribute = (Attribute) enumAtts.nextElement();
      if (attribute.isNumeric()) {
	for (int j = 0; j < instances.numClasses(); j++) {
	  if (m_Counts[j][attIndex][0] < 2) {
	    throw new Exception("attribute " + attribute.name() +
				": less than two values for class " +
				instances.classAttribute().value(j));
	  }
	  m_Means[j][attIndex] /= m_Counts[j][attIndex][0];
	}
      }
      attIndex++;
    }    
    
    // Compute standard deviations
    enumInsts = instances.enumerateInstances();
    while (enumInsts.hasMoreElements()) {
      Instance instance = 
	(Instance) enumInsts.nextElement();
      if (!instance.classIsMissing()) {
	enumAtts = instances.enumerateAttributes();
	attIndex = 0;
	while (enumAtts.hasMoreElements()) {
	  Attribute attribute = (Attribute) enumAtts.nextElement();
	  if (!instance.isMissing(attribute)) {
	    if (attribute.isNumeric()) {
	      m_Devs[(int)instance.classValue()][attIndex] +=
		(m_Means[(int)instance.classValue()][attIndex]-
		 instance.value(attribute))*
		(m_Means[(int)instance.classValue()][attIndex]-
		 instance.value(attribute));
	    }
	  }
	  attIndex++;
	}
      }
    }
    enumAtts = instances.enumerateAttributes();
    attIndex = 0;
    while (enumAtts.hasMoreElements()) {
      Attribute attribute = (Attribute) enumAtts.nextElement();
      if (attribute.isNumeric()) {
	for (int j = 0; j < instances.numClasses(); j++) {
	  if (m_Devs[j][attIndex] <= 0) {
	    throw new Exception("attribute " + attribute.name() +
				": standard deviation is 0 for class " +
				instances.classAttribute().value(j));
	  }
	  else {
	    m_Devs[j][attIndex] /= m_Counts[j][attIndex][0] - 1;
	    m_Devs[j][attIndex] = Math.sqrt(m_Devs[j][attIndex]);
	  }
	}
      }
      attIndex++;
    } 
    
    // Normalize counts
    enumAtts = instances.enumerateAttributes();
    attIndex = 0;
    while (enumAtts.hasMoreElements()) {
      Attribute attribute = (Attribute) enumAtts.nextElement();
      if (attribute.isNominal()) {
	for (int j = 0; j < instances.numClasses(); j++) {
	  sum = Utils.sum(m_Counts[j][attIndex]);
	  for (int i = 0; i < attribute.numValues(); i++) {
	    m_Counts[j][attIndex][i] =
	      (m_Counts[j][attIndex][i] + 1) 
	      / (sum + (double)attribute.numValues());
	  }
	}
      }
      attIndex++;
    }
    
    // Normalize priors
    sum = Utils.sum(m_Priors);
    for (int j = 0; j < instances.numClasses(); j++)
      m_Priors[j] = (m_Priors[j] + 1) 
	/ (sum + (double)instances.numClasses());
  }

  /**
   * Calculates the class membership probabilities for the given test instance.
   *
   * @param instance the instance to be classified
   * @return predicted class probability distribution
   * @exception Exception if distribution can't be computed
   */
  public double[] distributionForInstance(Instance instance) throws Exception {
    
    double [] probs = new double[instance.numClasses()];
    int attIndex;
    
    for (int j = 0; j < instance.numClasses(); j++) {
      probs[j] = 1;
      Enumeration enumAtts = instance.enumerateAttributes();
      attIndex = 0;
      while (enumAtts.hasMoreElements()) {
	Attribute attribute = (Attribute) enumAtts.nextElement();
	if (!instance.isMissing(attribute)) {
	  if (attribute.isNominal()) {
	    probs[j] *= m_Counts[j][attIndex][(int)instance.value(attribute)];
	  } else {
	    probs[j] *= normalDens(instance.value(attribute),
				   m_Means[j][attIndex],
				   m_Devs[j][attIndex]);}
	}
	attIndex++;
      }
      probs[j] *= m_Priors[j];
    }
    
    // Normalize probabilities
    Utils.normalize(probs);

    return probs;
  }

   /**
   * Classifies the given test instance. The instance has to belong to a
   * dataset when it's being classified.
   *
   * @param instance the instance to be classified
   * @return the predicted most likely class for the instance or 
   * Instance.missingValue() if no prediction is made
   * @exception Exception if an error occurred during the prediction
   */
  public double classifyInstance(Instance instance) throws Exception {

    double [] dist = distributionForInstance(instance);
    if (dist == null) {
      throw new Exception("Null distribution predicted");
    }
    switch (instance.classAttribute().type()) 
    {
    case Attribute.NOMINAL:
      double max = 0;
      int maxIndex = 0;
      
      for (int i = 0; i < dist.length; i++) 
      {
		if (dist[i] > max) 
		{
			  maxIndex = i;
			  max = dist[i];
		}
      }
      if (max > 0) 
      {
		return maxIndex;
      } 
      else
      {
      	return Instance.missingValue();
      }
    case Attribute.NUMERIC:
      return dist[0];
    default:
      return Instance.missingValue();
    }
  }
  
   /**
   * Returns a description of the classifier.
   *
   * @return a description of the classifier as a string.
   */
  public String toString() {

    if (m_Instances == null) {
      return "My Naive Bayes (simple): No model built yet.";
    }
    try {
      StringBuffer text = new StringBuffer("My Naive Bayes (simple)");
      int attIndex;
      
    text.append("\nnumber of Class " + Utils.doubleToString(m_Instances.numClasses(), 10, 8) );
    text.append("\nnumber of Attribute: " + Utils.doubleToString(m_Instances.numAttributes(), 10, 8) );
   	text.append("\nclass attribute's index: " + Utils.doubleToString(m_Instances.classIndex(), 10, 8));
   	text.append("\nnumber of instances: " + Utils.doubleToString(m_Instances.numInstances(), 10, 8) );
   	
   	for (int i = 0; i < m_Instances.numClasses(); i++) {
	text.append("\n\nClass " + m_Instances.classAttribute().value(i) 
		    + ": P(C) = " 
		    + Utils.doubleToString(m_Priors[i], 10, 8)
		    + "\n\n");
   	Enumeration enumAtts = m_Instances.enumerateAttributes();
	attIndex = 0;
	while (enumAtts.hasMoreElements()) {
	  Attribute attribute = (Attribute) enumAtts.nextElement();
	  text.append("Attribute " + attribute.name() + "\n");
	  if (attribute.isNominal()) {
	    for (int j = 0; j < attribute.numValues(); j++) {
	      text.append(attribute.value(j) + "\t");
	    }
	    text.append("\n");
	    for (int j = 0; j < attribute.numValues(); j++)
	      text.append(Utils.
			  doubleToString(m_Counts[i][attIndex][j], 10, 8)
			  + "\t");
	  } else {
	    text.append("Mean: " + Utils.
			doubleToString(m_Means[i][attIndex], 10, 8) + "\t");
	    text.append("Standard Deviation: " 
			+ Utils.doubleToString(m_Devs[i][attIndex], 10, 8));
	  }
	  text.append("\n\n");
	  attIndex++;
	}
      }
      
      return text.toString();
    } catch (Exception e) {
      return "Can't print Naive Bayes classifier!";
    }
  }

  /**
   * Density function of normal distribution.
   */
  private double normalDens(double x, double mean, double stdDev) {
    
    double diff = x - mean;
    
    return (1 / (NORM_CONST * stdDev)) 
      * Math.exp(-(diff * diff / (2 * stdDev * stdDev)));
  }

  /**
   * Main method for testing this class.
   *
   * @param argv the options
   */
  public static void main(String [] argv) {

    Classifier scheme;

    try {
      scheme = new BayesianNaive();
      System.out.println(Evaluation.evaluateModel(scheme, argv));
    } catch (Exception e) {
      System.err.println(e.getMessage());
    }
  }
}


⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -