📄 viterbipropagation.java
字号:
/* Copyright (C) 2003 Univ. of Massachusetts Amherst, Computer Science Dept.
   This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
   http://www.cs.umass.edu/~mccallum/mallet
   This software is provided under the terms of the Common Public License,
   version 1.0, as published by http://www.opensource.org.  For further
   information, see the file `LICENSE' included with this distribution. */

package edu.umass.cs.mallet.grmm;

import salvo.jesus.graph.GraphException;
import java.util.Set;
import java.util.Collection;
import java.util.Iterator;
import salvo.jesus.graph.Graph;
import java.util.logging.Level;
import java.util.logging.Logger;

/**
 * Uses max-product propagation to calculate most
 * likely values of hidden nodes, given the evidence.
 *
 * Created: Wed Oct 1 11:28:32 2003
 *
 * @author <a href="mailto:casutton@cs.umass.edu">Charles Sutton</a>
 * @version $Id: ViterbiPropagation.java,v 1.1 2004/07/15 17:53:31 casutton Exp $
 */
public class ViterbiPropagation extends BeliefPropagation
{

  /**
   * Creates a max-product propagator.  Clears the inherited
   * {@code normalizeBeliefs} flag.
   */
  public ViterbiPropagation ()
  {
    normalizeBeliefs = false;
  }

  /**
   * When true (the default), every {@code lookupMarginal} overload in this
   * class calls {@code delogify()} on the potential before returning it.
   */
  private boolean delogify = true;

  /** Returns whether looked-up marginals are delogified before being returned. */
  public boolean getDelogify () { return delogify; }

  /** Sets whether looked-up marginals are delogified before being returned. */
  public void setDelogify (boolean d) { delogify = d; }

  /** Returns the inherited {@code inLogSpace} flag consulted when building messages. */
  public boolean getInLogSpace () { return inLogSpace; }

  /** Sets the inherited {@code inLogSpace} flag consulted when building messages. */
  public void setInLogSpace (boolean d) { inLogSpace = d; }

  /**
   * Sends a max-product message from FROM to TO over a junction tree.
   * The message is FROM's potential maximized onto the separator between
   * the two cliques.  It replaces the stored separator potential, and TO's
   * potential is multiplied by the new message and divided by the previous
   * separator potential.
   */
  protected void sendMessage (JunctionTree jt, Clique from, Clique to)
  {
    // Avoid building the message string when FINE logging is disabled.
    if (logger.isLoggable (Level.FINE)) {
      logger.fine ("Sending message "+from+" --> "+to);
    }
    Collection sepset = jt.getSepset (from, to);
    DiscretePotential fromCpf = jt.getCPF (from);
    DiscretePotential toCpf = jt.getCPF (to);
    // Read before it is overwritten below: the old message must be divided out.
    DiscretePotential oldSepsetPot = jt.getSepsetPot (from, to);

    DiscretePotential lambda = fromCpf.extractMax (sepset);
    jt.setSepsetPot (lambda, from, to);
    toCpf.multiplyBy (lambda);
    toCpf.divideBy (oldSepsetPot);
  }

  // TODO: Make more efficient by doing edgePtl.duplicate() first
  /**
   * Sends a max-product message from variable FROM to variable TO of an
   * undirected model.  The message is built as follows:
   * <ol>
   *   <li>Multiply together every stored message into FROM except the one from TO.</li>
   *   <li>Multiply in the FROM--TO edge potential (its {@code log()} when
   *       {@code inLogSpace} is set).</li>
   *   <li>Maximize out everything but TO, then normalize.</li>
   * </ol>
   * The result is stored in {@code messages[toIdx][fromIdx]}.
   */
  protected void sendMessage (UndirectedModel mdl, Variable from, Variable to)
  {
    // Avoid building the message string when FINE logging is disabled.
    if (logger.isLoggable (Level.FINE)) {
      logger.fine ("Sending message "+from+" --> "+to);
    }
    int fromIdx = mdl.getIndex (from);
    int toIdx = mdl.getIndex (to);

    DiscretePotential product = msgProduct (fromIdx, toIdx);
    DiscretePotential edgePtl = mdl.potentialOfEdge (from, to);
    if (inLogSpace) {
      edgePtl = edgePtl.log();
    }
    product.multiplyBy (edgePtl);

    DiscretePotential msg = product.extractMax (to);
    msg.normalize();

    assert msg.varSet().size() == 1
      : "Error in message from "+from+" to "+to+": varSet too big: "+msg;
    assert msg.varSet().contains (to)
      : "Error in message from "+from+" to "+to+"\n potential "+msg + "does not contain "+to;

    messages[toIdx][fromIdx] = msg;
  }

  /**
   * Returns the max-marginal of VAR, obtained by maximizing the potential of
   * the cluster returned by {@code jt.findParentCluster(var)} onto VAR.
   * Delogified first if {@link #getDelogify()} is true.
   */
  protected DiscretePotential lookupMarginal (JunctionTree jt, Variable var)
  {
    Clique parent = jt.findParentCluster (var);
    DiscretePotential ptl = jt.getCPF (parent);
    return delogifyIfRequested (ptl.extractMax (var));
  }

  /**
   * Returns the max-marginal over CLIQUE, obtained by maximizing the
   * potential of the cluster returned by {@code jt.findParentCluster(clique)}
   * onto CLIQUE.  Delogified first if {@link #getDelogify()} is true.
   */
  protected DiscretePotential lookupMarginal (JunctionTree jt, Clique clique)
  {
    Clique parent = jt.findParentCluster (clique);
    DiscretePotential ptl = jt.getCPF (parent);
    return delogifyIfRequested (ptl.extractMax (clique));
  }

  /**
   * Returns the product of all non-null messages stored in
   * {@code messages[idx]}, skipping the entry at index EXCLUDEMSGFROM.
   * The product starts from a new, empty {@code MultinomialPotential}, which
   * is logified first when {@code inLogSpace} is set.
   */
  protected DiscretePotential msgProduct (int idx, int excludeMsgFrom)
  {
    DiscretePotential product = new MultinomialPotential ();
    if (inLogSpace) {
      product.logify ();
    }
    for (int j = 0; j < messages[idx].length; j++) {
      if ((messages[idx][j] != null) && (j != excludeMsgFrom)) {
        product.multiplyBy (messages[idx][j]);
      }
    }
    // Rendering a whole potential is costly; only do it when FINEST is enabled.
    if (logger.isLoggable (Level.FINEST)) {
      logger.finest ("msgProduct returned "+product);
    }
    return product;
  }

  /**
   * Returns the marginal of VAR computed by the superclass, delogified
   * first if {@link #getDelogify()} is true.
   */
  protected DiscretePotential lookupMarginal (UndirectedModel mdl, Variable var)
  {
    return delogifyIfRequested (super.lookupMarginal (mdl, var));
  }

  /**
   * Calls {@code delogify()} on PTL when the delogify flag is set.
   * Returns the same PTL reference in either case.
   */
  private DiscretePotential delogifyIfRequested (DiscretePotential ptl)
  {
    if (delogify) {
      ptl.delogify();
    }
    return ptl;
  }

} // ViterbiPropagation
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -