bayesnet.java

来自「Weka」· Java 代码 · 共 1,118 行 · 第 1/3 页

JAVA
1,118
字号
/*
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 2 of the License, or
 * (at your option) any later version.
 * 
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 * 
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

/*
 * BayesNet.java
 * Copyright (C) 2001 University of Waikato, Hamilton, New Zealand
 * 
 */
package weka.classifiers.bayes;

import weka.classifiers.Classifier;
import weka.classifiers.bayes.net.ADNode;
import weka.classifiers.bayes.net.BIFReader;
import weka.classifiers.bayes.net.ParentSet;
import weka.classifiers.bayes.net.estimate.BayesNetEstimator;
import weka.classifiers.bayes.net.estimate.DiscreteEstimatorBayes;
import weka.classifiers.bayes.net.estimate.SimpleEstimator;
import weka.classifiers.bayes.net.search.SearchAlgorithm;
import weka.classifiers.bayes.net.search.local.K2;
import weka.classifiers.bayes.net.search.local.LocalScoreSearchAlgorithm;
import weka.classifiers.bayes.net.search.local.Scoreable;
import weka.core.AdditionalMeasureProducer;
import weka.core.Attribute;
import weka.core.Capabilities;
import weka.core.Drawable;
import weka.core.FastVector;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.Utils;
import weka.core.WeightedInstancesHandler;
import weka.core.Capabilities.Capability;
import weka.estimators.Estimator;
import weka.filters.Filter;
import weka.filters.supervised.attribute.Discretize;
import weka.filters.unsupervised.attribute.ReplaceMissingValues;

import java.util.Enumeration;
import java.util.Vector;

/**
 <!-- globalinfo-start -->
 * Bayes Network learning using various search algorithms and quality measures.<br/>
 * Base class for a Bayes Network classifier. Provides datastructures (network structure, conditional probability distributions, etc.) and facilities common to Bayes Network learning algorithms like K2 and B.<br/>
 * <br/>
 * For more information see:<br/>
 * <br/>
 * http://www.cs.waikato.ac.nz/~remco/weka.pdf
 * <p/>
 <!-- globalinfo-end -->
 * 
 <!-- options-start -->
 * Valid options are: <p/>
 * 
 * <pre> -D
 *  Do not use ADTree data structure
 * </pre>
 * 
 * <pre> -B &lt;BIF file&gt;
 *  BIF file to compare with
 * </pre>
 * 
 * <pre> -Q weka.classifiers.bayes.net.search.SearchAlgorithm
 *  Search algorithm
 * </pre>
 * 
 * <pre> -E weka.classifiers.bayes.net.estimate.SimpleEstimator
 *  Estimator algorithm
 * </pre>
 * 
 <!-- options-end -->
 *
 * @author Remco Bouckaert (rrb@xm.co.nz)
 * @version $Revision: 1.31 $
 */
public class BayesNet extends Classifier implements OptionHandler, WeightedInstancesHandler, Drawable, AdditionalMeasureProducer {

    /** for serialization */
    static final long serialVersionUID = 746037443258775954L;
    
    
    /**
     * The parent sets.
     */
    protected ParentSet[] m_ParentSets;

    /**
     * The attribute estimators containing CPTs.
     */
    public Estimator[][] m_Distributions;


   	/** filter used to quantize continuous variables, if any **/
    protected Discretize m_DiscretizeFilter = null;
    
    /** attribute index of a non-nominal attribute */
    int m_nNonDiscreteAttribute = -1;

    /** filter used to fill in missing values, if any **/
    protected ReplaceMissingValues m_MissingValuesFilter = null;	
	
    /**
     * The number of classes
     */
    protected int m_NumClasses;

    /**
     * The dataset header for the purposes of printing out a semi-intelligible
     * model
     */
    public Instances m_Instances;

    /**
     * Datastructure containing ADTree representation of the database.
     * This may result in more efficient access to the data.
     */
    ADNode m_ADTree;

    /**
     * Bayes network to compare the structure with.
     */
    protected BIFReader m_otherBayesNet = null;

    /**
     * Use the experimental ADTree datastructure for calculating contingency tables
     */
    boolean m_bUseADTree = false;

    /**
     * Search algorithm used for learning the structure of a network.
     */
    SearchAlgorithm m_SearchAlgorithm = new K2();

    /**
     * Estimator used for computing the conditional probability tables of the network.
     */
    BayesNetEstimator m_BayesNetEstimator = new SimpleEstimator();

    /**
     * Describes the data this classifier accepts: nominal and numeric
     * attributes (missing values allowed), a nominal class (missing class
     * values allowed), and a minimum number of instances of zero.
     *
     * @return      the capabilities of this classifier
     */
    public Capabilities getCapabilities() {
      // enabled in the listed order on top of whatever the superclass reports
      final Capability[] supported = {
        // attributes
        Capability.NOMINAL_ATTRIBUTES,
        Capability.NUMERIC_ATTRIBUTES,
        Capability.MISSING_VALUES,
        // class
        Capability.NOMINAL_CLASS,
        Capability.MISSING_CLASS_VALUES
      };

      Capabilities capabilities = super.getCapabilities();
      for (int i = 0; i < supported.length; i++) {
        capabilities.enable(supported[i]);
      }

      // instances
      capabilities.setMinimumNumberInstances(0);

      return capabilities;
    }

    /**
     * Generates the classifier: validates and normalizes the training data,
     * then initializes the structure, learns the structure and estimates the
     * conditional probability tables, in that order.
     * 
     * @param instances set of instances serving as training data
     * @throws Exception if the classifier has not been generated
     * successfully
     */
    public void buildClassifier(Instances instances) throws Exception {

        // refuse data this classifier cannot handle
        getCapabilities().testWithFail(instances);

        // work on a private copy so the caller's data set stays untouched,
        // and drop the instances that have no class value
        Instances trainingData = new Instances(instances);
        trainingData.deleteWithMissingClass();

        // ensure we have a data set with discrete variables only and with no missing values
        trainingData = normalizeDataSet(trainingData);

        // keep a copy of the normalized data and remember the number of classes
        m_Instances = new Instances(trainingData);
        m_NumClasses = trainingData.numClasses();

        // the ADTree is only built when requested
        if (m_bUseADTree) {
            m_ADTree = ADNode.makeADTree(trainingData);
        }

        // allocate the parent sets
        initStructure();

        // learn the network structure
        buildStructure();

        // build the set of CPTs
        estimateCPTs();

        // Save space: the ADTree is dropped once the network has been built
        // (m_Instances is deliberately kept in full)
        m_ADTree = null;
    } // buildClassifier

	/**
	 * Ensures that all variables are nominal and that there are no missing values.
	 * Any filters and non-nominal attribute index remembered from an earlier call
	 * are discarded first. If a non-nominal attribute is found among the attributes
	 * returned by <code>enumerateAttributes()</code>, a Discretize filter is set up
	 * on the data and applied; if a missing value is found for any of those
	 * attributes, a ReplaceMissingValues filter is set up and applied afterwards.
	 * The filters are kept in m_DiscretizeFilter and m_MissingValuesFilter.
	 * 
	 * @param instances data set to check and quantize and/or fill in missing values
	 * @return filtered instances
	 * @throws Exception if a filter (Discretize, ReplaceMissingValues) fails
	 */
	protected Instances normalizeDataSet(Instances instances) throws Exception {
		m_DiscretizeFilter = null;
		m_MissingValuesFilter = null;
		// reset together with the filters so a stale index from a previous
		// call cannot survive
		m_nNonDiscreteAttribute = -1;

		boolean bHasNonNominal = false;
		boolean bHasMissingValues = false;

		Enumeration enu = instances.enumerateAttributes();
		while (enu.hasMoreElements()) {
			Attribute attribute = (Attribute) enu.nextElement();
			if (attribute.type() != Attribute.NOMINAL) {
				// the index of the last non-nominal attribute seen is the one kept
				m_nNonDiscreteAttribute = attribute.index();
				bHasNonNominal = true;
			}
			// a single missing value settles the question, so stop scanning
			// as soon as one has been found
			if (!bHasMissingValues) {
				Enumeration enum2 = instances.enumerateInstances();
				while (!bHasMissingValues && enum2.hasMoreElements()) {
					bHasMissingValues = ((Instance) enum2.nextElement()).isMissing(attribute);
				}
			}
		}

		if (bHasNonNominal) {
			System.err.println("Warning: discretizing data set");
			m_DiscretizeFilter = new Discretize();
			m_DiscretizeFilter.setInputFormat(instances);
			instances = Filter.useFilter(instances, m_DiscretizeFilter);
		}

		if (bHasMissingValues) {
			System.err.println("Warning: filling in missing values in data set");
			m_MissingValuesFilter = new ReplaceMissingValues();
			m_MissingValuesFilter.setInputFormat(instances);
			instances = Filter.useFilter(instances, m_MissingValuesFilter);
		}
		return instances;
	} // normalizeDataSet

	/**
	 * Ensures that all variables of a single instance are nominal and that it
	 * has no missing values, using the filters kept in m_DiscretizeFilter and
	 * m_MissingValuesFilter. When no missing values filter exists yet and the
	 * instance has a missing non-class value, a ReplaceMissingValues filter is
	 * created from m_Instances, stored in m_MissingValuesFilter and applied.
	 * 
	 * @param instance instance to check and quantize and/or fill in missing values
	 * @return filtered instance
	 * @throws Exception if a filter (Discretize, ReplaceMissingValues) fails
	 */
	protected Instance normalizeInstance(Instance instance) throws Exception {
		// only discretize when a filter exists and the remembered attribute
		// is still non-nominal in this instance
		boolean bDiscretize = (m_DiscretizeFilter != null)
			&& (instance.attribute(m_nNonDiscreteAttribute).type() != Attribute.NOMINAL);
		if (bDiscretize) {
			m_DiscretizeFilter.input(instance);
			instance = m_DiscretizeFilter.output();
		}

		// a missing values filter is already available: just apply it
		if (m_MissingValuesFilter != null) {
			m_MissingValuesFilter.input(instance);
			return m_MissingValuesFilter.output();
		}

		// is there a missing value in this instance?
		// this can happen when there is no missing value in the training set
		for (int iAttribute = 0; iAttribute < m_Instances.numAttributes(); iAttribute++) {
			if (iAttribute == instance.classIndex() || !instance.isMissing(iAttribute)) {
				continue;
			}
			System.err.println("Warning: Found missing value in test set, filling in values.");
			m_MissingValuesFilter = new ReplaceMissingValues();
			m_MissingValuesFilter.setInputFormat(m_Instances);
			// the filtered data is discarded; the call only runs m_Instances
			// through the new filter before it is used on the instance
			Filter.useFilter(m_Instances, m_MissingValuesFilter);
			m_MissingValuesFilter.input(instance);
			instance = m_MissingValuesFilter.output();
			// the whole instance has been filtered, no need to look further
			break;
		}
		return instance;
	} // normalizeInstance

    /**
     * Init structure reserves memory for the network structure: it allocates
     * m_ParentSets with one newly constructed ParentSet per attribute of
     * m_Instances. No arcs are added here.
     * 
     * @throws Exception in case of an error
     */
    public void initStructure() throws Exception {

        final int nAttributes = m_Instances.numAttributes();

        // reserve memory
        m_ParentSets = new ParentSet[nAttributes];

        for (int iAttribute = 0; iAttribute < nAttributes; iAttribute++) {
            // each parent set is constructed with the total number of attributes
            m_ParentSets[iAttribute] = new ParentSet(nAttributes);
        }
    } // initStructure

    /**
     * buildStructure determines the network structure/graph of the network.
     * The work is delegated to the configured search algorithm
     * (m_SearchAlgorithm, a K2 instance unless replaced), which is handed
     * this network and the data set in m_Instances.
     * This method can be overridden by derived classes to restrict the class
     * of network structures that are acceptable.
     * 
     * @throws Exception in case of an error
     */
    public void buildStructure() throws Exception {
        m_SearchAlgorithm.buildStructure(this, m_Instances);
    } // buildStructure

    /**
     * estimateCPTs estimates the conditional probability tables for the Bayes
     * Net using the network structure. The work is delegated to the configured
     * estimator (m_BayesNetEstimator, a SimpleEstimator unless replaced).
     * 
     * @throws Exception in case of an error
     */
    public void estimateCPTs() throws Exception {
        m_BayesNetEstimator.estimateCPTs(this);
    } // estimateCPTs

    /**
     * initializes the conditional probabilities by delegating to the
     * configured estimator (m_BayesNetEstimator)
     * 
     * @throws Exception in case of an error
     */
    public void initCPTs() throws Exception {
        m_BayesNetEstimator.initCPTs(this);
    } // initCPTs

    /**
     * Updates the classifier with the given instance. The instance is first
     * passed through normalizeInstance and the result is handed to the
     * configured estimator.
     * 
     * @param instance the new training instance to include in the model
     * @throws Exception if the instance could not be incorporated in
     * the model.
     */
    public void updateClassifier(Instance instance) throws Exception {
        // bring the instance into the form the estimator is given
        Instance normalized = normalizeInstance(instance);
        m_BayesNetEstimator.updateClassifier(this, normalized);
    } // updateClassifier

    /**
     * Calculates the class membership probabilities for the given test

⌨️ 快捷键说明

复制代码Ctrl + C
搜索代码Ctrl + F
全屏模式F11
增大字号Ctrl + =
减小字号Ctrl + -
显示快捷键?