Skip to content

Instantly share code, notes, and snippets.

@kitlith
Created January 30, 2021 01:13
Show Gist options
  • Save kitlith/37c40cf68bbf66dfa6744beeec69bba6 to your computer and use it in GitHub Desktop.
Save kitlith/37c40cf68bbf66dfa6744beeec69bba6 to your computer and use it in GitHub Desktop.
Layer visualization. Requires the graphviz and Pillow Python libraries, as well as Graphviz itself, to be installed.
from collections import deque
from collections.abc import Iterable, Iterator
from io import BytesIO
from itertools import islice

from graphviz import Graph
from PIL import Image
class Comparator:
    """A constant value plus a mode flag, combined with one incoming value.

    Attributes:
        value: the comparator's own (constant) value.
        is_sub: when true, ``eval`` subtracts instead of passing the base
            value through; rendered by ``__str__`` as a leading ``*`` and
            recognised the same way by ``from_str``.
    """

    def __init__(self, value: int, is_sub: bool):
        self.value = value
        self.is_sub = is_sub

    def __str__(self) -> str:
        marker = '*' * self.is_sub  # '*' for True, '' for False
        return f'{marker}{self.value}'

    def __repr__(self) -> str:
        return f'Comp(value={self.value}, is_sub={self.is_sub})'

    def eval(self, val: int, internal_is_base: bool) -> int:
        """Combine *val* with this comparator's value.

        With *internal_is_base* true, ``self.value`` is the base and *val*
        is the side; otherwise the roles are swapped. The result is 0 when
        side > base; otherwise ``base - side`` if ``is_sub`` is set, else
        ``base``.
        """
        if internal_is_base:
            base, side = self.value, val
        else:
            base, side = val, self.value
        if side > base:
            return 0
        if self.is_sub:
            return base - side
        # side <= base here, so the larger of the two is simply base.
        return base

    @classmethod
    def from_str(cls, inp: str):
        """Parse a token such as ``"*0"`` (is_sub=True) or ``"7"`` (is_sub=False)."""
        is_sub = inp.startswith('*')
        digits = inp[1:] if is_sub else inp
        return cls(int(digits, 10), is_sub)
class Layer:
    """Two comparators whose combined result maps one input value to an output.

    Attributes:
        sub: comparator evaluated with the input as its base
            (``internal_is_base=False``).
        inv: comparator evaluated with the input as its side
            (``internal_is_base=True``).
    """

    def __init__(self, sub: Comparator, inv: Comparator):
        self.sub = sub
        self.inv = inv

    def __str__(self) -> str:
        return ','.join((str(self.sub), str(self.inv)))

    def __repr__(self) -> str:
        return f'Layer(sub={self.sub!r}, inv={self.inv!r})'

    def eval(self, val: int) -> int:
        """Return the larger of the two comparators' results for *val*."""
        from_sub = self.sub.eval(val, False)
        from_inv = self.inv.eval(val, True)
        return max(from_sub, from_inv)

    @classmethod
    def from_str(cls, inp: str):
        """Parse ``"SUB,INV"``, e.g. ``"0,*0"``, into a Layer."""
        comparators = [Comparator.from_str(piece) for piece in inp.split(',')]
        return cls(*comparators)

    @classmethod
    def parse_layers(cls, inp: str):
        """Parse whitespace-separated layer tokens into a list of Layers."""
        return [cls.from_str(token) for token in inp.split()]
def window(iterable: Iterable, n: int = 2, cls: type = tuple) -> Iterator:
    """Yield every run of *n* consecutive items from *iterable*.

    Each window is produced by calling ``cls`` on a deque of the current
    items, so ``cls=tuple`` (the default) yields tuples and ``cls=list``
    yields lists. Nothing is yielded when *iterable* has fewer than *n*
    items.

    >>> list(window([1, 2, 3]))
    [(1, 2), (2, 3)]

    Raises:
        ValueError: if ``n < 1``. This is a generator, so the error
            surfaces when iteration starts.
    """
    if n < 1:
        raise ValueError(f'window size must be at least 1, got {n}')
    it = iter(iterable)
    # A bounded deque discards its oldest item automatically on append.
    win = deque(islice(it, n), maxlen=n)
    if len(win) < n:
        return
    yield cls(win)
    for item in it:
        win.append(item)
        yield cls(win)
def _add_node_row(graph, prefix: str, row: int, labelled: bool,
                  spacing_x: int, spacing_y: int, **node_attrs) -> list[str]:
    """Add one horizontal row of 16 nodes to *graph* and return their names.

    Node ``i`` (0..15) is named ``f'{prefix}_{i}'``, labelled with ``i`` in
    hex when *labelled* is true (empty label otherwise), and placed at
    ``pos = (i * spacing_x, -row * spacing_y)``. The trailing ``!`` in
    ``pos`` asks neato to keep the node there. Extra keyword arguments are
    passed through to ``graph.node`` as node attributes.
    """
    names = []
    for i in range(16):
        node_name = f'{prefix}_{i}'
        graph.node(
            node_name,
            hex(i)[2:] if labelled else "",
            pos=f'{i*spacing_x},{-row*spacing_y}!',
            **node_attrs,
        )
        names.append(node_name)
    return names


def gen_graph(layers: list[Layer], present: 'set[int] | None' = None):
    """Build a graphviz Graph showing how each layer maps the values 0..15.

    Rows of 16 nodes are stacked top to bottom:
    - a labelled input row;
    - one unlabelled row of black squares between each pair of consecutive
      layers;
    - a labelled output row.

    For each layer, an edge runs from every value ``i`` in one row to
    ``layer.eval(i)`` in the next row.

    Args:
        layers: the layers to apply, in order.
        present: input values whose paths are highlighted; ``None`` means
            all 16. The set of values reachable from it is propagated
            layer by layer. Edges leaving values that are not reachable are
            drawn in ``gray80``.

    Returns:
        The ``graphviz.Graph`` (neato engine, PNG format). Nothing is
        rendered here.

    Raises:
        ValueError: if a layer maps some value outside 0..15.
    """
    spacing_y = 2
    spacing_x = 1
    # The graph is named after the concatenated layer strings.
    graph = Graph(''.join(map(str, layers)), engine='neato', format='png')

    # Node names per row, top to bottom.
    layer_nodes = [_add_node_row(graph, 'i', 0, True, spacing_x, spacing_y)]
    for _ in range(len(layers) - 1):
        row = len(layer_nodes)
        layer_nodes.append(_add_node_row(
            graph, str(row), row, False, spacing_x, spacing_y,
            shape='square', style='filled', fillcolor='black',
            fixedsize="true", width="0.25",
        ))
    layer_nodes.append(
        _add_node_row(graph, 'o', len(layer_nodes), True, spacing_x, spacing_y))

    if present is None:
        present = set(range(16))
    for layer, (input_nodes, output_nodes) in zip(layers, window(layer_nodes, 2)):
        reached = set()
        for i in range(16):
            # Edges leave the bottom ('s') of the upper node and enter the
            # top ('n') of the lower one.
            attrs = {'headport': 'n', 'tailport': 's'}
            output = layer.eval(i)
            if not 0 <= output < len(output_nodes):
                # Previously a cryptic IndexError (or a silently wrong edge
                # for a negative value); report the offending layer instead.
                raise ValueError(
                    f'layer {layer} maps {i} to {output}, '
                    f'outside 0..{len(output_nodes) - 1}')
            if i in present:
                reached.add(output)
            else:
                attrs['color'] = 'gray80'
            graph.edge(input_nodes[i], output_nodes[output], **attrs)
        present = reached
    return graph
if __name__ == "__main__":
    import sys

    # Expects one argument: whitespace-separated layer specs, e.g. "0,*0 *3,2".
    if len(sys.argv) < 2:
        sys.exit(f'usage: {sys.argv[0]} "LAYER [LAYER ...]"  (each LAYER is SUB,INV, e.g. "0,*0")')
    layers = Layer.parse_layers(sys.argv[1])

    # Static overview with every input value highlighted.
    graph = gen_graph(layers)
    graph.render("test.gv")

    # Animation: one frame per starting value 0..15, tracing where it ends up.
    frames = [
        Image.open(BytesIO(gen_graph(layers, {start}).pipe()))
        for start in range(16)
    ]
    frames[0].save("test.gv.gif", save_all=True, append_images=frames[1:], duration=250, loop=0)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment