Skip to content

Instantly share code, notes, and snippets.

@Guest0x0
Last active December 28, 2024 12:53
Show Gist options
  • Save Guest0x0/98bbff786b955435f44a7c5e933386d2 to your computer and use it in GitHub Desktop.
Save Guest0x0/98bbff786b955435f44a7c5e933386d2 to your computer and use it in GitHub Desktop.
Demonstration of the refocusing technique for transforming a small-step reducer into a big-step interpreter, based on <https://arxiv.org/pdf/2302.10455>
(* Binary arithmetic operators. The reducers in this file treat [Sub] on
   naturals: [x - y] with [x < y] does not reduce (it is "stuck"). *)
type op = Add | Sub
(* Arithmetic expressions: integer literals and binary operations. *)
type expr =
| Lit of int
| Op of expr * op * expr
(* Outcome of one small step of reduction:
   [Val n]  - the expression is already the literal [n] (no step possible);
   [Term e] - the expression steps to [e];
   [Stuck]  - a subtraction [x - y] with [x < y] was reached.
   NOTE: this shadows Stdlib's [('a, 'b) result] type name in this file. *)
type result =
| Val of int
| Term of expr
| Stuck
(* plain small-step reduction implemented with structural recursion *)
module V1 = struct
  (* One step of left-to-right small-step reduction.
     The left operand is stepped first. Once it is a literal, the right
     operand is stepped. Once both are literals, the operation is contracted.
     A subtraction with a negative result is reported as [Stuck]. *)
  let rec reduce expr =
    match expr with
    | Lit n -> Val n
    | Op (lhs, op, rhs) -> (
        match reduce lhs with
        | Stuck -> Stuck
        | Term lhs' -> Term (Op (lhs', op, rhs))
        | Val a -> (
            match reduce rhs with
            | Stuck -> Stuck
            | Term rhs' -> Term (Op (Lit a, op, rhs'))
            | Val b -> (
                match op with
                | Add -> Term (Lit (a + b))
                | Sub -> if a >= b then Term (Lit (a - b)) else Stuck)))

  (* Iterate [reduce] to a fixed point.
     Returns [Some n] for a final literal, or [None] if evaluation gets stuck. *)
  let rec normalize expr =
    match reduce expr with
    | Term next -> normalize next
    | Val n -> Some n
    | Stuck -> None
end
(* CPS *)
module V2 = struct
  (* [V1.reduce] in continuation-passing style. The one-step outcome is
     handed to [k] instead of being returned. The evaluation order is
     unchanged: left operand, then right operand, then contraction. *)
  let rec reduce expr k =
    match expr with
    | Lit n -> k (Val n)
    | Op (lhs, op, rhs) ->
        (* continuation for the right operand, once the left one is [Val a] *)
        let after_rhs a = function
          | Val b -> (
              match op with
              | Add -> k (Term (Lit (a + b)))
              | Sub -> if a >= b then k (Term (Lit (a - b))) else k Stuck)
          | Term rhs' -> k (Term (Op (Lit a, op, rhs')))
          | Stuck -> k Stuck
        in
        (* continuation for the left operand *)
        let after_lhs = function
          | Val a -> reduce rhs (after_rhs a)
          | Term lhs' -> k (Term (Op (lhs', op, rhs)))
          | Stuck -> k Stuck
        in
        reduce lhs after_lhs

  (* Run [reduce] with the identity continuation until a value or a stuck
     state is reached. *)
  let rec normalize expr =
    match reduce expr Fun.id with
    | Term next -> normalize next
    | Val n -> Some n
    | Stuck -> None
end
(* split the single value/term/stuck continuation into two continuations:
one for values/terms and one for the stuck case, using the equation
(A + B) -> R = (A -> R) * (B -> R)
*)
(* The non-stuck part of [result]: a value [V n] or a stepped term [T e].
   [V3], [V3_1] and [V4] pass it to the value/term continuation, while the
   stuck case is handled separately. *)
type vot = V of int | T of expr
module V3 = struct
  (* [reduce expr k ks]: one reduction step in CPS with two continuations.
     [k] receives either a value ([V]) or the stepped term ([T]).
     [ks ()] is called when a subtraction [a - b] has [a < b].
     The evaluation order is the same as [V2.reduce]. *)
  let rec reduce expr k ks =
    match expr with
    | Lit n -> k (V n)
    | Op (lhs, op, rhs) ->
        (* contract [Op (Lit a, op, Lit b)] *)
        let combine a b =
          match op with
          | Add -> k (T (Lit (a + b)))
          | Sub -> if a >= b then k (T (Lit (a - b))) else ks ()
        in
        let on_rhs a = function
          | V b -> combine a b
          | T rhs' -> k (T (Op (Lit a, op, rhs')))
        in
        let on_lhs = function
          | V a -> reduce rhs (on_rhs a) ks
          | T lhs' -> k (T (Op (lhs', op, rhs)))
        in
        reduce lhs on_lhs ks

  (* Iterate [reduce]. The two continuations rebuild a [result]:
     [V]/[T] become [Val]/[Term], and the stuck continuation yields [Stuck]. *)
  let rec normalize expr =
    let to_result = function V n -> Val n | T e -> Term e in
    match reduce expr to_result (fun () -> Stuck) with
    | Term next -> normalize next
    | Val n -> Some n
    | Stuck -> None
end
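(* notice that [ks] is always the same: it is passed down unchanged, and
[normalize] supplies [fun () -> Stuck]. So inline it here: a stuck
subtraction now returns [Stuck] directly instead of calling a continuation,
which introduces a discontinuity *)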
module V3_1 = struct
  (* Like [V3.reduce], but with the stuck continuation inlined. A stuck
     subtraction returns [Stuck] directly as the final answer and skips [k]
     entirely. *)
  let rec reduce expr k =
    match expr with
    | Lit n -> k (V n)
    | Op (lhs, op, rhs) ->
        (* contract [Op (Lit a, op, Lit b)]; [Stuck] escapes without [k] *)
        let combine a b =
          match op with
          | Add -> k (T (Lit (a + b)))
          | Sub -> if a >= b then k (T (Lit (a - b))) else Stuck
        in
        let on_rhs a = function
          | V b -> combine a b
          | T rhs' -> k (T (Op (Lit a, op, rhs')))
        in
        let on_lhs = function
          | V a -> reduce rhs (on_rhs a)
          | T lhs' -> k (T (Op (lhs', op, rhs)))
        in
        reduce lhs on_lhs

  (* Iterate [reduce] until a value ([Some]) or a stuck state ([None]). *)
  let rec normalize expr =
    let to_result = function V n -> Val n | T e -> Term e in
    match reduce expr to_result with
    | Term next -> normalize next
    | Val n -> Some n
    | Stuck -> None
end
(* defunctionalize the continuation *)
module V4 = struct
  (* Defunctionalized continuation frames:
     [L (op, r)] - waiting for the left operand, with right operand [r];
     [R (x, op)] - the left operand evaluated to [x], waiting for the right. *)
  type frame =
    | L of op * expr
    | R of int * op

  (* Descend through left operands, pushing a frame for each enclosing [Op]. *)
  let rec reduce expr k =
    match expr with
    | Op (lhs, op, rhs) -> reduce lhs (L (op, rhs) :: k)
    | Lit n -> apply k (V n)

  (* Interpret the frame stack [k] on the outcome [v] of a subexpression.
     [v] is either a value [V] or an already-stepped term [T]. *)
  and apply k v =
    match (k, v) with
    | [], V n -> Val n
    | [], T e -> Term e
    | L (op, rhs) :: rest, V a -> reduce rhs (R (a, op) :: rest)
    | L (op, rhs) :: rest, T lhs' -> apply rest (T (Op (lhs', op, rhs)))
    | R (a, op) :: rest, T rhs' -> apply rest (T (Op (Lit a, op, rhs')))
    | R (a, Add) :: rest, V b -> apply rest (T (Lit (a + b)))
    | R (a, Sub) :: rest, V b ->
        if a >= b then apply rest (T (Lit (a - b))) else Stuck

  (* Iterate single steps until a value ([Some]) or a stuck state ([None]). *)
  and normalize expr =
    match reduce expr [] with
    | Term next -> normalize next
    | Val n -> Some n
    | Stuck -> None
end
(* Next, for clarity, we split the [apply] function for [V] case and [T] case.
Note that this is DIFFERENT from splitting the value/term continuation early at [V3] above:
we still have only *ONE* defunctionalized continuation, just two apply functions for it.
*)
module V4_1 = struct
  type frame =
    | L of op * expr
    | R of int * op

  (* Corresponds to [V4.apply k (T t)]. It plugs [t] back into its context
     [k], innermost frame first, and returns the rebuilt term. The [Term]
     wrapper is applied by [apply_v] instead (fixedpoint demotion). *)
  let apply_t k t =
    List.fold_left
      (fun acc frame ->
        match frame with
        | L (op, rhs) -> Op (acc, op, rhs)
        | R (a, op) -> Op (Lit a, op, acc))
      t k

  (* Descend through left operands, pushing a frame for each enclosing [Op]. *)
  let rec reduce expr k =
    match expr with
    | Op (lhs, op, rhs) -> reduce lhs (L (op, rhs) :: k)
    | Lit n -> apply_v k n

  (* Corresponds to [V4.apply k (V v)]: the continuation applied to a value. *)
  and apply_v k v =
    match k with
    | [] -> Val v
    | L (op, rhs) :: rest -> reduce rhs (R (v, op) :: rest)
    | R (a, Add) :: rest -> Term (apply_t rest (Lit (a + v)))
    | R (a, Sub) :: rest ->
        if a >= v then Term (apply_t rest (Lit (a - v))) else Stuck

  (* Iterate single steps until a value ([Some]) or a stuck state ([None]). *)
  and normalize expr =
    match reduce expr [] with
    | Term next -> normalize next
    | Val n -> Some n
    | Stuck -> None
end
(* Next, if we perform lightweight fixedpoint demotion to [reduce] and [apply_v],
we can get the decompose/contract/recompose structure of small step reduction *)
module V4_2 = struct
  type frame =
    | L of op * expr
    | R of int * op

  (* [Value n]: the whole term is the literal [n].
     [Decomp (x, op, y, k)]: the next redex is [Op (Lit x, op, Lit y)],
     located in context [k] (innermost frame first). *)
  type decomposition =
    | Value of int
    | Decomp of int * op * int * frame list

  (* Plug [t] back into context [k], innermost frame first (was [apply_t]). *)
  let recompose k t =
    List.fold_left
      (fun acc frame ->
        match frame with
        | L (op, rhs) -> Op (acc, op, rhs)
        | R (a, op) -> Op (Lit a, op, acc))
      t k

  (* Find the leftmost-innermost redex of [expr] in context [k]
     (was [reduce]). *)
  let rec decompose expr k =
    match expr with
    | Op (lhs, op, rhs) -> decompose lhs (L (op, rhs) :: k)
    | Lit n -> apply_v k n

  and apply_v k v =
    match k with
    | [] -> Value v
    | L (op, rhs) :: rest -> decompose rhs (R (v, op) :: rest)
    | R (a, op) :: rest -> Decomp (a, op, v, rest)

  (* Decompose / contract / recompose loop. The contraction result is used
     directly (case-of-case) rather than being wrapped in [Term]/[Stuck] and
     matched again. *)
  and normalize expr =
    match decompose expr [] with
    | Value n -> Some n
    | Decomp (a, Add, b, k) -> normalize (recompose k (Lit (a + b)))
    | Decomp (a, Sub, b, k) ->
        if a >= b then normalize (recompose k (Lit (a - b))) else None
end
(* notice that the input of [normalize] is always supplied to [decompose],
so we can lift the [decompose] call outside [normalize],
this way we get the [decompose (recompose ...)] structure in [normalize]
*)
module V4_3 = struct
  type frame =
    | L of op * expr
    | R of int * op

  type decomposition =
    | Value of int
    | Decomp of int * op * int * frame list

  (* Plug [t] back into context [k], innermost frame first. *)
  let recompose k t =
    List.fold_left
      (fun acc frame ->
        match frame with
        | L (op, rhs) -> Op (acc, op, rhs)
        | R (a, op) -> Op (Lit a, op, acc))
      t k

  (* Find the leftmost-innermost redex of [expr] in context [k]. *)
  let rec decompose expr k =
    match expr with
    | Op (lhs, op, rhs) -> decompose lhs (L (op, rhs) :: k)
    | Lit n -> apply_v k n

  and apply_v k v =
    match k with
    | [] -> Value v
    | L (op, rhs) :: rest -> decompose rhs (R (v, op) :: rest)
    | R (a, op) :: rest -> Decomp (a, op, v, rest)

  (* Normalize starting from an existing decomposition. Each iteration
     contracts the redex, rebuilds the whole term with [recompose], and
     decomposes it again from the root. *)
  let rec normalize_decomp decomp =
    match decomp with
    | Value n -> Some n
    | Decomp (a, op, b, k) -> (
        let contracted =
          match op with
          | Add -> Some (a + b)
          | Sub -> if a >= b then Some (a - b) else None
        in
        match contracted with
        | None -> None
        | Some n -> normalize_decomp (decompose (recompose k (Lit n)) []))

  let normalize expr = normalize_decomp (decompose expr [])
end
(* The [decompose (recompose ...)] pattern is inefficient here:
[recompose] builds up a term, but [decompose] immediately breaks it down.
We want to introduce a function [refocus t k], such that:
refocus t k = decompose (recompose k t) []
Meanwhile, we wish that [refocus] will not build up an intermediate term.
In fact, we can *calculate* the definition of [refocus] by induction on [k]:
(1) case [k = []]:
refocus t [] = decompose (recompose [] t) [] = decompose t []
this equation will help us fill in some cases later.
(2) case [k = L (op, r) :: k']:
refocus t k
= decompose (recompose k t) []
= decompose (recompose k' (Op (t, op, r))) []
= refocus (Op (t, op, r)) k' (IH)
(3) case [k = R (x, op) :: k']:
refocus t k
= decompose (recompose k t) []
= decompose (recompose k' (Op (Lit x, op, t))) []
= refocus (Op (Lit x, op, t)) k' (IH)
= refocus (Lit x) (L (op, t) :: k') (2), reading backwards
we are not *defining* [refocus] by induction on [k] here,
instead, we are using induction to generate equations about [refocus].
So when actually defining [refocus], we can read (2) and (3) above *backwards*.
Now we can almost get the definition of [refocus]:
let rec refocus t k =
match t with
| Lit v ->
(match k with
| [] ->
(* this case is determined by case (1) above *)
Value v
| L (op, r) :: k' ->
(* this case is determined by case (3) above, reading backwards *)
refocus r (R (v, op) :: k')
| R (x, op) :: k' ->
(* we cannot determine this case yet *)
???)
| Op (l, op, r) ->
(* this case is determined by case (2) above, reading backwards. *)
refocus l (L (op, r) :: k)
To fill in [???], the last missing equation is
[refocus t [] = decompose t []] for [t = Op (l, op, r)].
But calculating [???] directly from this equation is difficult.
Notice that [refocus] is exactly the same as [decompose] everywhere except [???],
so a good guess would be:
refocus t k = decompose t k
This is indeed the case, by simple induction on [k], we can prove that:
decompose t k = decompose (recompose k t) []
(1) case [k = []]: immediate
(2) case [k = L (op, r) :: k']:
right
= decompose (recompose k' (Op (t, op, r))) []
= decompose (Op (t, op, r)) k' (IH)
= decompose t (L (op, r) :: k')
= left
(3) case [k = R (x, op) :: k']:
right
= decompose (recompose k' (Op (Lit x, op, t))) []
= decompose (Op (Lit x, op, t)) k' (IH)
= decompose (Lit x) (L (op, t) :: k')
= apply_v (L (op, t) :: k') x
= decompose t (R (x, op) :: k')
= left
*)
(* utilize the refocus theorem, and remove the need for [recompose] *)
module V5 = struct
  type frame =
    | L of op * expr
    | R of int * op

  type decomposition =
    | Value of int
    | Decomp of int * op * int * frame list

  (* Find the leftmost-innermost redex of [expr] in context [k]. *)
  let rec decompose expr k =
    match expr with
    | Op (lhs, op, rhs) -> decompose lhs (L (op, rhs) :: k)
    | Lit n -> apply_v k n

  and apply_v k v =
    match k with
    | [] -> Value v
    | L (op, rhs) :: rest -> decompose rhs (R (v, op) :: rest)
    | R (a, op) :: rest -> Decomp (a, op, v, rest)

  (* Refocused loop. By the refocus theorem,
     [decompose (recompose k t) []] = [decompose t k], so the contractum is
     decomposed in place, inside its context, without rebuilding the term. *)
  let rec normalize_decomp decomp =
    match decomp with
    | Value n -> Some n
    | Decomp (a, op, b, k) -> (
        let contracted =
          match op with
          | Add -> Some (a + b)
          | Sub -> if a >= b then Some (a - b) else None
        in
        match contracted with
        | None -> None
        | Some n -> normalize_decomp (decompose (Lit n) k))

  let normalize expr = normalize_decomp (decompose expr [])
end
(* by performing lightweight fixedpoint promotion,
we directly get a reduction-free normalization function *)
module V6 = struct
  type frame =
    | L of op * expr
    | R of int * op

  (* [decompose] with [normalize_decomp] fused into its base cases.
     [normalize_aux expr k] evaluates [expr] in context [k] all the way to
     [Some n], or to [None] if a subtraction gets stuck. *)
  let rec normalize_aux expr k =
    match expr with
    | Op (lhs, op, rhs) -> normalize_aux lhs (L (op, rhs) :: k)
    | Lit n -> apply_v k n

  (* Resume context [k] with value [v]. A redex is contracted on the spot,
     and evaluation continues in the same context. *)
  and apply_v k v =
    match k with
    | [] -> Some v
    | L (op, rhs) :: rest -> normalize_aux rhs (R (v, op) :: rest)
    | R (a, Add) :: rest -> normalize_aux (Lit (a + v)) rest
    | R (a, Sub) :: rest ->
        if a >= v then normalize_aux (Lit (a - v)) rest else None

  let normalize expr = normalize_aux expr []
end
(* refunctionalize the continuation *)
module V7 = struct
  (* Refunctionalized [V6]: the frame stack becomes a closure [k] that
     receives the integer value of the current subexpression. A stuck
     subtraction returns [None] without ever calling [k]. *)
  let rec normalize_aux expr k =
    match expr with
    | Lit n -> k n
    | Op (lhs, Add, rhs) ->
        normalize_aux lhs (fun a -> normalize_aux rhs (fun b -> k (a + b)))
    | Op (lhs, Sub, rhs) ->
        normalize_aux lhs (fun a ->
            normalize_aux rhs (fun b -> if a >= b then k (a - b) else None))

  let normalize expr = normalize_aux expr (fun n -> Some n)
end
(* back to direct style. Now we get a big step interpreter *)
module V8 = struct
  (* Raised by [normalize_aux] when a subtraction [x - y] has [x < y]. This is
     the condition [V1]..[V7] report as [Stuck] / [None]. Despite its name,
     it signals that the result would go *below* zero; the name is kept so
     existing handlers of [V8.Overflow] keep working. *)
  exception Overflow

  (* Big-step evaluation: evaluate the left operand, then the right operand,
     then combine them.
     @raise Overflow when a subtraction would produce a negative result. *)
  let rec normalize_aux expr =
    match expr with
    | Lit i -> i
    | Op (l, op, r) -> (
        let x = normalize_aux l in
        let y = normalize_aux r in
        match op with
        | Add -> x + y
        | Sub -> if x < y then raise Overflow else x - y)

  (* Total wrapper around [normalize_aux]: [Some n] on success, [None] when
     evaluation gets stuck. It is not part of the recursive group because it
     is not recursive. The [exception] case covers only the evaluation
     itself, not the construction of [Some]. *)
  let normalize expr =
    match normalize_aux expr with
    | v -> Some v
    | exception Overflow -> None
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment