Skip to content

Instantly share code, notes, and snippets.

@LeeeeT
Last active April 27, 2026 12:55
Show Gist options
  • Select an option

  • Save LeeeeT/bf3d6ea8aeed8b92e80f7dadda9ae14b to your computer and use it in GitHub Desktop.

Select an option

Save LeeeeT/bf3d6ea8aeed8b92e80f7dadda9ae14b to your computer and use it in GitHub Desktop.
k-SIC + REF
import itertools
from collections.abc import Callable
from dataclasses import dataclass
# Locations are plain integers drawn from one global counter. They are used as
# keys of Net.vars / Net.subs and as the shared names created by wire().
type Location = int
location = itertools.count()
# Labels are integers drawn from a second global counter; SUP and DUP carry one.
type Label = int
label = itertools.count()
# Identifiers are the keys of Net.scop, i.e. the names of definitions.
type Identifier = str
# Terms that may be stored in Net.vars.
type Pos = VAR | NIL | LAM | SUP | REF
# Terms that may be stored in Net.subs.
type Neg = SUB | ERA | APP | DUP
@dataclass(frozen=True)
class VAR:
    """Positive term that stands for whatever is (or will be) stored under `loc`."""

    # Location this variable refers to. show_pos() and reduce() follow it when
    # `loc` is present in Net.vars.
    loc: Location
@dataclass(frozen=True)
class NIL:
    """Positive term with no fields; show_pos() renders it as "+_"."""

    pass
@dataclass(frozen=True)
class LAM:
    """Positive term for a lambda: a binder paired with a body."""

    # Location of the binder, a negative term (show_pos() renders it via show_neg()).
    bnd: Location
    # Location of the body, a positive term.
    bod: Location
@dataclass(frozen=True)
class SUP:
    """Positive term that pairs two positive terms under a label."""

    # Label; reduce() compares it with DUP.dpl to pick the DUP/SUP rule.
    spl: Label
    # Locations of the two positive components.
    sp0: Location
    sp1: Location
@dataclass(frozen=True)
class REF:
    """Positive term that names a definition instead of containing it."""

    # Key into Net.scop; reduce() looks it up when the reference is expanded.
    ide: Identifier
@dataclass(frozen=True)
class SUB:
    """Negative term that stands for whatever is (or will be) stored under `loc`."""

    # Location this substitution refers to. show_neg() and reduce() follow it
    # when `loc` is present in Net.subs.
    loc: Location
@dataclass(frozen=True)
class ERA:
    """Negative term with no fields; show_neg() renders it as "-_"."""

    pass
@dataclass(frozen=True)
class APP:
    """Negative term for an application: an argument and a place for the result."""

    # Location of the argument, a positive term.
    arg: Location
    # Location of the negative term that receives the result.
    ret: Location
@dataclass(frozen=True)
class DUP:
    """Negative term that feeds one positive term to two negative terms under a label."""

    # Label; reduce() compares it with SUP.spl to pick the DUP/SUP rule.
    dpl: Label
    # Locations of the two negative terms that each receive a copy.
    dp0: Location
    dp1: Location
# A stored definition: the net that holds its terms and the location of its root.
type Package = tuple[Net, Location]
# A pending interaction: (location in Net.subs, location in Net.vars).
type Redex = tuple[Location, Location]
@dataclass(frozen=True)
class Net:
    """State of one net: two term heaps, a table of definitions and pending redexes.

    The dataclass is frozen, so the four attributes cannot be rebound, but the
    containers they hold are mutated in place by the functions in this file.
    """

    # Positive terms, keyed by location.
    vars: dict[Location, Pos]
    # Negative terms, keyed by location.
    subs: dict[Location, Neg]
    # Definitions, keyed by identifier; filled by define(), read by reduce().
    scop: dict[Identifier, Package]
    # Redexes still to be processed by reduce().
    book: set[Redex]
def net_empty() -> Net:
    """Return a new net with empty heaps, no definitions and no pending redexes."""
    return Net(vars={}, subs={}, scop={}, book=set())
def net_embed(net: Net, embedding: Net) -> None:
    """Merge every component of `embedding` into `net`, mutating `net` in place.

    Where both nets hold the same key, the entry from `embedding` wins.
    """
    merges = (
        (net.vars, embedding.vars),
        (net.subs, embedding.subs),
        (net.scop, embedding.scop),
        (net.book, embedding.book),
    )
    for target, source in merges:
        target.update(source)
def net_clone(net: Net, locations: dict[Location, Location], labels: dict[Label, Label]) -> Net:
    """Return a copy of `net` in which locations and labels are renamed.

    `locations` and `labels` map old names to new ones. A name already present
    in the mapping is reused; any other name gets a fresh value from the global
    counter and is recorded in the mapping, so the caller can look up afterwards
    where an original location ended up. Entries of `net.scop` are shared with
    the copy, not cloned.
    """

    def fresh_loc(old: Location) -> Location:
        try:
            return locations[old]
        except KeyError:
            renamed = locations[old] = next(location)
            return renamed

    def fresh_lab(old: Label) -> Label:
        try:
            return labels[old]
        except KeyError:
            renamed = labels[old] = next(label)
            return renamed

    def copy_pos(term: Pos) -> Pos:
        if isinstance(term, VAR):
            return VAR(fresh_loc(term.loc))
        if isinstance(term, NIL):
            return NIL()
        if isinstance(term, LAM):
            return LAM(fresh_loc(term.bnd), fresh_loc(term.bod))
        if isinstance(term, SUP):
            return SUP(fresh_lab(term.spl), fresh_loc(term.sp0), fresh_loc(term.sp1))
        if isinstance(term, REF):
            return REF(term.ide)

    def copy_neg(term: Neg) -> Neg:
        if isinstance(term, SUB):
            return SUB(fresh_loc(term.loc))
        if isinstance(term, ERA):
            return ERA()
        if isinstance(term, APP):
            return APP(fresh_loc(term.arg), fresh_loc(term.ret))
        if isinstance(term, DUP):
            return DUP(fresh_lab(term.dpl), fresh_loc(term.dp0), fresh_loc(term.dp1))

    clone = net_empty()
    for old, term in net.vars.items():
        # The term is copied before its own key is renamed; this fixes the
        # order in which fresh names are drawn from the counters.
        copied_pos = copy_pos(term)
        clone.vars[fresh_loc(old)] = copied_pos
    for old, term in net.subs.items():
        copied_neg = copy_neg(term)
        clone.subs[fresh_loc(old)] = copied_neg
    clone.scop.update(net.scop)
    for neg_loc, pos_loc in net.book:
        clone.book.add((fresh_loc(neg_loc), fresh_loc(pos_loc)))
    return clone
def pos(net: Net, term: Pos) -> Location:
    """Store `term` in `net.vars` under a freshly drawn location and return that location."""
    fresh = next(location)
    net.vars[fresh] = term
    return fresh
def neg(net: Net, term: Neg) -> Location:
    """Store `term` in `net.subs` under a freshly drawn location and return that location."""
    fresh = next(location)
    net.subs[fresh] = term
    return fresh
def var(net: Net, loc: Location) -> Location:
    """Allocate a VAR referring to `loc` in `net` and return its location."""
    term = VAR(loc)
    return pos(net, term)
def nil(net: Net) -> Location:
    """Allocate a NIL in `net` and return its location."""
    term = NIL()
    return pos(net, term)
def lam(net: Net, bnd: Location, bod: Location) -> Location:
    """Allocate a LAM with binder `bnd` and body `bod` in `net` and return its location."""
    term = LAM(bnd, bod)
    return pos(net, term)
def sup(net: Net, spl: Label, sp0: Location, sp1: Location) -> Location:
    """Allocate a SUP labelled `spl` over `sp0` and `sp1` in `net` and return its location."""
    term = SUP(spl, sp0, sp1)
    return pos(net, term)
def ref(net: Net, ide: Identifier) -> Location:
    """Allocate a REF to the definition named `ide` in `net` and return its location."""
    term = REF(ide)
    return pos(net, term)
def sub(net: Net, loc: Location) -> Location:
    """Allocate a SUB referring to `loc` in `net` and return its location."""
    term = SUB(loc)
    return neg(net, term)
def era(net: Net) -> Location:
    """Allocate an ERA in `net` and return its location."""
    term = ERA()
    return neg(net, term)
def app(net: Net, arg: Location, ret: Location) -> Location:
    """Allocate an APP with argument `arg` and result `ret` in `net` and return its location."""
    term = APP(arg, ret)
    return neg(net, term)
def dup(net: Net, dpl: Label, dp0: Location, dp1: Location) -> Location:
    """Allocate a DUP labelled `dpl` over `dp0` and `dp1` in `net` and return its location."""
    term = DUP(dpl, dp0, dp1)
    return neg(net, term)
def wire(net: Net) -> tuple[Location, Location]:
    """Create a VAR and a SUB that share one fresh name.

    Returns:
        (location of the VAR, location of the SUB).
    """
    name = next(location)
    positive = var(net, name)
    negative = sub(net, name)
    return positive, negative
def define(net: Net, ide: Identifier, cons: Callable[[Net], Location]) -> None:
    """Register a definition named `ide` in `net.scop`.

    `cons` is called with a new empty net, builds the definition inside it and
    returns the root location; that net and root are stored as the package.
    An existing entry with the same name is replaced.
    """
    body = net_empty()
    root = cons(body)
    net.scop[ide] = (body, root)
def show_pos(net: Net, pos: Location) -> str:
match net.vars[pos]:
case VAR(loc) if loc in net.vars:
return show_pos(net, loc)
case VAR(loc):
return f"+{loc}"
case NIL():
return "+_"
case LAM(bnd, bod):
return f"+({show_neg(net, bnd)} {show_pos(net, bod)})"
case SUP(spl, sp0, sp1):
return f"+{spl}{{{show_pos(net, sp0)} {show_pos(net, sp1)}}}"
case REF(ide):
return f"+@{ide}"
def show_neg(net: Net, neg: Location) -> str:
    """Render the negative term stored at `neg` in `net` as text.

    Every rendering starts with "-". A SUB whose target is present in
    `net.subs` is rendered as that target; otherwise it is rendered as the bare
    location number. Sub-terms are rendered recursively.
    """
    term = net.subs[neg]
    if isinstance(term, SUB):
        if term.loc in net.subs:
            return show_neg(net, term.loc)
        return f"-{term.loc}"
    if isinstance(term, ERA):
        return "-_"
    if isinstance(term, APP):
        return f"-({show_pos(net, term.arg)} {show_neg(net, term.ret)})"
    if isinstance(term, DUP):
        return f"-&{term.dpl}{{{show_neg(net, term.dp0)} {show_neg(net, term.dp1)}}}"
def reduce(net: Net) -> int:
    """Process redexes of `net` in place until `net.book` is empty.

    Each step pops one redex (negative location, positive location) from
    `net.book`, removes both terms from their heaps and applies the first
    `case` that matches the pair. A rule may put a term back, allocate new
    terms and push new redexes. Redexes are taken in whatever order `set.pop`
    yields them.

    Returns:
        The number of redexes popped.

    Raises:
        KeyError: if a REF names an identifier that is missing from `net.scop`,
            or a redex refers to a location missing from its heap.
    """
    itrs = 0
    while net.book:
        itrs += 1
        lhs, rhs = net.book.pop()
        match net.subs.pop(lhs), net.vars.pop(rhs):
            # Positive side is a VAR whose target already has a value: put the
            # negative term back and retry against that value.
            case lhsc, VAR(loc) if loc in net.vars:
                net.subs[lhs] = lhsc
                net.book.add((lhs, loc))
            # Positive side is a VAR with no value yet: park the negative term
            # under the variable's name.
            case lhsc, VAR(loc):
                net.subs[loc] = lhsc
            # Mirror image of the two VAR cases for a SUB on the negative side.
            case SUB(loc), rhsc if loc in net.subs:
                net.vars[rhs] = rhsc
                net.book.add((loc, rhs))
            case SUB(loc), rhsc:
                net.vars[loc] = rhsc
            case ERA(), NIL():
                pass
            # ERA against a compound term: spread to the term's components.
            case ERA(), LAM(bnd, bod):
                net.book.add((bnd, nil(net)))
                net.book.add((era(net), bod))
            case ERA(), SUP(spl, sp0, sp1):
                net.book.add((era(net), sp0))
                net.book.add((era(net), sp1))
            # An erased reference is dropped without being expanded.
            case ERA(), REF(ide):
                pass
            case APP(arg, ret), NIL():
                net.book.add((era(net), arg))
                net.book.add((ret, nil(net)))
            # Application meets lambda: argument to binder, body to result.
            case APP(arg, ret), LAM(bnd, bod):
                net.book.add((bnd, arg))
                net.book.add((ret, bod))
            # Application meets SUP: the argument is duplicated under the SUP's
            # label, each component is applied to one copy, and the two results
            # are paired again under the same label.
            case APP(arg, ret), SUP(spl, sp0, sp1):
                ap, an = wire(net)
                bp, bn = wire(net)
                cp, cn = wire(net)
                dp, dn = wire(net)
                net.book.add((dup(net, spl, an, bn), arg))
                net.book.add((ret, sup(net, spl, cp, dp)))
                net.book.add((app(net, ap, cn), sp0))
                net.book.add((app(net, bp, dn), sp1))
            case DUP(dpl, dp0, dp1), NIL():
                net.book.add((dp0, nil(net)))
                net.book.add((dp1, nil(net)))
            # DUP meets lambda: each output receives its own lambda, the
            # original binder receives a SUP of the two new variables, and the
            # body is duplicated under the same label.
            case DUP(dpl, dp0, dp1), LAM(bnd, bod):
                ap, an = wire(net)
                bp, bn = wire(net)
                cp, cn = wire(net)
                dp, dn = wire(net)
                net.book.add((dp0, lam(net, an, bp)))
                net.book.add((dp1, lam(net, cn, dp)))
                net.book.add((bnd, sup(net, dpl, ap, cp)))
                net.book.add((dup(net, dpl, bn, dn), bod))
            # Same label: the DUP takes the SUP apart component by component.
            case DUP(dpl, dp0, dp1), SUP(spl, sp0, sp1) if dpl == spl:
                net.book.add((dp0, sp0))
                net.book.add((dp1, sp1))
            # Different labels: the two nodes pass through each other.
            case DUP(dpl, dp0, dp1), SUP(spl, sp0, sp1):
                ap, an = wire(net)
                bp, bn = wire(net)
                cp, cn = wire(net)
                dp, dn = wire(net)
                net.book.add((dp0, sup(net, spl, ap, bp)))
                net.book.add((dp1, sup(net, spl, cp, dp)))
                net.book.add((dup(net, dpl, an, cn), sp0))
                net.book.add((dup(net, dpl, bn, dn), sp1))
            # A duplicated reference is copied as a reference, not expanded.
            case DUP(dpl, dp0, dp1), REF(ide):
                net.book.add((dp0, ref(net, ide)))
                net.book.add((dp1, ref(net, ide)))
            # Any remaining negative term against a REF: put the negative term
            # back, embed a renamed copy of the definition and retry against the
            # copy's root. Both renaming maps start empty, so every expansion
            # gets fresh locations and fresh labels.
            case lhsc, REF(ide):
                net.subs[lhs] = lhsc
                dfn, rot = net.scop[ide]
                net_embed(net, net_clone(dfn, locations := {}, labels := {}))
                net.book.add((lhs, locations[rot]))
    return itrs
def print_scope(net: Net) -> None:
    """Print every definition in `net.scop`: its root term, then its pending redexes.

    Consecutive definitions are separated by one blank line.
    """
    for index, (ide, (dfn, rot)) in enumerate(net.scop.items()):
        if index:
            print()
        print(f"{ide} = {show_pos(dfn, rot)}")
        for lhs, rhs in dfn.book:
            print(f" {show_neg(dfn, lhs)} ⋈ {show_pos(dfn, rhs)}")
def print_state(net: Net, root: Location, *, heap: bool = False) -> None:
    """Print the term at `root` and the pending redexes of `net`.

    With `heap=True`, also print every entry of `net.vars` and `net.subs`.
    """
    print("ROOT:")
    print(f" {show_pos(net, root)}")
    print()
    print("BOOK:")
    for lhs, rhs in net.book:
        print(f" {show_neg(net, lhs)} ⋈ {show_pos(net, rhs)}")
    if not heap:
        return
    print()
    print("VARS:")
    for loc in net.vars:
        print(f" {loc} = {show_pos(net, loc)}")
    print()
    print("SUBS:")
    for loc in net.subs:
        print(f" {loc} = {show_neg(net, loc)}")
def print_reduction(net: Net, root: Location, *, heap: bool = False) -> None:
    """Print the definitions, the initial state, then reduce `net` and print the result.

    `net` is mutated by the call to reduce(). The number of steps it reports is
    printed last. `heap` is forwarded to print_state().
    """

    def banner(title: str) -> None:
        # Three-line boxed heading followed by a blank line.
        rule = "=" * 30
        print(rule)
        print("=", title.center(26), "=")
        print(rule)
        print()

    banner("SCOPE")
    print_scope(net)
    print()
    banner("INITIAL")
    print_state(net, root, heap=heap)
    print()
    banner("NORMALIZED")
    steps = reduce(net)
    print_state(net, root, heap=heap)
    print()
    print(f"ITRS: {steps}")
def mk_app(net: Net, fun: Location, arg: Location) -> Location:
    """Queue the application of `fun` to `arg` and return the location of its result.

    A new wire carries the result: its negative end goes into the APP node,
    which is paired with `fun` in `net.book`; its positive end is returned.
    """
    result_pos, result_neg = wire(net)
    node = app(net, arg, result_neg)
    net.book.add((node, fun))
    return result_pos
def mk_dup(net: Net, bod: Location, lab: Label | None = None) -> tuple[Location, Location]:
    """Queue a duplication of `bod` and return the locations of the two copies.

    When `lab` is None a fresh label is drawn from the global counter.
    """
    tag = next(label) if lab is None else lab
    first_pos, first_neg = wire(net)
    second_pos, second_neg = wire(net)
    node = dup(net, tag, first_neg, second_neg)
    net.book.add((node, bod))
    return first_pos, second_pos
# @C2 = λf.λx.(f (f x))
def mk_c2(net: Net) -> Location:
    """Build λf.λx.(f (f x)) in `net` and return the location of the outer lambda."""
    f_pos, f_neg = wire(net)
    x_pos, x_neg = wire(net)
    # f is used twice, so it goes through a DUP first.
    f_first, f_second = mk_dup(net, f_pos)
    once = mk_app(net, f_first, x_pos)
    twice = mk_app(net, f_second, once)
    inner = lam(net, x_neg, twice)
    return lam(net, f_neg, inner)
def mk_cpow2(net: Net, k: int) -> Location:
    """Build the term c_{2^k} in `net` and return its location.

    k=1 gives c2, k=2 gives c4, k=3 gives c8, and so on. For k=1 the result is
    a bare reference to @C2; for larger k it is λf.(@C2 (c f)), where c is the
    term built for k-1.

    Raises:
        ValueError: if k < 1.
    """
    if k < 1:
        raise ValueError("k must be >= 1")
    if k == 1:
        return ref(net, "C2")
    f_pos, f_neg = wire(net)
    smaller = mk_cpow2(net, k - 1)  # c_{2^(k-1)}
    half_power = mk_app(net, smaller, f_pos)  # f^(2^(k-1))
    squarer = ref(net, "C2")
    full_power = mk_app(net, squarer, half_power)  # squared -> f^(2^k)
    return lam(net, f_neg, full_power)
# @F0 = λb.λt.λf.(b f f)
def mk_F0(net: Net) -> Location:
    """Build λb.λt.λf.(b f f) in `net` and return the location of the outer lambda.

    The middle parameter t is unused, so its binder is an ERA.
    """
    b_pos, b_neg = wire(net)
    f_pos, f_neg = wire(net)
    f_first, f_second = mk_dup(net, f_pos)
    t_binder = era(net)
    b_f = mk_app(net, b_pos, f_first)
    b_f_f = mk_app(net, b_f, f_second)
    over_f = lam(net, f_neg, b_f_f)
    over_t = lam(net, t_binder, over_f)
    return lam(net, b_neg, over_t)
# @F1 = λb.λt.λf.(b (f f) f)
def mk_F1(net: Net) -> Location:
    """Build λb.λt.λf.(b (f f) f) in `net` and return the location of the outer lambda.

    The middle parameter t is unused, so its binder is an ERA. f is used three
    times, which takes two chained DUPs.
    """
    b_pos, b_neg = wire(net)
    f_pos, f_neg = wire(net)
    f_first, f_rest = mk_dup(net, f_pos)
    f_second, f_third = mk_dup(net, f_rest)
    t_binder = era(net)
    f_f = mk_app(net, f_first, f_second)
    b_ff = mk_app(net, b_pos, f_f)
    b_ff_f = mk_app(net, b_ff, f_third)
    over_f = lam(net, f_neg, b_ff_f)
    over_t = lam(net, t_binder, over_f)
    return lam(net, b_neg, over_t)
# F^(2^N)
def test_fusion(net: Net, F: Location, N: int) -> Location:
    """Queue the application of the term from mk_cpow2(net, N) to `F`; return the result location."""
    power = mk_cpow2(net, N)
    return mk_app(net, power, F)
def main() -> None:
    """Register C2, F0 and F1, apply the k=5 power term to @F1 and print the reduction."""
    net = net_empty()
    definitions = (("C2", mk_c2), ("F0", mk_F0), ("F1", mk_F1))
    for name, builder in definitions:
        define(net, name, builder)
    start = ref(net, "F1")
    root = test_fusion(net, start, 5)
    print_reduction(net, root)
# Run the demo only when this file is executed as a script, so that importing
# the module (for reuse or testing) does not trigger a full reduction.
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment