Skip to content

Instantly share code, notes, and snippets.

@algebraic-dev
Last active April 14, 2025 12:46
Show Gist options
  • Save algebraic-dev/e56604e833186f7795c2b43e13bf167d to your computer and use it in GitHub Desktop.
Save algebraic-dev/e56604e833186f7795c2b43e13bf167d to your computer and use it in GitHub Desktop.
Alts for Async
import Std.Internal.Async
import Std.Internal.UV
import Std.Net.Addr
open Std.Internal.IO.Async.TCP.Socket
open Std.Internal.IO.Async.TCP
open Std.Internal.IO.Async
open Std.Net
-- All of these functions are used to build the monads below.
/-- An already-finished task holding a successful `()`; used as the "nothing left to do" result. -/
@[inline]
def emptyTask : Task (Except IO.Error Unit) :=
  .pure (.ok ())
/--
Blocks the current thread until the task `io` has finished and returns its result.

Waiting is done through the `IO.wait` action instead of `pure io.get`. `io.get` is a pure
expression, so when it is evaluated is not tied to when this `IO` action runs. `IO.wait` is
sequenced like any other effect.
-/
@[inline]
def block (io : Task α) : IO α :=
  BaseIO.toIO (IO.wait io)
/--
Registers `f` to run on the value of `t` once `t` finishes, without blocking the caller.
The task returned by `BaseIO.bindTask` is ignored.
-/
@[inline]
def nextTask {α} (t : Task α) (f : α → BaseIO Unit) : BaseIO Unit := do
  let _ ← BaseIO.bindTask t fun value => do
    f value
    pure emptyTask
  pure ()
namespace Naive
-- This is the most naive Monad that we should consider: It does not allow tail
-- recursion, so it creates a huge task stack when you use it.
/-- A `BaseIO` action that, when run, returns an `AsyncTask` for the eventual result. -/
def Async (α : Type) := BaseIO (AsyncTask α)
/--
Runs `x` and blocks on the resulting task.
NOTE(review): `Async.toIO` is declared further down, so here `x.toIO` presumably resolves
through the unfolded type to `BaseIO.toIO`, and `task.block` to `AsyncTask.block`.
Reordering declarations would change what these names refer to.
-/
@[inline]
def Async.exec (x : Async α) : IO α := do
let task ← x.toIO
task.block
-- Instances
/-- Runs `x`, then chains `f` onto its task with `bindIO`; every bind adds another link to the task chain. -/
@[inline]
def bind {α β : Type} (x : BaseIO (AsyncTask α)) (f : α → BaseIO (AsyncTask β)) : BaseIO (AsyncTask β) := do
let task ← x
task.bindIO (BaseIO.toIO ∘ f)
/-- Runs the `BaseIO` action now and wraps its value in an already-finished task. -/
@[inline]
def baseMonadLift {α : Type} (x : BaseIO α) : BaseIO (AsyncTask α) := do
AsyncTask.pure <$> x
/-- Runs the `IO` action now; a thrown error becomes an already-failed task instead of propagating. -/
@[inline]
def monadLift {α : Type} (x : IO α) : BaseIO (AsyncTask α) := do
let res ← x.toBaseIO
match res with
| .ok a => pure (AsyncTask.pure a)
| .error e => pure (Task.pure (Except.error e))
/--
Runs `t`; once its task finishes, a success is passed through and a failure is handed to `f`.
NOTE(review): `t.run` / `f err |>.run` presumably resolve to `EStateM.run` after unfolding
`Async`; TODO confirm.
-/
@[inline]
def tryCatch (t : Async α) (f : IO.Error → Async α) : BaseIO (AsyncTask α) := do
let task ← t.run
BaseIO.bindTask task fun
| .ok res => pure (AsyncTask.pure res)
| .error err => f err |>.run
/-- `pure` yields an already-finished task; `bind` is `Naive.bind` above. -/
instance : Monad Async where
pure x := EStateM.pure (AsyncTask.pure x)
bind := bind
instance : MonadLift BaseIO Async where
monadLift := baseMonadLift
instance : MonadLift IO Async where
monadLift := monadLift
/-- `throw` yields an already-failed task; `tryCatch` is `Naive.tryCatch` above. -/
instance : MonadExcept IO.Error Async where
throw e := EStateM.pure (Task.pure (Except.error e))
tryCatch := tryCatch
/-- Runs an `IO` action that returns a task; a thrown error becomes an already-failed task. -/
@[inline]
def fromIO (x : IO (AsyncTask α)) : BaseIO (AsyncTask α) := do
let res ← x.toBaseIO
match res with
| .ok result => pure result
| .error err => pure (Task.pure (Except.error err))
/-- Creates a `Sleep` for `time` and returns the task from `Sleep.wait`. -/
@[inline]
def wait (time : Std.Time.Millisecond.Offset) : Async Unit :=
fromIO <| Sleep.mk time >>= Sleep.wait
/-- Same behaviour as `fromIO`, written directly against the world token. -/
def await (task : IO (AsyncTask α)) : Async α := fun w =>
match task w with
| .ok result s => .ok result s
| .error err s => .ok (Task.pure (.error err)) s
/--
Runs `task` for its effects and discards its result.
NOTE(review): `Async.toIO` is declared below, so `task.toIO` here presumably means
`BaseIO.toIO task`. The resulting `IO` value is then adapted to `Async Unit` through the
`MonadLift IO Async` instance. Confirm that this elaborates as intended.
-/
def parallel (task : Async α) : Async Unit :=
discard <| task.toIO
/-- Runs `task` and blocks the current thread until its `AsyncTask` finishes. -/
def Async.toIO {α : Type} (task : Async α) : IO α := do
BaseIO.toIO task >>= AsyncTask.block
/-- Alias for `Async.toIO`. -/
def block (task : Async α) : IO α :=
task.toIO
end Naive
namespace CPS
-- This is a more advanced Monad that allows tail recursion, so it does not
-- create a huge task stack.
/-- Continuation-passing computation: given a callback that receives the result task, performs `BaseIO` work (it is expected to eventually invoke the callback). -/
def Async (α : Type) := (AsyncTask α → BaseIO Unit) → BaseIO Unit
/--
Runs `x` with a continuation that resolves a promise once the result task finishes. It then
blocks on `promise.result!.get` and rethrows any error with `IO.ofExcept`.
NOTE(review): if `x` never calls its continuation, the promise is never resolved; check
`Promise.result!` semantics for that case.
-/
def Async.toIO {α : Type} (x : Async α) : IO α := do
let promise : IO.Promise (Except IO.Error α) ← IO.Promise.new
-- the lambda's `x` (the task's outcome) shadows the computation `x`
x fun task => discard <| IO.bindTask task fun x => do
promise.resolve x
return AsyncTask.pure ()
IO.ofExcept promise.result!.get
/--
Runs `x` and returns, without blocking, a task that completes with its outcome.
NOTE(review): `[Nonempty α]` looks unnecessary; `Async.toIO` above creates the same kind of
promise without it.
-/
@[inline]
def Async.toTask {α : Type} [Nonempty α] (x : Async α) : IO (Task (Except IO.Error α)) := do
let promise ← IO.Promise.new
x (nextTask · promise.resolve)
pure promise.result!
/-- Blocks on `x.toTask` and rethrows an error. `block` here is the root-level helper, because `CPS.block` is only declared later. -/
@[inline]
def Async.exec {α : Type} [Nonempty α] (x : Async α) : IO α := do
match (← block =<< x.toTask) with
| .ok res => pure res
| .error err => throw err
/-- Runs an `IO` action that returns a task and passes that task (or an already-failed one) to the continuation. -/
@[inline]
def await {α : Type} (task : IO (AsyncTask α)) : Async α :=
fun k => do
match ← task.toBaseIO with
| .ok result => k result
| .error err => k (Task.pure (.error err))
/-- Starts `task` via `IO.asTask` with a no-op continuation, then continues immediately with `emptyTask`. The started task's result is ignored. -/
@[inline]
def parallel (task : Async α) : Async Unit :=
fun k => do
discard <| IO.asTask (task (fun _ => pure ()))
k emptyTask
-- Instances
/-- Maps the result task before handing it to the continuation. -/
instance : Functor Async where
map f x := fun k => x (fun t => k (t.map f))
/-- `bind` waits for the task via the `nextTask` callback (no blocking), then continues with `f` or forwards the error. -/
instance : Monad Async where
pure a := fun k => k (AsyncTask.pure a)
bind x f := fun k => x fun t =>
nextTask t (fun
| .ok res => (f res) k
| .error err => k (Task.pure (.error err)))
/-- Runs the `IO` action immediately; a thrown error is passed on as an already-failed task. -/
instance : MonadLift IO Async where
monadLift io := fun k => do
match ← io.toBaseIO with
| .ok res => k (AsyncTask.pure res)
| .error err => k (Task.pure (.error err))
/-- `tryCatch` waits for the task via `nextTask` and runs the handler `f` on failure. -/
instance : MonadExcept IO.Error Async where
throw x := fun t => t (Task.pure (.error x))
tryCatch x f := fun k => x fun t =>
nextTask t fun
| .ok res => k (Task.pure (.ok res))
| .error err => (f err) k
/-- Same behaviour as `await`. -/
@[inline]
def fromIO (x : IO (AsyncTask α)) : Async α := fun k => do
let res ← x.toBaseIO
match res with
| .ok result => k result
| .error err => k (Task.pure (.error err))
/-- Creates a `Sleep` for `time` and continues with the task from `Sleep.wait`. -/
@[inline]
def wait (time : Std.Time.Millisecond.Offset) : Async Unit :=
fromIO <| Sleep.mk time >>= Sleep.wait
/-- Alias for `Async.toIO`. -/
@[inline]
def block (t : Async α) : IO α := do
t.toIO
--
end CPS
namespace SuspendableNaive

/--
Outcome of running a suspendable computation up to its first asynchronous point.
`ok`/`error` mean it finished synchronously; `cont` carries the `Task` that will produce the
final `EStateM.Result`. Every constructor also carries a state `σ`.
-/
inductive AsyncResult (ε σ α : Type u) where
  | ok : α → σ → AsyncResult ε σ α
  | error : ε → σ → AsyncResult ε σ α
  | cont : Task (EStateM.Result ε σ α) → σ → AsyncResult ε σ α

/-- An already-finished successful result with value `a` and the `default` state. -/
@[inline]
def pure [Inhabited σ] (a : α) : AsyncResult ε σ α :=
  AsyncResult.ok a default

/-- Views any result as a `Task`. Finished results become `Task.pure`; the state stored next to `cont` is dropped. -/
@[inline]
def toTask : AsyncResult ε σ α → Task (EStateM.Result ε σ α)
  | .ok a s => Task.pure (.ok a s)
  | .error e s => Task.pure (.error e s)
  | .cont c _ => c

/-- Applies `f` to the success value, immediately or (for `cont`) when the task finishes. -/
@[inline]
partial def map {α β : Type} (x : AsyncResult ε σ α) (f : α → β) : AsyncResult ε σ β :=
  match x with
  | .ok a s => .ok (f a) s
  | .error e s => .error e s
  | .cont c s => .cont (c.map fun
      | .ok a s => .ok (f a) s
      | .error e s => .error e s) s

/--
Sequences `f` after `x`. A synchronous result goes straight to `f`. For `cont`, `f` runs
inside `Task.bind` once the task finishes, and its own result is flattened into that task.
-/
@[inline]
partial def bind {α β : Type} (x : AsyncResult ε σ α) (f : α → σ → AsyncResult ε σ β) : AsyncResult ε σ β :=
  match x with
  | .ok a s => f a s
  | .error e s => .error e s
  | .cont c s => .cont (c.bind fun result =>
      match result with
      | .ok a s =>
        match (f a s) with
        | .ok a s => Task.pure (.ok a s)
        | .error e s => Task.pure (.error e s)
        | .cont c _ => c
      | .error e s => Task.pure (.error e s)) s

/-- Same behaviour as `bind`; used by the `Suspend` instance below. -/
@[inline]
partial def «suspend» {α β : Type} (x : AsyncResult ε σ α) (f : α → σ → AsyncResult ε σ β) : AsyncResult ε σ β :=
  match x with
  | .ok a s => f a s
  | .error e s => .error e s
  | .cont c s => .cont (c.bind fun result =>
      match result with
      | .ok a s =>
        match (f a s) with
        | .ok a s => Task.pure (.ok a s)
        | .error e s => Task.pure (.error e s)
        | .cont c _ => c
      | .error e s => Task.pure (.error e s)) s

/-- A suspendable computation: given the world token, runs until it finishes or reaches an asynchronous point. -/
def Async (α : Type) := IO.RealWorld → AsyncResult IO.Error IO.RealWorld α

/--
Runs `x` and returns a task for its outcome. A synchronous error is raised as an `IO` error
rather than returned as a failed task.
-/
def Async.toTask {α : Type} (x : Async α) : IO (Task (Except IO.Error α)) := fun w =>
  match x w with
  | .ok a s => .ok (Task.pure (Except.ok a)) s
  | .error e s => .error e s
  | .cont c s => .ok (c.map fun
      | .ok a _ => .ok a
      | .error a _ => .error a) s

/-- Runs `x`, blocks on the resulting task with `Task.get`, and rethrows any error. -/
def Async.toIO {α : Type} (x : Async α) : IO α := do
  let task ← x.toTask
  IO.ofExcept task.get

instance : Functor Async where
  map f x := fun w => match x w with
    | .ok a s => .ok (f a) s
    | .error e s => .error e s
    | .cont c s => .cont (c.map fun
        | .ok a s => .ok (f a) s
        | .error e s => .error e s) s

instance : Monad Async where
  pure a := fun w => .ok a w
  bind x f := fun w => bind (x w) f

instance : MonadLift IO Async where
  monadLift io := fun w => match io w with
    | .ok a w => .ok a w
    | .error e w => .error e w

/--
`throw` fails synchronously. `tryCatch x f` runs the handler `f` whenever `x` fails, whether
the failure happens synchronously or only surfaces later through `x`'s task.
-/
instance : MonadExcept IO.Error Async where
  throw x := fun w => .error x w
  tryCatch x f := fun w =>
    match x w with
    | .ok a s => .ok a s
    | .error e w =>
      match f e w with
      | .ok a w => .ok a w
      | .error e w => .error e w
      | .cont c s => .cont (c.map fun
          | .ok a s => .ok a s
          | .error e s => .error e s) s
    | .cont c s => .cont (c.bind fun
        | .ok a s => Task.pure (.ok a s)
        -- An error that only surfaces asynchronously must reach the handler as well; otherwise
        -- `try … catch` would miss every failure raised after the first asynchronous point.
        | .error e s => SuspendableNaive.toTask (f e s)) s

instance : Suspend Async Async where
  «suspend» x f := fun w => «suspend» (x w) f

/-- Runs an `IO` action that returns a task and suspends on that task. A thrown error fails synchronously. -/
@[inline]
def fromIO (x : IO (AsyncTask α)) : Async α := fun w =>
  match x w with
  | .ok a s => .cont (Task.map (fun
      | Except.ok a => .ok a s
      | Except.error a => .error a s) a) s
  | .error e s => .error e s

/-- Creates a `Sleep` for `time` and suspends on the task from `Sleep.wait`. -/
@[inline]
def wait (time : Std.Time.Millisecond.Offset) : Async Unit :=
  fromIO <| Sleep.mk time >>= Sleep.wait

/-- Alias for `fromIO`. -/
@[inline]
def await (t : IO (AsyncTask α)) : Async α := fromIO t

/--
Starts `task` and returns `()` without waiting for an asynchronous result. A synchronous
error is propagated.
NOTE(review): in the `cont` case the `bindTask` with an identity continuation appears to do
nothing beyond threading the world token.
-/
@[inline]
def parallel (task : Async α) : Async Unit := fun w =>
  match task w with
  | .ok _ s => .ok () s
  | .error e s => .error e s
  | .cont c s =>
    let op := BaseIO.bindTask c (fun t => do return (Task.pure t))
    let .ok _ s' := op s
    .ok () s'

/-- Runs `t`, waits for its task with `IO.wait`, and rethrows any error. -/
@[inline]
def block (t : Async α) : IO α := do
  let t ← t.toTask
  let f ← IO.wait t
  IO.ofExcept f

end SuspendableNaive
namespace SuspendableCPS
/-- Outcome of a suspendable step: finished (`ok`/`error`) or a CPS continuation (`cont`) that will deliver the result task to a callback. Each constructor also carries a state `σ`. -/
inductive AsyncResult (ε σ α : Type u) where
| ok : α → σ → AsyncResult ε σ α
| error : ε → σ → AsyncResult ε σ α
| cont : ((Task (Except ε α) → BaseIO Unit) → BaseIO Unit) → σ → AsyncResult ε σ α
/-- An already-finished successful result with value `a` and the `default` state. -/
@[inline]
def pure [Inhabited σ] (a : α) : AsyncResult ε σ α :=
AsyncResult.ok a default
/-- A suspendable computation over the world token. -/
def Async (α : Type) :=
IO.RealWorld → AsyncResult IO.Error IO.RealWorld α
/-- Runs a `BaseIO` action and returns its value as `Except.ok`, discarding the new world token. -/
@[inline]
def Excp.ofBaseIO (a : BaseIO α) : IO.RealWorld → Except ε α := fun w =>
match a w with
| .ok a _ => .ok a
/-- Runs a `BaseIO` action synchronously as an `Async`. -/
@[inline]
def Async.ofBaseIO (a : BaseIO α) : Async α := fun w =>
match a w with
| .ok a s => .ok a s
/-- Runs an `IO` action synchronously as an `Async`; a thrown error becomes `.error`. -/
@[inline]
def Async.ofIO (a : IO α) : Async α := fun w =>
match a w with
| .ok a s => .ok a s
| .error e s => .error e s
/-- Runs `x` and returns a task for its outcome. For `cont`, the continuation resolves a promise whose `result!` is returned. -/
@[inline]
def Async.toTask (x : Async α) : BaseIO (Task (Except IO.Error α)) := fun w =>
let result := x w
match result with
| .ok a s => .ok (Task.pure (.ok a)) s
| .error e s => .ok (Task.pure (.error e)) s
| .cont c s =>
let op : BaseIO (Task (Except IO.Error α)) := do
let promise ← IO.Promise.new (α := Except IO.Error α)
c fun t =>
nextTask t fun res => do promise.resolve res
return promise.result!
op s
/--
Sequences `f` after `x`.
NOTE: in the `cont` case this blocks the current thread on `promise.result!.get` before
calling `f`, so `bind` is not asynchronous.
-/
@[inline]
partial def bind {α β : Type} (x : Async α) (f : α → Async β) : Async β := fun w =>
match x w with
| .ok a s => f a s
| .error e s => .error e s
| .cont c s =>
let op : BaseIO (Except IO.Error α) := do
let promise ← IO.Promise.new (α := Except IO.Error α)
c (nextTask · promise.resolve)
return promise.result!.get
let .ok a s := op s
match a with
| .ok a => f a s
| .error e => .error e s
/-- Non-blocking sequencing: for `cont`, registers a callback via `nextTask` that runs `f` and forwards its outcome to `k`. -/
@[inline]
partial def «suspend» {α β : Type} (x : AsyncResult IO.Error IO.RealWorld α) (f : α → IO.RealWorld → AsyncResult IO.Error IO.RealWorld β) : AsyncResult IO.Error IO.RealWorld β :=
match x with
| .ok a s => f a s
| .error e s => .error e s
| .cont x s => .cont (fun k s => x (fun t =>
nextTask t fun
| .ok res =>
-- NOTE(review): `s` here is the world token captured when the continuation started, reused inside the callback
match (f res) s with
| .ok a _ => k (Task.pure (.ok a))
| .error e _ => k (Task.pure (.error e))
| .cont c _ => c k
| .error err => k (Task.pure (.error err))
) s
) s
/-- Runs `x`, blocks on its task with `Task.get`, and rethrows any error. -/
@[inline]
def Async.toIO {α : Type} (x : Async α) : IO α := do
let task ← x.toTask
IO.ofExcept task.get
/-- `bind` is the blocking `SuspendableCPS.bind` above. -/
instance : Monad Async where
pure a := fun w => .ok a w
bind := bind
instance : MonadLift IO Async where
monadLift io := fun w => match io w with
| .ok a w => .ok a w
| .error e w => .error e w
-- NOTE(review): `Suspend` is not defined in this file; presumably provided by the toolchain used for this experiment.
instance : Suspend Async Async where
«suspend» x f := fun w => «suspend» (x w) f
/-- Suspends on an `IO` action that returns a task: once the task finishes, runs `f` (via `Async.toTask`) and forwards its result task to `k`. -/
instance : Suspend (IO ∘ AsyncTask) Async where
«suspend» x f := fun w =>
match x w with
| .ok t s => .cont (fun k => do
nextTask t (fun result =>
match result with
| .ok a => do
let .ok res _ := (Async.toTask <| f a) s
k res
| .error e => k (Task.pure (.error e))
)
) s
| .error e s => .error e s
/-- Runs an `IO` action immediately and passes its task to the continuation; a thrown error fails synchronously. -/
@[inline]
def fromIO (x : IO (AsyncTask α)) : Async α := fun w =>
match x w with
| .ok a s => .cont (· a) s
| .error e s => .error e s
/-- Creates a `Sleep` for `time` and continues with the task from `Sleep.wait`. -/
@[inline]
def wait (time : Std.Time.Millisecond.Offset) : Async Unit :=
fromIO <| Sleep.mk time >>= Sleep.wait
/-- Like `fromIO`, but the `IO` action only runs when the continuation is invoked. -/
def await (t : IO (AsyncTask α)) : Async α := fun w =>
AsyncResult.cont (fun k => do
let result ← t.toBaseIO
match result with
| .ok a => k a
| .error e => k (Task.pure (.error e))
) w
/-- Starts `task` via `Async.toTask`, ignores its result, and continues with `emptyTask`. -/
def parallel (task : Async α) : Async Unit := fun w =>
AsyncResult.cont (fun k => do
discard task.toTask
k emptyTask
) w
/-- Runs `t`, waits for its task with `IO.wait`, and rethrows any error. -/
def block (t : Async α) : IO α := do
let task ← t.toTask
IO.ofExcept =<< IO.wait task
end SuspendableCPS
namespace MacMaybeTask
namespace MaybeTask

/-- A value that is either already available (`pure`) or still being computed by a `Task` (`ofTask`). -/
inductive MaybeTask (α : Type u)
  | protected pure (a : α)
  | ofTask (a : Task α)

/--
Flattens a task producing a `MaybeTask` into a plain `Task`. An immediate value becomes
`Task.pure` and a pending task is returned as-is. The bind uses `sync := true` and the given
priority.
-/
@[inline]
def joinTask (t : Task (MaybeTask α)) (prio := Task.Priority.default) : Task α :=
  t.bind (prio := prio) (sync := true) fun
    | .pure a => .pure a
    | .ofTask t => t

/-- A `BaseIO` action whose result may still be pending. -/
def AsyncIO (α : Type) := BaseIO (MaybeTask α)

/-- Wraps a `BaseIO (MaybeTask α)` action as an `AsyncIO α`. -/
@[inline]
def AsyncIO.mk (x : BaseIO (MaybeTask α)) : AsyncIO α := x

-- Close both namespaces so the benchmarks and `main` that follow are top-level declarations.
-- Without these `end`s they would all become `MacMaybeTask.MaybeTask.*`, and the executable
-- would have no root `main`.
end MaybeTask
end MacMaybeTask
/-- Iteration bound for the benchmarks below, which recurse while `n < durat`. -/
def durat : Nat := 600000
-- 101 lines of IR
/--
Baseline benchmark in plain `IO`. Each iteration prints "a", "b", "c", performs a 0 ms sleep,
prints "d {n}", and recurses while `n < durat`. This is the same per-iteration work as the
`d`/`e` variants.
-/
partial def a (n : Nat) : IO Unit := do
  IO.println "a"
  IO.println "b"
  IO.println "c"
  -- Run the sleep: `let _ := IO.sleep 0` only binds the action, it never executes it.
  IO.sleep 0
  IO.println s!"d {n}"
  if n < durat then
    a (n + 1)
-- 397 lines of IR
open Naive in
/--
Benchmark in `Naive.Async`. Each iteration prints "a", "b", "c", awaits a 0 ms `wait`, prints
"d {n}", and recurses while `n < durat`. This matches the `d`/`e` variants.
-/
partial def b (n : Nat) : Async Unit := do
  IO.println "a"
  IO.println "b"
  IO.println "c"
  -- `←` actually runs the wait; `let _ := wait 0` would only bind the action.
  let _ ← wait 0
  IO.println s!"d {n}"
  if n < durat then
    b (n + 1)
-- 832 lines of IR
open CPS in
/--
Benchmark in `CPS.Async`. Each iteration prints "a", "b", "c", awaits a 0 ms `wait`, prints
"d {n}", and recurses while `n < durat`. This matches the `d`/`e` variants.
-/
partial def c (n : Nat) : Async Unit := do
  IO.println "a"
  IO.println "b"
  IO.println "c"
  -- `←` actually runs the wait; `let _ := wait 0` would only bind the action.
  let _ ← wait 0
  IO.println s!"d {n}"
  if n < durat then
    c (n + 1)
-- 233 lines of IR
-- The bind behavior on the monad causes it to become blocking.
-- Benchmark in `SuspendableNaive.Async`: each iteration prints "a", "b", "c", awaits a 0 ms
-- `wait`, prints "d {n}", and recurses while `n < durat`.
open SuspendableNaive in
partial def d (n : Nat) : Async Unit := do
IO.println "a"
IO.println "b"
IO.println "c"
-- `←` runs the wait (a `let _ :=` binding would not)
let _ ← wait 0
IO.println s!"d {n}"
if n < durat then
d (n + 1)
-- 460 lines of IR
-- The bind behavior on the monad causes it to become blocking.
-- Benchmark in `SuspendableCPS.Async`: each iteration prints "a", "b", "c", awaits a 0 ms
-- `wait`, prints "d {n}", and recurses while `n < durat`.
open SuspendableCPS in
partial def e (n : Nat) : Async Unit := do
IO.println "a"
IO.println "b"
IO.println "c"
-- `←` runs the wait (a `let _ :=` binding would not)
let _ ← wait 0
IO.println s!"d {n}"
if n < durat then
e (n + 1)
-- TCP
namespace Test
-- TCP echo experiment on top of `SuspendableCPS.Async`.
-- NOTE(review): the `suspend` prefix used below is not defined in this file; presumably it is
-- syntax backed by the `Suspend` instances in `SuspendableCPS` (TODO confirm its semantics,
-- in particular whether `let x := suspend …` runs the action the way `←` would).
open SuspendableCPS
/-- Sends `message`, then receives up to 1024 bytes; stops when `recv?` yields `none`, otherwise repeats. -/
partial def writeLoop (client : Socket.Client) (message : String) : Async Unit := do
IO.println s!"write loop: {message}"
suspend await <| client.send (String.toUTF8 message)
if let none := suspend await <| client.recv? 1024 then
IO.println "client disconnected from receiving"
else
suspend writeLoop client message
/-- Creates a client socket, connects it to `addr`, then runs `writeLoop`. -/
def echoClient (addr : SocketAddress) (message : String) : Async Unit := do
let socket ← Client.mk
suspend await <| socket.connect addr
-- NOTE(review): `t` is never used; the call only has its side effect / error behaviour
let t ← socket.getPeerName
suspend writeLoop socket message
/-- Receives up to 1024 bytes and echoes them back, looping until `recv?` yields `none`. -/
partial def echoLoop (client : Socket.Client) : Async Unit := do
let message := suspend await <| client.recv? 1024
IO.println s!"received: {String.fromUTF8! <$> message}"
if let some msg := message then
let _ := suspend await <| client.send msg
suspend echoLoop client
else
IO.println "client disconnected from echoing"
/-- Accepts a connection, starts `echoLoop` for it via `parallel`, and loops without a stop condition. -/
partial def acceptLoop (server : Socket.Server) : Async Unit := do
let client := suspend await (server.accept)
suspend parallel (echoLoop client)
suspend acceptLoop server
/-- Binds a server to 127.0.0.1:8080 (backlog 128), launches `acceptLoop` through `parallel`, then runs one `echoClient` and prints its `Except` result. -/
def tcp : IO Unit := do
let addr := SocketAddressV4.mk (.ofParts 127 0 0 1) 8080
let server ← Server.mk
server.bind addr
server.listen 128
let action := parallel (acceptLoop server)
action |>.toIO
let t ← block (echoClient addr "monad") |>.toBaseIO
IO.println s!"--> endddd!!! {t}"
end Test
/-- Runs `f` once and returns the elapsed time between two `Timestamp.now` readings taken around it. -/
def measureIO (f : IO Unit) : IO Std.Time.Duration := do
  let startedAt ← Std.Time.Timestamp.now
  f
  let finishedAt ← Std.Time.Timestamp.now
  pure (finishedAt - startedAt)
/--
Times each benchmark variant once, in order (plain `IO`, `Naive`, `CPS`, `SuspendableNaive`,
`SuspendableCPS`), then prints one `label: duration` line per variant.
-/
def main : IO Unit := do
  let ioTime ← measureIO (a 0)
  let naiveTime ← measureIO (b 0).toIO
  let cpsTime ← measureIO (c 0).toIO
  let suspNaiveTime ← measureIO (d 0).toIO
  let suspCpsTime ← measureIO (e 0).toIO
  let report := [
    ("IO", ioTime),
    ("Naive", naiveTime),
    ("CPS", cpsTime),
    ("Suspendable Naive", suspNaiveTime),
    ("Suspendable CPS", suspCpsTime)
  ]
  for (label, elapsed) in report do
    IO.println s!"{label}: {elapsed}"
/-
IO: 101 lines of IR
Naive: 397 lines of IR
CPS: 832 lines of IR
Suspendable Naive: 233 lines of IR
Suspendable CPS: 460 lines of IR
IO: 3.601164000s
Naive: 25.425553000s
CPS: 29.176067000s
Suspendable Naive: 23.824065000s
Suspendable CPS: 25.893960000s
-/
-- New Ideas
-- Async (Task α)
-- Async α := BaseIO (MaybeTask α)
-- await : Task α → Async α
-- async : Async α → Async (Task α)
-- async do
-- IO.println "a"
-- await task
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment