Created
February 21, 2025 05:10
-
-
Save LeeeeT/70f29d144683e792a8950436288ee29c to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import annotations | |
import itertools | |
import sys | |
from collections import Counter, defaultdict | |
from dataclasses import dataclass | |
# Runtime state of the interaction net. Ports are plain integers.
port = itertools.count()  # source of fresh port ids
# Symmetric adjacency: wires[a] == b and wires[b] == a for every live wire.
wires = dict[int, int]()
# Live cells keyed by their own principal port:
# principal port -> (symbol, principal port, auxiliary ports).
cells = dict[int, tuple[str, int, tuple[int, ...]]]()
# Newly made connections, queued in both orientations; `reduce` checks them
# for active pairs.
queue = set[tuple[int, int]]()


def wire() -> tuple[int, int]:
    """Create a fresh wire and return its two end ports."""
    a, b = next(port), next(port)
    wires[a], wires[b] = b, a
    return a, b


def join(first: int, second: int) -> None:
    """Connect the far end of `first` directly to the far end of `second`.

    `first` and `second` are removed from `wires`, and their former partners
    become each other's partners. The new connection is added to `queue` in
    both orientations.

    If `first` and `second` are the two ends of one wire, joining them closes
    that wire into a loop with nothing attached, so the wire is just removed
    instead of being left behind as a dangling two-port loop.
    """
    if wires[first] == second:
        del wires[first], wires[second]
        return
    left, right = wires[first], wires[second]
    queue.update({(left, right), (right, left)})
    del wires[first], wires[second]
    wires[left], wires[right] = right, left


def cut(first: int, second: int) -> None:
    """Delete the cells whose principal ports are `first` and `second`,
    along with those two ports' entries in `wires`."""
    del cells[first], cells[second]
    del wires[first], wires[second]


def cell(s: str, arity: int) -> tuple[int, tuple[int, ...]]:
    """Create a cell with symbol `s` and `arity` auxiliary ports.

    Each port gets its own fresh wire; the cell keeps the inner ends (stored in
    `cells`) and the outer ends are returned as (principal, auxiliaries).
    """
    # Principal wire first, then auxiliaries: keeps port numbering stable.
    p, xs = wire(), tuple(wire() for _ in range(arity))
    cells[p[0]] = (s, p[0], tuple(x[0] for x in xs))
    return p[1], tuple(x[1] for x in xs)
def print_net() -> None:
    """Print every live cell, one per line, as `S(aux, ...) → principal`.

    Each wire gets one number shared by both of its ends, assigned in `wires`
    iteration order; nullary cells are printed as `S → principal`.
    """
    numbering = itertools.count()
    label = dict[int, int]()
    for here, there in wires.items():
        if here in label:
            continue
        label[here] = label[there] = next(numbering)
    for symbol, principal, auxes in cells.values():
        if not auxes:
            print(f"{symbol} → {label[principal]}")
            continue
        listed = ", ".join(str(label[aux]) for aux in auxes)
        print(f"{symbol}({listed}) → {label[principal]}")
def show_root(root: int) -> str:
    """Render the cells hanging off port `root` as text.

    Wires are numbered in `wires` iteration order, both ends sharing a number.
    When the wire at a port ends in a cell's principal port, that cell is shown
    as its symbol followed by its rendered auxiliary ports (an auxiliary that is
    a cell with auxiliaries of its own is parenthesized). Any other wire is
    shown as its number.
    """
    counter = itertools.count()
    wire_of = dict[int, int]()
    cell_on = dict[int, tuple[str, int, tuple[int, ...]]]()
    for end, other_end in wires.items():
        if end in wire_of:
            continue
        number = next(counter)
        wire_of[end] = wire_of[other_end] = number
        for candidate in (end, other_end):
            if candidate in cells:
                cell_on[number] = cells[candidate]

    def compound(number: int) -> bool:
        # A cell with auxiliaries needs parentheses when it is an argument.
        return number in cell_on and bool(cell_on[number][2])

    # NOTE(review): recursive; a very deep tree can exceed the recursion limit.
    def render(at: int) -> str:
        number = wire_of[at]
        if number not in cell_on:
            return str(number)
        symbol, _, auxes = cell_on[number]
        if not auxes:
            return symbol
        pieces = []
        for aux in auxes:
            text = render(aux)
            pieces.append(f"({text})" if compound(wire_of[aux]) else text)
        return symbol + " " + " ".join(pieces)

    return render(root)
@dataclass(frozen=True) | |
class Program: | |
rules: list[Cut] | |
net: list[Cut] | |
roots: list[Tree] | |
@dataclass(frozen=True) | |
class Cut: | |
left: Tree | |
right: Tree | |
type Tree = Wire | Cell | |
type Wire = str | |
@dataclass(frozen=True) | |
class Cell: | |
symbol: Symbol | |
auxes: list[Tree] | |
type Symbol = str | |
def parse_program(input: str) -> tuple[Program, str] | None:
    """Parse a program: cuts, then one or more root trees.

    The cuts (split into rules and net by `parse_cuts`) come first. After any
    newlines, at least one root tree is required. Further roots may follow,
    each optionally preceded by newlines, and trailing newlines are consumed.

    Returns the program and the unparsed remainder, or None when no root
    tree can be parsed.
    """
    parsed_cuts = parse_cuts(input)
    if parsed_cuts is None:
        return None
    (rules, cuts), rest = parsed_cuts

    skipped = parse_newlines(rest)
    if skipped is None:
        return None
    _, rest = skipped

    first = parse_tree(rest)
    if first is None:
        return None
    root, rest = first
    roots = [root]
    while True:
        skipped = parse_newlines(rest)
        if skipped is None:
            break
        _, rest = skipped
        following = parse_tree(rest)
        if following is None:
            break
        root, rest = following
        roots.append(root)

    skipped = parse_newlines(rest)
    if skipped is None:
        return None
    _, rest = skipped
    return Program(rules, cuts, roots), rest
def parse_cuts(input: str) -> tuple[tuple[list[Cut], list[Cut]], str] | None:
    """Parse zero or more cuts and split them into (rules, net).

    Consecutive cuts may be separated by any number of newlines, including
    none. If at least one cut was parsed, newlines after the last one are
    consumed too. A cut for which `balanced` holds is a rule; every other cut
    belongs to the net. Order is preserved within each list. Never returns None.
    """
    found = list[Cut]()
    rest = input
    step = parse_cut(rest)
    while step is not None:
        item, rest = step
        found.append(item)
        gap = parse_newlines(rest)
        if gap is None:
            break
        _, rest = gap
        step = parse_cut(rest)
    rules, net = list[Cut](), list[Cut]()
    for item in found:
        (rules if balanced(item) else net).append(item)
    return (rules, net), rest
def parse_cut(input: str) -> tuple[Cut, str] | None:
    """Parse `left ~ right` (exactly one space on each side of the tilde).

    Returns the cut and the unparsed remainder, or None if either tree or
    the separator is missing.
    """
    left_result = parse_tree(input)
    if left_result is None:
        return None
    left, rest = left_result
    separator = parse_string(" ~ ", rest)
    if separator is None:
        return None
    right_result = parse_tree(separator[1])
    if right_result is None:
        return None
    right, rest = right_result
    return Cut(left, right), rest
def parse_tree(input: str) -> tuple[Tree, str] | None:
    """Parse a tree: a wire name if one starts here, otherwise a cell.

    Returns the tree and the unparsed remainder, or None if neither applies.
    """
    for parser in (parse_wire, parse_cell):
        result = parser(input)
        if result is not None:
            return result
    return None
def parse_wire(input: str) -> tuple[Wire, str] | None:
    """Parse a wire name: a wire-initial character (see `parse_wire_first`)
    followed by any number of identifier characters (see `parse_id_char`).

    Returns the name and the unparsed remainder, or None.
    """
    head = parse_wire_first(input)
    if head is None:
        return None
    name, rest = head
    while (step := parse_id_char(rest)) is not None:
        char, rest = step
        name += char
    return name, rest
def parse_cell(input: str) -> tuple[Cell, str] | None: | |
match parse_symbol(input): | |
case None: | |
return None | |
case symbol, input: | |
pass | |
auxes = list[Tree]() | |
match parse_string(" ", input): | |
case None: | |
pass | |
case _, input: | |
match parse_string("(", input): | |
case None: | |
match parse_wire(input): | |
case None: | |
match parse_symbol(input): | |
case None: | |
input = " " + input | |
pass | |
case aux, input: | |
auxes.append(Cell(aux, [])) | |
case aux, input: | |
auxes.append(aux) | |
case _, input: | |
match parse_tree(input): | |
case None: | |
input = " " + input | |
pass | |
case aux, input: | |
auxes.append(aux) | |
match parse_string(")", input): | |
case None: | |
pass | |
case _, input: | |
pass | |
if auxes: | |
while True: | |
match parse_string(" ", input): | |
case None: | |
break | |
case _, input: | |
pass | |
match parse_string("(", input): | |
case None: | |
match parse_wire(input): | |
case None: | |
match parse_symbol(input): | |
case None: | |
input = " " + input | |
break | |
case aux, input: | |
auxes.append(Cell(aux, [])) | |
case aux, input: | |
auxes.append(aux) | |
case _, input: | |
match parse_tree(input): | |
case None: | |
input = " " + input | |
break | |
case aux, input: | |
auxes.append(aux) | |
match parse_string(")", input): | |
case None: | |
return None | |
case _, input: | |
pass | |
return Cell(symbol, auxes), input | |
def parse_symbol(input: str) -> tuple[Symbol, str] | None:
    """Parse a symbol: a symbol-initial character (see `parse_symbol_first`)
    followed by any number of identifier characters (see `parse_id_char`).

    Returns the symbol and the unparsed remainder, or None.
    """
    first = parse_symbol_first(input)
    if first is None:
        return None
    chars, rest = [first[0]], first[1]
    while (step := parse_id_char(rest)) is not None:
        chars.append(step[0])
        rest = step[1]
    return "".join(chars), rest
def parse_id_char(input: str) -> tuple[str, str] | None: | |
if not input or input[0] in {" ", "(", ")", "~", "\r", "\n"}: | |
return None | |
return input[0], input[1:] | |
def parse_wire_first(input: str) -> tuple[str, str] | None: | |
if not input or input[0] in {" ", "(", ")", "~", "\r", "\n", "_"} or input[0].isupper(): | |
return None | |
return input[0], input[1:] | |
def parse_symbol_first(input: str) -> tuple[str, str] | None: | |
if not input or not (input[0].isupper() or input[0] == "_"): | |
return None | |
return input[0], input[1:] | |
def parse_string(string: str, input: str) -> tuple[str, str] | None: | |
if not input.startswith(string): | |
return None | |
return string, input[len(string):] | |
def parse_newlines(input: str) -> tuple[str, str] | None: | |
newlines = "" | |
while True: | |
match parse_string("\n", input): | |
case None: | |
break | |
case char, input: | |
newlines += char | |
return newlines, input | |
def balanced(cut: Cut) -> bool: | |
wires = extract_wires(cut.left) + extract_wires(cut.right) | |
counter = Counter(wires) | |
return all(value == 2 for value in counter.values()) | |
def extract_wires(tree: Tree) -> list[Wire]: | |
match tree: | |
case str(wire): | |
return [wire] | |
case Cell(_, auxes): | |
return sum((extract_wires(aux) for aux in auxes), list[Wire]()) | |
def parse_text(text: str) -> Program: | |
match parse_program(text): | |
case None: | |
raise SyntaxError | |
case program, rest: | |
if rest: | |
raise SyntaxError | |
return program | |
def parse_file(path: str) -> Program: | |
with open(path) as file: | |
text = file.read() | |
return parse_text(text) | |
def show_program(program: Program) -> str: | |
return f"{show_cuts(program.rules)}\n\n{show_cuts(program.net)}\n\n{"\n".join(show_tree(root) for root in program.roots)}" | |
def show_cuts(cuts: list[Cut]) -> str: | |
return "\n".join(show_cut(cut) for cut in cuts) | |
def show_cut(cut: Cut) -> str: | |
return f"{show_tree(cut.left)} ~ {show_tree(cut.right)}" | |
def show_tree(tree: Tree) -> str: | |
match tree: | |
case str(wire): | |
return wire | |
case Cell(symbol, auxes): | |
return f"{symbol} {show_auxes(auxes)}" if auxes else symbol | |
def show_auxes(auxes: list[Tree]) -> str: | |
return " ".join(f"({show_tree(aux)})" if isinstance(aux, Cell) and aux.auxes else f"{show_tree(aux)}" for aux in auxes) | |
def load_net(net: list[Cut], id_to_port: dict[str, list[int]]) -> None:
    """Instantiate each cut of `net` and connect its two sides' top ports.

    Wire-name occurrences are collected into `id_to_port` by `load_tree`;
    pairing them up is left to the caller.
    """
    for each in net:
        left_port = load_tree(each.left, id_to_port)
        right_port = load_tree(each.right, id_to_port)
        join(left_port, right_port)
def load_tree(tree: Tree, id_to_port: dict[str, list[int]]) -> int: | |
match tree: | |
case str(id): | |
a, b = wire() | |
id_to_port[id].append(b) | |
return a | |
case Cell(symbol, auxes): | |
p, xs = cell(symbol, len(auxes)) | |
for aux, x_port in zip(auxes, xs): | |
join(load_tree(aux, id_to_port), x_port) | |
return p | |
# Rule table: (left symbol, right symbol) -> (left aux trees, right aux trees).
rules = dict[tuple[str, str], tuple[list[Tree], list[Tree]]]()


def load_rule(rule: Cut) -> None:
    """Register `rule` in `rules` when both sides are cells; ignore it otherwise."""
    left, right = rule.left, rule.right
    if isinstance(left, Cell) and isinstance(right, Cell):
        rules[left.symbol, right.symbol] = left.auxes, right.auxes


def load_rules(rules: list[Cut]) -> None:
    """Register every cut in `rules` via `load_rule`.

    The parameter shadows the module-level table, which `load_rule` fills.
    """
    for each in rules:
        load_rule(each)
def load_program(program: Program) -> list[int]:
    """Load a program's rules and net into the runtime; return one port per root.

    Occurrences of the same wire name, across the net and then the roots, are
    joined pairwise in the order they were loaded (1st with 2nd, 3rd with 4th,
    ...). An unpaired final occurrence is left unconnected.
    """
    load_rules(program.rules)
    id_to_port = defaultdict[str, list[int]](list)
    load_net(program.net, id_to_port)
    roots = [load_tree(root, id_to_port) for root in program.roots]
    for ports in id_to_port.values():
        walker = iter(ports)
        for a, b in zip(walker, walker):
            join(a, b)
    return roots
def interact(first: tuple[str, int, tuple[int, ...]], second: tuple[str, int, tuple[int, ...]]) -> None:
    """Rewrite the active pair of cells `first` and `second` using a rule.

    The rule is looked up for (first symbol, second symbol) in exactly that
    order; if none is registered, nothing happens. (`join` queues both
    orientations, so `reduce` also tries the reversed pair.) Otherwise both
    cells are removed, the rule's trees are instantiated onto the cells' former
    auxiliary ports, and wire names that occur exactly twice in the rule are
    connected.
    """
    left_symbol, left_principal, left_auxes = first
    right_symbol, right_principal, right_auxes = second
    rule = rules.get((left_symbol, right_symbol))
    if rule is None:
        return
    cut(left_principal, right_principal)
    id_to_port = defaultdict[str, list[int]](list)
    for trees, targets in zip(rule, (left_auxes, right_auxes)):
        for tree, target in zip(trees, targets):
            join(load_tree(tree, id_to_port), target)
    for occurrences in id_to_port.values():
        if len(occurrences) == 2:
            join(*occurrences)
def reduce() -> None:
    """Process queued connections until the queue is empty.

    A queued pair is rewritten only if both ends are still principal ports of
    live cells; anything else is skipped.
    """
    while queue:
        a, b = queue.pop()
        if a not in cells or b not in cells:
            continue
        interact(cells[a], cells[b])
def main() -> None:
    """Parse the program file named by the first command-line argument,
    reduce its net, and print each root on its own line.

    Exits with a usage message (status 1) if no file argument is given,
    instead of crashing with IndexError. Extra arguments are ignored.
    """
    if len(sys.argv) < 2:
        sys.exit(f"usage: {sys.argv[0] if sys.argv else 'python script.py'} FILE")
    path = sys.argv[1]
    program = parse_file(path)
    roots = load_program(program)
    reduce()
    for root in roots:
        print(show_root(root))


# Run only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment