Created
January 30, 2021 01:13
-
-
Save kitlith/37c40cf68bbf66dfa6744beeec69bba6 to your computer and use it in GitHub Desktop.
Layer visualization. Requires the graphviz and Pillow Python libraries, as well as the Graphviz program itself, to be installed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from collections import deque
from collections.abc import Iterable, Iterator
from io import BytesIO
from itertools import islice

from graphviz import Graph
from PIL import Image
class Comparator:
    """One comparator input of a layer: a constant level plus a mode flag.

    Textual form is the decimal ``value``, prefixed with ``*`` when ``is_sub``
    is set, e.g. ``"7"`` or ``"*7"``.
    """

    def __init__(self, value: int, is_sub: bool):
        # Expected types: value is an int, is_sub is a bool (as produced by from_str).
        self.value = value
        self.is_sub = is_sub

    def __str__(self) -> str:
        # Multiplying by the bool emits the '*' marker only when is_sub is True.
        return '*' * self.is_sub + str(self.value)

    def __repr__(self) -> str:
        return f'Comp(value={self.value}, is_sub={self.is_sub})'

    def eval(self, val: int, internal_is_base: bool) -> int:
        """Combine ``val`` with this comparator's constant ``value``.

        ``internal_is_base`` chooses which operand is the base: when True the
        constant is the base and ``val`` the side input, otherwise the reverse.

        Returns 0 if the side input exceeds the base. Otherwise it returns
        ``base - side`` when ``is_sub`` is set, and ``max(base, side)``
        (i.e. the base) when it is not.
        """
        base, side = (self.value, val) if internal_is_base else (val, self.value)
        if side > base:
            return 0
        if self.is_sub:
            return base - side
        return max(base, side)

    @classmethod
    def from_str(cls, inp: str):
        """Parse ``"N"`` or ``"*N"`` (N in base 10) into a Comparator."""
        is_sub = inp.startswith('*')
        digits = inp[1:] if is_sub else inp
        return cls(int(digits, 10), is_sub)
class Layer:
    """Two comparators whose results are combined with ``max``.

    ``sub`` is evaluated with the incoming value as its base operand, and
    ``inv`` with the incoming value as its side operand. Textual form is
    ``"<sub>,<inv>"``, e.g. ``"0,*0"``.
    """

    def __init__(self, sub: Comparator, inv: Comparator):
        self.sub = sub
        self.inv = inv

    def __str__(self) -> str:
        return ','.join((str(self.sub), str(self.inv)))

    def __repr__(self) -> str:
        return f'Layer(sub={self.sub!r}, inv={self.inv!r})'

    def eval(self, val: int) -> int:
        """Return the larger of the two comparator results for input ``val``."""
        via_sub = self.sub.eval(val, False)
        via_inv = self.inv.eval(val, True)
        return max(via_sub, via_inv)

    @classmethod
    def from_str(cls, inp: str):
        """Parse ``"<sub>,<inv>"`` into a Layer."""
        comparators = [Comparator.from_str(part) for part in inp.split(',')]
        return cls(*comparators)

    @classmethod
    def parse_layers(cls, inp: str):
        """Parse whitespace-separated layer specs into a list of Layers."""
        return [cls.from_str(token) for token in inp.split()]
def window(iterable: Iterable, n: int = 2, cls: type = tuple) -> Iterator:
    """Yield successive overlapping windows of ``n`` consecutive items.

    Each window is produced by calling ``cls`` on a deque holding the current
    items, so ``cls=tuple`` (the default) yields tuples and ``cls=list`` lists.
    Nothing is yielded when ``iterable`` has fewer than ``n`` items.

    >>> list(window([1, 2, 3]))
    [(1, 2), (2, 3)]

    Raises:
        ValueError: if ``n`` is less than 1. The check is eager, at call time
            rather than on first iteration. A zero-width window used to yield
            a stream of empty windows.
    """
    if n < 1:
        raise ValueError(f'window size must be at least 1, got {n}')
    return _window_iter(iterable, n, cls)


def _window_iter(iterable: Iterable, n: int, cls: type) -> Iterator:
    """Generator behind ``window``; assumes ``n >= 1`` was already checked."""
    it = iter(iterable)
    # A bounded deque drops the oldest item automatically on each append.
    win = deque(islice(it, n), maxlen=n)
    if len(win) < n:
        return
    yield cls(win)
    for item in it:
        win.append(item)
        yield cls(win)
def gen_graph(layers: list[Layer], present: set[int] = None, levels: int = 16):
    """Build a Graphviz diagram showing how each layer maps levels.

    The diagram has one horizontal row of ``levels`` nodes per stage:
    - a hex-labelled input row at the top;
    - an unlabelled row of small black squares between each pair of
      consecutive layers;
    - a hex-labelled output row at the bottom.

    For every layer, level ``i`` in the row above is joined by an edge to
    ``layer.eval(i)`` in the row below.

    Args:
        layers: layers applied in order, top to bottom.
        present: input levels to highlight. An edge is drawn normally when
            its start level is reachable from ``present`` through the previous
            layers; all other edges are drawn ``gray80``. ``None`` (default)
            highlights every level. The caller's set is never mutated.
        levels: number of levels per row (default 16, i.e. 0-15).

    Returns:
        A ``graphviz.Graph`` using the ``neato`` engine with pinned (``!``)
        node positions and ``png`` output format.

    Raises:
        ValueError: if a layer maps some level outside ``0 .. levels - 1``.
            This used to surface as an opaque IndexError while drawing edges.
    """
    graph = Graph(''.join(map(str, layers)), engine='neato', format='png')
    spacing_y = 2
    spacing_x = 1

    # Rows of node names, top to bottom: input, one per inner boundary, output.
    rows = [_add_node_row(graph, 'i', 0, levels, spacing_x, labelled=True)]
    for row_idx in range(1, len(layers)):
        rows.append(_add_node_row(
            graph, str(row_idx), -row_idx * spacing_y, levels, spacing_x,
            labelled=False, shape='square', style='filled', fillcolor='black',
            fixedsize='true', width='0.25',
        ))
    output_y = -len(rows) * spacing_y
    rows.append(_add_node_row(graph, 'o', output_y, levels, spacing_x, labelled=True))

    if present is None:
        present = set(range(levels))
    # Each layer connects a pair of adjacent rows.
    for layer, (above, below) in zip(layers, window(rows, 2)):
        reached = set()
        for level in range(levels):
            out = layer.eval(level)
            if not 0 <= out < levels:
                raise ValueError(
                    f'layer {layer} maps level {level} to {out}, '
                    f'outside 0..{levels - 1}'
                )
            attrs = {'headport': 'n', 'tailport': 's'}
            if level in present:
                reached.add(out)
            else:
                # Not reachable from the highlighted inputs: de-emphasize.
                attrs['color'] = 'gray80'
            graph.edge(above[level], below[out], **attrs)
        # Only levels reached by highlighted edges stay highlighted downstream.
        present = reached
    return graph


def _add_node_row(graph: Graph, prefix: str, y: int, count: int, spacing_x: int,
                  labelled: bool, **attrs) -> list[str]:
    """Add ``count`` nodes named ``f'{prefix}_{i}'`` at height ``y``; return their names."""
    names = []
    for i in range(count):
        node = f'{prefix}_{i}'
        label = format(i, 'x') if labelled else ''
        graph.node(node, label, pos=f'{i * spacing_x},{y}!', **attrs)
        names.append(node)
    return names
if __name__ == "__main__":
    import sys

    # Usage: script '<sub>,<inv> [<sub>,<inv> ...]', e.g. '0,*0 *3,5'.
    # All arguments are joined, so unquoted layers are no longer silently dropped.
    # Quoting is still advisable, because the shell may glob-expand '*'.
    if len(sys.argv) < 2:
        sys.exit(f'usage: {sys.argv[0]} "<sub>,<inv> [<sub>,<inv> ...]"')
    layers = Layer.parse_layers(' '.join(sys.argv[1:]))

    # Static diagram of every path, rendered via graphviz using "test.gv" as the filename.
    graph = gen_graph(layers)
    graph.render("test.gv")

    # Animated GIF: frame k highlights only the paths starting at input level k.
    frames = [
        Image.open(BytesIO(gen_graph(layers, {level}).pipe()))
        for level in range(16)
    ]
    frames[0].save("test.gv.gif", save_all=True, append_images=frames[1:],
                   duration=250, loop=0)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment