Created
December 20, 2020 04:02
-
-
Save ezyang/c3db0e55a7661998c8a66ea8619f1081 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from dataclasses import dataclass | |
from typing import List, Tuple, Union, Dict, Iterator, Optional | |
import sys | |
import itertools | |
# E-class / union-find identifiers are plain integer indices.
Id = int
class UnionFind:
    """Disjoint-set forest over dense integer ids.

    ``parents[i]`` is the parent pointer of element ``i``; an element is the
    root (representative) of its set exactly when it is its own parent.
    """

    parents: List[Id]

    def __init__(self) -> None:
        self.parents = []

    def make_set(self) -> Id:
        """Create a fresh singleton set; its id is the next free index."""
        fresh = len(self.parents)
        self.parents.append(fresh)
        return fresh

    def parent(self, query: Id) -> Id:
        """Return the immediate parent pointer of ``query`` (no compression)."""
        return self.parents[query]

    def set_parent(self, query: Id, new_parent: Id) -> None:
        """Overwrite the parent pointer of ``query``."""
        self.parents[query] = new_parent

    def find(self, current: Id) -> Id:
        """Return the root of the set containing ``current``.

        Path halving: every visited node is re-pointed at its grandparent and
        the walk resumes from that grandparent.
        """
        node = current
        up = self.parents[node]
        while up != node:
            skip = self.parent(up)
            self.set_parent(node, skip)
            node = skip
            up = self.parents[node]
        return node

    def union(self, set1: Id, set2: Id) -> Tuple[Id, Id]:
        """Merge the sets containing ``set1`` and ``set2``.

        Returns ``(kept_root, absorbed_root)``. The smaller root is kept as
        the root of the merged set; if both were already in one set, both
        entries are that root.
        """
        # NOTE egg relies on returned id being minimum
        low, high = sorted((self.find(set1), self.find(set2)))
        if low != high:
            self.set_parent(high, low)
        return (low, high)
def make_union_find(n: int) -> UnionFind:
    """Return a UnionFind pre-populated with the singleton sets 0 .. n-1."""
    forest = UnionFind()
    for _ in itertools.repeat(None, n):
        forest.make_set()
    return forest
def union_find() -> None:
    """Self-test for UnionFind: min-root unions and full path compression."""
    size = 10
    uf = make_union_find(size)
    # Every element starts as its own root; asking twice must be stable.
    for elem in range(size):
        for _ in range(2):
            assert uf.find(elem) == elem
    # (arguments, expected (kept_root, absorbed_root)), applied in order.
    expected_unions = [
        ((0, 1), (0, 1)),
        ((1, 2), (0, 2)),
        ((3, 2), (0, 3)),
        ((6, 7), (6, 7)),
        ((8, 9), (8, 9)),
        ((7, 9), (6, 8)),
        # make sure union on same set returns to == from
        ((1, 3), (0, 0)),
        ((7, 8), (6, 6)),
    ]
    for (a, b), want in expected_unions:
        assert uf.union(a, b) == want
    # Walk over everything so every path gets compressed...
    for elem in range(size):
        uf.find(elem)
    # ...after which each element points directly at its root.
    assert all(uf.parent(elem) == uf.find(elem) for elem in range(size))
class Symbol:
    """An interned operator name.

    The name is passed through ``sys.intern`` so that equal names share one
    string object, which lets equality be an identity check.
    """

    sym: str

    def __init__(self, sym: str) -> None:
        self.sym = sys.intern(sym)

    def __str__(self) -> str:
        return self.sym

    def __repr__(self) -> str:
        return f"Symbol({self.sym!r})"

    def __eq__(self, other: object) -> bool:
        # Previously `other.sym` was read unconditionally, so comparing with a
        # non-Symbol (a str, None, ...) raised AttributeError. Returning
        # NotImplemented lets Python fall back to its default (False for ==).
        if not isinstance(other, Symbol):
            return NotImplemented
        # Both names are interned, so equal names are the same object.
        return self.sym is other.sym

    def __hash__(self) -> int:
        return hash(self.sym)
@dataclass(init=False, frozen=True)
class SymbolLang:
    """An e-node: an operator symbol applied to a tuple of child e-class ids.

    Frozen so instances can serve as hashcons (dict) keys; the
    dataclass-generated ``__eq__``/``__hash__`` compare ``(op, children)``.
    """

    op: Symbol
    children: Tuple[Id, ...]

    def __init__(self, op: Union[str, Symbol], children: Tuple[Id, ...] = ()) -> None:
        # Frozen dataclass: fields can only be set through object.__setattr__.
        object.__setattr__(self, 'op', op if isinstance(op, Symbol) else Symbol(op))
        # Coerce to a tuple: a list here would make the node unhashable and
        # break its use as a dict key.
        object.__setattr__(self, 'children', tuple(children))

    @classmethod
    def leaf(cls, op: Union[str, Symbol]) -> 'SymbolLang':
        """Build a node with no children (uses ``cls`` so subclasses work)."""
        return cls(op, ())

    def len(self) -> int:
        """Number of children (the node's arity)."""
        return len(self.children)

    def is_leaf(self) -> bool:
        """True if the node has no children."""
        return not self.children

    def matches(self, other: 'SymbolLang') -> bool:
        """True if both nodes share operator and arity (child ids ignored)."""
        return self.op == other.op and self.len() == other.len()

    def __str__(self) -> str:
        """S-expression form, e.g. ``(f 0 1)``; a leaf prints as ``(x)``."""
        return f"({' '.join([str(self.op), *map(str, self.children)])})"
# Short alias for the e-node type; the e-graph code below is written against it.
L = SymbolLang
# Conceptually a recursive expression, but actually just a list of
# enodes. Invariant: enodes children must refer to elements that
# come before it in the list.
@dataclass
class RecExpr:
    """A term stored as a topologically ordered list of e-nodes.

    Each node's children are indices of earlier entries in ``nodes``.
    """

    # Annotations are quoted forward references so the class does not need
    # L / Id to be resolvable at definition time.
    nodes: "List[L]"

    def add(self, node: "L") -> "Id":
        """Append ``node`` and return its index.

        Raises:
            ValueError: if any child index does not refer to an earlier node.
                Negative indices are rejected too; they would silently wrap
                around to the end of the list and break the invariant.
        """
        n = len(self.nodes)
        bad = [c for c in node.children if not 0 <= c < n]
        if bad:
            raise ValueError(f"child ids {bad} must be in range [0, {n})")
        self.nodes.append(node)
        return n

    def __getitem__(self, id: "Id") -> "L":
        return self.nodes[id]
# An equivalence class of enodes
@dataclass
class EClass:
    """An e-class: e-nodes known to be equal, plus back-references.

    ``parents`` lists ``(enode, class_id)`` pairs for e-nodes that use this
    class as a child, together with the id of the class holding each one.
    """

    # This eclass's id
    id: Id
    # The equivalent enodes in this equivalence class
    nodes: List[L]
    # (enode, id of the class containing that enode) for every enode that
    # has this class among its children
    parents: List[Tuple[L, Id]]

    def is_empty(self) -> bool:
        """True if the class holds no e-nodes."""
        return len(self.nodes) == 0

    def len(self) -> int:
        """Number of e-nodes in the class."""
        count = len(self.nodes)
        return count

    def leaves(self) -> Iterator[L]:
        """Lazily yield the e-nodes of this class that have no children."""
        return (node for node in self.nodes if node.is_leaf())
class EGraph:
    """An e-graph with deferred ("rebuilding") congruence closure.

    ``add`` hashconses e-nodes into e-classes. ``union`` merges classes but
    only records the merged root in ``dirty_unions``. ``rebuild`` then
    repairs the hashcons and parent lists, merges classes that became
    congruent, and deduplicates the nodes stored in each class.
    """

    # Hashcons: e-node -> id of its class. Between a union and the next
    # rebuild, keys and values may be non-canonical.
    memo: Dict[L, Id]
    unionfind: UnionFind
    # Indexed by class id; set to None once a class is merged into another.
    classes: List[Optional[EClass]]
    # invariant: number of non-None classes
    n_classes: int
    # Roots produced by unions whose parents still need repairing.
    dirty_unions: List[Id]
    # Number of class repairs done since the last rebuild (reported by rebuild).
    repairs_since_rebuild: int

    def __init__(self) -> None:
        self.memo = {}
        self.unionfind = UnionFind()
        self.classes = []
        self.n_classes = 0
        self.dirty_unions = []
        self.repairs_since_rebuild = 0

    def is_empty(self) -> bool:
        """True if no e-node has been added yet."""
        return not self.memo

    def total_size(self) -> int:
        """Number of entries in the hashcons."""
        return len(self.memo)

    def total_number_of_nodes(self) -> int:
        """Total number of e-nodes stored across all live classes."""
        return sum(c.len() for c in self.classes if c is not None)

    def find(self, id: Id) -> Id:
        """Return the canonical id of the class containing ``id``."""
        return self.unionfind.find(id)

    def __getitem__(self, id: Id) -> EClass:
        """Return the live class containing ``id`` (``id`` need not be canonical)."""
        eclass = self.classes[self.find(id)]
        assert eclass is not None
        return eclass

    def canonicalize(self, enode: L) -> L:
        """Return ``enode`` with every child id replaced by its canonical id."""
        return L(enode.op, tuple(self.find(child) for child in enode.children))

    def add_expr(self, expr: RecExpr) -> Id:
        """Add every node of ``expr``; return the class id of its last (root) node.

        Raises:
            ValueError: if ``expr`` contains no nodes.
        """
        if not expr.nodes:
            raise ValueError("cannot add an empty RecExpr")
        # new_ids[i] is the class id assigned to expr.nodes[i]. Children in a
        # RecExpr are indices into expr, so they are translated through it.
        new_ids: List[Id] = []
        for node in expr.nodes:
            node = SymbolLang(node.op, tuple(new_ids[i] for i in node.children))
            new_ids.append(self.add(node))
        return new_ids[-1]

    def lookup(self, enode: L) -> Optional[Id]:
        """Return the canonical id of the class containing ``enode``, or None."""
        found = self.memo.get(self.canonicalize(enode))
        return None if found is None else self.find(found)

    def add(self, enode: L) -> Id:
        """Add ``enode`` to the graph and return the id of its class.

        Performs hashconsing: if an equal canonical e-node is already in the
        graph, the id of the class it was found in is returned and nothing
        new is created.
        """
        enode = self.canonicalize(enode)
        # Same as lookup(), without canonicalizing a second time.
        existing = self.memo.get(enode)
        if existing is not None:
            return self.find(existing)
        new_id = self.unionfind.make_set()
        clas = EClass(new_id, nodes=[enode], parents=[])  # N::make
        for child in enode.children:
            self[child].parents.append((enode, new_id))
        # Class ids and union-find ids are allocated in lockstep.
        assert len(self.classes) == new_id
        self.classes.append(clas)
        self.n_classes += 1
        self.memo[enode] = new_id
        # N::modify
        return new_id

    # TODO: port egg's `equivs` (ids where two RecExprs are both represented)
    # and `check_goals`.

    def union_impl(self, id1: Id, id2: Id) -> Tuple[Id, bool]:
        """Merge the classes of ``id1`` and ``id2`` without restoring invariants.

        The absorbed class's nodes and parents are moved into the surviving
        root, which is recorded in ``dirty_unions``. Returns
        ``(root, merged)``.
        """
        # N::pre_union
        to, fro = self.unionfind.union(id1, id2)
        if to == fro:
            return to, False
        self.dirty_unions.append(to)
        fro_class = self.classes[fro]
        to_class = self.classes[to]
        assert fro_class is not None
        assert to_class is not None
        # analysis.merge
        # TODO: minimize amount of copying
        to_class.nodes.extend(fro_class.nodes)
        to_class.parents.extend(fro_class.parents)
        self.classes[fro] = None
        self.n_classes -= 1
        # N::modify
        return to, True

    def union(self, id1: Id, id2: Id) -> Tuple[Id, bool]:
        """Union two e-classes given (not necessarily canonical) ids.

        Returns the canonical id of the merged class and whether a union was
        actually performed (False if the classes were already equivalent).
        Call ``rebuild`` afterwards to restore congruence.
        """
        union = self.union_impl(id1, id2)
        # upward-merging feature
        return union

    def rebuild_classes(self) -> int:
        """Canonicalize and deduplicate each live class's nodes.

        Returns the number of nodes removed as duplicates.
        """
        trimmed = 0
        for clas in self.classes:
            if clas is None:
                continue
            old_len = clas.len()
            # dict.fromkeys dedups while keeping first-seen order, so the
            # result (and __str__) is deterministic; a set's iteration order
            # would depend on per-process string hash randomization.
            clas.nodes = list(dict.fromkeys(self.canonicalize(n) for n in clas.nodes))
            trimmed += old_len - len(clas.nodes)
        return trimmed

    def check_memo(self) -> bool:
        """Assert that the hashcons agrees with the contents of the classes.

        Every node stored in a live class must be in ``memo`` and map (up to
        ``find``) to that class. A node seen in two classes is only allowed
        if they are the same class. Returns True so it can be used as
        ``assert self.check_memo()``.
        """
        test_memo: Dict[L, Id] = {}
        for class_id, clas in enumerate(self.classes):
            if clas is None:
                continue
            assert clas.id == class_id
            for node in clas.nodes:
                old = test_memo.get(node)
                test_memo[node] = class_id
                if old is not None:
                    assert self.find(old) == self.find(class_id)
        for node, class_id in test_memo.items():
            assert class_id == self.find(class_id)
            memo_id = self.memo.get(node)
            assert memo_id is not None
            assert class_id == self.find(memo_id)
        return True

    def process_unions(self) -> None:
        """Restore hashcons and congruence invariants for all dirty classes.

        Runs in rounds until no union is pending. Each round does four things:
        - re-canonicalizes the parents of every dirty class;
        - merges parents that became identical (congruence);
        - refreshes the hashcons;
        - performs the unions discovered, which may dirty further classes.
        """
        while self.dirty_unions:
            todo = {self.find(d) for d in self.dirty_unions}
            self.dirty_unions = []
            to_union: List[Tuple[Id, Id]] = []
            for class_id in todo:
                self.repairs_since_rebuild += 1
                parents = self[class_id].parents
                # Drop hashcons entries keyed by possibly stale e-nodes; they
                # are re-inserted below in canonical form.
                for node, _owner in parents:
                    self.memo.pop(node, None)
                # Canonicalize and deduplicate. Parents that canonicalize to
                # the same e-node are congruent, so their classes must merge.
                # A dict is used instead of sorting the (enode, id) pairs:
                # e-nodes define no ordering, so sorting raised TypeError as
                # soon as two distinct parent e-nodes were compared.
                deduped: Dict[L, Id] = {}
                for node, owner in parents:
                    node = self.canonicalize(node)
                    owner = self.find(owner)
                    kept = deduped.setdefault(node, owner)
                    if kept != owner:
                        to_union.append((kept, owner))
                new_parents = list(deduped.items())
                for node, owner in new_parents:
                    old = self.memo.get(node)
                    self.memo[node] = owner
                    if old is not None:
                        to_union.append((old, owner))
                # propagate_metadata
                self[class_id].parents = new_parents
                # N::modify
            # union_impl records every newly merged root in dirty_unions,
            # which drives the next round of the loop.
            for id1, id2 in to_union:
                self.union_impl(id1, id2)
        assert not self.dirty_unions

    def rebuild(self) -> int:
        """Restore all invariants, print a summary, and return the repair count.

        The returned value is the number of class repairs performed since the
        previous rebuild.
        """
        old_hc_size = len(self.memo)
        old_n_eclasses = self.n_classes
        self.process_unions()
        n_unions = self.repairs_since_rebuild
        trimmed_nodes = self.rebuild_classes()
        print(f"""\
REBUILT!
Old: hc size {old_hc_size}, eclasses: {old_n_eclasses}
New: hc size {len(self.memo)}, eclasses: {self.n_classes}
unions: {self.repairs_since_rebuild}, trimmed nodes: {trimmed_nodes}""")
        self.repairs_since_rebuild = 0
        assert self.check_memo()
        return n_unions

    def __str__(self) -> str:
        """One line per live class, ordered by id: ``id: node, node, ...``."""
        live_ids = sorted(c.id for c in self.classes if c is not None)
        return "".join(
            f"{cid}: {', '.join(str(n) for n in self[cid].nodes)}\n"
            for cid in live_ids
        )
def simple_add():
    """Demo: adding ``x`` twice yields one class; then union ``x`` with ``y``."""
    graph = EGraph()
    first_x = graph.add(L.leaf('x'))
    second_x = graph.add(L.leaf('x'))
    graph.add(L('+', (first_x, second_x)))
    y_id = graph.add(L.leaf("y"))
    graph.union(first_x, y_id)
    graph.rebuild()
    print(graph)
def exercise_trim():
    """Demo: rebuilding must trim duplicate nodes out of a class.

    ``f(x, y)`` and ``f(z, y)`` are put in one class while ``x`` and ``z``
    are still distinct. Once ``x == z`` is learned, both nodes canonicalize
    to the same e-node, so rebuild must drop one of them (a nonzero
    "trimmed nodes" count).
    """
    egraph = EGraph()
    x = egraph.add(L('x'))
    y = egraph.add(L('y'))
    z = egraph.add(L('z'))
    f1 = egraph.add(L('f', (x, y)))
    f2 = egraph.add(L('f', (z, y)))
    egraph.union(f1, f2)
    egraph.union(x, z)
    egraph.rebuild()
    print(egraph)
    # The goal used to be visible only in rebuild()'s printout; check it.
    assert egraph.find(x) == egraph.find(z)
    assert egraph.find(f1) == egraph.find(f2)
    assert egraph[f1].len() == 1, "expected f(x, y) / f(z, y) to be trimmed to one node"
def exercise_n_classes_reduction():
    """Demo: congruence discovered during rebuild reduces the class count.

    Only ``x == y`` is asserted; ``f(x) == f(y)`` is never unioned directly
    and must be found by rebuild, so two merges happen in total.
    """
    egraph = EGraph()
    x = egraph.add(L('x'))
    y = egraph.add(L('y'))
    fx = egraph.add(L('f', (x,)))
    fy = egraph.add(L('f', (y,)))
    classes_before = egraph.n_classes
    egraph.union(x, y)
    egraph.rebuild()
    print(egraph)
    # fx / fy were previously unused; check that congruence merged them.
    assert egraph.find(fx) == egraph.find(fy)
    assert egraph.n_classes == classes_before - 2
# Run the union-find self-test and the e-graph demos, in order.
for _demo in (union_find, simple_add, exercise_trim, exercise_n_classes_reduction):
    _demo()
del _demo
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment