Skip to content

Instantly share code, notes, and snippets.

@dariusf
Last active August 9, 2025 02:06
Show Gist options
  • Save dariusf/6dd5917e69cf34ecc417a05824aeb4ad to your computer and use it in GitHub Desktop.
Save dariusf/6dd5917e69cf34ecc417a05824aeb4ad to your computer and use it in GitHub Desktop.
Shift/reset monad with answer-type modification (ATM)
-- https://stackoverflow.com/questions/72833519/how-to-extract-delimited-continuation-reset-shift-for-future-use-in-haskell/72836141#72836141
{-# LANGUAGE RebindableSyntax #-}
import Data.String
-- RebindableSyntax makes do-notation use whichever (>>=) and (>>) are in
-- scope, so every operator redefined below must be hidden from Prelude.
-- Otherwise any use of (>>) (e.g. a do-statement that is not a bind) is an
-- ambiguous occurrence.
import Prelude hiding ((>>=), (>>), return)

-- | Indexed continuation monad, which permits answer-type modification.
-- A @Cont i o a@ yields an @a@: given a continuation from @a@ to an
-- intermediate answer @i@, it produces a final answer @o@.
newtype Cont i o a = Cont { runCont :: (a -> i) -> o }

-- | Pass a value to the continuation without changing the answer type.
return :: a -> Cont r r a
return x = Cont ($ x)

-- | Indexed bind: answer types chain from @i@ through @m@ to @o@.
(>>=) :: Cont m o a -> (a -> Cont i m b) -> Cont i o b
Cont x >>= f = Cont $ \k -> x (($ k) . runCont . f)

-- | Sequence two computations, discarding the first result.
(>>) :: Cont m o a -> Cont i m b -> Cont i o b
a >> b = a >>= const b

-- | Run a computation with the identity continuation.
evalCont :: Cont a o a -> o
evalCont (Cont x) = x id

-- | Delimit continuations captured by 'shift'; identical to 'evalCont'.
reset :: Cont a o a -> o
reset = evalCont

-- | Hand the current continuation (up to the enclosing 'reset') to the
-- given function; its result becomes the answer of that 'reset'.
shift :: ((a -> i) -> o) -> Cont i o a
shift = Cont
-- | The shift returns its continuation unchanged, so the reset evaluates to
-- a function of the shifted value (@Int -> String@). This only typechecks
-- because 'Cont' lets the answer type differ from the body's result type.
-- The do-block body must be indented; at column 0 the layout rule closes
-- the block immediately and the program fails to parse.
printf :: Int -> String
printf = reset $ do
  r <- shift $ \k -> k
  return $ "hello " ++ show r
-- | Prints "hello 3". The do-block statement must be indented to parse.
main :: IO ()
main = do
  putStrLn $ printf 3
(** Plain (non-indexed) continuation monad with shift/reset.

    A computation of type [('r, 'a) t] yields an ['a] and is run by handing
    it a continuation ['a -> 'r]. The answer type ['r] is fixed, so control
    operators built on it cannot change it. *)
module ShiftResetCont : sig
  type ('r, 'a) t = ('a -> 'r) -> 'r

  (** [return x] passes [x] straight to the continuation. *)
  val return : 'a -> ('r, 'a) t

  (** [m >>= f] runs [m], whose continuation applies [f] to the result. *)
  val ( >>= ) : ('r, 'a) t -> ('a -> ('r, 'b) t) -> ('r, 'b) t

  (** Binding-operator alias of [( >>= )]. *)
  val ( let* ) : ('r, 'a) t -> ('a -> ('r, 'b) t) -> ('r, 'b) t

  (** Runs a computation with the identity continuation. *)
  val eval_cont : ('a, 'a) t -> 'a

  (** [shift f] hands the current continuation to [f]; the value [f]
      returns becomes the answer of the enclosing [reset]. *)
  val shift : (('a -> 'r) -> 'r) -> ('r, 'a) t

  (** Delimits continuations captured by [shift]; same as [eval_cont]. *)
  val reset : ('a, 'a) t -> 'a

  (** [map f m] applies [f] to the value produced by [m]. *)
  val map : ('a -> 'b) -> ('r, 'a) t -> ('r, 'b) t

  (** Infix alias of [map]. *)
  val ( <$> ) : ('a -> 'b) -> ('r, 'a) t -> ('r, 'b) t
end = struct
  type ('r, 'a) t = ('a -> 'r) -> 'r

  let return x = fun k -> k x

  let ( >>= ) m f = fun k -> m (fun a -> f a k)

  let ( let* ) = ( >>= )

  let eval_cont m = m (fun a -> a)

  (* A computation already is its CPS function, so shift is the identity.
     The alternative [let shift f k = eval_cont (f k)] permits nested
     shifts, but that is equivalent to wrapping the body in a reset. *)
  let shift f = f

  let reset = eval_cont

  let map f m = fun k -> m (fun a -> k (f a))

  let ( <$> ) = map
end
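(* Without answer-type modification, [shift (fun k -> k)] cannot be typed:
   with [k : 'a -> 'r] the body's result would force ['r = 'a -> 'r]. *)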
(* open ShiftResetCont *)
(* let f a = shift (fun k -> k) a *)
(* https://stackoverflow.com/questions/72833519 *)
(** Indexed continuation monad with shift/reset (answer-type modification).

    A computation of type [('i, 'o, 'a) t] yields an ['a]: given a
    continuation from ['a] to an intermediate answer ['i], it produces a
    final answer ['o]. Because ['i] and ['o] may differ, [shift] can change
    the answer type of its enclosing [reset]. *)
module ShiftReset : sig
  type ('i, 'o, 'a) t = ('a -> 'i) -> 'o

  (** [return x] passes [x] to the continuation; answer type unchanged. *)
  val return : 'a -> ('r, 'r, 'a) t

  (** Indexed bind: answer types chain from ['i] through ['m] to ['o]. *)
  val ( >>= ) : ('m, 'o, 'a) t -> ('a -> ('i, 'm, 'b) t) -> ('i, 'o, 'b) t

  (** Binding-operator alias of [( >>= )]. *)
  val ( let* ) : ('m, 'o, 'a) t -> ('a -> ('i, 'm, 'b) t) -> ('i, 'o, 'b) t

  (** Runs a computation with the identity continuation. *)
  val eval_cont : ('a, 'o, 'a) t -> 'o

  (** [shift f] hands the current continuation to [f]; the value [f]
      returns (of any type ['o]) becomes the answer of the enclosing
      [reset]. *)
  val shift : (('a -> 'i) -> 'o) -> ('i, 'o, 'a) t

  (** Delimits continuations captured by [shift]; same as [eval_cont]. *)
  val reset : ('a, 'o, 'a) t -> 'o

  (** [map f m] applies [f] to the value produced by [m]. *)
  val map : ('a -> 'b) -> ('c, 'd, 'a) t -> ('c, 'd, 'b) t

  (** Infix alias of [map]. *)
  val ( <$> ) : ('a -> 'b) -> ('c, 'd, 'a) t -> ('c, 'd, 'b) t
end = struct
  type ('i, 'o, 'a) t = ('a -> 'i) -> 'o

  let return x = fun k -> k x

  let ( >>= ) m f = fun k -> m (fun a -> f a k)

  let ( let* ) = ( >>= )

  let eval_cont m = m (fun a -> a)

  (* A computation already is its CPS function, so shift is the identity.
     The alternative [let shift f k = eval_cont (f k)] permits nested
     shifts, but that is equivalent to wrapping the body in a reset. *)
  let shift f = f

  let reset = eval_cont

  let map f m = fun k -> m (fun a -> k (f a))

  let ( <$> ) = map
end
open ShiftReset
(* Nested resets. The outer shift discards its continuation (hence the
   underscore-prefixed [_k1], which keeps warning 27 quiet), so
   [1 + r + 10000] never runs. The outer answer is the inner reset, where
   [k2] is [fun r1 -> 3 + r1]: 100 + (3 + 50) = 153. *)
let test1 =
  reset
    (let* r =
       shift (fun _k1 ->
           reset
             (let* r1 = shift (fun k2 -> 100 + k2 50) in
              return (3 + r1)))
     in
     return (1 + r + 10000))

let () = Format.printf "test1 %b@." (test1 = 153)
(* The captured continuation escapes its reset: the reset evaluates to
   [fun r -> 1 + r], which is then applied to 0. *)
let test2 =
  let add_one = reset (shift (fun k -> k) >>= fun r -> return (1 + r)) in
  add_one 0

let () =
  let ok = test2 = 1 in
  Format.printf "test2 %b@." ok
(* The continuation [fun r -> 1 + r] is invoked twice, and the results are
   summed: (1 + 10) + (1 + 10) = 22. *)
let test3 =
  let body k =
    let first = k 10 in
    let second = k 10 in
    first + second
  in
  reset (shift body >>= fun r -> return (1 + r))

let () = Format.printf "test3 %b@." (22 = test3)
(* The continuation escapes the reset, which therefore yields
   [fun r -> r ^ "!"]; that function is then applied to "a". *)
let printf =
  let suffix = reset (shift (fun k -> k) >>= fun r -> return (r ^ "!")) in
  suffix "a"

let () =
  let ok = String.equal printf "a!" in
  Format.printf "printf %b@." ok
(* [append_aux xs] rebuilds [xs] around a hole. On reaching [[]] it shifts
   out the continuation itself, so the enclosing reset returns a function
   [fun tail -> x1 :: ... :: tail]. *)
let rec append_aux = function
  | [] -> shift (fun k -> k)
  | hd :: tl -> append_aux tl >>= fun rest -> return (hd :: rest)

(* [append x y] equals [x @ y]: [y] is plugged into the hole left by
   [append_aux x]. *)
let append x y =
  let with_hole = reset (append_aux x) in
  with_hole y

let () =
  let ok = append [1; 2] [3; 4] = [1; 2; 3; 4] in
  Format.printf "append %b@." ok
(* [xpl l] returns the non-empty prefixes of [l], shortest first.

   The shift sits under [let* a = ... in return (n :: a)], so its
   continuation [k] already conses [n] (and every earlier element):
   [k []] is the prefix ending at [n]. Recursing on [ns'] under [k] emits
   the longer prefixes. The [[]] case ignores its continuation (hence the
   [_] pattern, which avoids unused-variable warning 27) and contributes
   no prefix. *)
let xpl l =
  let rec aux (l : int list) : (int list, int list list, int list) t =
    match l with
    | [] -> shift (fun _ -> [])
    | n :: ns' ->
        let* a = shift (fun k -> k [] :: reset (k <$> aux ns')) in
        return (n :: a)
  in
  reset (aux l)

let () = Format.printf "xpl %b@." (xpl [0; 1; 2] = [[0]; [0; 1]; [0; 1; 2]])
(* [original_prefixes l] returns the proper prefixes of [l] (every prefix
   except [l] itself), shortest first.

   Each element shifts out the continuation [k] that prepends the elements
   before it, so [k []] is the prefix ending just before [n]. The recursion
   then runs under [k] after consing [n]. The [[]] case ignores its
   continuation (hence the [_] pattern, avoiding unused-variable warning
   27), so the full list is never emitted. *)
let original_prefixes l =
  let rec aux (l : int list) : (int list, int list list, int list) t =
    match l with
    | [] -> shift (fun _ -> [])
    | n :: ns' ->
        shift (fun k ->
            k [] :: reset (k <$> ((fun r -> n :: r) <$> aux ns')))
  in
  reset (aux l)

let () =
  Format.printf "original_prefixes %b@."
    (original_prefixes [0; 1; 2] = [[]; [0]; [0; 1]])
(* [prefixes l] returns every prefix of [l], from [[]] up to [l] itself.
   At each step the captured continuation [c] prepends the elements seen
   so far: [c []] is the current prefix, and the recursion under [c] (after
   consing the next element) produces the longer ones. *)
let prefixes l =
  let rec go (xs : int list) : (int list, int list list, int list) t =
    shift (fun c ->
        let longer =
          match xs with
          | [] -> []
          | n :: tl -> reset (map c (map (fun r -> n :: r) (go tl)))
        in
        c [] :: longer)
  in
  reset (go l)

let () =
  let ok = prefixes [0; 1; 2] = [[]; [0]; [0; 1]; [0; 1; 2]] in
  Format.printf "prefixes %b@." ok
(* [length_prefixes l] sums the lengths of all prefixes of [l] (including
   [[]] and [l]). It has the same traversal as [prefixes], but the answer
   type of each reset is [int] rather than a list of prefixes. *)
let length_prefixes l =
  let rec go (xs : int list) : (int list, int, int list) t =
    shift (fun c ->
        let here = List.length (c []) in
        let rest =
          match xs with
          | [] -> 0
          | n :: tl -> reset (c <$> ((fun r -> n :: r) <$> go tl))
        in
        here + rest)
  in
  reset (go l)
(* [string_of_list f xs] renders [xs] as "[x1; x2; ...]", converting each
   element with [f]. *)
let string_of_list f xs =
  let items = String.concat "; " (List.map f xs) in
  "[" ^ items ^ "]"
let () =
  let ok = length_prefixes [0; 1; 2] = 6 in
  Format.printf "length_prefixes %b@." ok
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment