Skip to content

Instantly share code, notes, and snippets.

@Joker-vD
Created December 27, 2024 00:09
Show Gist options
  • Save Joker-vD/7cddcfada042ed486dd74690ae986101 to your computer and use it in GitHub Desktop.
Save Joker-vD/7cddcfada042ed486dd74690ae986101 to your computer and use it in GitHub Desktop.
First-order CPS transform
def cps(expr, tail_cont=None):
    '''
    CPS-transform the source expression `expr`.

    When called without `tail_cont`, transforms `expr` in "general position" and returns
    a `(cps_builder, tmp_name)` pair: `cps_builder` is a partially-built CPS expression
    (a `CpsExprBuilder`) that has the computed value of `expr` bound to `tmp_name` inside of it.

    When called with `tail_cont`, transforms `expr` in "tail position", with `tail_cont` being
    the name of the continuation that is immediately applied to the value of `expr`, and returns
    a fully completed CPS expression (nested lists).

    Recognized source expressions:
      * a `str` -- a variable;
      * an `int` -- a literal (`bool` is a subclass of `int`, so it lands here too);
      * a non-empty `list` -- `['let', [name, init], body]`, `['if', cond, *branches]`,
        a built-in `['+', ...]` / `['-', ...]`, or a function application `[f, *args]`.

    Fresh names are drawn from `gen_cont()` / `gen_tmp()`, so every call advances
    the shared counter.

    Raises `ValueError` if `expr` is not one of the forms above; this includes the empty list.
    '''
    # Variables are themselves: there is nothing to bind, the name can be used directly
    if isinstance(expr, str):
        return _cps_atom(CpsExprBuilder(), expr, tail_cont)

    # Literals get their own temporaries
    if isinstance(expr, int):
        builder = CpsExprBuilder()
        tmp_name = gen_tmp()
        builder.add_let_without_body([[tmp_name, expr]])
        return _cps_atom(builder, tmp_name, tail_cont)

    # All other forms are non-empty lists. The emptiness check keeps `[]` from failing later
    # with an unrelated unpacking error; it falls through to the diagnostic below instead
    if isinstance(expr, list) and expr:
        what, *rest = expr
        if what == 'let':
            return _cps_let(rest, tail_cont)
        if what == 'if':
            return _cps_if(rest, tail_cont)
        return _cps_apply(expr, tail_cont)

    raise ValueError(f'not an expr: {expr}')


def _cps_atom(builder, value_name, tail_cont):
    '''Finish an atom: hand back `(builder, value_name)`, or `$ret` the value to `tail_cont`.'''
    if tail_cont is None:
        return builder, value_name
    return builder.plug_hole(['$ret', tail_cont, value_name])


def _cps_let(rest, tail_cont):
    '''Transform a single-variable `let`; `rest` is `[[name, init_expr], body]`.'''
    # Can be pretty straightforwardly extended to let* and/or letrec
    (name, init_expr), body = rest
    cont_name = gen_cont()
    # Surprisingly, the variable initializer is actually in tail position. It's the `let`'s
    # body that may or may not be in tail position
    init_tree = cps(init_expr, cont_name)
    body_result = cps(body, tail_cont)

    # Essentially, `(let (name expr) body)` turns into
    # `(let (@cont_name (name) body) (evaluate expr with @cont_name))`
    builder = CpsExprBuilder()
    builder.add_let_without_defns(init_tree)
    builder.add_cont_without_body(cont_name, [name])
    if tail_cont is None:
        # General position: the body's own builder continues inside the new continuation
        body_builder, result_name = body_result
        builder.splice_from(body_builder)
        return builder, result_name
    return builder.plug_hole(body_result)


def _cps_if(rest, tail_cont):
    '''Transform an `if`; `rest` is `[cond, *branches]`, each branch becomes a continuation.'''
    # Generalizing to pattern-matching `switch` is, again, straightforward: each branch
    # turns into a separate continuation with captured variables as its arguments.
    cond, *branches = rest
    builder, cond_name = cps(cond)
    branch_names = [gen_cont() for _ in branches]
    dispatch = ['$if', cond_name, *branch_names]

    # In tail position, all of the branches will, in the end, invoke the `tail_cont`. But
    # in the general position, we need to "join" the control flow into a new continuation
    in_tail_position = tail_cont is not None
    if in_tail_position:
        cont_name = tail_cont
        builder.add_let_without_body([])
    else:
        cont_name, result_name = gen_cont(), gen_tmp()
        builder.add_let_without_defns(dispatch)

    # All branches execute in a tail position
    for branch_name, branch in zip(branch_names, branches):
        builder.add_defn_to_let(branch_name, ['cont', [], cps(branch, cont_name)])

    if in_tail_position:
        return builder.plug_hole(dispatch)
    builder.add_cont_without_body(cont_name, [result_name])
    return builder, result_name


def _cps_apply(expr, tail_cont):
    '''Transform a function application or a built-in `+`/`-`; `expr` is the whole form.'''
    what, *rest = expr
    # Built-ins are not evaluated as subexpressions: the operator is emitted directly as
    # `$+`/`$-`. Another option would be to treat them as value constructors and turn them
    # into definitions in `let`, e.g. `(+ x y)` would turn into `let (#tN ($+ x y) ...)`
    if what in ('+', '-'):
        call = [f'${what}']
        subexprs = rest
    else:
        call = []
        subexprs = expr

    # Evaluate subexpressions. Glue the sub-CPS expressions together, while remembering the
    # names `$tN` to which they are bound
    builder = CpsExprBuilder()
    for subexpr in subexprs:
        sub_builder, bound_tmp = cps(subexpr)
        builder.splice_from(sub_builder)
        call.append(bound_tmp)

    # In general position we need to generate an empty continuation `@j`, inside which the rest
    # of the surrounding expression will be placed. Then we need to invoke `$t0` with `$tN`s and
    # this continuation. The resulting builder will look like
    #
    #   (code-that-binds-$tNs... (LET (@j ($cont ($x) <HOLE>)) ($t0 $t1 ... $tN @j)))
    if tail_cont is None:
        cont_name, result_name = gen_cont(), gen_tmp()
        builder.add_let_without_defns([*call, cont_name])
        builder.add_cont_without_body(cont_name, [result_name])
        return builder, result_name

    # In tail position, we need to just call `($t0 $t1 ... $tN tail_cont)`, and we're done
    return builder.plug_hole([*call, tail_cont])
class CpsExprBuilder:
    '''
    Essentially, a tree zipper with a single hole. During CPS transformation, we build the
    resulting expression from outside in, and we always know where exactly we need to continue
    extending it. Sometimes it's the body of a `let` expression, sometimes it's inside one of
    the `let`'s definitions.

    `layers` holds the partial nodes, outermost first. Each layer is a tagged list:
    `['LET/BODY', defns]`, `['LET/DEFNS', defns, body]` or `['CONT/BODY', name, params]`.
    '''

    def __init__(self):
        self.layers = []

    def splice_from(self, other):
        '''Append all layers of `other` after (i.e. inside of) this builder's layers.'''
        self.layers += other.layers

    def add_let_without_body(self, defns):
        '''Push a `let` whose definitions are `defns` and whose body is the hole.'''
        self.layers.append(['LET/BODY', defns])

    def add_let_without_defns(self, body):
        '''Push a `let` with the given `body`; the hole is among its definitions.'''
        self.layers.append(['LET/DEFNS', [], body])

    def add_cont_without_body(self, name, params):
        '''Push a continuation called `name` taking `params`; its body is the hole.'''
        self.layers.append(['CONT/BODY', name, params])

    def add_defn_to_let(self, name, value):
        '''
        Add the definition `[name, value]` to the innermost layer, which must be a `let`.
        Raises `Exception` carrying the layer tag if the innermost layer is not a `let`.
        '''
        innermost = self.layers[-1]
        tag = innermost[0]
        if tag not in ('LET/BODY', 'LET/DEFNS'):
            raise Exception(tag)
        innermost[1].append([name, value])

    def plug_hole(self, cps_expr):
        '''Wrap all the layers around the `cps_expr`, and return the result'''
        wrapped = cps_expr
        # Work from the innermost layer outwards
        for layer in reversed(self.layers):
            tag, *payload = layer
            if tag == 'CONT/BODY':
                # A continuation with `wrapped` as its body. It will end up as a definition
                # in a bigger `let`, so its name is kept alongside it
                name, params = payload
                wrapped = [name, ['cont', params, wrapped]]
            elif tag == 'LET/DEFNS':
                # A `let` expression with `wrapped` as the first of its definitions
                defns, body = payload
                wrapped = make_LET([wrapped, *defns], body)
            elif tag == 'LET/BODY':
                # A `let` expression with `wrapped` as its body. A `let` with no
                # definitions is equivalent to just its body, so it is skipped
                (defns,) = payload
                if defns:
                    wrapped = make_LET(defns, wrapped)
            else:
                raise ValueError(f'unrecognized partial tree layer: {layer}')
        return wrapped
def make_LET(defns, body):
    '''
    Build `['let', defns, body]`, merging with `body` when `body` is itself a `let`:
    `(let (defns1...) (let (defns2...) body))` becomes `(let (defns1... defns2...) body)`.
    The flattening is a nicety, not a necessity. The returned definitions list is always
    a fresh list.
    '''
    body_is_let = isinstance(body, list) and len(body) > 0 and body[0] == 'let'
    if not body_is_let:
        return ['let', list(defns), body]
    _, inner_defns, inner_body = body
    return ['let', list(defns) + list(inner_defns), inner_body]
# Boring utility functions
# Single shared counter, kept in a one-element list so it can be updated (and reset) in place
COUNTER = [0]


def _take_counter():
    '''Return the current value of the shared counter and advance it by one.'''
    current = COUNTER[0]
    COUNTER[0] = current + 1
    return current


def gen_cont():
    '''Return a fresh continuation name of the form `@kN`.'''
    return '@k' + str(_take_counter())


def gen_tmp():
    '''Return a fresh temporary name of the form `#tN`.'''
    return '#t' + str(_take_counter())
# A rather inefficient way to pretty print
def pretty_cps(expr):
    '''
    Render a CPS expression (nested lists) as indented text.

    Anything that is not a list is rendered with `str`. Lists are dispatched on their first
    element: `let`, `cont` and `$ret` have dedicated layouts; everything else is printed as
    a call `head(arg, ...)`.
    '''
    if not isinstance(expr, list):
        return str(expr)

    head, *tail = expr
    if head == 'let':
        defns, body = tail
        return 'let\n' + entab(deflist(defns)) + '\nin ' + pretty_cps(body)
    if head == 'cont':
        params, body = tail
        signature = 'cont' + arglist(map(pretty_cps, params))
        return signature + ':\n' + entab(pretty_cps(body))
    if head == '$ret':
        cont, *args = tail
        return '$ret ' + pretty_cps(cont) + arglist(args)
    return pretty_cps(head) + arglist(tail)
def arglist(args):
    '''Format the iterable `args` as a parenthesized, comma-separated list, using `str`.'''
    rendered = [str(arg) for arg in args]
    return '({})'.format(', '.join(rendered))
def deflist(defns):
    '''Render `[name, definition]` pairs as `name = <pretty definition>`, one per line.'''
    return '\n'.join(f'{name} = {pretty_cps(defn)}' for name, defn in defns)
def entab(s):
    '''Indent every line of `s` by prefixing it with a space; line endings become `\\n`.'''
    return '\n'.join(' ' + line for line in s.splitlines())
def run_test(expr):
    '''
    CPS-transform `expr` in tail position with `@halt` as the final continuation and print
    the source expression followed by the pretty-printed result and a blank line.

    The shared name counter is reset first, so generated names start from 0 on every run.
    If pretty-printing fails, the raw result is printed instead.
    '''
    COUNTER[0] = 0
    result = cps(expr, '@halt')
    print(expr, '=>')
    try:
        rendered = pretty_cps(result)
    except Exception:
        # Best-effort fallback for results the printer can't handle. `Exception` rather than
        # a bare `except`, so KeyboardInterrupt/SystemExit still propagate
        rendered = result
    print(rendered)
    print()
def main():
    '''Run the transform on a fixed set of sample expressions, printing each result.'''
    samples = [
        # Variables and applications
        'x',
        ['f'],
        ['f', 'x'],
        ['f', 'x', 'y'],
        ['f', 'x', 'y', 'z'],
        # Literals and built-ins
        5,
        ['+', 5, 6],
        ['+', ['f', 5], 6],
        ['+', 5, ['f', 6]],
        ['+', ['f', 5], ['g', 6]],
        ['-', ['+', 5, 6], ['+', 7, ['+', 8, 9]]],
        ['i', ['f', 'x', 'y'], ['h', 'z', ['g', 't'], 'u']],
        # Conditionals
        ['if', 'b', 'true', 'false'],
        ['if', 'b0', ['if', 'b1', 'both_true', 'true_false'], 'false'],
        ['f', ['if', 'b', 'true', 'false']],
        # Let bindings
        ['let', ['x', 1], 'x'],
        ['let', ['x', 1], 'y'],
        ['let', ['x', 1], ['+', 'x', 2]],
        # NOTE(review): '2' below is a string, so it is handled as a variable named `2`,
        # not as the literal 2 -- possibly a typo in the samples; left as-is
        ['let', ['x', 1], ['let', ['y', '2'], ['+', 'x', 'y']]],
        ['let', ['x', ['if', 'b', ['f', 'a'], ['g', 'b']]], ['let', ['y', ['-', 'z', '2']], ['+', 'x', 'y']]],
        ['let', ['x', ['let', ['y', 2], ['+', 'y', 1]]], ['-', 'x', 'z']],
    ]
    for sample in samples:
        run_test(sample)


if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment