📄 variableelimination.java
字号:
/* Copyright (C) 2003 Univ. of Massachusetts Amherst, Computer Science Dept. This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit). http://www.cs.umass.edu/~mccallum/mallet This software is provided under the terms of the Common Public License, version 1.0, as published by http://www.opensource.org. For further information, see the file `LICENSE' included with this distribution. */package edu.umass.cs.mallet.grmm;import java.util.HashSet;import java.util.Set;import java.util.Iterator;import java.util.Collection;/** * The variable elimination algorithm for inference in graphical * models. * * Created: Mon Sep 22 17:34:00 2003 * * @author <a href="mailto:casutton@cs.umass.edu">Charles Sutton</a> * @version $Id: VariableElimination.java,v 1.1 2004/07/15 17:53:31 casutton Exp $ */public class VariableElimination extends AbstractInferencer { private MultinomialPotential eliminate (Collection allPhi, Variable node) { HashSet phiSet = new HashSet(); /* collect the potentials that include this variable */ for (Iterator j = allPhi.iterator(); j.hasNext(); ) { DiscretePotential cpf = (DiscretePotential) j.next (); if (cpf.containsVar (node)) { phiSet.add (cpf); j.remove (); } } return MultinomialPotential.multiplyAll (phiSet); } /** * The bulk of the variable-elimination algorithm. Returns the * marginal density of the variable QUERY in the undirected * model MODEL, except that the density is un-normalized. * The normalization is done in a separate function to make * computeNormalizationFactor easier. */ private DiscretePotential unnormalizedQuery (UndirectedModel model, Variable query) { /* here the elimination order is random */ /* note that using buckets would make this more efficient as well. */ /* make a copy of potentials */ HashSet allPhi = new HashSet(); for (Iterator i = model.potentialsIterator(); i.hasNext(); ){ MultinomialPotential cpf = (MultinomialPotential) i.next (); allPhi.add(cpf.clone()); } Set nodes = model.getVertexSet(); /* Eliminate each node in turn */ for (Iterator i = nodes.iterator(); i.hasNext(); ) { Variable node = (Variable) i.next(); if (node == query) continue; // Eliminate the query variable last! DiscretePotential newCPF = eliminate (allPhi, node); /* Extract (marginalize) over this variables */ DiscretePotential singleCPF; if(newCPF.varSet().size() == 1) { singleCPF = newCPF; } else { singleCPF = newCPF.marginalizeOut (node); } /* add it back to the list of potentials */ allPhi.add(singleCPF); } /* Now, all the potentials that are left should contain only the * query variable.... UNLESS the graph is disconnected. So just * eliminate the query var. */ DiscretePotential marginal = eliminate (allPhi, query); assert marginal.containsVar (query); assert marginal.varSet().size() == 1; return marginal; } /** * Computes the normalization constant for a model. */ public double computeNormalizationFactor (UndirectedModel m) { /* What we'll do is get the unnormalized marginal of an arbitrary * node; then sum the marginal to get the normalization factor. */ Variable var = (Variable) m.getVertexSet().iterator().next(); DiscretePotential marginal = unnormalizedQuery (m, var); return marginal.sum (); } UndirectedModel mdlCurrent; // Inert. All work done in lookupMarginal(). public void computeMarginals (UndirectedModel m) { mdlCurrent = m; } public DiscretePotential lookupMarginal (Variable var) { DiscretePotential marginal = unnormalizedQuery (mdlCurrent, var); marginal.normalize(); return marginal; }}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -