Skip to content

Instantly share code, notes, and snippets.

@dariusf
Last active October 26, 2025 16:16
Show Gist options
  • Save dariusf/fd9a93de6b9e20dd1cf2f64b146b8e53 to your computer and use it in GitHub Desktop.
Save dariusf/fd9a93de6b9e20dd1cf2f64b146b8e53 to your computer and use it in GitHub Desktop.
Genuine shift/reset
(* Based on
https://okmij.org/ftp/continuations/implementations.html
https://okmij.org/ftp/Haskell/ShiftResetGenuine.hs
https://www.cs.tsukuba.ac.jp/~kam/paper/aplas07.pdf *)
(* Parameterised monads *)
(** Signature of a parameterised monad: besides the value type ['a], a
    computation [('i, 'o, 'a) t] carries two extra type indices that [ret]
    leaves equal and [bind] chains together. *)
module type GM = sig
  type ('i, 'o, 'a) t

  (** [ret v] is the pure computation yielding [v]. Its two indices are the
      same type variable, i.e. a pure value is polymorphic in the answer
      type. *)
  val ret : 'v -> ('ans, 'ans, 'v) t

  (** [bind m f] sequences [m] and then [f]. The first index of [m] must
      equal the second index of [f]'s result; the composite keeps the first
      index of [f]'s result and the second index of [m], so changes of
      answer type compose. *)
  val bind :
    ('mid, 'out, 'v) t -> ('v -> ('inn, 'mid, 'w) t) -> ('inn, 'out, 'w) t
end
(** [shift]/[reset] delimited control with answer-type modification,
    implemented as a continuation-passing parameterised monad. *)
module ShiftReset : sig
  include GM

  (** Functions written [sigma/a -> tau/b]: they take a ['sigma], produce a
      ['tau], and may modify the answer type from ['a] to ['b]. *)
  type ('sigma, 'a, 'tau, 'b) fn = 'sigma -> ('a, 'b, 'tau) t

  (** [reset m] runs [m] with the identity continuation and passes the
      outcome on as a value, so the result is polymorphic in the answer
      type. *)
  val reset : ('sigma, 'tau, 'sigma) t -> ('a, 'a, 'tau) t

  (** [shift f] hands the current continuation to [f] as an ordinary
      function (not merely answer-type polymorphic, but pure) and runs the
      computation returned by [f] with the identity continuation. *)
  val shift : (('tau -> 'a) -> ('s, 'b, 's) t) -> ('a, 'b, 'tau) t

  (** [run m] executes [m] with the identity continuation and returns the
      final answer. *)
  val run : ('tau, 'tau, 'tau) t -> 'tau
end = struct
  (* A computation is a function from its continuation ['a -> 'i] to the
     final answer ['o]. *)
  type ('i, 'o, 'a) t = C of (('a -> 'i) -> 'o) [@@unboxed]
  type ('sigma, 'a, 'tau, 'b) fn = 'sigma -> ('a, 'b, 'tau) t

  (* [feed m k] runs [m] with continuation [k]. *)
  let feed (C m) k = m k

  (* [finish m] runs [m] with the identity continuation. *)
  let finish m = feed m Fun.id

  let ret v = C (fun k -> k v)
  let bind m h = C (fun k -> feed m (fun v -> feed (h v) k))

  (* The delimited computation is only executed once a continuation
     arrives, hence [finish m] stays under the [fun k]. *)
  let reset m = C (fun k -> k (finish m))
  let shift f = C (fun k -> finish (f k))
  let run = finish
end
open ShiftReset
(* Binding operator: [let* x = m in e] stands for [bind m (fun x -> e)]. *)
let ( let* ) m f = bind m f
(** [append xs] walks down [xs]; on reaching the end it captures the current
    continuation [k] (which conses the visited elements back on, up to the
    enclosing [reset]) and returns a function that applies [k] to its
    argument as a pure computation. Delimited by [reset], [append xs]
    therefore yields a function mapping [ys] to a computation of [xs]
    followed by [ys]. *)
let rec append :
    ('a list, 'a list, 'a list, 'a list -> ('b, 'b, 'a list) t) fn = function
  | [] -> shift (fun k -> ret (fun ys -> ret (k ys)))
  | x :: xs -> bind (append xs) (fun tail -> ret (x :: tail))
(** Delimits [append [1; 2; 3]] to obtain the prefix-appending function, then
    applies it to [[4; 5; 6]]; the result is [[1; 2; 3; 4; 5; 6]]. *)
let test_append =
  let prefix = reset (append [ 1; 2; 3 ]) in
  run
    (let* add_prefix = prefix in
     add_prefix [ 4; 5; 6 ])
(** Multi-shot continuation: the captured [k] (the rest of the delimited
    body, i.e. "add one") is invoked twice and the two answers are summed,
    giving [(1 + 1) + (2 + 1) = 5]. *)
let test_multishot =
  let body =
    let* hole = shift (fun k -> ret (k 1 + k 2)) in
    ret (hole + 1)
  in
  run (reset body)
(** Multi-shot continuation with answer-type modification: the hole is an
    [int] but the delimited body answers with a [string], so [k] has type
    [int -> string]. The two answers are concatenated, giving ["1!2!"]. *)
let test_multishot_atm =
  let twice = shift (fun k -> ret (k 1 ^ k 2)) in
  let exclaim n = ret (string_of_int n ^ "!") in
  run (reset (bind twice exclaim))
(** Printf-style formatting via answer-type modification: each [shift] turns
    the answer into a function awaiting one more argument, so the delimited
    body becomes an [int -> bool -> string]. Applied to [1] and [true] it
    produces ["1true!"]. *)
let test_printf =
  (* [directive to_string] abstracts over the enclosing context: the answer
     becomes a function that takes a value, renders it with [to_string] and
     resumes the captured continuation with the rendered text. *)
  let directive to_string =
    shift (fun k -> ret (fun v -> k (to_string v)))
  in
  let formatter =
    reset
      (let* r = directive string_of_int in
       let* b = directive string_of_bool in
       ret (r ^ b ^ "!"))
  in
  run formatter 1 true
(* https://www.cl.cam.ac.uk/teaching/2324/R277/handout-delimited-continuations.pdf *)
(* ⟨1 + ⟨(S k1. k1 100 + k1 10) + (S k2. S k3. 1)⟩⟩ *)
(** Evaluates the term given in the preceding comment with this file's
    [shift]/[reset] and maps the integer outcome to the name of the control
    operator that outcome is associated with: [3] to ["shift"], [1] to
    ["shift0"], [2] to ["control"], anything else to ["??"]. With
    [ShiftReset] as defined in this file the outcome is [3]: every call of
    [k1] returns [1], the inner [reset] yields [2], and the outer context
    adds [1]. *)
let test_shift_shift0_control =
  let outcome =
    run
    @@ reset
         (let* inner =
            reset
              (let* a = shift (fun k1 -> ret (k1 100 + k1 10)) in
               (* Neither captured continuation is invoked here; the
                  binders are underscore-prefixed so the unused-variable
                  warning (27) does not fire. *)
               let* b = shift (fun _k2 -> shift (fun _k3 -> ret 1)) in
               ret (a + b))
          in
          ret (1 + inner))
  in
  match outcome with
  | 3 -> "shift"
  | 1 -> "shift0"
  | 2 -> "control"
  | _ -> "??"
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment