// bayesnet.java
/*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 2 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
*/
/*
* BayesNet.java
* Copyright (C) 2001 Remco Bouckaert
*
*/
package weka.classifiers.bayes;
import java.util.Enumeration;
import java.util.Vector;
import weka.classifiers.Classifier;
import weka.classifiers.Evaluation;
import weka.classifiers.bayes.net.ADNode;
import weka.classifiers.bayes.net.BIFReader;
import weka.classifiers.bayes.net.ParentSet;
import weka.classifiers.bayes.net.estimate.BayesNetEstimator;
import weka.classifiers.bayes.net.estimate.DiscreteEstimatorBayes;
import weka.classifiers.bayes.net.estimate.SimpleEstimator;
import weka.classifiers.bayes.net.search.SearchAlgorithm;
import weka.classifiers.bayes.net.search.local.K2;
import weka.classifiers.bayes.net.search.local.LocalScoreSearchAlgorithm;
import weka.classifiers.bayes.net.search.local.Scoreable;
import weka.core.AdditionalMeasureProducer;
import weka.core.Attribute;
import weka.core.Drawable;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.UnsupportedAttributeTypeException;
import weka.core.UnsupportedClassTypeException;
import weka.core.Utils;
import weka.core.WeightedInstancesHandler;
import weka.estimators.Estimator;
import weka.filters.supervised.attribute.Discretize;
import weka.filters.unsupervised.attribute.ReplaceMissingValues;
/**
* Base class for a Bayes Network classifier. Provides data structures (network structure,
* conditional probability distributions, etc.) and facilities common to Bayes Network
* learning algorithms like K2 and B.
* Requires a nominal class. Non-nominal (continuous) attributes are discretized and
* missing values are filled in automatically; string and date attributes are not supported.
*
* For further documentation, see the
* <a href='http://www.cs.waikato.ac.nz/~remco/weka.pdf'>Bayesian networks in Weka</a>
* user documentation.
*
* @author Remco Bouckaert (rrb@xm.co.nz)
* @version $Revision$
*/
public class BayesNet extends Classifier implements OptionHandler, WeightedInstancesHandler, Drawable, AdditionalMeasureProducer {
/**
* Serialization version identifier for Java serialization. Streams written
* with a different value cannot be deserialized into this class, so change it
* only when the serialized form changes incompatibly.
*/
private static final long serialVersionUID = 2289809244537243787L;
static {
  // Register weka.gui.GenericObjectEditor as the java.beans property editor
  // for the two pluggable component types (search algorithm and estimator),
  // presumably so they can be configured from the GUI.
  // This is best effort by design: any failure (presumably e.g. GUI classes
  // not being available) is swallowed and registration is simply skipped.
  try {
    Class editorClass = weka.gui.GenericObjectEditor.class;
    java.beans.PropertyEditorManager.registerEditor(SearchAlgorithm.class, editorClass);
    java.beans.PropertyEditorManager.registerEditor(BayesNetEstimator.class, editorClass);
  } catch (Throwable ignored) {
    // intentionally ignored, see above
  }
}
/**
* The parent set of each attribute, one entry per attribute index
* (allocated in initStructure()).
*/
protected ParentSet[] m_ParentSets;
/**
* The attribute estimators containing the conditional probability tables.
* NOTE(review): populated by the configured BayesNetEstimator; the exact
* index layout is defined there — confirm before relying on it.
*/
public Estimator[][] m_Distributions;
/** Filter used to discretize non-nominal attributes; null if the training data needed none. **/
Discretize m_DiscretizeFilter = null;
/**
* Index of the last non-nominal attribute seen in the training data (-1 initially).
* normalizeInstance() inspects this attribute to decide whether an incoming
* instance still has to be discretized.
*/
int m_nNonDiscreteAttribute = -1;
/**
* Filter used to fill in missing values, if any. Created for the training data
* when it has missing values, or lazily by normalizeInstance() when a later
* instance has a missing value.
**/
ReplaceMissingValues m_MissingValuesFilter = null;
/**
* The number of class values (set in buildClassifier).
*/
protected int m_NumClasses;
/**
* The normalized training data (discretized and/or with missing values filled in).
* Note: the full data set is kept, not only the header; structure search and CPT
* estimation read the data from here.
*/
public Instances m_Instances;
/**
* ADTree representation of the training data, built only when m_bUseADTree is
* set. This may result in more efficient access to the data. It is released
* (set to null) at the end of buildClassifier.
*/
ADNode m_ADTree;
/**
* Bayes network to compare the structure with.
*/
protected BIFReader m_otherBayesNet = null;
/**
* Whether to build the experimental ADTree data structure for calculating contingency tables.
*/
boolean m_bUseADTree = false;
/**
* Search algorithm used for learning the structure of a network (K2 by default).
*/
SearchAlgorithm m_SearchAlgorithm = new K2();
/**
* Estimator used to initialize and estimate the conditional probability tables,
* and to update the model and compute class distributions (SimpleEstimator by default).
*/
BayesNetEstimator m_BayesNetEstimator = new SimpleEstimator();
/**
* Generates the classifier.
*
* Training instances whose class value is missing are discarded first. The
* remaining data is normalized (see normalizeDataSet), an ADTree is built if
* requested, and then the structure is initialized, searched and the CPTs are
* estimated. The caller's data set is never modified.
*
* @param instances set of instances serving as training data
* @exception UnsupportedClassTypeException if the class attribute is not nominal
* @exception Exception if the classifier has not been generated
* successfully
*/
public void buildClassifier(Instances instances) throws Exception {
  // Check that class is nominal
  if (!instances.classAttribute().isNominal()) {
    throw new UnsupportedClassTypeException("BayesNet: nominal class, please.");
  }
  // Work on a private copy so the caller's data is untouched, and drop instances
  // without a class label: they carry no training signal for the class and must
  // not reach the missing-value filter or CPT estimation.
  instances = new Instances(instances);
  instances.deleteWithMissingClass();
  // ensure we have a data set with discrete variables only and with no missing values
  instances = normalizeDataSet(instances);
  // Keep the normalized training data; structure search and CPT estimation read it
  m_Instances = new Instances(instances);
  m_NumClasses = instances.numClasses();
  try {
    // initialize ADTree
    if (m_bUseADTree) {
      m_ADTree = ADNode.makeADTree(instances);
    }
    // initialize the (empty) network structure
    initStructure();
    // search for the network structure
    buildStructure();
    // build the set of CPTs
    estimateCPTs();
  } finally {
    // The ADTree is only needed while learning; release it even if learning fails.
    m_ADTree = null;
  }
} // buildClassifier
/**
* Ensures that all attributes are nominal and that there are no missing values,
* discretizing non-nominal attributes and/or filling in missing values as needed.
*
* Side effects: resets m_DiscretizeFilter, m_MissingValuesFilter and
* m_nNonDiscreteAttribute, then re-creates whichever filters are needed.
* m_nNonDiscreteAttribute ends up holding the index of the last non-nominal
* attribute encountered, or -1 if there is none.
*
* @param instances data set to check and to quantize and/or fill in missing values for
* @return the filtered instances, or {@code instances} itself when no filtering was needed
* @throws UnsupportedAttributeTypeException if a string or date attribute is encountered
* @throws Exception if one of the filters fails
*/
Instances normalizeDataSet(Instances instances) throws Exception {
  m_DiscretizeFilter = null;
  m_MissingValuesFilter = null;
  m_nNonDiscreteAttribute = -1;
  boolean bHasNonNominal = false;
  boolean bHasMissingValues = false;
  // NOTE(review): Weka's enumerateAttributes() is expected to skip the class attribute.
  Enumeration enumAttributes = instances.enumerateAttributes();
  while (enumAttributes.hasMoreElements()) {
    Attribute attribute = (Attribute) enumAttributes.nextElement();
    if (attribute.type() == Attribute.STRING) {
      throw new UnsupportedAttributeTypeException("BayesNet does not handle string variables, only nominal and continuous.");
    }
    if (attribute.type() == Attribute.DATE) {
      throw new UnsupportedAttributeTypeException("BayesNet does not handle date variables, only nominal and continuous.");
    }
    if (attribute.type() != Attribute.NOMINAL) {
      m_nNonDiscreteAttribute = attribute.index();
      bHasNonNominal = true;
    }
    // A single missing value anywhere is enough to require the filter, so the data is
    // only scanned until one is found (the type checks above still run for every attribute).
    if (!bHasMissingValues) {
      bHasMissingValues = containsMissingValue(instances, attribute);
    }
  }
  if (bHasNonNominal) {
    System.err.println("Warning: discretizing data set");
    m_DiscretizeFilter = new Discretize();
    m_DiscretizeFilter.setInputFormat(instances);
    instances = Discretize.useFilter(instances, m_DiscretizeFilter);
  }
  if (bHasMissingValues) {
    System.err.println("Warning: filling in missing values in data set");
    m_MissingValuesFilter = new ReplaceMissingValues();
    m_MissingValuesFilter.setInputFormat(instances);
    instances = ReplaceMissingValues.useFilter(instances, m_MissingValuesFilter);
  }
  return instances;
} // normalizeDataSet

/**
* Returns true if any instance of the data set has a missing value for the
* given attribute; stops at the first one found.
*/
private static boolean containsMissingValue(Instances instances, Attribute attribute) {
  for (int iInstance = 0; iInstance < instances.numInstances(); iInstance++) {
    if (instances.instance(iInstance).isMissing(attribute)) {
      return true;
    }
  }
  return false;
}
/**
* Applies the filters fitted by normalizeDataSet to a single instance so that it
* matches the format of the normalized training data.
*
* If the training data had no missing values but this instance has a missing value
* in a non-class attribute, a ReplaceMissingValues filter is fitted on m_Instances
* on the fly and stored in m_MissingValuesFilter, so later calls reuse it.
*
* @param instance instance to check and to quantize and/or fill in missing values for
* @return the filtered instance (possibly the argument itself)
* @throws Exception if one of the filters fails
*/
Instance normalizeInstance(Instance instance) throws Exception {
  // Discretize only while the instance is still in the raw (non-nominal) format.
  if ((m_DiscretizeFilter != null)
      && (instance.attribute(m_nNonDiscreteAttribute).type() != Attribute.NOMINAL)) {
    m_DiscretizeFilter.input(instance);
    instance = m_DiscretizeFilter.output();
  }
  if (m_MissingValuesFilter == null && hasMissingNonClassValue(instance)) {
    // this can happen when there is no missing value in the training set
    System.err.println("Warning: Found missing value in test set, filling in values.");
    m_MissingValuesFilter = new ReplaceMissingValues();
    m_MissingValuesFilter.setInputFormat(m_Instances);
    // Pass the training data through once so the filter sees the whole batch and
    // can compute its replacement values; the filtered copy itself is not needed.
    ReplaceMissingValues.useFilter(m_Instances, m_MissingValuesFilter);
  }
  if (m_MissingValuesFilter != null) {
    m_MissingValuesFilter.input(instance);
    instance = m_MissingValuesFilter.output();
  }
  return instance;
} // normalizeInstance

/**
* Returns true if the instance has a missing value in any attribute other than
* the class attribute (attribute count taken from m_Instances).
*/
private boolean hasMissingNonClassValue(Instance instance) {
  for (int iAttribute = 0; iAttribute < m_Instances.numAttributes(); iAttribute++) {
    if (iAttribute != instance.classIndex() && instance.isMissing(iAttribute)) {
      return true;
    }
  }
  return false;
}
/**
* Initializes the network structure: allocates one ParentSet per attribute of
* m_Instances, each created with capacity for all attributes (presumably empty,
* i.e. no arcs yet — see ParentSet).
*
* @throws Exception if the structure cannot be initialized
*/
public void initStructure() throws Exception {
  // reserve memory
  int nAttributes = m_Instances.numAttributes();
  m_ParentSets = new ParentSet[nAttributes];
  for (int iAttribute = 0; iAttribute < nAttributes; iAttribute++) {
    m_ParentSets[iAttribute] = new ParentSet(nAttributes);
  }
} // initStructure
/**
* Determines the network structure (graph) by handing this network and the
* training data in m_Instances to the configured search algorithm.
* Subclasses may override this method to restrict the class of network
* structures that are acceptable.
*
* @throws Exception if the structure search fails
*/
public void buildStructure() throws Exception {
  SearchAlgorithm search = m_SearchAlgorithm;
  search.buildStructure(this, m_Instances);
} // buildStructure
/**
* Estimates the conditional probability tables for the current network
* structure by delegating to the configured BayesNetEstimator.
*
* @throws Exception if estimation fails
*/
public void estimateCPTs() throws Exception {
  BayesNetEstimator estimator = m_BayesNetEstimator;
  estimator.estimateCPTs(this);
} // estimateCPTs
/**
* Initializes the conditional probability tables by delegating to the
* configured BayesNetEstimator.
*
* @throws Exception if initialization fails
*/
public void initCPTs() throws Exception {
  BayesNetEstimator estimator = m_BayesNetEstimator;
  estimator.initCPTs(this);
} // initCPTs
/**
* Incorporates one more training instance into the model. The instance is first
* normalized with the same filters as the training data (see normalizeInstance)
* and then passed to the configured estimator.
*
* @param instance the new training instance to include in the model
* @exception Exception if the instance could not be incorporated in
* the model.
*/
public void updateClassifier(Instance instance) throws Exception {
  Instance prepared = normalizeInstance(instance);
  m_BayesNetEstimator.updateClassifier(this, prepared);
} // updateClassifier
/**
* Calculates the class membership probabilities for the given test
* instance.
*
* @param instance the instance to be classified
* @return predicted class probability distribution
* @exception Exception if there is a problem generating the prediction
*/
public double[] distributionForInstance(Instance instance) throws Exception {
instance = normalizeInstance(instance);
return m_BayesNetEstimator.distributionForInstance(this, instance);
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -