# Gist ezyang/c3db0e55a7661998c8a66ea8619f1081 by @ezyang,
# created December 20, 2020 04:02.
from dataclasses import dataclass
from typing import List, Tuple, Union, Dict, Iterator, Optional
import sys
import itertools
# Ids are plain integers; the alias only documents intent.
Id = int


class UnionFind:
    """Disjoint-set forest over consecutive integer ids.

    ``parents[i]`` holds the parent of element ``i``; an element whose
    parent is itself is the root (representative) of its set.
    """

    parents: List[Id]

    def __init__(self) -> None:
        self.parents = []

    def make_set(self) -> Id:
        """Create a new singleton set and return its id (the next free index)."""
        fresh = len(self.parents)
        self.parents.append(fresh)
        return fresh

    def parent(self, query: Id) -> Id:
        """Return the direct parent of ``query`` without any compression."""
        return self.parents[query]

    def set_parent(self, query: Id, new_parent: Id) -> None:
        """Point ``query`` directly at ``new_parent``."""
        self.parents[query] = new_parent

    def find(self, current: Id) -> Id:
        """Return the root of ``current``'s set.

        While walking up, every visited node is re-pointed at its
        grandparent, which shortens the path for later lookups.
        """
        links = self.parents
        node = current
        above = links[node]
        while above != node:
            grandparent = links[above]
            links[node] = grandparent
            node = grandparent
            above = links[node]
        return above

    def union(self, set1: Id, set2: Id) -> Tuple[Id, Id]:
        """Merge the sets containing ``set1`` and ``set2``.

        Returns ``(to, from)``: the surviving root and the root that was
        attached beneath it. When both elements are already in the same
        set, both entries are that set's root.
        """
        root1, root2 = self.find(set1), self.find(set2)
        if root1 != root2:
            # NOTE egg relies on returned id being minimum
            keep, absorb = min(root1, root2), max(root1, root2)
            self.parents[absorb] = keep
            return (keep, absorb)
        return (root1, root2)
def make_union_find(n: int) -> UnionFind:
    """Build a ``UnionFind`` holding ``n`` singleton sets with ids ``0 .. n-1``."""
    forest = UnionFind()
    for _unused in range(n):
        forest.make_set()
    return forest
def union_find() -> None:
    """Self-test for ``UnionFind``: initial state, unions, and path compression."""
    size = 10
    forest = make_union_find(size)

    # Initially every element is its own root; asking twice must not change that.
    for element in range(size):
        assert forest.find(element) == element
        assert forest.find(element) == element

    # (arguments to union, expected (to, from) result), applied in order.
    expected_unions = [
        ((0, 1), (0, 1)),
        ((1, 2), (0, 2)),
        ((3, 2), (0, 3)),
        ((6, 7), (6, 7)),
        ((8, 9), (8, 9)),
        ((7, 9), (6, 8)),
        # make sure union on same set returns to == from
        ((1, 3), (0, 0)),
        ((7, 8), (6, 6)),
    ]
    for (left, right), outcome in expected_unions:
        assert forest.union(left, right) == outcome

    # walk over everything
    for element in range(size):
        forest.find(element)

    # all paths are compressed
    for element in range(size):
        assert forest.parent(element) == forest.find(element)
class Symbol:
    """An operator name backed by an interned string.

    The string is interned on construction, so two symbols with equal
    text share one string object and equality reduces to an identity
    check.
    """

    sym: str

    def __init__(self, sym: str) -> None:
        self.sym = sys.intern(sym)

    def __str__(self) -> str:
        return self.sym

    def __repr__(self) -> str:
        return f"Symbol({self.sym!r})"

    def __eq__(self, other: object) -> bool:
        # Comparing against a non-Symbol must not crash on a missing
        # ``sym`` attribute; defer to Python's default handling instead.
        if not isinstance(other, Symbol):
            return NotImplemented
        # Both strings are interned, so identity is equivalent to equality.
        return self.sym is other.sym

    def __hash__(self) -> int:
        return hash(self.sym)
@dataclass(init=False, frozen=True)
class SymbolLang:
    """An enode: an operator symbol applied to a tuple of child ids.

    Instances are frozen and hashable so they can be used as dictionary
    keys and set members.
    """

    op: Symbol
    children: Tuple[Id, ...]

    def __init__(self, op: Union[str, Symbol], children: Tuple[Id, ...] = ()) -> None:
        # The dataclass is frozen, so fields must be set through object.__setattr__.
        object.__setattr__(self, 'op', op if isinstance(op, Symbol) else Symbol(op))
        # Normalize to a tuple so that a list (or other iterable) of child
        # ids cannot make the node unhashable.
        object.__setattr__(self, 'children', tuple(children))

    @classmethod
    def leaf(cls, op: Union[str, Symbol]) -> 'SymbolLang':
        """Build a node with no children."""
        return cls(op, ())

    def len(self) -> int:
        """Return the number of children."""
        return len(self.children)

    def is_leaf(self) -> bool:
        """Return True when the node has no children."""
        return not self.children

    def matches(self, other: 'SymbolLang') -> bool:
        """Return True when both nodes have the same operator and arity.

        The child ids themselves are not compared.
        """
        return self.op == other.op and self.len() == other.len()

    def __str__(self) -> str:
        # S-expression style: "(op child0 child1 ...)".
        parts = [str(self.op), *map(str, self.children)]
        return f"({' '.join(parts)})"


# Short alias used throughout the rest of the file.
L = SymbolLang
# Conceptually a recursive expression, but actually just a flat list of
# enodes. Invariant: an enode's children must refer to elements that
# come before it in the list.
@dataclass
class RecExpr:
    """A flat list of enodes in which children are indices of earlier entries."""

    nodes: List[L]

    def add(self, node: L) -> Id:
        """Append ``node`` and return its index.

        Asserts that every child of ``node`` indexes an entry already in
        the list.
        """
        next_index = len(self.nodes)
        assert all(child < next_index for child in node.children)
        self.nodes.append(node)
        return next_index

    def __getitem__(self, id: Id) -> L:
        return self.nodes[id]
# An equivalence class of enodes
@dataclass
class EClass:
    """An equivalence class of enodes."""

    # This eclass's id
    id: Id
    # The equivalent enodes in this equivalence class
    nodes: List[L]
    # (enode, eclass id) pairs for enodes that have this class as a child
    parents: List[Tuple[L, Id]]

    def is_empty(self) -> bool:
        """Return True when the class holds no enodes."""
        return len(self.nodes) == 0

    def len(self) -> int:
        """Return the number of enodes in the class."""
        return len(self.nodes)

    def leaves(self) -> Iterator[L]:
        """Lazily iterate over the enodes of this class that have no children."""
        return (node for node in self.nodes if node.is_leaf())
class EGraph:
    """A minimal e-graph: hashconsed enodes grouped into equivalence classes.

    ``union`` records a merge in the union-find and merges the two
    eclasses immediately, but the hashcons (``memo``) and the parent lists
    are only repaired by ``rebuild``. Between a ``union`` and the next
    ``rebuild`` the hashcons may therefore contain non-canonical entries.
    """

    # Hashcons: enode -> id of the eclass it was registered under.
    memo: Dict[L, Id]
    unionfind: UnionFind
    # Indexed by eclass id; a slot becomes None once that class has been
    # merged into another one.
    classes: List[Optional[EClass]]
    # invariant: number of non-None classes
    n_classes: int
    # Ids of eclasses that absorbed another class and still need repair.
    dirty_unions: List[Id]
    # Number of eclass repairs done by process_unions since the last rebuild.
    repairs_since_rebuild: int

    def __init__(self) -> None:
        self.memo = {}
        self.unionfind = UnionFind()
        self.classes = []
        self.n_classes = 0
        self.dirty_unions = []
        self.repairs_since_rebuild = 0

    def is_empty(self) -> bool:
        """Return True when no enode has been added yet."""
        return not self.memo

    def total_size(self) -> int:
        """Return the number of entries in the hashcons."""
        return len(self.memo)

    def total_number_of_nodes(self) -> int:
        """Return the number of enodes summed over all live eclasses."""
        return sum(c.len() for c in self.classes if c is not None)

    def find(self, id: Id) -> Id:
        """Return the canonical (root) id for ``id``."""
        return self.unionfind.find(id)

    def __getitem__(self, id: Id) -> EClass:
        """Return the eclass that ``id`` currently belongs to."""
        eclass = self.classes[self.find(id)]
        assert eclass is not None
        return eclass

    def canonicalize(self, enode: L) -> L:
        """Return a copy of ``enode`` whose children are all canonical ids."""
        return L(enode.op, tuple(self.find(child) for child in enode.children))

    # Add a RecExpr to the graph
    def add_expr(self, expr: RecExpr) -> Id:
        """Add every node of ``expr`` and return the eclass id of its last node.

        Child indices inside ``expr`` are translated to the eclass ids the
        children received when they were added.
        """
        new_ids: List[Id] = []
        for node in expr.nodes:
            node = SymbolLang(node.op, tuple(new_ids[i] for i in node.children))
            new_ids.append(self.add(node))
        return new_ids[-1]

    # Lookup the eclass of a given enode.
    def lookup(self, enode: L) -> Optional[Id]:
        """Return the canonical eclass id of ``enode``, or None if it is absent."""
        enode = self.canonicalize(enode)
        found = self.memo.get(enode)
        if found is None:
            return None
        return self.find(found)

    # Adds an enode to the EGraph. When adding an enode to the egraph,
    # it performs hashconsing, ensuring only one copy of the enode is
    # in the graph. If the copy was in the graph, add simply returns
    # the id of the class in which the enode was found.
    def add(self, enode: L) -> Id:
        """Add ``enode`` (hashconsed) and return the id of its eclass."""
        enode = self.canonicalize(enode)
        # Query the hashcons directly: the node is already canonical, so
        # going through lookup() would only canonicalize it a second time.
        existing = self.memo.get(enode)
        if existing is not None:
            return self.find(existing)
        new_id = self.unionfind.make_set()
        clas = EClass(
            new_id,
            nodes=[enode],  # noclone
            # N::make
            parents=[],
        )
        # Register the new node as a parent of each child class. A child
        # that occurs several times is registered once per occurrence.
        for child in enode.children:
            self[child].parents.append((enode, new_id))  # noclone
        assert len(self.classes) == new_id
        self.classes.append(clas)
        self.n_classes += 1
        assert enode not in self.memo
        self.memo[enode] = new_id
        # N::modify
        return new_id

    # Checks if two RecExprs are equivalent. Return a list of ids where
    # both expressions are represented. In most cases there will be none
    # or exactly one id.
    # TODO: not implemented yet:
    #   def equivs(self, expr1: RecExpr, expr2: RecExpr) -> List[Id]
    #   def check_goals

    def union_impl(self, id1: Id, id2: Id) -> Tuple[Id, bool]:
        """Merge the eclasses of ``id1`` and ``id2``.

        Returns ``(to, changed)`` where ``to`` is the surviving canonical
        id and ``changed`` is False when both ids were already in the same
        class. On a real merge the surviving class is queued in
        ``dirty_unions`` for repair by ``process_unions``.
        """
        # N::pre_union
        to, fro = self.unionfind.union(id1, id2)
        if to != fro:
            self.dirty_unions.append(to)
            fro_class = self.classes[fro]
            to_class = self.classes[to]
            assert fro_class is not None
            assert to_class is not None
            # analysis.merge
            # TODO: minimize amount of copying
            to_class.nodes.extend(fro_class.nodes)
            to_class.parents.extend(fro_class.parents)
            self.classes[fro] = None
            self.n_classes -= 1
            # N::modify
        return to, to != fro

    # Unions two eclasses given their ids. The given ids need not be
    # canonical. Returned bool indicates whether a union was done,
    # false if they were already equivalent. Both results are canonical
    def union(self, id1: Id, id2: Id) -> Tuple[Id, bool]:
        """Public entry point for merging two eclasses; see ``union_impl``."""
        result = self.union_impl(id1, id2)
        # upward-merging feature
        return result

    def rebuild_classes(self) -> int:
        """Canonicalize and deduplicate the enodes of every live eclass.

        Returns the total number of enodes dropped as duplicates.
        """
        trimmed = 0
        for clas in self.classes:
            if clas is None:
                continue
            old_len = clas.len()
            # dict.fromkeys removes duplicates while keeping first-seen
            # order, so the resulting node order does not depend on string
            # hash randomization (unlike going through a set).
            clas.nodes = list(dict.fromkeys(
                self.canonicalize(node) for node in clas.nodes
            ))
            trimmed += old_len - len(clas.nodes)
        return trimmed

    def check_memo(self) -> bool:
        """Assert that the hashcons agrees with the eclass contents.

        Rebuilds a reference enode -> eclass map from the classes and
        checks that every entry is canonical and present in ``memo`` under
        an equivalent id. Returns True when all assertions hold.
        """
        test_memo: Dict[L, Id] = {}
        for class_id, clas in enumerate(self.classes):
            if clas is None:
                continue
            assert clas.id == class_id
            for node in clas.nodes:
                old = test_memo.get(node)
                test_memo[node] = class_id
                if old is not None:
                    # The same enode may only appear in equivalent classes.
                    assert self.find(old) == self.find(class_id)
        for node, eclass_id in test_memo.items():
            assert eclass_id == self.find(eclass_id)
            memo_id = self.memo.get(node)
            assert memo_id is not None
            assert eclass_id == self.find(memo_id)
        return True

    def process_unions(self) -> None:
        """Repair the hashcons and parent lists after unions.

        For every dirty eclass, its parent enodes are removed from the
        hashcons, re-canonicalized, deduplicated, and re-inserted. Parents
        that become identical after canonicalization (congruence) are
        unioned, which may dirty further classes; the loop runs until no
        dirty class remains.
        """
        to_union: List[Tuple[Id, Id]] = []
        while self.dirty_unions:
            pending = self.dirty_unions
            self.dirty_unions = []
            todo = {self.find(dirty) for dirty in pending}
            assert todo
            for eclass_id in todo:
                self.repairs_since_rebuild += 1
                parents = self[eclass_id].parents
                # Drop the (possibly stale) hashcons entries of all parents.
                for node, _owner in parents:
                    self.memo.pop(node, None)
                parents = [
                    (self.canonicalize(node), self.find(owner))
                    for node, owner in parents
                ]
                # SymbolLang defines no ordering, so sorting the raw tuples
                # raises TypeError as soon as two different enodes are
                # compared. Sort on plain strings/tuples/ints instead; equal
                # enodes still end up adjacent, ordered by eclass id.
                parents.sort(
                    key=lambda entry: (entry[0].op.sym, entry[0].children, entry[1])
                )
                # deduplicate
                new_parents = []
                groups = itertools.groupby(parents, key=lambda entry: entry[0])
                for _, group in groups:
                    node, first_owner = next(group)
                    new_parents.append((node, first_owner))
                    # Identical canonical enodes must live in one class.
                    for _, other_owner in group:
                        to_union.append((first_owner, other_owner))
                parents = new_parents
                for node, owner in parents:
                    old = self.memo.get(node)
                    self.memo[node] = owner
                    if old is not None:
                        to_union.append((old, owner))
                # propagate_metadata
                self[eclass_id].parents = parents
                # N::modify
            for (id1, id2) in to_union:
                to, did_something = self.union_impl(id1, id2)
                if did_something:
                    # union_impl has already queued `to`; the repeat is
                    # harmless because dirty ids are collected into a set.
                    self.dirty_unions.append(to)
            to_union = []
        assert not self.dirty_unions
        assert not to_union

    def rebuild(self) -> int:
        """Restore the e-graph invariants after a batch of unions.

        Processes pending unions, deduplicates the enodes of every class,
        prints a summary, and verifies the hashcons. Returns the number of
        eclass repairs performed since the previous rebuild.
        """
        old_hc_size = len(self.memo)
        old_n_eclasses = self.n_classes
        self.process_unions()
        n_unions = self.repairs_since_rebuild
        trimmed_nodes = self.rebuild_classes()
        print(
            "REBUILT!\n"
            f"Old: hc size {old_hc_size}, eclasses: {old_n_eclasses}\n"
            f"New: hc size {len(self.memo)}, eclasses: {self.n_classes}\n"
            f"unions: {self.repairs_since_rebuild}, trimmed nodes: {trimmed_nodes}"
        )
        self.repairs_since_rebuild = 0
        assert self.check_memo()
        return n_unions

    def __str__(self) -> str:
        """Render one line per live eclass: ``<id>: <node>, <node>, ...``."""
        live_ids = sorted(c.id for c in self.classes if c is not None)
        lines = []
        for class_id in live_ids:
            nodes = self[class_id].nodes  # not sorted
            lines.append(f"{class_id}: {', '.join(str(n) for n in nodes)}\n")
        return "".join(lines)
def simple_add():
    """Demo: add a few nodes, union two leaves, rebuild, and print the graph."""
    graph = EGraph()
    first_x = graph.add(L.leaf('x'))
    # Adding the same leaf again goes through the hashcons.
    second_x = graph.add(L.leaf('x'))
    graph.add(L('+', (first_x, second_x)))
    y_id = graph.add(L.leaf("y"))
    graph.union(first_x, y_id)
    graph.rebuild()
    print(graph)
def exercise_trim():
    """Demo aiming for a nonzero "trimmed nodes" count during rebuild."""
    # For nodes to be trimmed, a class must contain two nodes that were
    # thought to be different but turn out to be identical once their
    # children are canonicalized during the rebuild.
    graph = EGraph()
    x_id, y_id, z_id = (graph.add(L(name)) for name in ('x', 'y', 'z'))
    f_of_x = graph.add(L('f', (x_id, y_id)))
    f_of_z = graph.add(L('f', (z_id, y_id)))
    graph.union(f_of_x, f_of_z)
    graph.union(x_id, z_id)
    graph.rebuild()
    print(graph)
def exercise_n_classes_reduction():
    """Demo aiming to reduce the number of eclasses during rebuild."""
    # This requires a union that was not requested explicitly, i.e. one
    # that only follows as a consequence of congruence.
    graph = EGraph()
    x_id = graph.add(L('x'))
    y_id = graph.add(L('y'))
    graph.add(L('f', (x_id,)))
    graph.add(L('f', (y_id,)))
    graph.union(x_id, y_id)
    graph.rebuild()
    print(graph)
# Run the union-find self-test followed by each e-graph demo, in order.
for _demo in (union_find, simple_add, exercise_trim, exercise_n_classes_reduction):
    _demo()
# End of gist (GitHub page footer omitted).