Created
December 27, 2024 00:09
-
-
Save Joker-vD/7cddcfada042ed486dd74690ae986101 to your computer and use it in GitHub Desktop.
First-order CPS transform
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def cps(expr, tail_cont=None):
    '''
    CPS-transform `expr`: a variable name (`str`), an integer literal, or a list form
    (`let`, `if`, or a function application).

    When called without `tail_cont`, CPS-transforms `expr` in "general position" and
    returns a `(cps_builder, tmp_name)` pair: `cps_builder` is a partially-built CPS expression
    that has the computed value of `expr` bound to `tmp_name` inside of it.

    When called with `tail_cont`, CPS-transforms `expr` in "tail position" with `tail_cont` being
    the continuation that is immediately applied to the value of `expr`, and returns
    a fully completed CPS expression.

    Raises `ValueError` for an empty list, and `Exception` for anything that is
    not a `str`, an `int` or a `list`.
    '''
    # Variables are themselves
    if isinstance(expr, str):
        return _cps_finish_atom(CpsExprBuilder(), expr, tail_cont)

    # Literals get their own temporaries
    if isinstance(expr, int):
        result = CpsExprBuilder()
        tmp_name = gen_tmp()
        result.add_let_without_body([[tmp_name, expr]])
        return _cps_finish_atom(result, tmp_name, tail_cont)

    # All other forms are lists
    if isinstance(expr, list):
        # An empty list has no head to dispatch on; say so explicitly instead of
        # failing inside the unpacking below
        if not expr:
            raise ValueError(f'not an expr: {expr}')
        what, *rest = expr
        if what == 'let':
            return _cps_let(rest, tail_cont)
        if what == 'if':
            return _cps_if(rest, tail_cont)
        return _cps_application(expr, tail_cont)

    raise Exception(f'not an expr: {expr}')


def _cps_finish_atom(result, result_name, tail_cont):
    '''Finish an atom: return the `(builder, name)` pair, or pass `result_name` to `tail_cont`.'''
    if tail_cont is None:
        return result, result_name
    return result.plug_hole(['$ret', tail_cont, result_name])


def _cps_let(operands, tail_cont):
    '''
    Transform the operands `((name init_expr) body)` of a single-variable `let`.
    Can be pretty straightforwardly extended to let* and/or letrec.
    '''
    (name, init_expr), body = operands
    cont_name = gen_cont()

    # Surprisingly, the variable initializer is actually in tail position. It's the `let`'s
    # body that may or may not be in tail position
    init_tree = cps(init_expr, cont_name)
    body_result = cps(body, tail_cont)

    # Essentially, `(let (name expr) body)` turns into
    # `(let (@cont_name (name) body) (evaluate expr with @cont_name))`
    result = CpsExprBuilder()
    result.add_let_without_defns(init_tree)
    result.add_cont_without_body(cont_name, [name])

    if tail_cont is None:
        # General position: the body is itself a partial tree, so the hole moves into it
        scope_subtree, result_name = body_result
        result.splice_from(scope_subtree)
        return result, result_name
    return result.plug_hole(body_result)


def _cps_if(operands, tail_cont):
    '''
    Transform the operands `(cond branch...)` of an `if`. Generalizing to pattern-matching
    `switch` is, again, straightforward: each branch turns into a separate continuation
    with captured variables as its arguments.
    '''
    cond, *branches = operands
    result, cond_name = cps(cond)
    branch_names = [gen_cont() for _ in branches]

    # In tail position, all of the branches will, in the end, invoke the `tail_cont`. But
    # in the general position, we need to "join" the control flow into a new continuation
    if tail_cont is None:
        cont_name, result_name = gen_cont(), gen_tmp()
        result.add_let_without_defns(['$if', cond_name, *branch_names])
    else:
        cont_name = tail_cont
        result.add_let_without_body([])

    # All branches execute in a tail position
    for branch_name, branch in zip(branch_names, branches):
        result.add_defn_to_let(branch_name, ['cont', [], cps(branch, cont_name)])

    if tail_cont is None:
        result.add_cont_without_body(cont_name, [result_name])
        return result, result_name
    return result.plug_hole(['$if', cond_name, *branch_names])


def _cps_application(expr, tail_cont):
    '''
    Transform a function application `(f arg...)`. Built-ins are evaluated statically.
    Another option would be to treat them as value constructors and turn them into
    definitions in `let`, e.g. `(+ x y)` would turn into `let (#tN ($+ x y) ...)`.
    '''
    what, *rest = expr
    if what in ('+', '-'):
        # The built-in itself is not evaluated; only its operands are
        bound_tmps = [f'${what}']
        subexprs = rest
    else:
        # The callee is an ordinary subexpression, evaluated like the arguments
        bound_tmps = []
        subexprs = expr

    # Evaluate subexpressions. Glue the sub-CPS expressions together, while remembering the
    # names `#tN` to which they are bound
    result = CpsExprBuilder()
    for subexpr in subexprs:
        sub_cps, bound_tmp = cps(subexpr)
        result.splice_from(sub_cps)
        bound_tmps.append(bound_tmp)

    # In general position we need to generate an empty continuation `@j`, inside which the rest of
    # the surrounding expression will be placed. Then we need to invoke `#t0` with `#tN`s and this
    # continuation. The resulting builder will look like
    #
    #   (code-that-binds-#tNs... (LET (@j (cont (#x) <HOLE>)) (#t0 #t1 ... #tN @j)))
    if tail_cont is None:
        cont_name, result_name = gen_cont(), gen_tmp()
        result.add_let_without_defns([*bound_tmps, cont_name])
        result.add_cont_without_body(cont_name, [result_name])
        return result, result_name

    # In tail position, we need to just call `(#t0 #t1 ... #tN tail_cont)`, and we're done
    return result.plug_hole([*bound_tmps, tail_cont])
class CpsExprBuilder:
    '''
    Essentially, a tree zipper with a single hole. During CPS transformation, we build the resulting expression
    from outside in, and we always know where exactly we need to continue extending it. Sometimes it's the body
    of a `let` expression, sometimes it's inside one of the `let`s definitions.

    `layers` is the list of partial tree nodes, outermost first. Each layer is a list
    whose first element is a tag:

    - `['LET/BODY', defns]`: a `let` whose body is the hole;
    - `['LET/DEFNS', defns, body]`: a `let` with a known body, whose first definition is the hole;
    - `['CONT/BODY', name, params]`: a named continuation whose body is the hole.
    '''

    def __init__(self):
        self.layers = []

    def __repr__(self):
        return f'CpsExprBuilder(layers={self.layers!r})'

    def splice_from(self, other):
        '''Append all layers of `other` inside this builder's hole (the layer objects are shared, not copied).'''
        self.layers.extend(other.layers)

    def add_let_without_body(self, defns):
        '''Open a `let` with definitions `defns`; its body becomes the new hole.'''
        self.layers.append(['LET/BODY', defns])

    def add_let_without_defns(self, body):
        '''Open a `let` with the given `body`; one of its definitions becomes the new hole.'''
        self.layers.append(['LET/DEFNS', [], body])

    def add_cont_without_body(self, name, params):
        '''Open a continuation called `name` taking `params`; its body becomes the new hole.'''
        self.layers.append(['CONT/BODY', name, params])

    def add_defn_to_let(self, name, value):
        '''
        Add the definition `[name, value]` to the innermost layer, which must be a `let`.

        Raises `IndexError` if there are no layers, and `Exception` (carrying the
        layer's tag) if the innermost layer is not a `let`.
        '''
        if not self.layers:
            raise IndexError('no layer to add a definition to')
        let = self.layers[-1]
        if let[0] in ('LET/BODY', 'LET/DEFNS'):
            # The definitions list sits at index 1 in both kinds of `let` layer
            let[1].append([name, value])
        else:
            raise Exception(let[0])

    def plug_hole(self, cps_expr):
        '''Wrap all the layers around the `cps_expr`, and return the result'''
        result = cps_expr
        # Layers are stored outermost first, so wrap starting from the innermost one
        for layer in reversed(self.layers):
            what, *rest = layer
            # Make a `let` expression with `result` as its body
            if what == 'LET/BODY':
                defns, = rest
                # A `let` expression with no definitions is equivalent to just its body
                if defns:
                    result = make_LET(defns, result)
            # Make a `let` expression with `result` as one of its definitions
            elif what == 'LET/DEFNS':
                defns, body = rest
                result = make_LET([result, *defns], body)
            # Make a continuation with `result` as its body. It will end up as a definition in
            # a bigger `let`, so don't forget its name
            elif what == 'CONT/BODY':
                name, params = rest
                result = [name, ['cont', params, result]]
            else:
                raise ValueError(f'unrecognized partial tree layer: {layer}')
        return result
def make_LET(defns, body):
    '''
    Build the expression `['let', defns, body]`.

    Flattens `(let (defns1...) (let (defns2...) body))` into `(let (defns1... defns2...) body)`.
    Not strictly needed, but nice to have and simple to implement.

    The returned definitions list is always a new list; `defns` is not modified.
    '''
    if isinstance(body, list) and body and body[0] == 'let':
        # The body is itself a `let`: hoist its definitions after ours
        _, inner_defns, inner_body = body
        return ['let', [*defns, *inner_defns], inner_body]
    return ['let', [*defns], body]
# Boring utility functions

# Next unused number for generated names. Kept in a one-element list so that
# functions can update it in place without a `global` declaration.
COUNTER = [0]
def gen_cont():
    '''
    Return a fresh continuation name of the form `@kN`.

    `N` is taken from the shared `COUNTER`, which is then advanced by one.
    '''
    num = COUNTER[0]
    COUNTER[0] += 1
    return f'@k{num}'
def gen_tmp():
    '''
    Return a fresh temporary name of the form `#tN`.

    `N` is taken from the shared `COUNTER`, which is then advanced by one.
    '''
    num = COUNTER[0]
    COUNTER[0] += 1
    return f'#t{num}'
# A rather inefficient way to pretty print
def pretty_cps(expr):
    '''
    Render a CPS expression as human-readable text.

    Anything that is not a list is rendered with `str`. Lists are dispatched on their
    first element: `$ret`, `let` and `cont` get dedicated layouts, and every other
    list is rendered as a call `head(arg, ...)`.
    '''
    if not isinstance(expr, list):
        return str(expr)

    what, *rest = expr
    if what == '$ret':
        cont, *args = rest
        return f'$ret {pretty_cps(cont)}{arglist(args)}'
    if what == 'let':
        defns, body = rest
        return f'let\n{entab(deflist(defns))}\nin {pretty_cps(body)}'
    if what == 'cont':
        params, body = rest
        return f'cont{arglist(map(pretty_cps, params))}:\n{entab(pretty_cps(body))}'
    return f'{pretty_cps(what)}{arglist(rest)}'
def arglist(args):
    '''Render the iterable `args` as a parenthesized, comma-separated list, applying `str` to each item.'''
    return '(' + ', '.join(str(arg) for arg in args) + ')'
def deflist(defns):
    '''Render `(name, definition)` pairs as `name = definition` lines, one pair per line.'''
    return '\n'.join(f'{name} = {pretty_cps(defn)}' for name, defn in defns)
def entab(s):
    '''
    Indent every line of `s` and return the result.

    Lines are split with `str.splitlines` and re-joined with a newline, so a trailing
    newline in `s` is not kept.
    '''
    # NOTE(review): the indent is a single space here; the function name suggests a tab
    # or a wider indent may have been intended -- confirm against the author's copy.
    return '\n'.join(f' {line}' for line in s.splitlines())
def run_test(expr):
    '''
    CPS-transform `expr` in tail position with `@halt` as the final continuation,
    then print the input followed by the result.

    `COUNTER` is reset first so that the generated names start from 0 on every run.
    '''
    COUNTER[0] = 0
    result = cps(expr, '@halt')
    print(expr, '=>')
    # Pretty-printing is best effort: fall back to the raw tree if rendering fails.
    # Only `Exception` is caught, so KeyboardInterrupt/SystemExit still propagate.
    try:
        rendered = pretty_cps(result)
    except Exception:
        rendered = result
    print(rendered)
    print()
def main():
    '''Run the transform on a set of sample expressions, printing each input and its result.'''
    # Variables and applications
    run_test('x')
    run_test(['f'])
    run_test(['f', 'x'])
    run_test(['f', 'x', 'y'])
    run_test(['f', 'x', 'y', 'z'])

    # Literals and built-in operators
    run_test(5)
    run_test(['+', 5, 6])
    run_test(['+', ['f', 5], 6])
    run_test(['+', 5, ['f', 6]])
    run_test(['+', ['f', 5], ['g', 6]])
    run_test(['-', ['+', 5, 6], ['+', 7, ['+', 8, 9]]])
    run_test(['i', ['f', 'x', 'y'], ['h', 'z', ['g', 't'], 'u']])

    # Conditionals
    run_test(['if', 'b', 'true', 'false'])
    run_test(['if', 'b0', ['if', 'b1', 'both_true', 'true_false'], 'false'])
    run_test(['f', ['if', 'b', 'true', 'false']])

    # Let bindings
    run_test(['let', ['x', 1], 'x'])
    run_test(['let', ['x', 1], 'y'])
    run_test(['let', ['x', 1], ['+', 'x', 2]])
    # NOTE(review): '2' in the next two samples is a string, so it is transformed as a
    # variable named "2" rather than as the integer literal 2 -- confirm this is intended.
    run_test(['let', ['x', 1], ['let', ['y', '2'], ['+', 'x', 'y']]])
    run_test(['let', ['x', ['if', 'b', ['f', 'a'], ['g', 'b']]], ['let', ['y', ['-', 'z', '2']], ['+', 'x', 'y']]])
    run_test(['let', ['x', ['let', ['y', 2], ['+', 'y', 1]]], ['-', 'x', 'z']])
# Run the samples only when executed as a script, not when imported
if __name__ == '__main__':
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment