Last active
August 9, 2025 02:06
-
-
Save dariusf/6dd5917e69cf34ecc417a05824aeb4ad to your computer and use it in GitHub Desktop.
Shift/reset monad with ATM (answer-type modification)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
-- https://stackoverflow.com/questions/72833519/how-to-extract-delimited-continuation-reset-shift-for-future-use-in-haskell/72836141#72836141
| {-# LANGUAGE RebindableSyntax #-} | |
| import Data.String | |
| import Prelude hiding ((>>=), return) | |
-- | An indexed continuation computation. Given a continuation from @a@ to
-- the inner answer type @i@, it produces the outer answer type @o@.
-- Keeping @i@ and @o@ separate is what allows the answer type to change.
newtype Cont i o a = Cont
  { runCont :: (a -> i) -> o
  }
-- | Inject a pure value. The computation immediately feeds @x@ to its
-- continuation, so the inner and outer answer types coincide.
return :: a -> Cont r r a
return x = Cont (\k -> k x)
-- | Sequence two computations. The continuation handed to the first
-- computation runs the second one with the outer continuation @k@. As a
-- result, the second computation's outer answer type @m@ becomes the first
-- computation's inner answer type.
(>>=) :: Cont m o a -> (a -> Cont i m b) -> Cont i o b
m >>= f = Cont (\k -> runCont m (\a -> runCont (f a) k))
-- | Sequence two computations, discarding the first one's result.
--
-- NOTE(review): the Prelude import hides only '(>>=)' and 'return'. That
-- leaves Prelude's '(>>)' in scope next to this definition, so a do-block
-- that sequences a statement without binding it would hit an ambiguous
-- '(>>)'. Consider adding '(>>)' to the hiding list.
(>>) :: Cont m o a -> Cont i m b -> Cont i o b
m >> n = m >>= \_ -> n
-- | Run a computation whose result type equals its inner answer type by
-- supplying the identity continuation. This yields the outer answer.
evalCont :: Cont a o a -> o
evalCont m = runCont m id
-- | Delimit a computation by running it with the identity continuation.
reset :: Cont a o a -> o
reset = evalCont

-- | Capture the current continuation. The body receives the captured
-- continuation @a -> i@ and its result (of any type @o@) becomes the
-- computation's outer answer. This is what permits answer-type modification.
shift :: ((a -> i) -> o) -> Cont i o a
shift = Cont
-- | Capture the rest of the delimited computation and return it as the
-- answer. The result is a function: @printf n@ evaluates to
-- @"hello " ++ show n@.
printf :: Int -> String
printf = reset (shift id >>= \n -> return ("hello " ++ show n))
-- | Print the captured continuation applied to 3.
main :: IO ()
main = putStrLn (printf 3)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
(* Shift/reset on the ordinary continuation monad. The whole delimited
   computation shares a single answer type 'r, so a [shift] body must
   produce the same type that its continuation returns. *)
module ShiftResetCont : sig
  type ('r, 'a) t = ('a -> 'r) -> 'r

  (** [return a] passes [a] straight to the continuation. *)
  val return : 'a -> ('r, 'a) t

  (** Monadic bind: run the first computation, then feed its result to the
      second, which continues with the original continuation. *)
  val ( >>= ) : ('r, 'a) t -> ('a -> ('r, 'b) t) -> ('r, 'b) t

  (** Binding-operator form of [( >>= )]. *)
  val ( let* ) : ('r, 'a) t -> ('a -> ('r, 'b) t) -> ('r, 'b) t

  (** Run a computation with the identity continuation. *)
  val eval_cont : ('a, 'a) t -> 'a

  (** [shift f] hands [f] the current continuation. Whatever [f] returns
      becomes the answer of the computation. *)
  val shift : (('a -> 'r) -> 'r) -> ('r, 'a) t

  (** Delimit a computation; identical to [eval_cont]. *)
  val reset : ('a, 'a) t -> 'a

  (** Apply a pure function to the result of a computation. *)
  val map : ('a -> 'b) -> ('r, 'a) t -> ('r, 'b) t

  (** Infix form of [map]. *)
  val ( <$> ) : ('a -> 'b) -> ('r, 'a) t -> ('r, 'b) t
end = struct
  type ('r, 'a) t = ('a -> 'r) -> 'r

  let return x = fun k -> k x

  let ( >>= ) m f = fun k -> m (fun x -> f x k)

  let[@inline] ( let* ) m f = m >>= f

  let eval_cont m = m Fun.id

  (* The body already has the computation's shape: it is called with the
     continuation directly. *)
  let shift body = body

  (* An alternative that also supports nested shifts. It is equivalent to
     wrapping the body in a reset:
       let shift f k = eval_cont (f k) *)

  let reset m = eval_cont m

  let map f m = m >>= fun x -> return (f x)

  let[@inline] ( <$> ) f m = map f m
end
(* shift (fun k -> k) cannot be typed in this single-answer-type monad: its body would need 'r = 'a -> 'r. *)
| (* open ShiftResetCont *) | |
| (* let f a = shift (fun k -> k) a *) | |
| (* https://stackoverflow.com/questions/72833519 *) | |
(* Shift/reset with answer-type modification. A computation of type
   ('i, 'o, 'a) t feeds an 'a to a continuation returning 'i and produces an
   'o. This lets a [shift] body change the answer type of the enclosing
   [reset]. *)
module ShiftReset : sig
  type ('i, 'o, 'a) t = ('a -> 'i) -> 'o

  (** [return a] passes [a] straight to the continuation; the answer type is
      left unchanged. *)
  val return : 'a -> ('r, 'r, 'a) t

  (** Monadic bind. The continuation handed to the first computation runs the
      second one, so the second's produced answer type 'm becomes the first's
      inner answer type. *)
  val ( >>= ) : ('m, 'o, 'a) t -> ('a -> ('i, 'm, 'b) t) -> ('i, 'o, 'b) t

  (** Binding-operator form of [( >>= )]. *)
  val ( let* ) : ('m, 'o, 'a) t -> ('a -> ('i, 'm, 'b) t) -> ('i, 'o, 'b) t

  (** Run a computation with the identity continuation. *)
  val eval_cont : ('a, 'o, 'a) t -> 'o

  (** [shift f] hands [f] the current continuation. [f] may return any type
      'o, which becomes the computation's answer. *)
  val shift : (('a -> 'i) -> 'o) -> ('i, 'o, 'a) t

  (** Delimit a computation; identical to [eval_cont]. *)
  val reset : ('a, 'o, 'a) t -> 'o

  (** Apply a pure function to the result of a computation. *)
  val map : ('a -> 'b) -> ('c, 'd, 'a) t -> ('c, 'd, 'b) t

  (** Infix form of [map]. *)
  val ( <$> ) : ('a -> 'b) -> ('c, 'd, 'a) t -> ('c, 'd, 'b) t
end = struct
  type ('i, 'o, 'a) t = ('a -> 'i) -> 'o

  let return x = fun k -> k x

  let ( >>= ) m f = fun k -> m (fun x -> f x k)

  let[@inline] ( let* ) m f = m >>= f

  let eval_cont m = m Fun.id

  (* The body already has the computation's shape: it is called with the
     continuation directly. *)
  let shift body = body

  (* An alternative that also supports nested shifts. It is equivalent to
     wrapping the body in a reset:
       let shift f k = eval_cont (f k) *)

  let reset m = eval_cont m

  let map f m = m >>= fun x -> return (f x)

  let[@inline] ( <$> ) f m = map f m
end
| open ShiftReset | |
(* Nested delimiters. The inner [reset] evaluates to 100 + (3 + 50) = 153.
   The outer [shift] ignores its continuation [_k1], so that value becomes
   the whole answer and the [1 + r + 10000] tail never runs. *)
let test1 =
  reset
    (shift (fun _k1 ->
         reset (shift (fun k2 -> 100 + k2 50) >>= fun r1 -> return (3 + r1)))
     >>= fun r -> return (1 + r + 10000))

let () = Format.printf "test1 %b@." (test1 = 153)
(* The captured continuation is returned as the answer itself, so [reset]
   yields a function. Applying it to 0 runs [1 + r] with r = 0. *)
let test2 =
  let body = shift (fun k -> k) >>= fun r -> return (1 + r) in
  reset body 0

let () = Format.printf "test2 %b@." (test2 = 1)
(* The captured continuation is invoked twice. Each call computes 1 + 10, so
   the sum is 22. *)
let test3 =
  let twice k =
    let first = k 10 in
    let second = k 10 in
    first + second
  in
  reset (shift twice >>= fun r -> return (1 + r))

let () = Format.printf "test3 %b@." (test3 = 22)
(* Returning the captured continuation turns the delimited block into a
   function of the value plugged in for [r]. Here that function is applied
   to "a". *)
let printf =
  let exclaim r = return (r ^ "!") in
  reset (shift (fun k -> k) >>= exclaim) "a"

let () = Format.printf "printf %b@." (printf = "a!")
(* [append_aux xs] walks [xs]. At the empty tail it captures the continuation
   "cons the elements of xs back on" and returns that continuation as the
   answer. [reset (append_aux x)] is therefore a function that prepends [x]
   to its argument. *)
let rec append_aux = function
  | [] -> shift (fun k -> k)
  | hd :: tl -> append_aux tl >>= fun rest -> return (hd :: rest)

let append x y = reset (append_aux x) y
let () = Format.printf "append %b@." (append [1; 2] [3; 4] = [1; 2; 3; 4])
(* [xpl l] lists the non-empty prefixes of [l], shortest first. The empty
   case discards its continuation, so [xpl []] is []. *)
let xpl l =
  let rec aux : int list -> (int list, int list list, int list) t = function
    | [] -> shift (fun _ -> [])
    | n :: rest ->
        (fun tail -> n :: tail)
        <$> shift (fun k -> k [] :: reset (k <$> aux rest))
  in
  reset (aux l)

let () = Format.printf "xpl %b@." (xpl [0; 1; 2] = [[0]; [0; 1]; [0; 1; 2]])
(* A prefixes variant (the name suggests an earlier version). Its empty-list
   case discards its continuation, so the full list is never emitted and the
   result stops one prefix short: [[]; [0]; [0; 1]] for [0; 1; 2]. *)
let original_prefixes l =
  let rec aux : int list -> (int list, int list list, int list) t = function
    | [] -> shift (fun _ -> [])
    | n :: rest ->
        let cons_n r = n :: r in
        shift (fun k -> k [] :: reset (k <$> (cons_n <$> aux rest)))
  in
  reset (aux l)

let () =
  Format.printf "original_prefixes %b@."
    (original_prefixes [0; 1; 2] = [[]; [0]; [0; 1]])
(* [prefixes l] lists every prefix of [l], from [] up to [l] itself. Each
   step emits [c []] (the prefix built so far) before recursing on the tail,
   so the empty-list step also contributes its prefix. *)
let prefixes l =
  let rec aux (l : int list) : (int list, int list list, int list) t =
    shift (fun c ->
        let deeper =
          match l with
          | [] -> []
          | n :: ns' -> reset (c <$> ((fun r -> n :: r) <$> aux ns'))
        in
        let here = c [] in
        here :: deeper)
  in
  reset (aux l)

let () =
  Format.printf "prefixes %b@."
    (prefixes [0; 1; 2] = [[]; [0]; [0; 1]; [0; 1; 2]])
(* [length_prefixes l] sums the lengths of every prefix of [l]
   (0 + 1 + ... + length l). At each step it adds the length of the prefix
   built so far, [c []], to the total from the remaining steps. *)
let length_prefixes l =
  let rec aux (l : int list) : (int list, int, int list) t =
    shift (fun c ->
        let rest =
          match l with
          | [] -> 0
          | n :: ns' -> reset (c <$> ((fun r -> n :: r) <$> aux ns'))
        in
        List.length (c []) + rest)
  in
  reset (aux l)
(* Render a list in OCaml syntax, e.g. "[1; 2; 3]", using [f] to print each
   element. *)
let string_of_list f xs =
  Format.sprintf "[%s]" (String.concat "; " (List.map f xs))
let () =
  let ok = length_prefixes [0; 1; 2] = 6 in
  Format.printf "length_prefixes %b@." ok
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment