/-
# Alts for Async

Alternative encodings of an `Async` monad on top of `Std.Internal.IO.Async`
(naive, CPS, and two "suspendable" variants), benchmarked against plain `IO`.

Source: GitHub gist `algebraic-dev/e56604e833186f7795c2b43e13bf167d`, last active April 14, 2025.
GitHub flags this file as containing hidden or bidirectional Unicode text; review it in an
editor that reveals such characters.
-/
import Std.Internal.Async | |
import Std.Internal.UV | |
import Std.Net.Addr | |
open Std.Internal.IO.Async.TCP.Socket | |
open Std.Internal.IO.Async.TCP | |
open Std.Internal.IO.Async | |
open Std.Net | |
-- Shared helpers used by all of the monad encodings below.

/-- A task that has already completed with `Except.ok ()`. -/
@[inline]
def emptyTask : Task (Except IO.Error Unit) :=
  .pure (.ok ())
/-- Waits for the task `io` to finish and returns its value.

`IO.wait` is used instead of `pure io.get`: `Task.get` is a pure function, so
`IO` sequencing does not pin down where the blocking happens, whereas `IO.wait`
performs the wait as an effect at this point of the action. -/
@[inline]
def block (io : Task α) : IO α :=
  IO.wait io
/-- Schedules `f` to run on the value of `t` once `t` completes. The caller does
not wait for it, and the handle of the follow-up task is discarded. -/
@[inline]
def nextTask {α} (t : Task α) (f : α → BaseIO Unit) : BaseIO Unit := do
  let _ ← BaseIO.mapTask f t
  pure ()
namespace Naive

/-- The most naive encoding to consider: running an `Async α` starts the
computation and returns the `AsyncTask` for its result.

It does not allow tail recursion. Every `bind` chains a fresh task onto the
previous one, so a recursive loop builds a huge chain of tasks. -/
def Async (α : Type) := BaseIO (AsyncTask α)

/-- Starts `x` and blocks the calling thread until its task has a result. -/
@[inline]
def Async.exec (x : Async α) : IO α := do
  let task ← BaseIO.toIO x
  task.block

-- Instances

/-- Runs `x` to obtain its task, then chains `f` onto that task with
`AsyncTask.bindIO`. -/
@[inline]
def bind {α β : Type} (x : BaseIO (AsyncTask α)) (f : α → BaseIO (AsyncTask β)) : BaseIO (AsyncTask β) := do
  let task ← x
  task.bindIO (BaseIO.toIO ∘ f)

/-- Lifts a `BaseIO` action: it runs now and its value is wrapped in an
already-completed task. -/
@[inline]
def baseMonadLift {α : Type} (x : BaseIO α) : BaseIO (AsyncTask α) :=
  AsyncTask.pure <$> x

/-- Lifts an `IO` action: it runs now. Success becomes a completed task and a
thrown error becomes an already-failed task (the resulting `BaseIO` cannot throw). -/
@[inline]
def monadLift {α : Type} (x : IO α) : BaseIO (AsyncTask α) := do
  match ← x.toBaseIO with
  | .ok a => pure (AsyncTask.pure a)
  | .error e => pure (Task.pure (.error e))

/-- Runs `t`. Once its task finishes with an error, runs `f` on that error, and
`f`'s task becomes the result; successful results pass through unchanged. -/
@[inline]
def tryCatch (t : Async α) (f : IO.Error → Async α) : BaseIO (AsyncTask α) := do
  -- `Async α` unfolds to `BaseIO (AsyncTask α)`, so `t` and `f err` are used as
  -- `BaseIO` actions directly; no detour through `EStateM.run` is needed.
  let task ← t
  BaseIO.bindTask task fun
    | .ok res => pure (AsyncTask.pure res)
    | .error err => f err

instance : Monad Async where
  pure x := EStateM.pure (AsyncTask.pure x)
  bind := Naive.bind

instance : MonadLift BaseIO Async where
  monadLift := Naive.baseMonadLift

instance : MonadLift IO Async where
  monadLift := Naive.monadLift

instance : MonadExcept IO.Error Async where
  throw e := EStateM.pure (Task.pure (Except.error e))
  tryCatch := Naive.tryCatch

/-- Runs the `IO` action that produces a task. If the action itself throws, the
error is turned into an already-failed task. -/
@[inline]
def fromIO (x : IO (AsyncTask α)) : BaseIO (AsyncTask α) := do
  match ← x.toBaseIO with
  | .ok result => pure result
  | .error err => pure (Task.pure (.error err))

/-- Waits for `time` using `Sleep.mk` and `Sleep.wait`. -/
@[inline]
def wait (time : Std.Time.Millisecond.Offset) : Async Unit :=
  fromIO <| Sleep.mk time >>= Sleep.wait

/-- Awaits the task produced by `task`. An error thrown while producing the task
becomes an already-failed task. Equivalent to `fromIO`. -/
def await (task : IO (AsyncTask α)) : Async α :=
  fromIO task

/-- Runs `task` to obtain its task handle, discards that handle, and returns an
already-completed `()` task. The caller therefore does not wait for `task`'s
result, and any failure of `task` is not observed. -/
def parallel (task : Async α) : Async Unit :=
  -- Elaborate in `BaseIO` explicitly: a plain `do` here would take the `Async`
  -- monad from the expected type and wait for `task` instead of detaching it.
  show BaseIO (AsyncTask Unit) from do
    let _ ← task
    pure (AsyncTask.pure ())

/-- Runs `task` and blocks until its result is available. -/
def Async.toIO {α : Type} (task : Async α) : IO α :=
  BaseIO.toIO task >>= AsyncTask.block

/-- Runs `task` and blocks until its result is available (alias of `Async.toIO`). -/
def block (task : Async α) : IO α :=
  task.toIO

end Naive
namespace CPS

/-- Continuation-passing encoding: an `Async α` is handed a continuation that
receives the `AsyncTask` carrying the result.

This encoding is meant to allow tail recursion without creating a huge task
stack. `bind` re-enters continuations from `nextTask` callbacks instead of
nesting task chains. -/
def Async (α : Type) := (AsyncTask α → BaseIO Unit) → BaseIO Unit

/-- Runs `x` and blocks until the task handed to its continuation completes.

A failure is rethrown unchanged. `IO.ofExcept` is deliberately not used: it
rewraps the error as `IO.userError (toString e)` and loses the original
`IO.Error` constructor. -/
def Async.toIO {α : Type} (x : Async α) : IO α := do
  let promise : IO.Promise (Except IO.Error α) ← IO.Promise.new
  x fun task => discard <| IO.bindTask task fun result => do
    promise.resolve result
    return AsyncTask.pure ()
  match ← IO.wait promise.result! with
  | .ok a => pure a
  | .error e => throw e

/-- Runs `x` and returns a task that completes with the result handed to its
continuation. -/
@[inline]
def Async.toTask {α : Type} [Nonempty α] (x : Async α) : IO (Task (Except IO.Error α)) := do
  let promise ← IO.Promise.new
  x (nextTask · promise.resolve)
  pure promise.result!

/-- Runs `x` and blocks until it completes, rethrowing any failure unchanged. -/
@[inline]
def Async.exec {α : Type} [Nonempty α] (x : Async α) : IO α := do
  match ← IO.wait (← x.toTask) with
  | .ok res => pure res
  | .error err => throw err

/-- Runs the `IO` action producing a task and hands that task to the continuation.
An error thrown by the action is handed over as an already-failed task. -/
@[inline]
def fromIO (x : IO (AsyncTask α)) : Async α := fun k => do
  match ← x.toBaseIO with
  | .ok result => k result
  | .error err => k (Task.pure (.error err))

/-- Awaits the task produced by `task`; identical to `fromIO`. -/
@[inline]
def await {α : Type} (task : IO (AsyncTask α)) : Async α :=
  fromIO task

/-- Starts `task` in a separate task (`IO.asTask`) and immediately continues with a
completed `()` task. The result of `task`, including any failure, is discarded. -/
@[inline]
def parallel (task : Async α) : Async Unit :=
  fun k => do
    discard <| IO.asTask (task (fun _ => pure ()))
    k emptyTask

-- Instances

instance : Functor Async where
  map f x := fun k => x (fun t => k (t.map f))

instance : Monad Async where
  pure a := fun k => k (AsyncTask.pure a)
  -- Wait for `x`'s task via `nextTask`, then run `f` with the same continuation
  -- `k`; an error skips `f` and is handed straight to `k`.
  bind x f := fun k => x fun t =>
    nextTask t fun
      | .ok res => f res k
      | .error err => k (Task.pure (.error err))

instance : MonadLift IO Async where
  monadLift io := fun k => do
    match ← io.toBaseIO with
    | .ok res => k (AsyncTask.pure res)
    | .error err => k (Task.pure (.error err))

instance : MonadExcept IO.Error Async where
  throw e := fun k => k (Task.pure (.error e))
  tryCatch x f := fun k => x fun t =>
    nextTask t fun
      | .ok res => k (Task.pure (.ok res))
      | .error err => f err k

/-- Waits for `time` using `Sleep.mk` and `Sleep.wait`. -/
@[inline]
def wait (time : Std.Time.Millisecond.Offset) : Async Unit :=
  fromIO <| Sleep.mk time >>= Sleep.wait

/-- Runs `t` and blocks until its result is available (alias of `Async.toIO`). -/
@[inline]
def block (t : Async α) : IO α :=
  t.toIO

end CPS
namespace SuspendableNaive

/-- Result of running one step of an `Async` computation. It is either finished
successfully, failed, or suspended on a task that will eventually yield the final
`EStateM.Result`. Each constructor also carries a state value. -/
inductive AsyncResult (ε σ α : Type u) where
  | ok : α → σ → AsyncResult ε σ α
  | error : ε → σ → AsyncResult ε σ α
  | cont : Task (EStateM.Result ε σ α) → σ → AsyncResult ε σ α

/-- A successful result paired with the `default` state. -/
@[inline]
def pure [Inhabited σ] (a : α) : AsyncResult ε σ α :=
  AsyncResult.ok a default

/-- Views any `AsyncResult` as a task. Finished results become completed tasks,
and a suspended result yields its pending task (its outer state is dropped). -/
@[inline]
def toTask : AsyncResult ε σ α → Task (EStateM.Result ε σ α)
  | .ok a s => Task.pure (.ok a s)
  | .error e s => Task.pure (.error e s)
  | .cont c _ => c

/-- Applies `f` to the successful value, either now or once the pending task
finishes. Not recursive, so it is an ordinary (non-`partial`) definition. -/
@[inline]
def map {α β : Type} (x : AsyncResult ε σ α) (f : α → β) : AsyncResult ε σ β :=
  match x with
  | .ok a s => .ok (f a) s
  | .error e s => .error e s
  | .cont c s => .cont (c.map fun
      | .ok a s => .ok (f a) s
      | .error e s => .error e s) s

/-- Sequences `x` with `f`. If `x` is suspended, `f` runs inside `Task.bind` once
the pending task finishes, and its outcome (finished or suspended) is flattened
back into a task with `toTask`. Not recursive, so no `partial` is needed. -/
@[inline]
def bind {α β : Type} (x : AsyncResult ε σ α) (f : α → σ → AsyncResult ε σ β) : AsyncResult ε σ β :=
  match x with
  | .ok a s => f a s
  | .error e s => .error e s
  | .cont c s => .cont (c.bind fun
      | .ok a s => SuspendableNaive.toTask (f a s)
      | .error e s => Task.pure (.error e s)) s

/-- Suspension point used by the `Suspend` instance. It behaves exactly like
`bind`, to which it delegates. -/
@[inline]
def «suspend» {α β : Type} (x : AsyncResult ε σ α) (f : α → σ → AsyncResult ε σ β) : AsyncResult ε σ β :=
  SuspendableNaive.bind x f

/-- An `Async` computation threads the `IO` world token and may suspend. -/
def Async (α : Type) := IO.RealWorld → AsyncResult IO.Error IO.RealWorld α

/-- Runs `x` and returns a task for its final outcome. An error raised before the
first suspension is thrown by this `IO` action directly; later errors complete the
task with `.error`. -/
def Async.toTask {α : Type} (x : Async α) : IO (Task (Except IO.Error α)) := fun w =>
  match x w with
  | .ok a s => .ok (Task.pure (Except.ok a)) s
  | .error e s => .error e s
  | .cont c s => .ok (c.map fun
      | .ok a _ => .ok a
      | .error a _ => .error a) s

/-- Runs `x` and blocks until it finishes, rethrowing any failure unchanged
(`IO.ofExcept` would rewrap it as `IO.userError (toString e)`). -/
def Async.toIO {α : Type} (x : Async α) : IO α := do
  let task ← x.toTask
  -- `return` rather than `pure`: `pure` in this namespace is `SuspendableNaive.pure`.
  match ← IO.wait task with
  | .ok a => return a
  | .error e => throw e

instance : Functor Async where
  map f x := fun w => SuspendableNaive.map (x w) f

instance : Monad Async where
  pure a := fun w => .ok a w
  bind x f := fun w => SuspendableNaive.bind (x w) f

instance : MonadLift IO Async where
  monadLift io := fun w =>
    match io w with
    | .ok a w => .ok a w
    | .error e w => .error e w

instance : MonadExcept IO.Error Async where
  throw x := fun w => .error x w
  -- The handler must also see failures that happen after `x` suspends. In the
  -- `.cont` case it therefore runs inside `Task.bind`, and its outcome is
  -- flattened back into a task with `toTask`.
  tryCatch x f := fun w =>
    match x w with
    | .ok a s => .ok a s
    | .error e s => f e s
    | .cont c s => .cont (c.bind fun
        | .ok a s => Task.pure (.ok a s)
        | .error e s => SuspendableNaive.toTask (f e s)) s

instance : Suspend Async Async where
  «suspend» x f := fun w => SuspendableNaive.«suspend» (x w) f

/-- Runs the `IO` action producing a task and suspends on it. An error thrown by
the action itself fails immediately. The pending result reuses the state observed
when the action returned. -/
@[inline]
def fromIO (x : IO (AsyncTask α)) : Async α := fun w =>
  match x w with
  | .ok a s => .cont (Task.map (fun
      | Except.ok a => .ok a s
      | Except.error a => .error a s) a) s
  | .error e s => .error e s

/-- Waits for `time` using `Sleep.mk` and `Sleep.wait`. -/
@[inline]
def wait (time : Std.Time.Millisecond.Offset) : Async Unit :=
  fromIO <| Sleep.mk time >>= Sleep.wait

/-- Awaits the task produced by `t`; identical to `fromIO`. -/
@[inline]
def await (t : IO (AsyncTask α)) : Async α := fromIO t

/-- Runs `task` and continues with `()` without waiting for its pending part.

A failure before `task`'s first suspension is still propagated; failures after it
are not observed. In the suspended case a no-op `BaseIO.bindTask` is registered on
the pending task, and the world token is threaded through it.
NOTE(review): this presumably keeps the pending task referenced after it is
dropped here; confirm before removing it. -/
@[inline]
def parallel (task : Async α) : Async Unit := fun w =>
  match task w with
  | .ok _ s => .ok () s
  | .error e s => .error e s
  | .cont c s =>
    let op := BaseIO.bindTask c (fun t => do return (Task.pure t))
    let .ok _ s' := op s
    .ok () s'

/-- Runs `t` and blocks until it finishes, rethrowing any failure unchanged. -/
@[inline]
def block (t : Async α) : IO α :=
  t.toIO

end SuspendableNaive
namespace SuspendableCPS

/-- Result of running one step of an `Async` computation. It is either finished,
failed, or suspended as a CPS computation that hands its final task to a
continuation. -/
inductive AsyncResult (ε σ α : Type u) where
  | ok : α → σ → AsyncResult ε σ α
  | error : ε → σ → AsyncResult ε σ α
  | cont : ((Task (Except ε α) → BaseIO Unit) → BaseIO Unit) → σ → AsyncResult ε σ α

/-- A successful result paired with the `default` state. -/
@[inline]
def pure [Inhabited σ] (a : α) : AsyncResult ε σ α :=
  AsyncResult.ok a default

/-- An `Async` computation threads the `IO` world token and may suspend. -/
def Async (α : Type) :=
  IO.RealWorld → AsyncResult IO.Error IO.RealWorld α

/-- Runs a `BaseIO` action on the given world and wraps its value in `Except.ok`
(a `BaseIO` action cannot fail). -/
@[inline]
def Excp.ofBaseIO (a : BaseIO α) : IO.RealWorld → Except ε α := fun w =>
  match a w with
  | .ok a _ => .ok a

/-- Lifts a `BaseIO` action into `Async`; it cannot fail. -/
@[inline]
def Async.ofBaseIO (a : BaseIO α) : Async α := fun w =>
  match a w with
  | .ok a s => .ok a s

/-- Lifts an `IO` action into `Async`, preserving its error. -/
@[inline]
def Async.ofIO (a : IO α) : Async α := fun w =>
  match a w with
  | .ok a s => .ok a s
  | .error e s => .error e s

/-- Runs `x` and returns a task for its outcome. Finished results become completed
tasks. A suspended computation is driven with a continuation that resolves a
promise once the task it is handed completes. -/
@[inline]
def Async.toTask (x : Async α) : BaseIO (Task (Except IO.Error α)) := fun w =>
  match x w with
  | .ok a s => .ok (Task.pure (.ok a)) s
  | .error e s => .ok (Task.pure (.error e)) s
  | .cont c s =>
    let op : BaseIO (Task (Except IO.Error α)) := do
      let promise ← IO.Promise.new (α := Except IO.Error α)
      c fun t => nextTask t fun res => do promise.resolve res
      return promise.result!
    op s

/-- Sequences `x` with `f`.

NOTE: when `x` is suspended, this blocks the current thread until `x`'s task
completes and then continues synchronously. The non-blocking path is
`«suspend»`. Not recursive, so no `partial` is needed. -/
@[inline]
def bind {α β : Type} (x : Async α) (f : α → Async β) : Async β := fun w =>
  match x w with
  | .ok a s => f a s
  | .error e s => .error e s
  | .cont c s =>
    let op : BaseIO (Except IO.Error α) := do
      let promise ← IO.Promise.new (α := Except IO.Error α)
      c (nextTask · promise.resolve)
      IO.wait promise.result!
    let .ok a s := op s
    match a with
    | .ok a => f a s
    | .error e => .error e s

/-- Non-blocking sequencing. If `x` is suspended, the result is a new suspended
computation whose continuation runs `f` from a `nextTask` callback once `x`'s task
completes. `f` is applied to the state captured when that suspended computation is
run. Not recursive, so no `partial` is needed. -/
@[inline]
def «suspend» {α β : Type} (x : AsyncResult IO.Error IO.RealWorld α) (f : α → IO.RealWorld → AsyncResult IO.Error IO.RealWorld β) : AsyncResult IO.Error IO.RealWorld β :=
  match x with
  | .ok a s => f a s
  | .error e s => .error e s
  | .cont c s => .cont (fun k s => c (fun t =>
      nextTask t fun
        | .ok res =>
          match f res s with
          | .ok a _ => k (Task.pure (.ok a))
          | .error e _ => k (Task.pure (.error e))
          | .cont c' _ => c' k
        | .error err => k (Task.pure (.error err))) s) s

/-- Runs `x` and blocks until it finishes, rethrowing any failure unchanged
(`IO.ofExcept` would rewrap it as `IO.userError (toString e)`). -/
@[inline]
def Async.toIO {α : Type} (x : Async α) : IO α := do
  let task ← x.toTask
  -- `return` rather than `pure`: `pure` in this namespace is `SuspendableCPS.pure`.
  match ← IO.wait task with
  | .ok a => return a
  | .error e => throw e

instance : Monad Async where
  pure a := fun w => .ok a w
  bind := SuspendableCPS.bind

instance : MonadLift IO Async where
  monadLift io := fun w =>
    match io w with
    | .ok a w => .ok a w
    | .error e w => .error e w

instance : Suspend Async Async where
  «suspend» x f := fun w => SuspendableCPS.«suspend» (x w) f

/-- Suspending on an `IO` action that yields an `AsyncTask`. The action runs now;
the continuation runs `f` from a `nextTask` callback once the task completes,
using the state observed when the action returned. -/
instance : Suspend (IO ∘ AsyncTask) Async where
  «suspend» x f := fun w =>
    match x w with
    | .ok t s =>
      .cont (fun k => nextTask t fun
        | .ok a => do
          let .ok res _ := (Async.toTask <| f a) s
          k res
        | .error e => k (Task.pure (.error e))) s
    | .error e s => .error e s

/-- Runs the `IO` action producing a task now and suspends on that task. An error
thrown by the action itself fails immediately. -/
@[inline]
def fromIO (x : IO (AsyncTask α)) : Async α := fun w =>
  match x w with
  | .ok a s => .cont (· a) s
  | .error e s => .error e s

/-- Waits for `time` using `Sleep.mk` and `Sleep.wait`. -/
@[inline]
def wait (time : Std.Time.Millisecond.Offset) : Async Unit :=
  fromIO <| Sleep.mk time >>= Sleep.wait

/-- Unlike `fromIO`, defers running `t` until the suspended computation is driven.
An error thrown by `t` is handed to the continuation as an already-failed task. -/
def await (t : IO (AsyncTask α)) : Async α := fun w =>
  AsyncResult.cont (fun k => do
    match ← t.toBaseIO with
    | .ok a => k a
    | .error e => k (Task.pure (.error e))) w

/-- Suspends with a continuation that starts `task` (via `Async.toTask`) and then
immediately continues with a completed `()` task. `task`'s result is discarded. -/
def parallel (task : Async α) : Async Unit := fun w =>
  AsyncResult.cont (fun k => do
    discard task.toTask
    k emptyTask) w

/-- Runs `t` and blocks until it finishes, rethrowing any failure unchanged. -/
def block (t : Async α) : IO α :=
  t.toIO

end SuspendableCPS
namespace MacMaybeTask | |
namespace MaybeTask | |
/-- A value that is either already available (`pure`) or still being computed by
a task (`ofTask`). -/
inductive MaybeTask (α : Type u) where
  | protected pure (a : α) : MaybeTask α
  | ofTask (a : Task α) : MaybeTask α
/-- Flattens a task that produces a `MaybeTask` into a plain task. An immediate
value becomes a completed task, and a nested task is returned as is. The binding
step runs with `sync := true` at priority `prio`. -/
@[inline]
def joinTask (t : Task (MaybeTask α)) (prio := Task.Priority.default) : Task α :=
  Task.bind t (prio := prio) (sync := true) fun maybe =>
    match maybe with
    | .pure value => Task.pure value
    | .ofTask inner => inner
/-- A `BaseIO` action whose result is either available immediately or only
through a task. -/
def AsyncIO (α : Type) : Type := BaseIO (MaybeTask α)

/-- Views a `BaseIO (MaybeTask α)` action as an `AsyncIO α`; the value is
returned unchanged. -/
@[inline]
def AsyncIO.mk (x : BaseIO (MaybeTask α)) : AsyncIO α :=
  x
/-- Iteration bound for the benchmark loops: each loop keeps recursing while
`n < durat`. -/
def durat : Nat := 600000
-- 101 lines of IR (recorded when the `IO.sleep 0` below was never executed; re-measure)
/-- Baseline benchmark in plain `IO`. It prints a few lines, sleeps for 0 ms, and
recurses until `n` reaches `durat`.

The sleep is executed as a statement. `let _ := IO.sleep 0` only built the
action without running it, so the baseline never slept, unlike the
`let _ ← wait 0` in `d` and `e`.
NOTE(review): this variant also prints an extra `"d"` line that `d` and `e` do
not; keep that in mind when comparing timings. -/
partial def a (n : Nat) : IO Unit := do
  IO.println "a"
  IO.println "b"
  IO.println "c"
  IO.sleep 0
  IO.println s!"d {n}"
  IO.println "d"
  if n < durat then
    a (n + 1)
-- 397 lines of IR (recorded when the `wait 0` below was never awaited; re-measure)
open Naive in
/-- Benchmark loop for the `Naive` encoding. It prints a few lines, awaits a 0 ms
`wait`, and recurses until `n` reaches `durat`.

`wait 0` is bound with `←` so it actually runs. `let _ := wait 0` only built the
action and discarded it. -/
partial def b (n : Nat) : Async Unit := do
  IO.println "a"
  IO.println "b"
  IO.println "c"
  let _ ← wait 0
  IO.println s!"d {n}"
  IO.println "d"
  if n < durat then
    b (n + 1)
-- 832 lines of IR (recorded when the `wait 0` below was never awaited; re-measure)
open CPS in
/-- Benchmark loop for the `CPS` encoding. It prints a few lines, awaits a 0 ms
`wait`, and recurses until `n` reaches `durat`.

`wait 0` is bound with `←` so it actually runs. `let _ := wait 0` only built the
action and discarded it. -/
partial def c (n : Nat) : Async Unit := do
  IO.println "a"
  IO.println "b"
  IO.println "c"
  let _ ← wait 0
  IO.println s!"d {n}"
  IO.println "d"
  if n < durat then
    c (n + 1)
-- 233 lines of IR
-- Original note: "The bind behavior on the monad causes it to become blocking."
-- NOTE(review): `SuspendableNaive.bind` chains suspended steps with `Task.bind`,
-- so it is unclear where this variant blocks; confirm against the IR.
open SuspendableNaive in
/-- Benchmark loop for the `SuspendableNaive` encoding. It prints a few lines,
awaits a 0 ms `wait`, and recurses until `n` reaches `durat`. -/
partial def d (n : Nat) : Async Unit := do
  IO.println "a"
  IO.println "b"
  IO.println "c"
  let _ ← wait 0
  IO.println s!"d {n}"
  if n < durat then
    d (n + 1)
-- 460 lines of IR
-- Blocking: `SuspendableCPS.bind` waits on a promise whenever its left-hand side
-- has suspended, so binding the result of `wait 0` here blocks the current thread.
open SuspendableCPS in
/-- Benchmark loop for the `SuspendableCPS` encoding. It prints a few lines,
awaits a 0 ms `wait`, and recurses until `n` reaches `durat`. -/
partial def e (n : Nat) : Async Unit := do
  IO.println "a"
  IO.println "b"
  IO.println "c"
  let _ ← wait 0
  IO.println s!"d {n}"
  if n < durat then
    e (n + 1)
-- TCP: echo-server smoke test for `SuspendableCPS`, written with the
-- experimental `suspend` syntax.
namespace Test
open SuspendableCPS

/-- Sends `message`, then receives up to 1024 bytes. Repeats until `recv?`
returns `none`, which is reported as the client having disconnected. -/
partial def writeLoop (client : Socket.Client) (message : String) : Async Unit := do
  IO.println s!"write loop: {message}"
  suspend await <| client.send (String.toUTF8 message)
  if let none := suspend await <| client.recv? 1024 then
    IO.println "client disconnected from receiving"
  else
    suspend writeLoop client message

/-- Connects a new client socket to `addr` and runs `writeLoop` with `message`. -/
def echoClient (addr : SocketAddress) (message : String) : Async Unit := do
  let socket ← Client.mk
  suspend await <| socket.connect addr
  -- The peer name is unused; the call is kept so the action sequence is unchanged.
  let _ ← socket.getPeerName
  suspend writeLoop socket message

/-- Echoes each chunk of up to 1024 bytes received on `client` back to it, until
`recv?` returns `none`. -/
partial def echoLoop (client : Socket.Client) : Async Unit := do
  let message := suspend await <| client.recv? 1024
  IO.println s!"received: {String.fromUTF8! <$> message}"
  if let some msg := message then
    let _ := suspend await <| client.send msg
    suspend echoLoop client
  else
    IO.println "client disconnected from echoing"

/-- Accepts connections forever, serving each client with `echoLoop` in parallel. -/
partial def acceptLoop (server : Socket.Server) : Async Unit := do
  let client := suspend await (server.accept)
  suspend parallel (echoLoop client)
  suspend acceptLoop server

/-- Starts an echo server on 127.0.0.1:8080 in the background, runs one echo
client against it, and prints the client's outcome. -/
def tcp : IO Unit := do
  let addr := SocketAddressV4.mk (.ofParts 127 0 0 1) 8080
  let server ← Server.mk
  server.bind addr
  server.listen 128
  let action := parallel (acceptLoop server)
  action |>.toIO
  let t ← block (echoClient addr "monad") |>.toBaseIO
  IO.println s!"--> endddd!!! {t}"

end Test
/-- Runs `f` and returns the time it took. The elapsed time is the difference of
two `Std.Time.Timestamp.now` readings, taken before and after. -/
def measureIO (f : IO Unit) : IO Std.Time.Duration := do
  let started ← Std.Time.Timestamp.now
  f
  let finished ← Std.Time.Timestamp.now
  pure (finished - started)
/-- Times each benchmark loop in turn, then prints one `label: duration` line per
variant, in the same order. -/
def main : IO Unit := do
  let benchmarks : List (String × IO Unit) := [
    ("IO", a 0),
    ("Naive", (b 0).toIO),
    ("CPS", (c 0).toIO),
    ("Suspendable Naive", (d 0).toIO),
    ("Suspendable CPS", (e 0).toIO)
  ]
  let timings ← benchmarks.mapM fun (label, run) => do
    return (label, ← measureIO run)
  for (label, elapsed) in timings do
    IO.println s!"{label}: {elapsed}"
/- | |
IO: 101 lines of IR | |
Naive: 397 lines of IR | |
CPS: 832 lines of IR | |
Suspendable Naive: 233 lines of IR | |
Suspendable CPS: 460 lines of IR | |
IO: 3.601164000s | |
Naive: 25.425553000s | |
CPS: 29.176067000s | |
Suspendable Naive: 23.824065000s | |
Suspendable CPS: 25.893960000s | |
-/ | |
-- New Ideas | |
-- Async (Task α) | |
-- Async α := BaseIO (MaybeTask α) | |
-- await : Task α → Async α | |
-- async : Async α → Async (Task α) | |
-- async do | |
-- IO.println "a" | |
-- await task |
/- End of the GitHub gist page. Page footer: "Sign up for free to join this
conversation on GitHub. Already have an account? Sign in to comment" -/