📄 dotwriter.py
字号:
# -*- coding: utf-8 -*-
#
# This file is a part of the Bayes Blocks library
#
# Copyright (C) 2001-2006 Markus Harva, Antti Honkela, Alexander
# Ilin, Tapani Raiko, Harri Valpola and Tomas Östman.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2, or (at your option)
# any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License (included in file License.txt in the
# program package) for more details.
#
# $Id: DotWriter.py 7 2006-10-26 10:26:41Z ah $
#

"""Convert the BB network to dot source code"""

from __future__ import nested_scopes

import bblocks.Label as Label
import bblocks.Helpers as Helpers
import copy
import string
import os
import os.path
import subprocess

# All node types supported by Bayes Blocks library
nodetypes = [
    "Constant", "ConstantV", "Prod", "ProdV", "Sum2", "Sum2V",
    "SumN", "SumNV", "Rectification", "RectificationV", "DelayV",
    "Gaussian", "GaussianV", "DelayGaussV", "SparseGaussV",
    "RectifiedGaussian", "RectifiedGaussianV", "GaussRect", "GaussRectV",
    "GaussNonlin", "GaussNonlinV", "MoG", "MoGV", "Discrete", "DiscreteV",
    "DiscreteDirichlet", "DiscreteDirichletV", "Dirichlet", "Proxy",
    "Relay", "Evidence", "EvidenceV", "Memory", "OLDelayS", "OLDelayD", ]

# Node types representing variables
variables = [
    "Gaussian", "GaussianV", "DelayGaussV", "SparseGaussV",
    "RectifiedGaussian", "RectifiedGaussianV", "GaussRect", "GaussRectV",
    "GaussNonlin", "GaussNonlinV", "MoG", "MoGV", "Discrete", "DiscreteV",
    "DiscreteDirichlet", "DiscreteDirichletV", "Dirichlet", ]

# Possible node properties
all_properties = {'shape': ['ellipse', 'box'],
                  'style': ['plain', 'bold', 'filled'],
                  'peripheries': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]}


class NodeProperties:
    """A class to represent visual node properties supported by dot

    The values are kept in the attributes shape, style, color,
    fillcolor and peripheries.  Shape and style are validated against
    all_properties when set; the other three are stored as given.
    """

    def __init__(self, shape='ellipse', style='plain', color='#000000',
                 fillcolor='#bebebe', peripheries=1):
        """Create a property set; raises ValueError on an invalid
        shape or style"""
        self.SetShape(shape)
        self.SetStyle(style)
        self.SetColor(color)
        self.SetFillcolor(fillcolor)
        self.SetPeripheries(peripheries)

    def SetShape(self, shape):
        """Set the node shape; raises ValueError if it is not listed in
        all_properties['shape']"""
        if shape not in all_properties['shape']:
            raise ValueError('Invalid shape')
        self.shape = shape

    def SetStyle(self, style):
        """Set the node style; raises ValueError if it is not listed in
        all_properties['style']"""
        if style not in all_properties['style']:
            raise ValueError('Invalid style')
        self.style = style

    def SetColor(self, color):
        """Set the outline color (stored without validation)"""
        self.color = color

    def SetFillcolor(self, color):
        """Set the fill color (stored without validation)"""
        self.fillcolor = color

    def SetPeripheries(self, peripheries):
        """Set the number of peripheries (stored without validation)"""
        self.peripheries = peripheries

    def SetAll(self, propd):
        """Set all properties to values defined in given dictionary

        The dictionary must contain the keys 'shape', 'style',
        'peripheries', 'color' and 'fillcolor'.
        """
        self.SetShape(propd['shape'])
        self.SetStyle(propd['style'])
        self.SetPeripheries(propd['peripheries'])
        self.SetColor(propd['color'])
        self.SetFillcolor(propd['fillcolor'])

    def PropertyDict(self):
        """Return the values of the properties in a dictionary"""
        return {'shape': self.shape,
                'style': self.style,
                'peripheries': self.peripheries,
                'color': self.color,
                'fillcolor': self.fillcolor}

    def PropertyString(self):
        """Return a string of dot source code with the specified properties

        Only properties that differ from the constructor defaults are
        emitted; if none differ, an empty string is returned.
        """
        components = []
        if self.shape != 'ellipse':
            components.append('shape=%s' % self.shape)
        if self.style not in ['', 'plain']:
            components.append('style=%s' % self.style)
        if self.color != '#000000':
            components.append('color="%s"' % self.color)
        if self.fillcolor != '#bebebe':
            components.append('fillcolor="%s"' % self.fillcolor)
        if self.peripheries != 1:
            components.append('peripheries=%d' % self.peripheries)
        if not components:
            return ''
        # str.join instead of string.join, which no longer exists in
        # Python 3.
        return '[' + ','.join(components) + ']'


# Default properties for different node types
nodeproperties = {}
for n in nodetypes:
    prop = NodeProperties()
    if n.endswith("V"):
        prop.SetStyle("bold")
    if n.startswith("Discrete"):
        prop.SetShape("box")
    if n.startswith("Evidence"):
        prop.SetColor("#bebebe")
    nodeproperties[n] = prop


class PlainWriter:
    """BBNet to dot source converter with network modification properties

    The network is mirrored in two dictionaries mapping a node label to
    a two-element list [node type, list of parent labels]: self.d is
    the unmodified original and self.d_act is the working copy that the
    filter methods rewrite and WriteGraph emits.
    """

    def __init__(self, net):
        """Create a writer for the given network (Net or PyNet)"""
        self.net = net
        self.net.SortNodes()

        # Create an auxiliary representation of the network
        self.d = {}
        for i in range(self.net.NodeCount()):
            node = self.net.GetNodeByIndex(i)
            entry = [node.GetType(), []]
            # GetParent returns a false value once j runs past the last
            # parent, which ends the scan.
            j = 0
            p = node.GetParent(j)
            while p:
                entry[1].append(p.GetLabel())
                j += 1
                p = node.GetParent(j)
            self.d[node.GetLabel()] = entry

        # Make an active copy of the representation for modification
        self.d_act = copy.deepcopy(self.d)

        # Node visual properties specified by label
        self.specpropsauto = {}
        self.specpropsmanual = {}

        # Get a list of node types in the network and add new ones to
        # the default list
        self.types = copy.deepcopy(nodetypes)
        self.properties = copy.deepcopy(nodeproperties)
        self.variables = copy.deepcopy(variables)
        for n in self.d:
            nodetype = self.d[n][0]
            if nodetype not in self.types:
                self.types.append(nodetype)
                self.properties[nodetype] = NodeProperties()
                # NOTE(review): the nesting of this check under the
                # "new type" branch was reconstructed from source that
                # had lost its indentation -- confirm against upstream.
                if self.net.GetVariable(n) is not None:
                    self.variables.append(nodetype)

        # Number of last indices to drop by label stem
        self.indexdrops = {}
        # Drop all indices?
        self.dropall = 0

    def WriteGraph(self, f):
        """Write the current representation to a dot source file

        f is either a file name, which is opened for writing, or an
        object with write() and close() methods.  In both cases f is
        closed before returning, also when writing fails.
        """
        if isinstance(f, str):
            f = open(f, 'w')
        try:
            f.write("// This file was automatically generated by Bayes Blocks Visualiser\n")
            f.write("digraph G {\n")
            for n in self.d_act:
                # Manual per-label properties take precedence over
                # automatic ones, which take precedence over the
                # defaults of the node type.
                if n in self.specpropsmanual:
                    props = self.specpropsmanual[n]
                elif n in self.specpropsauto:
                    props = self.specpropsauto[n]
                else:
                    props = self.properties[self.d_act[n][0]]
                f.write(" \"" + n + "\" " + props.PropertyString() + ";\n")
                for p in self.d_act[n][1]:
                    f.write(" \"" + p + "\" -> \"" + n + "\";\n")
            f.write("}\n")
        finally:
            f.close()

    def WriteToTk(self, fname):
        """Write a graph and convert it to Tk source using dot and fig2dev

        The dot source is piped through the shell pipeline
        'dot -Tfig | fig2dev -Ltk' and its output is written to the
        file fname.
        """
        f = open(fname, 'w')
        try:
            # subprocess replaces os.popen2, which is deprecated since
            # Python 2.6 and absent from Python 3.  The command is a
            # fixed string, so running it through the shell is safe.
            # universal_newlines makes the pipes text-mode on Python 3.
            proc = subprocess.Popen('dot -Tfig | fig2dev -Ltk', shell=True,
                                    stdin=subprocess.PIPE,
                                    stdout=subprocess.PIPE,
                                    universal_newlines=True)
            try:
                # WriteGraph closes the pipe when done, which signals
                # end of input to dot.
                self.WriteGraph(proc.stdin)
                for line in proc.stdout:
                    f.write(line)
            finally:
                proc.stdout.close()
                proc.wait()
        finally:
            f.close()

    def ResetFilter(self):
        """Restore the working copy to be equal to the original"""
        self.dropall = 0
        self.specpropsauto = {}
        self.d_act = copy.deepcopy(self.d)

    def FilterType(self, type):
        """Remove all nodes of given types (list) from the graph"""
        # Collect first and delete afterwards: deleting while iterating
        # over the dictionary is an error when keys() is a live view.
        removed = [n for n in self.d_act if self.d_act[n][0] in type]
        for n in removed:
            del self.d_act[n]
        # Drop the removed nodes from the parent lists of the survivors
        for n in self.d_act:
            self.d_act[n][1] = [x for x in self.d_act[n][1]
                                if x not in removed]

    def FilterButType(self, type):
        """Remove nodes of all but given types (list) from the graph

        Each child of a removed node inherits the parents of the
        removed node, so paths through removed nodes stay connected.
        """
        removed = [n for n in self.d_act if self.d_act[n][0] not in type]
        for r in removed:
            for n in self.d_act:
                if r in self.d_act[n][1]:
                    self.d_act[n][1] = [x for x in self.d_act[n][1]
                                        if x != r]
                    self.d_act[n][1] += self.d_act[r][1]
            del self.d_act[r]

    def FilterConst(self):
        """Remove constant nodes"""
        self.FilterType(['Constant', 'ConstantV'])

    def FilterEvidence(self):
        """Remove evidence nodes"""
        self.FilterType(['Evidence', 'EvidenceV'])

    def CombineSumtrees(self):
        """Hide the internal structure of sum trees by combining them
        to a single node"""
        sumtypes = ['Sum2', 'Sum2V']

        # Phase 1: Find all Sum2(V) nodes
        sums = [n for n in self.d_act if self.d_act[n][0] in sumtypes]

        # Phase 2: Remove all that are not roots of sum trees, i.e.
        # those that are a parent of some other Sum2(V) node
        roots = sums
        for n in sums:
            roots = [x for x in roots if x not in self.d_act[n][1]]

        # Phase 3: Combine all other nodes in given tree with the root
        for n in roots:
            parents = self.d_act[n][1]
            k = 0
            # The list grows while it is scanned: the parents of every
            # absorbed Sum2(V) node are appended and examined in turn.
            while k < len(parents):
                if parents[k] not in self.d_act:
                    # Parent no longer exists in the working copy
                    del parents[k]
                elif self.d_act[parents[k]][0] in sumtypes:
                    parents += self.d_act[parents[k]][1]
                    del self.d_act[parents[k]]
                    del parents[k]
                else:
                    k += 1
            parents.sort()
            # Helpers.Unique is applied to the sorted list, presumably
            # to strip duplicate labels -- verify in bblocks.Helpers.
            parents = Helpers.Unique(parents)
            self.d_act[n] = [self.d_act[n][0], parents]

    def DropIndices(self):
        """Drop all indices from labels

        Nodes whose labels share a stem are merged into one node named
        by the stem.  A merged node that had indices gets extra
        peripheries (number of indices + 1) in self.specpropsauto.
        """
        # Phase 1: Drop all indices, combine nodes when necessary
        d = {}
        for n in self.d_act:
            stem, inds = Label.Unlabel(n)
            # List comprehension instead of map(): the result is
            # extended and sorted in place below, so it must be a list.
            parentstems = [Label.Unlabel(x)[0] for x in self.d_act[n][1]]
            if stem in d:
                d[stem][1] += parentstems
            else:
                if len(inds) > 0:
                    # Add peripheries to show indices are being dropped
                    self.specpropsauto[stem] = copy.deepcopy(
                        self.properties[self.d_act[n][0]])
                    self.specpropsauto[stem].SetPeripheries(len(inds) + 1)
                d[stem] = [self.d_act[n][0], parentstems]

        # Phase 2: Cleanup by removing possible duplicate parents
        for n in d:
            d[n][1].sort()
            d[n][1] = Helpers.Unique(d[n][1])
        self.d_act = d
        self.dropall = 1

    def GetAllLabels(self):
        """Return a dictionary with all occurring label stems and
        numbers of indices in the corresponding labels

        The count includes indices hidden by self.indexdrops and, after
        DropIndices has run, those recorded as extra peripheries in
        self.specpropsauto.  The first label seen for a stem decides
        its count.
        """
        d = {}
        for n in self.d_act:
            stem, inds = Label.Unlabel(n)
            if stem in d:
                continue
            count = len(inds) + self.indexdrops.get(stem, 0)
            if self.dropall and stem in self.specpropsauto:
                # DropIndices stored (number of indices + 1) here
                count += self.specpropsauto[stem].peripheries - 1
            d[stem] = count
        return d

    def TruncateLabel(self, label):
        """Drop the extra indices from the end of given label according
        to the preset list of what to drop

        Returns the label unchanged when self.indexdrops has no
        positive entry for its stem.
        """
        stem, inds = Label.Unlabel(label)
        drop = self.indexdrops.get(stem, 0)
        if drop > 0:
            label = Label.Label(stem, inds[:-drop])
        return label

    def SelectivelyDropIndices(self):
        """Selectively drop indices from the labels

        Labels are shortened with TruncateLabel; nodes whose shortened
        labels coincide are merged and their parent lists combined.
        """
        # Phase 0: check whether we are needed at all
        for count in self.indexdrops.values():
            if count > 0:
                break
        else:
            # No stem has a positive drop count: nothing to do
            return

        d = {}
        # Phase 1: drop indices from dictionary keys and combine items
        for n in self.d_act:
            newlabel = self.TruncateLabel(n)
            if newlabel in d:
                d[newlabel][1] += self.d_act[n][1]
            else:
                d[newlabel] = self.d_act[n]

        # Phase 2: drop indices from lists in dictionary
        for n in d:
            d[n][1] = [self.TruncateLabel(x) for x in d[n][1]]
            d[n][1].sort()
            d[n][1] = Helpers.Unique(d[n][1])
        self.d_act = d
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -