Skip to content

Instantly share code, notes, and snippets.

@matsadler
Created May 1, 2020 21:47
Show Gist options
  • Save matsadler/79950cce41d00166f3abe7ff5a16cec5 to your computer and use it in GitHub Desktop.
Save matsadler/79950cce41d00166f3abe7ff5a16cec5 to your computer and use it in GitHub Desktop.
Toy implementation of Rust's async/await in Ruby
require "set"
# Handle given to a future when it is polled. Calling #wake puts the
# owning task's id back on the executor's ready queue so it is polled again.
class Waker
  # @param id [Integer] id of the task to re-schedule
  # @param queue [#<<] ready queue the id is appended to
  def initialize(id, queue)
    @id, @queue = id, queue
  end

  # Appends the task id to the ready queue.
  # @return the queue (result of `queue << id`)
  def wake
    @queue << @id
  end
end
# Tracks which IO objects tasks are waiting on and, on each #tick, wakes
# the wakers whose IO has become readable or writable.
class Reactor
  def initialize
    # interest (:r / :w / :e) => set of IOs registered for it
    @io = {r: Set[], w: Set[], e: Set[]}
    # io => { interest => waker }
    @wakers = {}
  end

  # Wake +waker+ the next time +io+ is readable.
  def register_read(io, waker)
    register(io, waker, :r)
  end

  # Wake +waker+ the next time +io+ is writable.
  def register_write(io, waker)
    register(io, waker, :w)
  end

  # Stop watching +io+ for readability.
  def deregister_read(io)
    deregister(io, :r)
  end

  # Stop watching +io+ for writability.
  def deregister_write(io)
    deregister(io, :w)
  end

  # Waits until at least one registered IO is ready, then calls #wake on
  # the waker registered for each ready IO/interest pair.
  #
  # @param timeout [Numeric, nil] maximum seconds to wait; nil (the
  #   default) waits until something is ready.
  # @return [nil] when nothing is registered or the wait timed out;
  #   otherwise the array of writable IOs.
  def tick(timeout = nil)
    # With every interest set empty there is nothing that could ever become
    # ready, and select with no timeout would wait forever.
    return if @io.each_value.all?(&:empty?)

    ready = select(@io[:r].to_a, @io[:w].to_a, @io[:e].to_a, timeout)
    return unless ready # select returns nil when the timeout expires

    readable, writable, = ready
    readable.each { |io| @wakers[io][:r].wake }
    writable.each { |io| @wakers[io][:w].wake }
  end

  private

  def register(io, waker, interest)
    @io[interest] << io
    (@wakers[io] ||= {})[interest] = waker
    nil
  end

  def deregister(io, interest)
    @io[interest].delete(io)
    wakers = @wakers[io]
    if wakers
      wakers.delete(interest)
      # Drop the per-IO entry once no interest remains.
      @wakers.delete(io) if wakers.empty?
    end
    nil
  end
end
# Runs top-level futures ("tasks") to completion, polling each one when its
# Waker has put it on the ready queue.
class Executor
  # One Reactor shared by every executor; futures reach it through
  # Executor.reactor.
  @reactor = Reactor.new

  # @return [Reactor] the shared reactor used to wait for IO readiness
  def self.reactor
    @reactor
  end

  def initialize
    @tasks = {}   # task id => future
    @ready = []   # ids of woken tasks, appended to by Waker#wake
    @task_id = 0
  end

  # @return [Reactor] same object as Executor.reactor
  def reactor
    self.class.reactor
  end

  # Registers +task+ and polls it once right away.
  # @param task [#poll] a future
  # @return [nil]
  def spawn(task)
    task_id = (@task_id += 1)
    @tasks[task_id] = task
    poll(task_id)
  end

  # Drives all spawned tasks until every one has finished.
  def run
    until @tasks.empty?
      # Poll tasks that are already woken before blocking in the reactor;
      # otherwise a wake that did not come from the reactor would be stuck
      # behind select.
      while (task_id = @ready.shift)
        poll(task_id)
      end
      reactor.tick unless @tasks.empty?
    end
  end

  private

  # Polls one task and forgets it once it returns anything but :pending.
  def poll(task_id)
    task = @tasks[task_id]
    # A task id can be queued more than once (e.g. its IO was reported both
    # readable and writable in one tick). Ignore wakes for tasks that have
    # already completed instead of calling #poll on nil.
    return unless task

    @tasks.delete(task_id) if :pending != task.poll(Waker.new(task_id, @ready))
    nil
  end
end
# Mixin for objects that respond to #to_io: adds a non-blocking,
# reactor-aware read.
module AsyncRead
  # Tries one non-blocking read of up to 16 KiB from +to_io+ into +buf+.
  #
  # If no data is available yet, registers +waker+ with the reactor for
  # readability and returns :pending. Otherwise it clears any read
  # registration and returns the number of bytes read, or nil at EOF.
  def poll_read(waker, buf)
    io = to_io
    outcome = io.read_nonblock(16 * 1024, buf, exception: false)
    if outcome == :wait_readable
      Executor.reactor.register_read(io, waker)
      return :pending
    end
    Executor.reactor.deregister_read(io)
    outcome && outcome.bytesize
  end
end
# Mixin for objects that respond to #to_io: adds a non-blocking,
# reactor-aware write.
module AsyncWrite
  # Tries one non-blocking write of +buf+ to +to_io+.
  #
  # If the IO cannot accept data yet, registers +waker+ with the reactor for
  # writability and returns :pending. Otherwise it clears any write
  # registration and returns the number of bytes written (possibly fewer
  # than buf.bytesize).
  def poll_write(waker, buf)
    res = to_io.write_nonblock(buf, exception: false)
    # write_nonblock signals "would block" with :wait_writable. The previous
    # check for :wait_readable let that symbol escape as if it were a byte
    # count, which then crashed WriteAll.
    if res != :wait_writable
      Executor.reactor.deregister_write(to_io)
      return res
    end
    Executor.reactor.register_write(to_io, waker)
    :pending
  end
end
# Combinators shared by every future (anything with #poll(waker)).
module Future
  # With a block: returns a future that calls the block with this future's
  # result and then passes the result through. Without a block: behaves
  # like the normal Object#inspect.
  def inspect(&block)
    if block
      InspectFuture.new(self, block)
    else
      super
    end
  end

  # Suspends the current Fiber, yielding an AwaitFuture that wraps this
  # future. Returns whatever value the fiber is resumed with.
  def await
    wrapper = AwaitFuture.new(self)
    Fiber.yield(wrapper)
  end

  # Returns a future that, once this one resolves, calls the block with the
  # result and then continues by polling the future the block returned.
  def then(&block)
    ThenFuture.new(self, block)
  end
end
# Future that forwards to +parent+ and, when the parent resolves, calls
# +block+ with the value before returning that same value.
class InspectFuture
  include Future

  def initialize(parent, block)
    @parent = parent
    @block = block
  end

  # @return :pending, or the parent's resolved value (unchanged)
  def poll(waker)
    @parent.poll(waker).tap do |outcome|
      @block.call(outcome) unless :pending == outcome
    end
  end
end
# Future that waits for +parent+, passes its value to +block+, and then
# forwards every later poll to the future the block returned.
class ThenFuture
  include Future

  def initialize(parent, block)
    @parent = parent
    @block = block
    @child = nil
  end

  # @return :pending, or the resolved value of the block's future
  def poll(waker)
    unless @child
      resolved = @parent.poll(waker)
      return :pending if :pending == resolved
      @child = @block.call(resolved)
    end
    @child.poll(waker)
  end
end
# Marker wrapper produced by Future#await. Async#poll recognises values of
# this class that a fiber yields; polling simply delegates to the wrapped
# future.
class AwaitFuture
  include Future

  def initialize(parent)
    @parent = parent
  end

  def poll(waker)
    awaited = @parent
    awaited.poll(waker)
  end
end
# Wraps a plain IO so it picks up the reactor-aware #poll_read and
# #poll_write from AsyncRead and AsyncWrite.
class FutureIO
  include AsyncRead
  include AsyncWrite

  # @param io [IO] the underlying IO object
  def initialize(io)
    @to_io = io
  end

  # @return [IO] the wrapped IO (used by the AsyncRead/AsyncWrite mixins)
  attr_reader :to_io
end
# Future that finishes a non-blocking connect of +sock+ to +addrinfo+ and
# resolves to +sock+.
class ConnectFuture
  include Future

  def initialize(sock, addrinfo)
    @sock = sock
    @addrinfo = addrinfo
  end

  # Retries connect_nonblock. While it reports :wait_writable, registers
  # +waker+ for writability and returns :pending. Otherwise it clears the
  # write registration and returns the wrapped socket. Errors raised by
  # connect_nonblock are not rescued here.
  def poll(waker)
    io = @sock.to_io
    status = io.connect_nonblock(@addrinfo, exception: false)
    if status == :wait_writable
      Executor.reactor.register_write(io, waker)
      :pending
    else
      Executor.reactor.deregister_write(io)
      @sock
    end
  end
end
require "socket"
# FutureIO wrapping a Socket.
class FutureSocket < FutureIO
  class << self
    # Creates an unconnected Socket matching +addrinfo+'s family and type,
    # and returns a ConnectFuture that resolves to the wrapped socket.
    def connect(addrinfo)
      raw = Socket.new(addrinfo.pfamily, addrinfo.socktype)
      wrapper = new(raw)
      ConnectFuture.new(wrapper, addrinfo)
    end
  end
end
# FutureSocket preconfigured for TCP.
class FutureTCPSocket < FutureSocket
  # Returns a future that connects a new TCP socket to +host+:+port+.
  #
  # Previously the arguments were ignored and the socket always connected
  # to example.com:80.
  # NOTE(review): Addrinfo.tcp resolves +host+ itself before the future is
  # created, so that lookup is not driven by the reactor.
  def self.connect(host, port)
    super(Addrinfo.tcp(host, port))
  end
end
# Future that reads from +io+ (via #poll_read) until EOF and resolves to
# everything read, as one String.
class ReadAll
  include Future

  def initialize(io)
    @io = io
    @buffer = +""
  end

  # Reads as much as is available. Returns :pending if the IO would block,
  # or the accumulated buffer once poll_read reports EOF (nil).
  def poll(waker)
    chunk = +""
    while (status = @io.poll_read(waker, chunk))
      return status if status == :pending
      @buffer << chunk
    end
    @buffer
  end
end
# Convenience constructor: a future that resolves to everything readable
# from +io+ up to EOF.
#
# @param io [#poll_read] e.g. a FutureIO
# @return [ReadAll]
def read_all(io)
  ReadAll.new(io)
end
# Future that writes all of +buffer+ to +io+ (via #poll_write) and resolves
# to the total number of bytes written.
class WriteAll
  include Future

  def initialize(io, buffer)
    @io = io
    @buffer = buffer
    @total = 0
  end

  # Writes as much of the remaining buffer as the IO accepts. Returns
  # :pending if the IO would block, otherwise the total bytes written.
  def poll(waker)
    until @buffer.empty?
      written = @io.poll_write(waker, @buffer)
      return written if written == :pending
      # poll_write reports a *byte* count; String#[] with a range counts
      # characters, which resent or skipped data for multibyte strings.
      # byteslice returns a new string, so the caller's buffer is untouched.
      @buffer = @buffer.byteslice(written..-1)
      @total += written
    end
    @total
  end
end
# Convenience constructor: a future that writes every byte of +buffer+ to
# +io+ and resolves to the number of bytes written.
#
# @param io [#poll_write] e.g. a FutureIO
# @param buffer [String] data to send
# @return [WriteAll]
def write_all(io, buffer)
  WriteAll.new(io, buffer)
end
require "fiber"
# Future backed by a Fiber. Each Future#await inside the fiber yields an
# AwaitFuture, which this class polls on the fiber's behalf.
class Async
  include Future

  # @param fiber [Fiber] runs the async block
  def initialize(fiber)
    @parent = nil
    @fiber = fiber
    @started = false
  end

  # Advances the fiber as far as possible without blocking.
  #
  # On the first call the fiber is started. While the fiber has yielded an
  # AwaitFuture, that future is polled and its value is passed back in with
  # resume.
  #
  # Returns :pending while an awaited future is pending. Otherwise returns
  # the first value produced by the fiber that is not an AwaitFuture
  # (normally the block's return value). Polling again afterwards returns
  # that same value.
  def poll(waker)
    # Track "started" explicitly. The old `unless @parent` check re-resumed
    # the already-finished fiber whenever the block returned nil; that
    # raised FiberError, which was silently rescued.
    unless @started
      @started = true
      @parent = @fiber.resume
    end
    while @parent.is_a?(AwaitFuture)
      res = @parent.poll(waker)
      return :pending if :pending == res
      @parent = @fiber.resume(res)
    end
    @parent
  end
end
# Runs +block+ as an async task: the block goes into a new Fiber, wrapped
# in an Async future that an Executor can poll.
#
# @return [Async]
def async(&block)
  fiber = Fiber.new(&block)
  Async.new(fiber)
end
# Example: fetch http://example.com/ over a raw TCP socket and print the
# full HTTP response.
x = async do
  sock = FutureTCPSocket.connect("example.com", 80).await
  begin
    write_all(sock, <<~REQ).await
      GET / HTTP/1.0\r
      Host: example.com\r
      \r
    REQ
    puts read_all(sock).await
  ensure
    # Nothing else closes the connection; release it once the exchange is
    # done (read_all has already deregistered it from the reactor at EOF).
    sock.to_io.close
  end
end
e = Executor.new
e.spawn(x)
e.run
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment