-
-
Save sjas/40eafc307269f3a8d702160ee0d11c65 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class DotGenerator(Visitor):
    """Visitor that renders a node graph as Graphviz DOT statements.

    Each visit method appends one vertex statement (plus the edges feeding
    into it) to ``self.lines`` and returns the DOT name of that vertex, so the
    caller can attach an edge to it.  ``generate_dot_file`` wraps the
    collected statements into a complete ``digraph``.
    """

    # Characters that have structural meaning inside a Graphviz ``record``
    # label and must be backslash-escaped to appear literally.
    _RECORD_SPECIAL = '{}|<>'

    def __init__(self):
        super().__init__()
        self.lines = []   # DOT statements, in emission order
        self.next_id = 0  # counter used to mint fresh vertex names n0, n1, ...

    @staticmethod
    def _escape_text(text):
        """Escape *text* for use inside a double-quoted DOT string.

        Backslashes are doubled and double quotes are backslash-escaped, so
        arbitrary values cannot terminate the quoted label early.
        """
        return str(text).replace('\\', '\\\\').replace('"', '\\"')

    @classmethod
    def _escape_record(cls, text):
        """Escape *text* for use as the text of one field of a record label.

        Applies ``_escape_text`` and additionally backslash-escapes every
        record metacharacter (``{ } | < >``) so it is displayed literally
        instead of being parsed as record structure.
        """
        quoted = cls._escape_text(text)
        return ''.join(
            '\\' + ch if ch in cls._RECORD_SPECIAL else ch for ch in quoted
        )

    def make_name(self, node, name=None):
        """Record and return the DOT name used for *node*.

        Args:
            node: the node (or module) being named.
            name: an explicit name to use; when omitted a fresh ``n<k>``
                identifier is minted from ``self.next_id``.

        Returns:
            The name, which is also stored in ``self.values[node]``.
        """
        if name is None:
            name = "n%d" % self.next_id
            self.next_id += 1
        # NOTE(review): ``self.values`` comes from ``Visitor``; presumably it
        # memoizes visit results, so registering the name before children are
        # visited lets a revisited node reuse it — TODO confirm.
        self.values[node] = name
        return name

    def vertex(self, name, shape, label):
        """Emit a vertex statement.

        *label* is inserted verbatim between double quotes, so callers must
        pass it already escaped (see ``_escape_text`` / ``_escape_record``).
        """
        self.lines.append('%s [ shape = %s, label = "%s" ];' % (name, shape, label))

    def edge(self, from_name, to_name):
        """Emit an edge leaving *from_name* on its east side and entering
        *to_name* on its west side (either may carry a ``:port`` suffix)."""
        self.lines.append('%s:e -> %s:w;' % (from_name, to_name))

    def ConstantNode(self, node):
        """Render a constant as a circle labelled with its value."""
        name = self.make_name(node)
        self.vertex(name, 'circle', self._escape_text(node.value))
        return name

    def BinaryNode(self, node):
        """Render a binary operation as a record with two input ports.

        The operator text is record-escaped as a whole so that multi-character
        operators such as ``<<`` or ``>=`` do not corrupt the record syntax.
        """
        name = self.make_name(node)
        self.vertex(name, 'record',
                    '{{<i0>|<i1>}|%s}' % self._escape_record(node.op))
        self.edge(self(node.left), name + ':i0')
        self.edge(self(node.right), name + ':i1')
        return name

    def IndexNode(self, node):
        """Render a single-index selection ``[i]`` as a box fed by its operand."""
        name = self.make_name(node)
        self.vertex(name, 'box', '[%d]' % node.index)
        self.edge(self(node.operand), name)
        return name

    def SliceNode(self, node):
        """Render a slice ``[start:stop]`` as a box fed by its operand."""
        name = self.make_name(node)
        self.vertex(name, 'box', '[%d:%d]' % (node.start, node.stop))
        self.edge(self(node.operand), name)
        return name

    def ConcatNode(self, node):
        """Render a concatenation as a record with one input port per operand."""
        name = self.make_name(node)
        ports = '|'.join('<i%d>' % i for i in range(len(node.operands)))
        # Top row holds the input ports; the bottom field is left empty.
        self.vertex(name, 'record', '{{%s}|}' % ports)
        for i, operand in enumerate(node.operands):
            self.edge(self(operand), '%s:i%d' % (name, i))
        return name

    def InputNode(self, node):
        """Render a graph input as an arrow-shaped vertex labelled by name."""
        name = self.make_name(node)
        self.vertex(name, 'rarrow', self._escape_text(node.name))
        return name

    def OutputNode(self, node):
        """Render a graph output as an arrow-shaped vertex labelled by name,
        connected from its operand when it has one."""
        name = self.make_name(node)
        self.vertex(name, 'rarrow', self._escape_text(node.name))
        # NOTE(review): truthiness test kept from the original; presumably
        # means "operand is set" — confirm whether ``is not None`` is intended.
        if node.operand:
            self.edge(self(node.operand), name)
        return name

    def ModuleOutputNode(self, node):
        """Name a sub-module output as ``<module vertex>:<output port>``.

        No vertex is emitted; edges from this node attach directly to the
        corresponding port of the module's record vertex.
        """
        return self.make_name(node, '%s:%s' % (self(node.module), node.name))

    def Module(self, module):
        """Render a module instance as a record: input ports on the left, the
        class name in the middle, output ports on the right, with an edge from
        each connected source into its input port."""
        name = self.make_name(module)
        inputs = '|'.join('<%s> %s' % (input_name, self._escape_record(input_name))
                          for input_name in module._inputs)
        outputs = '|'.join('<%s> %s' % (output_name, self._escape_record(output_name))
                           for output_name in module._outputs)
        title = self._escape_record(type(module).__name__)
        self.vertex(name, 'record', '{{%s}|%s|{%s}}' % (inputs, title, outputs))
        for input_node, node in module._connections.items():
            self.edge(self(node), '%s:%s' % (name, input_node.name))
        return name

    def default(self, x):
        """Fallback dispatch: send ``Module`` instances to the ``Module``
        method and defer everything else to ``Visitor.default``."""
        if isinstance(x, Module):
            return self.Module(x)
        return super().default(x)
def generate_dot_file(module):
    """Render the graph reachable from *module*'s outputs as DOT source.

    A fresh ``DotGenerator`` visits every node in ``module._outputs``; the
    statements it collects become the body of a left-to-right (``rankdir =
    "LR"``) ``digraph`` named after ``module.__name__``.

    Returns:
        The complete DOT file contents as a string, ending with a newline.
    """
    dot = DotGenerator()
    for output_node in module._outputs.values():
        dot(output_node)
    statements = '\n'.join(dot.lines)
    return 'digraph "%s" {\nrankdir = "LR";\n%s\n}\n' % (module.__name__, statements)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.