Skip to content

Instantly share code, notes, and snippets.

@iitalics
Created November 21, 2024 17:30
Show Gist options
  • Save iitalics/4931c9b8cb7a42b9c8146bcc7075631c to your computer and use it in GitHub Desktop.
Save iitalics/4931c9b8cb7a42b9c8146bcc7075631c to your computer and use it in GitHub Desktop.
(* This module, [S], defines a very minimal syntax tree that we are going to try to
convert into A-normal form (ANF). ANF is an intermediate form that makes control flow
explicit and lifts intermediate values into named variables. We are going to start with
a very basic lowering algorithm and gradually refine it until it is robust. *)
module S = struct
(* Identifiers (plain strings). *)
type name = string [@@deriving show]
(* Source expressions: the input of every [lower] function in this file. *)
type t =
(* Reference to a variable. *)
| Var of name
(* Integer literal. *)
| Int of int
(* Binary operation on two operands; the operator itself is not represented. *)
| Bop of t * t
(* [Let (x, e1, e2)] binds [x] to [e1] in the body [e2]. *)
| Let of name * t * t
(* [If0 (c, e1, e2)]: conditional on [c] with two branches. The name suggests [c]
   is compared against zero, but nothing in this file evaluates it. *)
| If0 of t * t * t
[@@deriving show {with_path = false}]
end
(* The first compiler, [C0], does nothing at all. It traverses the syntax tree and then
reproduces it using helper functions for constructing each variant of the AST. This is
going to be a template that acts as a good foundation for future passes. In fact, we
are hardly going to touch the [lower] function at all, instead only changing the helper
functions. *)
module C0 = struct
  (* Identity "lowering": every source node is rebuilt unchanged. The smart
     constructors below are the only seams later compilers replace; [lower]
     itself stays the same shape. *)
  type t = S.t

  let var (x : S.name) : t = S.Var x
  let int (i : int) : t = S.Int i
  let bop (lhs : t) (rhs : t) : t = S.Bop (lhs, rhs)
  let let_ (x : S.name) (rhs : t) (body : t) : t = S.Let (x, rhs, body)
  let if0 (cond : t) (then_ : t) (else_ : t) : t = S.If0 (cond, then_, else_)

  (* Structural recursion over the source tree, delegating each node to the
     matching smart constructor. *)
  let rec lower (s : S.t) : t =
    match s with
    | S.Var x -> var x
    | S.Int i -> int i
    | S.Bop (s1, s2) -> bop (lower s1) (lower s2)
    | S.Let (x, s1, s2) -> let_ x (lower s1) (lower s2)
    | S.If0 (s1, s2, s3) -> if0 (lower s1) (lower s2) (lower s3)
end
(* The first real compiler, [C1], turns intermediate expressions into bound
variables. This is simply accomplished by editing the helper functions that take
multiple terms as arguments, and creating [fresh] variables for the arguments, then
introducing them in scope.
Example input:
(Let ("x", (Bop (Int 1, (Bop (Int 2, Int 3))),
(Bop (Var "x", Int 4))))
C1 output:
(Let ("x", (Let ("%5", Int 1,
(Let ("%6", (Let ("%3", Int 2,
(Let ("%4", Int 3,
(Bop (Var "%3", Var "%4")))))),
(Bop (Var "%5", Var "%6")))))),
(Let ("%1", Var "x",
(Let ("%2", Int 4,
(Bop (Var "%1", Var "%2"))))))))
This works, but it creates [Let] bindings at various nesting levels: a [Let] can end up
on the right-hand side of another [Let]. We would prefer the bindings to be "flattened",
so that a [Let] only ever appears in the body of another [Let]:
"(Let (x1, e1, (Let (x2, e2, (Let (x3, e3, ...))))))". *)
(* Global counter backing [fresh]. Nothing in this file resets it, so every name
   it hands out within one program run is distinct. *)
let _fresh = ref 0

(* [fresh ()] bumps the counter and returns the next temporary name: ["%1"],
   ["%2"], ... The ["%"] prefix is presumably meant to keep temporaries apart
   from source variables (assumes source names never start with ["%"]). *)
let fresh () =
  incr _fresh;
  "%" ^ string_of_int !_fresh
module C1 = struct
  (* Naming pass: both operands of [Bop] and the scrutinee of [If0] are bound
     to fresh temporaries, so those positions only ever hold variables. The
     generated [Let]s may still sit on the right-hand side of other [Let]s. *)
  type t = S.t

  let var (x : S.name) : t = S.Var x
  let int (i : int) : t = S.Int i

  (* [bop a b] becomes [let x1 = a in let x2 = b in Bop (x1, x2)]. *)
  let bop (lhs : t) (rhs : t) : t =
    let x1 = fresh () in
    let x2 = fresh () in
    S.Let (x1, lhs, S.Let (x2, rhs, S.Bop (S.Var x1, S.Var x2)))

  (* A source [Let] needs no extra naming; it is rebuilt as is. *)
  let let_ (x : S.name) (rhs : t) (body : t) : t = S.Let (x, rhs, body)

  (* Name the scrutinee, then branch on that name. *)
  let if0 (cond : t) (then_ : t) (else_ : t) : t =
    let x = fresh () in
    S.Let (x, cond, S.If0 (S.Var x, then_, else_))

  let rec lower (s : S.t) : t =
    match s with
    | S.Var x -> var x
    | S.Int i -> int i
    | S.Bop (s1, s2) -> bop (lower s1) (lower s2)
    | S.Let (x, s1, s2) -> let_ x (lower s1) (lower s2)
    | S.If0 (s1, s2, s3) -> if0 (lower s1) (lower s2) (lower s3)
end
(* To solve the nested [Let] problem, we do a clever trick. We simply convert the output
type of [lower] into its continuation-passing-style equivalent: [(S.t -> S.t) -> S.t].
Values of this type should construct a result and then pass that result to the
continuation argument passed to it. Adapting the helper functions to use this type is
not very difficult in many cases, but it requires care. We only pass the final result
(e.g. [Bop (...)]) to the given continuation, and for the intermediate values we pass a
lambda to them to extract the expression (e.g. [t1 (fun e1 -> ... e1 ...)]). Finally,
to "kick-off" the algorithm, we have to pass [Fun.id] as the initial continuation.
Example input (same as before):
(Let ("x", (Bop (Int 1, (Bop (Int 2, Int 3))),
(Bop (Var "x", Int 4))))
C2 output:
(Let ("%5", Int 1,
(Let ("%3", Int 2,
(Let ("%4", Int 3,
(Let ("%6", (Bop (Var "%3", Var "%4")),
(Let ("x", (Bop (Var "%5", Var "%6")),
(Let ("%1", Var "x",
(Let ("%2", Int 4,
(Bop (Var "%1", Var "%2"))))))))))))))))
Incredibly, this produces exactly the output we were looking for! Every [Let] now appears
only in the body of the previous one, so the bindings form a single flat chain. *)
module C2 = struct
  (* CPS naming pass. A lowered term is a function that receives the
     continuation [k] (what to do with the term's final value) and returns the
     whole program; intermediate [Let]s are wrapped around whatever [k] built,
     which keeps them in body position. *)
  type t = (S.t -> S.t) -> S.t

  let var (x : S.name) : t = fun k -> k (S.Var x)
  let int (i : int) : t = fun k -> k (S.Int i)

  (* Temporaries are chosen once, when the term is built, not each time it is
     run. The program is assembled inside-out: the operation itself (handed to
     [k]), then the binding of [x2], then the binding of [x1]. *)
  let bop (lhs : t) (rhs : t) : t =
    let x1 = fresh () in
    let x2 = fresh () in
    fun k ->
      let op = k (S.Bop (S.Var x1, S.Var x2)) in
      let after_rhs = rhs (fun e2 -> S.Let (x2, e2, op)) in
      lhs (fun e1 -> S.Let (x1, e1, after_rhs))

  let let_ (x : S.name) (rhs : t) (body : t) : t =
    fun k -> rhs (fun e1 -> S.Let (x, e1, body k))

  (* [k] is handed to both branches, so whatever it builds is duplicated in
     each of them. *)
  let if0 (cond : t) (then_ : t) (else_ : t) : t =
    let x = fresh () in
    fun k ->
      let branches = S.If0 (S.Var x, then_ k, else_ k) in
      cond (fun e1 -> S.Let (x, e1, branches))

  (* Lower [s] and run it with the identity continuation. *)
  let lower (s : S.t) : S.t =
    let rec go (s : S.t) : t =
      match s with
      | S.Var x -> var x
      | S.Int i -> int i
      | S.Bop (s1, s2) -> bop (go s1) (go s2)
      | S.Let (x, s1, s2) -> let_ x (go s1) (go s2)
      | S.If0 (s1, s2, s3) -> if0 (go s1) (go s2) (go s3)
    in
    go s Fun.id
end
(* At this point, we are ready to define an actual ANF data type which will be the target
of the compiler. This type only allows [name]s as the arguments of [Bop] and [If0].
Moreover, [Let] and [If0] are terms ([t]) rather than expressions ([exp]), so they can
only appear as the body of another [Let] or as a branch of an [If0], never on the
right-hand side of a [Let]. This makes it syntactically impossible to have nested
[Let]s like the ones we encountered in [C1].
C3 output:
(Let ("%5", Int 1,
(Let ("%3", Int 2,
(Let ("%4", Int 3,
(Let ("%6", (Bop ("%3", "%4")),
(Let ("x", (Bop ("%5", "%6")),
(Let ("%1", Var "x",
(Let ("%2", Int 4,
(Return (Bop ("%1", "%2")))))))))))))))))
The compiler is more or less identical to [C2], except that the arguments of [Bop] and
[If0] are no longer wrapped in [Var]: the requirement that they be names is now enforced
by the type.
Unfortunately this algorithm suffers from a problem which already existed in [C2] but was
not pointed out. Sometimes, when we lower an [If0] expression, it duplicates the
remainder of the program.
Example program:
(Let ("x", (If0 (Int 0, Int 1, Int 2)),
(If0 ((Bop (Var "x", Int 3)),
Int 4,
Int 5))))
C3 output:
(Let ("%4", Int 0,
(If0 ("%4",
(Let ("x", Int 1,
(Let ("%1", Var "x",
(Let ("%2", Int 3,
(Let ("%3", (Bop ("%1", "%2")),
(If0 ("%3",
(Return (Int 4)),
(Return (Int 5)))))))))))),
(Let ("x", Int 2,
(Let ("%1", Var "x",
(Let ("%2", Int 3,
(Let ("%3", (Bop ("%1", "%2")),
(If0 ("%3",
(Return (Int 4)),
(Return (Int 5))))))))))))))))
It is not hard to notice that there is a lot of duplicated code here. The construction
of [Bop ("%1", "%2")] and the following [If0 ("%3", ...)] is copy-pasted in its
entirety for both branches of the prior [If0 ("%4", ...)]! This problem can quickly
grow out of control as it will exponentially duplicate code for each [If0] expression
in the source program. Another issue is that we have duplicate [Let] bindings for the
same name, which is undesirable since we would like variables to have a single unique
place where they are defined. *)
(* First ANF target. Arguments of [Bop] and the scrutinee of [If0] are names, and
   [Let]/[If0] are terms ([t]) rather than expressions ([exp]), so a [Let] can
   never appear on the right-hand side of another [Let]. *)
module A1 = struct
(* Identifiers (plain strings). *)
type name = string [@@deriving show]
(* Expressions: the right-hand side of a [Let], or the value of a [Return]. *)
type exp =
(* Reference to a variable. *)
| Var of name
(* Integer literal. *)
| Int of int
(* Binary operation whose operands must already be bound to names; the operator
   itself is not represented. *)
| Bop of name * name
[@@deriving show {with_path = false}]
(* Terms: a chain of bindings that ends in a [Return] or splits at an [If0]. *)
type t =
(* Ends the term, yielding [exp] as its value. *)
| Return of exp
(* [Let (x, e, body)] binds [x] to [e] in [body]. *)
| Let of name * exp * t
(* Branches on the value of a name; each branch is a complete term. *)
| If0 of name * t * t
[@@deriving show {with_path = false}]
end
module C3 = struct
  (* [C2] retargeted at the typed ANF language [A1]: [Bop] and [If0] now take
     names directly. As in [C2], [if0] passes [k] to both branches, so the code
     [k] builds is duplicated, including its [Let]s of the same names. *)
  module A = A1

  type t = (A.exp -> A.t) -> A.t

  let var (x : A.name) : t = fun k -> k (A.Var x)
  let int (i : int) : t = fun k -> k (A.Int i)

  (* Names are fixed at build time; the result is assembled inside-out. *)
  let bop (lhs : t) (rhs : t) : t =
    let x1 = fresh () in
    let x2 = fresh () in
    fun k ->
      let op = k (A.Bop (x1, x2)) in
      let after_rhs = rhs (fun e2 -> A.Let (x2, e2, op)) in
      lhs (fun e1 -> A.Let (x1, e1, after_rhs))

  let let_ (x : A.name) (rhs : t) (body : t) : t =
    fun k -> rhs (fun e1 -> A.Let (x, e1, body k))

  let if0 (cond : t) (then_ : t) (else_ : t) : t =
    let x = fresh () in
    fun k ->
      let branches = A.If0 (x, then_ k, else_ k) in
      cond (fun e1 -> A.Let (x, e1, branches))

  (* Lower [s]; the outermost value becomes a [Return]. *)
  let lower (s : S.t) : A.t =
    let rec go (s : S.t) : t =
      match s with
      | S.Var x -> var x
      | S.Int i -> int i
      | S.Bop (s1, s2) -> bop (go s1) (go s2)
      | S.Let (x, s1, s2) -> let_ x (go s1) (go s2)
      | S.If0 (s1, s2, s3) -> if0 (go s1) (go s2) (go s3)
    in
    go s (fun e -> A.Return e)
end
(* In order to fix the duplication problem, we have to extend our ANF type to include a
new construct, a "join point". Join points are labelled functions which may only be
called in tail position, i.e. in the body of a [Let], never on its right-hand side.
[LetJoin] defines a named join point with a single parameter, and [Join] jumps to the
join point by passing an expression to it. The new ANF module is named [A2].
Join points are introduced during compilation of [If0]. Our aim is to avoid calling the
continuation [k] multiple times, so we instead apply [k] once, as the body of a new join
point [j], and pass [fun e -> Join (j, e)] to both branches.
Example program (same as before):
(Let ("x", (If0 (Int 0, Int 1, Int 2)),
(If0 ((Bop (Var "x", Int 3)),
Int 4,
Int 5))))
C4 output:
(LetJoin ("%5", "%6",
(Let ("x", Var "%6",
(LetJoin ("%7", "%8", (Return (Var "%8")),
(Let ("%1", Var "x",
(Let ("%2", Int 3,
(Let ("%3", (Bop ("%1", "%2")),
(If0 ("%3",
(Join ("%7", Int 4)),
(Join ("%7", Int 5)))))))))))))),
(Let ("%4", Int 0,
(If0 ("%4", (Join ("%5", Int 1)), (Join ("%5", Int 2))))))))
Our duplication issue is solved, as we now have introduced join points for each [If0]
in the source program. These join points are called rather than duplicating code for
each continuation. *)
(* ANF target with join points: [A1] plus [LetJoin]/[Join], used to share one
   continuation between both branches of an [If0]. *)
module A2 = struct
(* Identifiers (plain strings). *)
type name = string [@@deriving show]
(* Expressions: the right-hand side of a [Let], the value of a [Return], or the
   argument of a [Join]. *)
type exp =
(* Reference to a variable. *)
| Var of name
(* Integer literal. *)
| Int of int
(* Binary operation whose operands must already be bound to names; the operator
   itself is not represented. *)
| Bop of name * name
[@@deriving show {with_path = false}]
(* Terms: a chain of bindings that ends in a [Return] or a [Join], or splits at
   an [If0]. *)
type t =
(* Ends the term, yielding [exp] as its value. *)
| Return of exp
(* [Let (x, e, body)] binds [x] to [e] in [body]. *)
| Let of name * exp * t
(* Branches on the value of a name; each branch is a complete term. *)
| If0 of name * t * t
(* [LetJoin (j, x, jbody, rest)] defines join point [j] with the single
   parameter [x] and body [jbody]; [rest] is the term that may jump to [j]. *)
| LetJoin of name * name * t * t
(* [Join (j, e)] jumps to join point [j], passing [e] as its argument. *)
| Join of name * exp
[@@deriving show {with_path = false}]
end
module C4 = struct
  (* Like [C3], but [if0] routes both branches through a join point, so the
     surrounding continuation is applied exactly once. *)
  module A = A2

  type t = (A.exp -> A.t) -> A.t

  let var (x : A.name) : t = fun k -> k (A.Var x)
  let int (i : int) : t = fun k -> k (A.Int i)

  (* Names are fixed at build time; the result is assembled inside-out. *)
  let bop (lhs : t) (rhs : t) : t =
    let x1 = fresh () in
    let x2 = fresh () in
    fun k ->
      let op = k (A.Bop (x1, x2)) in
      let after_rhs = rhs (fun e2 -> A.Let (x2, e2, op)) in
      lhs (fun e1 -> A.Let (x1, e1, after_rhs))

  let let_ (x : A.name) (rhs : t) (body : t) : t =
    fun k -> rhs (fun e1 -> A.Let (x, e1, body k))

  (* The join point [j] and its parameter [x0] are named when the continuation
     is applied, not when the term is built. [k0] runs once, as the join body,
     with [x0] standing for the conditional's value; both branches end in a
     jump to [j] instead of a copy of [k0]'s code. *)
  let if0 (cond : t) (then_ : t) (else_ : t) : t =
    let x1 = fresh () in
    fun k0 ->
      let j = fresh () in
      let x0 = fresh () in
      let join_body = k0 (A.Var x0) in
      let jump : A.exp -> A.t = fun e -> A.Join (j, e) in
      let branches = A.If0 (x1, then_ jump, else_ jump) in
      let with_cond = cond (fun e1 -> A.Let (x1, e1, branches)) in
      A.LetJoin (j, x0, join_body, with_cond)

  (* Lower [s]; the outermost value becomes a [Return]. *)
  let lower (s : S.t) : A.t =
    let rec go (s : S.t) : t =
      match s with
      | S.Var x -> var x
      | S.Int i -> int i
      | S.Bop (s1, s2) -> bop (go s1) (go s2)
      | S.Let (x, s1, s2) -> let_ x (go s1) (go s2)
      | S.If0 (s1, s2, s3) -> if0 (go s1) (go s2) (go s3)
    in
    go s (fun e -> A.Return e)
end
(* The previous version of the compiler works reasonably well, but there is a subtle
issue. Because a join point is introduced for every [If0], certain expressions that used
to be in tail position no longer are, as they become arguments to joins. In the running
example, the [Int 4] and [Int 5] expressions are no longer in tail position, even though
they ought to be: in the structure of the source program they are obviously the final
expressions returned. In other words, we would like to produce the ANF terms
[Return (Int 4)] and [Return (Int 5)] rather than [Join ("%7", Int 4)] and
[Join ("%7", Int 5)].
The source of this issue is that [if0] is defensively creating a join point to guard
every continuation. However, some continuations are perfectly reasonable to duplicate,
such as the initial continuation [fun e -> Return e]. In [C4] continuations are all
lambda functions, which means that they cannot be inspected to determine if they should
be allowed to be duplicated.
We can solve this problem by defunctionalizing the continuation type, replacing it with
a concrete type representing the different lambdas that appear in the program. Luckily,
there are not that many ways that we create continuations. In fact, the resulting
defunctionalized data type only has three constructors: [Return], which represents
[fun e -> Return e]; [Join j], which represents [fun e -> Join (j, e)]; and [Let (x, t)],
which represents [fun e -> Let (x, e, t)]. Only the last kind of continuation must not be
duplicated, since it carries the rest of the program [t], so the [if0] function in [C5]
singles it out. When the continuation may be copied, [if0] lowers the conditional the
"old way", as in [C3], passing the continuation directly to both branches.
Example program (same as before):
(Let ("x", (If0 (Int 0, Int 1, Int 2)),
(If0 ((Bop (Var "x", Int 3)),
Int 4,
Int 5))))
C5 output:
(LetJoin ("%5", "%6",
(Let ("x", Var "%6",
(Let ("%1", Var "x",
(Let ("%2", Int 3,
(Let ("%3", (Bop ("%1", "%2")),
(If0 ("%3", (Return (Int 4)), (Return (Int 5)))))))))))),
(Let ("%4", (Int 0),
(If0 ("%4",
(Join ("%5", Int 1)),
(Join ("%5", Int 2))))))))
This variation does not have any problematic duplication, and it preserves tail position
for expressions that were in tail position in the original syntax tree. *)
module C5 = struct
  (* [C4] with defunctionalized continuations: [if0] only introduces a join
     point when its continuation is not safe to copy. *)
  module A = A2

  (* One constructor per kind of lambda used as a continuation in [C4]. *)
  type cont =
    | Return  (* stands for [fun e -> A.Return e] *)
    | Join of A.name  (* stands for [fun e -> A.Join (j, e)] *)
    | Let of A.name * A.t  (* stands for [fun e -> A.Let (x, e, t)] *)

  (* [k $ e] plugs the expression [e] into the continuation [k]. *)
  let ( $ ) (k : cont) : A.exp -> A.t =
    match k with
    | Return -> fun e -> A.Return e
    | Join j -> fun e -> A.Join (j, e)
    | Let (x, body) -> fun e -> A.Let (x, e, body)

  (* Only [Let] carries the rest of the program; [Return] and [Join] may be
     duplicated freely. *)
  let is_copyable_cont (k : cont) : bool =
    match k with
    | Return | Join _ -> true
    | Let _ -> false

  type t = cont -> A.t

  let var (x : A.name) : t = fun k -> k $ A.Var x
  let int (i : int) : t = fun k -> k $ A.Int i

  (* Names are fixed at build time; the result is assembled inside-out. *)
  let bop (lhs : t) (rhs : t) : t =
    let x1 = fresh () in
    let x2 = fresh () in
    fun k ->
      let op = k $ A.Bop (x1, x2) in
      let after_rhs = rhs (Let (x2, op)) in
      lhs (Let (x1, after_rhs))

  (* The body is lowered first, and its result is stored in the [Let]
     continuation handed to the right-hand side. *)
  let let_ (x : A.name) (rhs : t) (body : t) : t =
    fun k ->
      let rest = body k in
      rhs (Let (x, rest))

  let if0 (cond : t) (then_ : t) (else_ : t) : t =
    let x1 = fresh () in
    (* Lower both branches against [k_branch], then bind the scrutinee to [x1]. *)
    let lower_with (k_branch : cont) : A.t =
      let branches = A.If0 (x1, then_ k_branch, else_ k_branch) in
      cond (Let (x1, branches))
    in
    fun k ->
      if is_copyable_cont k then
        (* [Return]/[Join]: safe to copy into both branches directly. *)
        lower_with k
      else
        (* [k] holds the rest of the program: apply it once as the body of a
           join point [j] with parameter [x0], and make both branches jump
           to [j]. *)
        let j = fresh () in
        let x0 = fresh () in
        let join_body = k $ A.Var x0 in
        let main = lower_with (Join j) in
        A.LetJoin (j, x0, join_body, main)

  (* Lower [s]; the outermost value becomes a [Return]. *)
  let lower (s : S.t) : A.t =
    let rec go (s : S.t) : t =
      match s with
      | S.Var x -> var x
      | S.Int i -> int i
      | S.Bop (s1, s2) -> bop (go s1) (go s2)
      | S.Let (x, s1, s2) -> let_ x (go s1) (go s2)
      | S.If0 (s1, s2, s3) -> if0 (go s1) (go s2) (go s3)
    in
    go s Return
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment