Skip to content

Instantly share code, notes, and snippets.

@LeeeeT
Created February 21, 2025 05:10
Show Gist options
  • Save LeeeeT/70f29d144683e792a8950436288ee29c to your computer and use it in GitHub Desktop.
Save LeeeeT/70f29d144683e792a8950436288ee29c to your computer and use it in GitHub Desktop.
from __future__ import annotations
import itertools
import sys
from collections import Counter, defaultdict
from dataclasses import dataclass
# Global interaction-net state shared by all functions in this module.
# Source of fresh port numbers; ports are never reused.
port = itertools.count()
# port -> port at the other end of its wire (every wire is stored in both directions).
wires: dict[int, int] = {}
# principal port -> (symbol, principal port, auxiliary ports) for every live cell.
cells: dict[int, tuple[str, int, tuple[int, ...]]] = {}
# Port pairs connected by ``join`` that ``reduce`` still has to inspect.
queue: set[tuple[int, int]] = set()
def wire() -> tuple[int, int]:
    """Allocate a new wire and return its two ends.

    Two fresh port numbers are drawn from ``port`` and linked to each other in
    ``wires`` (the first end is inserted before the second).
    """
    near = next(port)
    far = next(port)
    wires[near] = far
    wires[far] = near
    return near, far
def join(first: int, second: int) -> None:
    """Splice two dangling wire ends together.

    ``first`` and ``second`` are each one end of a wire.  Both ports are
    removed from ``wires`` and the ports at the opposite ends of their wires
    are linked directly to each other.  The new connection is added to
    ``queue`` in both orientations so that ``reduce`` can check whether it
    forms an active pair.
    """
    queue.update({(wires[first], wires[second]), (wires[second], wires[first])})
    # The right-hand side is evaluated before any assignment target, so it
    # captures both old far ends; each ``wires.pop`` then removes the spliced
    # port and yields the far end that gets rewired.
    wires[wires.pop(first)], wires[wires.pop(second)] = wires[second], wires[first]
def cut(first: int, second: int) -> None:
    """Remove two cells, given their principal ports.

    The cell records and the ``wires`` entries of the two principal ports are
    deleted.  The cells' auxiliary ports remain in ``wires`` so a caller can
    join new structure onto them.
    """
    for table in (cells, wires):
        for key in (first, second):
            del table[key]
def cell(s: str, arity: int) -> tuple[int, tuple[int, ...]]:
    """Create a cell labelled ``s`` with ``arity`` auxiliary ports.

    One wire is allocated for the principal port, then one per auxiliary
    port.  The cell owns the first end of each wire (the principal end is its
    key in ``cells``).  The opposite ends are returned as
    ``(principal_end, aux_ends)`` for the caller to join into the net.
    """
    inner, outer = wire()
    aux_wires = [wire() for _ in range(arity)]
    cells[inner] = (s, inner, tuple(near for near, _ in aux_wires))
    return outer, tuple(far for _, far in aux_wires)
def print_net() -> None:
    """Print every live cell as ``Symbol(aux, ...) → principal``.

    Each wire gets a number in ``wires`` iteration order.  Ports are printed
    as the number of the wire they sit on.  Cells without auxiliaries print
    as ``Symbol → principal``.
    """
    counter = itertools.count()
    labels: dict[int, int] = {}
    for end, other in wires.items():
        if end in labels:
            continue
        labels[end] = labels[other] = next(counter)
    for symbol, principal, auxes in cells.values():
        if auxes:
            args = ", ".join(str(labels[aux]) for aux in auxes)
            print(f"{symbol}({args}) → {labels[principal]}")
        else:
            print(f"{symbol} → {labels[principal]}")
def show_root(root: int) -> str:
    """Render the term connected to the dangling port ``root`` as text.

    Wires are numbered in ``wires`` iteration order.  Starting from ``root``:
    - A wire whose other end is the principal port of a live cell is rendered
      as that cell, ``Symbol aux1 aux2 ...``.  Auxiliaries that are
      themselves cells with auxiliaries are wrapped in parentheses.
    - Any other wire is rendered as its number.

    A wire that would be expanded a second time within one rendering (which
    can only happen when the net contains a cycle) is printed as its number
    instead.  Cyclic nets therefore no longer recurse until
    ``RecursionError``.  Very deep terms can still exceed Python's recursion
    limit.
    """
    numbering = itertools.count()
    port_to_id: dict[int, int] = {}
    id_to_cell: dict[int, tuple[str, int, tuple[int, ...]]] = {}
    for a, b in wires.items():
        if a in port_to_id:
            continue
        wire_id = port_to_id[a] = port_to_id[b] = next(numbering)
        # Remember the cell (if any) whose principal port sits on this wire.
        for end in (a, b):
            if end in cells:
                id_to_cell[wire_id] = cells[end]

    expanded: set[int] = set()

    def show(port: int) -> tuple[str, bool]:
        # Returns the rendering and whether it must be parenthesised when
        # it appears as an auxiliary (i.e. it is a cell with auxiliaries).
        wire_id = port_to_id[port]
        target = id_to_cell.get(wire_id)
        if target is None or wire_id in expanded:
            return str(wire_id), False
        expanded.add(wire_id)
        symbol, _, auxes = target
        if not auxes:
            return symbol, False
        parts = []
        for aux in auxes:
            text, compound = show(aux)
            parts.append(f"({text})" if compound else text)
        return f"{symbol} {' '.join(parts)}", True

    return show(root)[0]
@dataclass(frozen=True)
class Program:
    """A parsed program: interaction rules, the initial net, and its roots."""

    # Balanced cuts (every wire name occurs exactly twice); loaded as rules.
    rules: list[Cut]
    # The remaining cuts, which make up the initial net.
    net: list[Cut]
    # Trees whose terms are printed by ``main`` after reduction.
    roots: list[Tree]
@dataclass(frozen=True)
class Cut:
    """Two trees connected at their roots, written ``left ~ right``."""

    left: Tree
    right: Tree
# A tree is either a wire name or a cell with auxiliary subtrees.
type Tree = Wire | Cell
# A wire name.  Its first character is not uppercase, not "_", and not a
# delimiter (see parse_wire_first).
type Wire = str
@dataclass(frozen=True)
class Cell:
    """A symbol applied to an ordered list of auxiliary subtrees."""

    symbol: Symbol
    # Subtrees attached to the cell's auxiliary ports, in port order.
    auxes: list[Tree]
# A cell symbol: starts with an uppercase character or "_" (see parse_symbol_first).
type Symbol = str
def parse_program(input: str) -> tuple[Program, str] | None:
    """Parse cuts, then one or more newline-separated root trees.

    Returns ``(Program, rest)`` where ``rest`` is the unconsumed input, or
    None when no root tree follows the cuts.
    """
    cuts_result = parse_cuts(input)
    if cuts_result is None:
        return None
    (rule_cuts, net_cuts), rest = cuts_result

    gap = parse_newlines(rest)
    if gap is None:
        return None
    _, rest = gap

    first = parse_tree(rest)
    if first is None:
        return None
    root, rest = first
    roots: list[Tree] = [root]

    # Further roots: newlines followed by a tree.  The newlines stay
    # consumed even if no tree follows them.
    while True:
        gap = parse_newlines(rest)
        if gap is None:
            break
        _, rest = gap
        step = parse_tree(rest)
        if step is None:
            break
        root, rest = step
        roots.append(root)

    gap = parse_newlines(rest)
    if gap is None:
        return None
    _, rest = gap
    return Program(rule_cuts, net_cuts, roots), rest
def parse_cuts(input: str) -> tuple[tuple[list[Cut], list[Cut]], str] | None:
    """Parse newline-separated cuts and split them into rules and net.

    Cuts are consumed until one fails to parse.  The newlines before that
    failed attempt stay consumed.  A cut in which every wire name occurs
    exactly twice (``balanced``) is a rule; every other cut belongs to the
    net.  Always returns a result: with no cuts, both lists are empty and the
    input is returned unchanged.
    """
    parsed: list[Cut] = []
    attempt = parse_cut(input)
    while attempt is not None:
        found, input = attempt
        parsed.append(found)
        gap = parse_newlines(input)
        if gap is None:
            break
        _, input = gap
        attempt = parse_cut(input)

    rule_cuts: list[Cut] = []
    net_cuts: list[Cut] = []
    for found in parsed:
        (rule_cuts if balanced(found) else net_cuts).append(found)
    return (rule_cuts, net_cuts), input
def parse_cut(input: str) -> tuple[Cut, str] | None:
    """Parse ``left ~ right`` (with exactly one space around ``~``).

    Returns ``(Cut, rest)`` or None if any of the three parts is missing.
    """
    if (lhs := parse_tree(input)) is None:
        return None
    left, rest = lhs
    if (separator := parse_string(" ~ ", rest)) is None:
        return None
    _, rest = separator
    if (rhs := parse_tree(rest)) is None:
        return None
    right, rest = rhs
    return Cut(left, right), rest
def parse_tree(input: str) -> tuple[Tree, str] | None:
    """Parse a tree: a wire name if possible, otherwise a cell.

    Returns ``(tree, rest)`` or None if neither parser succeeds.
    """
    # A successful parse is a non-empty tuple (truthy); failure is None.
    return parse_wire(input) or parse_cell(input)
def parse_wire(input: str) -> tuple[Wire, str] | None:
    """Parse a wire name: a wire-initial character, then identifier characters.

    Returns ``(name, rest)`` or None if ``input`` does not start a wire.
    """
    head = parse_wire_first(input)
    if head is None:
        return None
    first_char, rest = head
    chars = [first_char]
    while (step := parse_id_char(rest)) is not None:
        char, rest = step
        chars.append(char)
    return "".join(chars), rest
def parse_cell(input: str) -> tuple[Cell, str] | None:
match parse_symbol(input):
case None:
return None
case symbol, input:
pass
auxes = list[Tree]()
match parse_string(" ", input):
case None:
pass
case _, input:
match parse_string("(", input):
case None:
match parse_wire(input):
case None:
match parse_symbol(input):
case None:
input = " " + input
pass
case aux, input:
auxes.append(Cell(aux, []))
case aux, input:
auxes.append(aux)
case _, input:
match parse_tree(input):
case None:
input = " " + input
pass
case aux, input:
auxes.append(aux)
match parse_string(")", input):
case None:
pass
case _, input:
pass
if auxes:
while True:
match parse_string(" ", input):
case None:
break
case _, input:
pass
match parse_string("(", input):
case None:
match parse_wire(input):
case None:
match parse_symbol(input):
case None:
input = " " + input
break
case aux, input:
auxes.append(Cell(aux, []))
case aux, input:
auxes.append(aux)
case _, input:
match parse_tree(input):
case None:
input = " " + input
break
case aux, input:
auxes.append(aux)
match parse_string(")", input):
case None:
return None
case _, input:
pass
return Cell(symbol, auxes), input
def parse_symbol(input: str) -> tuple[Symbol, str] | None:
    """Parse a cell symbol: an uppercase character or "_", then identifier characters.

    Returns ``(symbol, rest)`` or None if ``input`` does not start a symbol.
    """
    head = parse_symbol_first(input)
    if head is None:
        return None
    symbol, rest = head
    while (step := parse_id_char(rest)) is not None:
        char, rest = step
        symbol += char
    return symbol, rest
def parse_id_char(input: str) -> tuple[str, str] | None:
if not input or input[0] in {" ", "(", ")", "~", "\r", "\n"}:
return None
return input[0], input[1:]
def parse_wire_first(input: str) -> tuple[str, str] | None:
if not input or input[0] in {" ", "(", ")", "~", "\r", "\n", "_"} or input[0].isupper():
return None
return input[0], input[1:]
def parse_symbol_first(input: str) -> tuple[str, str] | None:
if not input or not (input[0].isupper() or input[0] == "_"):
return None
return input[0], input[1:]
def parse_string(string: str, input: str) -> tuple[str, str] | None:
if not input.startswith(string):
return None
return string, input[len(string):]
def parse_newlines(input: str) -> tuple[str, str] | None:
    """Consume any run of leading ``\\n`` characters (possibly none).

    Always succeeds and returns ``(newlines, rest)``.  Only LF is consumed;
    CR is left in ``rest``.
    """
    rest = input.lstrip("\n")
    return input[: len(input) - len(rest)], rest
def balanced(cut: Cut) -> bool:
    """Return True when every wire name in ``cut`` occurs exactly twice.

    Such self-contained cuts are treated as rules by ``parse_cuts``.  A cut
    with no wire names at all is also balanced.
    """
    occurrences = Counter([*extract_wires(cut.left), *extract_wires(cut.right)])
    return set(occurrences.values()) <= {2}
def extract_wires(tree: Tree) -> list[Wire]:
match tree:
case str(wire):
return [wire]
case Cell(_, auxes):
return sum((extract_wires(aux) for aux in auxes), list[Wire]())
def parse_text(text: str) -> Program:
match parse_program(text):
case None:
raise SyntaxError
case program, rest:
if rest:
raise SyntaxError
return program
def parse_file(path: str) -> Program:
    """Read the program file at ``path`` as UTF-8 and parse it.

    The file is opened in text mode with universal newlines, so CRLF line
    endings reach the parser as LF.  An explicit encoding makes decoding
    independent of the platform locale.

    Raises:
        OSError: if the file cannot be opened or read.
        UnicodeDecodeError: if the file is not valid UTF-8.
        SyntaxError: propagated from ``parse_text``.
    """
    with open(path, encoding="utf-8") as file:
        text = file.read()
    return parse_text(text)
def show_program(program: Program) -> str:
    """Render ``program`` as text.

    The output is the rules, a blank line, the net, a blank line, then one
    root tree per line.
    """
    sections = [
        show_cuts(program.rules),
        show_cuts(program.net),
        "\n".join(map(show_tree, program.roots)),
    ]
    return "\n\n".join(sections)
def show_cuts(cuts: list[Cut]) -> str:
    """Render ``cuts`` one per line."""
    return "\n".join(map(show_cut, cuts))
def show_cut(cut: Cut) -> str:
    """Render a cut as ``left ~ right``."""
    left = show_tree(cut.left)
    right = show_tree(cut.right)
    return f"{left} ~ {right}"
def show_tree(tree: Tree) -> str:
    """Render a tree.

    A wire is rendered as its name.  A cell is rendered as its symbol,
    followed by its auxiliaries if it has any.
    """
    if isinstance(tree, str):
        return tree
    if isinstance(tree, Cell):
        if not tree.auxes:
            return tree.symbol
        return f"{tree.symbol} {show_auxes(tree.auxes)}"
def show_auxes(auxes: list[Tree]) -> str:
    """Render auxiliaries separated by spaces.

    Cells that have their own auxiliaries are wrapped in parentheses.
    """
    rendered = []
    for aux in auxes:
        text = str(show_tree(aux))
        needs_parens = isinstance(aux, Cell) and bool(aux.auxes)
        rendered.append(f"({text})" if needs_parens else text)
    return " ".join(rendered)
def load_net(net: list[Cut], id_to_port: dict[str, list[int]]) -> None:
    """Instantiate each cut of ``net`` and connect its two sides.

    The left side of each cut is loaded before the right side.  Wire-name
    occurrences are recorded in ``id_to_port`` for the caller to pair up.
    """
    for each in net:
        join(load_tree(each.left, id_to_port), load_tree(each.right, id_to_port))
def load_tree(tree: Tree, id_to_port: dict[str, list[int]]) -> int:
    """Instantiate ``tree`` in the global net and return its dangling root port.

    A wire name allocates a fresh wire.  One end is returned and the other is
    appended to ``id_to_port[name]`` so the caller can later join the
    occurrences of that name.  ``id_to_port`` must therefore create missing
    keys; the callers pass a ``defaultdict(list)``.

    A cell is created with ``cell``.  Each auxiliary subtree is then loaded,
    in order, and joined to the matching auxiliary port.
    """
    if isinstance(tree, str):
        near, far = wire()
        id_to_port[tree].append(far)
        return near
    if isinstance(tree, Cell):
        root_port, aux_ports = cell(tree.symbol, len(tree.auxes))
        for subtree, aux_port in zip(tree.auxes, aux_ports):
            join(load_tree(subtree, id_to_port), aux_port)
        return root_port
# Rule table: (left symbol, right symbol) -> (trees for the left cell's
# auxiliary ports, trees for the right cell's auxiliary ports).  Filled by
# load_rule and consulted by interact.
rules: dict[tuple[str, str], tuple[list[Tree], list[Tree]]] = {}
def load_rule(rule: Cut) -> None:
    """Register ``rule`` in the global ``rules`` table.

    The rule is stored only when both sides are cells.  It is keyed by the
    two symbols and stores both sides' auxiliary trees.  Cuts with a bare
    wire on either side are ignored.  A later rule for the same symbol pair
    replaces an earlier one.
    """
    left, right = rule.left, rule.right
    if isinstance(left, Cell) and isinstance(right, Cell):
        rules[left.symbol, right.symbol] = left.auxes, right.auxes
def load_rules(rules: list[Cut]) -> None:
    """Register each cut in ``rules`` with ``load_rule``, in order.

    The parameter shadows the module-level ``rules`` table only inside this
    function; ``load_rule`` still writes to the global table.
    """
    for each in rules:
        load_rule(each)
def load_program(program: Program) -> list[int]:
    """Load rules, net, and roots into the global state; return the root ports.

    After everything is instantiated, the occurrences of each wire name are
    joined pairwise in order of appearance (1st with 2nd, 3rd with 4th, ...).
    A trailing unpaired occurrence stays dangling.
    """
    load_rules(program.rules)
    occurrences = defaultdict[str, list[int]](list)
    load_net(program.net, occurrences)
    root_ports = [load_tree(tree, occurrences) for tree in program.roots]
    for ends in occurrences.values():
        # zip over the same iterator twice yields consecutive pairs.
        pairs = iter(ends)
        for a, b in zip(pairs, pairs):
            join(a, b)
    return root_ports
def interact(first: tuple[str, int, tuple[int, ...]], second: tuple[str, int, tuple[int, ...]]) -> None:
    """Apply the rule for the active pair formed by cells ``first`` and ``second``.

    If no rule is registered for ``(first symbol, second symbol)``, nothing
    happens.  Otherwise:
    1. Both cells are removed.
    2. The rule's trees are instantiated and joined onto the old auxiliary
       ports of the corresponding cell.
    3. Wire names that occur exactly twice in the rule are connected.

    Auxiliary ports or rule trees beyond the shorter of the two are ignored,
    because ``zip`` truncates.
    """
    rule = rules.get((first[0], second[0]))
    if rule is None:
        return
    left_trees, right_trees = rule
    cut(first[1], second[1])
    names = defaultdict[str, list[int]](list)
    for trees, ports in ((left_trees, first[2]), (right_trees, second[2])):
        for tree, old_port in zip(trees, ports):
            join(load_tree(tree, names), old_port)
    for ends in names.values():
        if len(ends) == 2:
            join(*ends)
def reduce() -> None:
    """Fire interaction rules until ``queue`` is empty.

    Each queued port pair whose ports are both principal ports of live cells
    is handed to ``interact``, which may queue further pairs.  Other pairs
    are discarded.
    """
    while queue:
        left, right = queue.pop()
        if left not in cells or right not in cells:
            continue
        interact(cells[left], cells[right])
def main() -> None:
    """Run the program file named on the command line and print its roots.

    Usage: ``python <script> PROGRAM_FILE``.  The file is parsed, loaded into
    the global net, and reduced; then each root is printed via ``show_root``.
    Exits with a usage message (status 1) when no file argument is given,
    instead of failing with an IndexError.
    """
    if len(sys.argv) < 2:
        program_name = sys.argv[0] if sys.argv else "script"
        sys.exit(f"usage: {program_name} PROGRAM_FILE")
    path = sys.argv[1]
    program = parse_file(path)
    roots = load_program(program)
    reduce()
    for root in roots:
        print(show_root(root))
# Run only when executed as a script, not when the module is imported.
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment