📄 bayesnet.java
Font size:
/* * This program is free software; you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation; either version 2 of the License, or * (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with this program; if not, write to the Free Software * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA. *//* * BayesNet.java * Copyright (C) 2001 Remco Bouckaert * */package weka.classifiers.bayes;import java.io.*;import java.util.*;import weka.core.*;import weka.estimators.*;import weka.classifiers.*;/** * Base class for a Bayes Network classifier. Provides datastructures (network structure, * conditional probability distributions, etc.) and facilities common to Bayes Network * learning algorithms like K2 and B. * Works with nominal variables and no missing values only. * * @author Remco Bouckaert (rrb@xm.co.nz) * @version $Revision: 1.1.1.1 $ */public class BayesNet extends DistributionClassifier implements OptionHandler, WeightedInstancesHandler { /** * topological ordering of the network */ protected int[] m_nOrder; /** * The parent sets. */ protected ParentSet[] m_ParentSets; /** * The attribute estimators containing CPTs. 
*/ protected Estimator[][] m_Distributions; /** * The number of classes */ protected int m_NumClasses; /** * The dataset header for the purposes of printing out a semi-intelligible * model */ public Instances m_Instances; ADNode m_ADTree; public static final Tag [] TAGS_SCORE_TYPE = { new Tag(Scoreable.BAYES, "BAYES"), new Tag(Scoreable.MDL, "MDL"), new Tag(Scoreable.ENTROPY, "ENTROPY"), new Tag(Scoreable.AIC, "AIC") }; /** * Holds the score type used to measure quality of network */ int m_nScoreType = Scoreable.BAYES; /** * Holds prior on count */ double m_fAlpha = 0.5; /** * Holds upper bound on number of parents */ int m_nMaxNrOfParents = 100000; /** * determines whether initial structure is an empty graph or a Naive Bayes network */ boolean m_bInitAsNaiveBayes = true; /** * Use the experimental ADTree datastructure for calculating contingency tables */ boolean m_bUseADTree = true; /** * Generates the classifier. * * @param instances set of instances serving as training data * @exception Exception if the classifier has not been generated * successfully */ public void buildClassifier(Instances instances) throws Exception { // Check that class is nominal if (!instances.classAttribute().isNominal()) { throw new UnsupportedClassTypeException("BayesNet: nominal class, please."); } // check that all variables are nominal and that there // are no missing values Enumeration enum = instances.enumerateAttributes(); while (enum.hasMoreElements()) { Attribute attribute = (Attribute) enum.nextElement(); if (attribute.type() != Attribute.NOMINAL) { throw new UnsupportedAttributeTypeException("BayesNet handles nominal variables only. 
Non-nominal variable in dataset detected."); } Enumeration enum2 = instances.enumerateInstances(); while (enum2.hasMoreElements()) { if (((Instance) enum2.nextElement()).isMissing(attribute)) { throw new NoSupportForMissingValuesException("BayesNet: no missing values, please."); } } } // Copy the instances m_Instances = new Instances(instances); // sanity check: need more than 1 variable in datat set m_NumClasses = instances.numClasses(); // initialize ADTree if (m_bUseADTree) { m_ADTree = ADNode.MakeADTree(instances);// System.out.println("Oef, done!"); } // build the network structure initStructure(); // build the network structure buildStructure(); // build the set of CPTs estimateCPTs(); // Save space // m_Instances = new Instances(m_Instances, 0); m_ADTree = null; } // buildClassifier /** * Init structure initializes the structure to an empty graph or a Naive Bayes * graph (depending on the -N flag). */ public void initStructure() throws Exception { // initialize topological ordering m_nOrder = new int[m_Instances.numAttributes()]; m_nOrder[0] = m_Instances.classIndex(); int nAttribute = 0; for (int iOrder = 1; iOrder < m_Instances.numAttributes(); iOrder++) { if (nAttribute == m_Instances.classIndex()) { nAttribute++; } m_nOrder[iOrder] = nAttribute++; } // reserce memory m_ParentSets = new ParentSet[m_Instances.numAttributes()]; for (int iAttribute = 0; iAttribute < m_Instances.numAttributes(); iAttribute++) { m_ParentSets[iAttribute] = new ParentSet(m_Instances.numAttributes()); } if (m_bInitAsNaiveBayes) { // initialize parent sets to have arrow from classifier node to // each of the other nodes for (int iOrder = 1; iOrder < m_Instances.numAttributes(); iOrder++) { int iAttribute = m_nOrder[iOrder]; m_ParentSets[iAttribute].AddParent(m_Instances.classIndex(), m_Instances); } } } /** * buildStructure determines the network structure/graph of the network. 
* The default behavior is creating a network where all nodes have the first * node as its parent (i.e., a BayesNet that behaves like a naive Bayes classifier). * This method can be overridden by derived classes to restrict the class * of network structures that are acceptable. */ public void buildStructure() throws Exception { // place holder for structure learing algorithms like K2, B, etc. } // buildStructure /** * estimateCPTs estimates the conditional probability tables for the Bayes * Net using the network structure. */ public void estimateCPTs() throws Exception { // Reserve space for CPTs int nMaxParentCardinality = 1; for (int iAttribute = 0; iAttribute < m_Instances.numAttributes(); iAttribute++) { if (m_ParentSets[iAttribute].GetCardinalityOfParents() > nMaxParentCardinality) { nMaxParentCardinality = m_ParentSets[iAttribute].GetCardinalityOfParents(); } } // Reserve plenty of memory m_Distributions = new Estimator[m_Instances.numAttributes()][nMaxParentCardinality]; // estimate CPTs for (int iAttribute = 0; iAttribute < m_Instances.numAttributes(); iAttribute++) { for (int iParent = 0; iParent < m_ParentSets[iAttribute].GetCardinalityOfParents(); iParent++) { m_Distributions[iAttribute][iParent] = new DiscreteEstimatorBayes(m_Instances.attribute(iAttribute) .numValues(), m_fAlpha); } } // Compute counts Enumeration enumInsts = m_Instances.enumerateInstances(); while (enumInsts.hasMoreElements()) { Instance instance = (Instance) enumInsts.nextElement(); updateClassifier(instance); } } // estimateCPTs /** * Updates the classifier with the given instance. * * @param instance the new training instance to include in the model * @exception Exception if the instance could not be incorporated in * the model. 
*/ public void updateClassifier(Instance instance) throws Exception { for (int iAttribute = 0; iAttribute < m_Instances.numAttributes(); iAttribute++) { double iCPT = 0; for (int iParent = 0; iParent < m_ParentSets[iAttribute].GetNrOfParents(); iParent++) { int nParent = m_ParentSets[iAttribute].GetParent(iParent); iCPT = iCPT * m_Instances.attribute(nParent).numValues() + instance.value(nParent); } m_Distributions[iAttribute][(int) iCPT] .addValue(instance.value(iAttribute), instance.weight()); } } // updateClassifier /** * Calculates the class membership probabilities for the given test * instance. * * @param instance the instance to be classified * @return predicted class probability distribution * @exception Exception if there is a problem generating the prediction */ public double[] distributionForInstance(Instance instance) throws Exception { double[] fProbs = new double[m_NumClasses]; for (int iClass = 0; iClass < m_NumClasses; iClass++) { fProbs[iClass] = 1.0; } for (int iClass = 0; iClass < m_NumClasses; iClass++) { double logfP = 0; for (int iAttribute = 0; iAttribute < m_Instances.numAttributes(); iAttribute++) { double iCPT = 0; for (int iParent = 0; iParent < m_ParentSets[iAttribute].GetNrOfParents(); iParent++) { int nParent = m_ParentSets[iAttribute].GetParent(iParent); if (nParent == m_Instances.classIndex()) { iCPT = iCPT * m_NumClasses + iClass; } else { iCPT = iCPT * m_Instances.attribute(nParent).numValues() + instance.value(nParent); } } if (iAttribute == m_Instances.classIndex()) {// fP *= // m_Distributions[iAttribute][(int) iCPT].getProbability(iClass); logfP += Math.log(m_Distributions[iAttribute][(int) iCPT].getProbability(iClass)); } else {// fP *= // m_Distributions[iAttribute][(int) iCPT]// .getProbability(instance.value(iAttribute)); logfP += Math.log(m_Distributions[iAttribute][(int) iCPT] .getProbability(instance.value(iAttribute))); } } // fProbs[iClass] *= fP; fProbs[iClass] += logfP; } // Find maximum double fMax = fProbs[0]; for 
(int iClass = 0; iClass < m_NumClasses; iClass++) { if (fProbs[iClass] > fMax) { fMax = fProbs[iClass]; } } // transform from log-space to normal-space for (int iClass = 0; iClass < m_NumClasses; iClass++) { fProbs[iClass] = Math.exp(fProbs[iClass] - fMax); } // Display probabilities Utils.normalize(fProbs); return fProbs; } // distributionForInstance /** * Calculates the counts for Dirichlet distribution for the * class membership probabilities for the given test instance. * * @param instance the instance to be classified * @return counts for Dirichlet distribution for class probability * @exception Exception if there is a problem generating the prediction */ public double[] countsForInstance(Instance instance) throws Exception { double[] fCounts = new double[m_NumClasses]; for (int iClass = 0; iClass < m_NumClasses; iClass++) { fCounts[iClass] = 0.0; } for (int iClass = 0; iClass < m_NumClasses; iClass++) { double fCount = 0; for (int iAttribute = 0; iAttribute < m_Instances.numAttributes(); iAttribute++) { double iCPT = 0; for (int iParent = 0;
⌨️ Keyboard shortcuts
Copy code
Ctrl + C
Search code
Ctrl + F
Full-screen mode
F11
Toggle theme
Ctrl + Shift + D
Show shortcuts
?
Increase font size
Ctrl + =
Decrease font size
Ctrl + -