Last active
October 26, 2025 16:16
-
-
Save dariusf/fd9a93de6b9e20dd1cf2f64b146b8e53 to your computer and use it in GitHub Desktop.
Genuine shift/reset
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
(* Based on
   https://okmij.org/ftp/continuations/implementations.html
   https://okmij.org/ftp/Haskell/ShiftResetGenuine.hs
   https://www.cs.tsukuba.ac.jp/~kam/paper/aplas07.pdf *)
(* Parameterised monads.

   A computation of type [('i, 'o, 'a) t] yields a value of type ['a]; the
   two extra indices track how the computation changes the answer type. *)
module type GM = sig
  type ('i, 'o, 'a) t

  (** [ret v] yields [v] without touching the answer type: both indices are
      the same ['a], so pure values are polymorphic in the answer type. *)
  val ret : 'tau -> ('a, 'a, 'tau) t

  (** [bind m h] sequences [m] with [h], composing their changes in answer
      type: the first index of [m] must match the second index of [h s], and
      the result takes its first index from [h s] and its second from [m]. *)
  val bind :
    ('b, 'g, 'sigma) t ->
    ('sigma -> ('a, 'b, 'tau) t) ->
    ('a, 'g, 'tau) t
end
(* Shift/reset with answer-type modification (ATM), realised as a
   parameterised continuation monad satisfying [GM]. *)
module ShiftReset : sig
  include GM

  (** The type of functions that can cause ATM [sigma/a -> tau/b]: applied to
      a ['sigma], such a function yields a ['tau] while changing the answer
      type from ['a] to ['b]. *)
  type ('sigma, 'a, 'tau, 'b) fn = 'sigma -> ('a, 'b, 'tau) t

  (** [reset m] delimits [m]: [m] is run with the identity continuation and
      its final answer becomes the value of [reset m]. Seen from outside the
      delimiter, the answer type is left unchanged. *)
  val reset : ('sigma, 'tau, 'sigma) t -> ('a, 'a, 'tau) t

  (** [shift f] hands the current continuation, up to the nearest enclosing
      delimiter, to [f] as an ordinary function [k], then runs [f k] under the
      identity continuation. The captured continuation is not just
      answer-type polymorphic, but pure. *)
  val shift : (('tau -> 'a) -> ('s, 'b, 's) t) -> ('a, 'b, 'tau) t

  (** [run m] evaluates a computation whose value and answer types agree. *)
  val run : ('tau, 'tau, 'tau) t -> 'tau
end = struct
  (* A computation receives its continuation (from the value type ['a] to the
     initial answer type ['i]) and produces the final answer type ['o]. *)
  type ('i, 'o, 'a) t = Cps of (('a -> 'i) -> 'o) [@@unboxed]

  type ('sigma, 'a, 'tau, 'b) fn = 'sigma -> ('a, 'b, 'tau) t

  (* Feed the continuation [k] to the computation [m]. *)
  let apply (Cps m) k = m k

  let ret v = Cps (fun k -> k v)

  let bind m h = Cps (fun k -> apply m (fun s -> apply (h s) k))

  (* The delimited body sees the identity continuation; whatever answer it
     produces is passed on to the surrounding continuation [k]. *)
  let reset m = Cps (fun k -> k (apply m Fun.id))

  (* [k] is given to [f] as a plain function; the body [f k] itself runs under
     the identity continuation, i.e. implicitly delimited. *)
  let shift f = Cps (fun k -> apply (f k) Fun.id)

  let run m = apply m Fun.id
end
open ShiftReset

(* Binding operator: [let* x = m in body] stands for [bind m (fun x -> body)]. *)
let ( let* ) m k = bind m k
(* [append xs] is a computation that, once delimited by [reset], yields the
   difference list of [xs]: a function taking a tail [ys] to [ret (xs @ ys)].
   The [[]] case uses [shift] to capture the pending conses of the earlier
   elements and returns them packaged as that function, which changes the
   answer type from ['a list] to ['a list -> ('b, 'b, 'a list) t]. *)
let rec append : ('a list, 'a list, 'a list, 'a list -> ('b, 'b, 'a list) t) fn
    = function
  | [] -> shift (fun pending -> ret (fun ys -> ret (pending ys)))
  | x :: xs -> bind (append xs) (fun tail -> ret (x :: tail))
(* Delimit [append [1; 2; 3]] to obtain its difference list, then close it
   with [4; 5; 6]; evaluates to [1; 2; 3; 4; 5; 6]. *)
let test_append =
  let program =
    bind (reset (append [1; 2; 3])) (fun prepend -> prepend [4; 5; 6])
  in
  run program
(* The captured continuation [fun r -> r + 1] is invoked twice:
   (1 + 1) + (2 + 1) = 5. *)
let test_multishot =
  let body =
    bind (shift (fun k -> ret (k 1 + k 2))) (fun r -> ret (r + 1))
  in
  run (reset body)
(* Multi-shot where the captured continuation takes an [int] but produces a
   [string] answer ([fun r -> string_of_int r ^ "!"]); calling it on 1 and 2
   and concatenating evaluates to "1!2!". *)
let test_multishot_atm =
  let both k = ret (k 1 ^ k 2) in
  let body =
    let* r = shift both in
    ret (string_of_int r ^ "!")
  in
  run (reset body)
(* Typed printf via shift/reset. Each hole captures the rest of the template
   and returns a function awaiting the value to fill in, so the delimited
   template becomes a curried [int -> bool -> string]; applied to [1] and
   [true] it evaluates to "1true!". *)
let test_printf =
  let hole to_string = shift (fun k -> ret (fun v -> k (to_string v))) in
  let template =
    let* r = hole string_of_int in
    let* b = hole string_of_bool in
    ret (r ^ b ^ "!")
  in
  let fill = run (reset template) in
  fill 1 true
(* https://www.cl.cam.ac.uk/teaching/2324/R277/handout-delimited-continuations.pdf *)
(* ⟨1 + ⟨(S k1. k1 100 + k1 10) + (S k2. S k3. 1)⟩⟩ *)
(* Runs the term above and names the control operator it behaves like: the
   term evaluates to 3 under shift, 1 under shift0 and 2 under control. Here
   [shift] runs its body under an implicit delimiter, so each call of [k1]
   returns 1, the inner [reset] yields 2, and the whole term yields 3,
   i.e. "shift". *)
let test_shift_shift0_control =
  let outcome =
    run
    @@ reset
         (let* inner =
            reset
              (let* a = shift (fun k1 -> ret (k1 100 + k1 10)) in
               (* Both continuations are deliberately discarded. The leading
                  underscore records that and avoids warning 27 (unused
                  variable), which strict builds such as dune's dev profile
                  turn into an error. *)
               let* b = shift (fun _k2 -> shift (fun _k3 -> ret 1)) in
               ret (a + b))
          in
          ret (1 + inner))
  in
  match outcome with
  | 3 -> "shift"
  | 1 -> "shift0"
  | 2 -> "control"
  | _ -> "??"
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment