Created May 1, 2020 21:47
Save matsadler/79950cce41d00166f3abe7ff5a16cec5 to your computer and use it in GitHub Desktop.
Toy implementation of Rust's async/await in Ruby
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
require "set" | |
# Handle passed to a task each time it is polled. Calling #wake appends the
# task's id to the queue it was built with (the executor's ready list), so
# the task gets polled again.
class Waker
  # @param id [Object] identifier of the task to re-poll
  # @param queue [#<<] collection that receives +id+ on every wake
  def initialize(id, queue)
    @id, @queue = id, queue
  end

  # Enqueues this waker's task id.
  # @return the queue (result of <<)
  def wake
    ready = @queue
    ready << @id
  end
end
# Readiness reactor built on IO.select.
#
# Futures register interest in an IO becoming readable or writable, together
# with the Waker of the task to re-poll. #tick waits in IO.select and wakes the
# registered task for every IO that became ready.
class Reactor
  def initialize
    # IOs with registered interest: :r readable, :w writable, :e exceptional
    # (nothing in this class registers for :e; it is passed through to select).
    @io = {r: Set[], w: Set[], e: Set[]}
    # io => {interest => Waker}
    @wakers = {}
  end

  # Wake +waker+ once +io+ is readable.
  def register_read(io, waker)
    register(io, waker, :r)
  end

  # Wake +waker+ once +io+ is writable.
  def register_write(io, waker)
    register(io, waker, :w)
  end

  # Stop watching +io+ for readability.
  def deregister_read(io)
    deregister(io, :r)
  end

  # Stop watching +io+ for writability.
  def deregister_write(io)
    deregister(io, :w)
  end

  # Waits for registered IOs to become ready and wakes their tasks.
  #
  # @param timeout [Numeric, nil] maximum seconds to wait; nil (the default)
  #   waits until at least one registered IO is ready.
  # @return [Array<IO>, nil] the IOs found writable, or nil when the timeout
  #   expired with nothing ready.
  def tick(timeout = nil)
    ready = IO.select(@io[:r].to_a, @io[:w].to_a, @io[:e].to_a, timeout)
    # IO.select returns nil (not empty arrays) when the timeout expires.
    return nil unless ready

    readable, writable, _exceptional = ready
    readable.each { |io| @wakers[io][:r].wake }
    writable.each { |io| @wakers[io][:w].wake }
  end

  private

  # Records +waker+ for +io+ under +interest+ (:r or :w), replacing any
  # previous waker for that pair.
  def register(io, waker, interest)
    @io[interest] << io
    (@wakers[io] ||= {})[interest] = waker
    nil
  end

  # Removes +io+'s +interest+; drops the io's waker table once it is empty.
  def deregister(io, interest)
    @io[interest].delete(io)
    wakers = @wakers[io]
    if wakers
      wakers.delete(interest)
      @wakers.delete(io) if wakers.empty?
    end
    nil
  end
end
# Single-threaded executor that drives tasks (objects responding to
# #poll(waker) and returning :pending until they finish) to completion.
class Executor
  def initialize
    # task_id => task that has not finished yet
    @tasks = {}
    # ids pushed by Wakers; these tasks must be polled again
    @ready = []
    @task_id = 0
  end

  # The reactor shared by all executors and by the I/O futures.
  #
  # It is created on first use, so this class no longer depends on Reactor
  # being defined first.
  def self.reactor
    @reactor ||= Reactor.new
  end

  # Goes through Executor explicitly rather than self.class. Class instance
  # variables are not inherited, and futures register with Executor.reactor,
  # so a subclass must tick that same reactor.
  def reactor
    Executor.reactor
  end

  # Adds +task+ and polls it once right away.
  # @param task [#poll]
  # @return [nil]
  def spawn(task)
    task_id = (@task_id += 1)
    @tasks[task_id] = task
    poll(task_id)
  end

  # Runs until every spawned task has finished.
  def run
    until @tasks.empty?
      # Poll already-woken tasks before blocking in the reactor. Otherwise a
      # task woken outside the reactor would wait on unrelated I/O, or forever.
      while (task_id = @ready.shift)
        poll(task_id)
      end
      reactor.tick unless @tasks.empty?
    end
  end

  private

  # Polls one task; removes it once it returns anything other than :pending.
  def poll(task_id)
    task = @tasks[task_id]
    # A task can be woken more than once (e.g. readable and writable in the
    # same tick); ignore wakes for tasks that have already finished.
    return nil unless task

    waker = Waker.new(task_id, @ready)
    @tasks.delete(task_id) unless :pending == task.poll(waker)
    nil
  end
end
# Mixin adding a non-blocking, reactor-aware read to any object that
# responds to #to_io.
module AsyncRead
  # Makes a single non-blocking read of up to 16 KiB into +buf+.
  #
  # @param waker [Waker] registered with Executor.reactor when no data is
  #   available yet, so the task is polled again once to_io is readable
  # @param buf [String] output buffer handed to read_nonblock, which fills it
  #   with the bytes read
  # @return [Integer, nil, Symbol] number of bytes read, nil at end of file,
  #   or :pending when the read would block
  def poll_read(waker, buf)
    io = to_io
    outcome = io.read_nonblock(16 * 1024, buf, exception: false)
    reactor = Executor.reactor
    if outcome == :wait_readable
      reactor.register_read(io, waker)
      :pending
    else
      # Data or EOF: this read no longer needs readiness notifications.
      reactor.deregister_read(io)
      outcome && outcome.bytesize
    end
  end
end
# Mixin adding a non-blocking, reactor-aware write to any object that
# responds to #to_io.
module AsyncWrite
  # Makes a single non-blocking write of +buf+.
  #
  # @param waker [Waker] registered with Executor.reactor when the write would
  #   block, so the task is polled again once to_io is writable
  # @param buf [String] bytes to write
  # @return [Integer, Symbol] number of bytes written, or :pending when the
  #   write would block
  def poll_write(waker, buf)
    io = to_io
    res = io.write_nonblock(buf, exception: false)
    # write_nonblock(exception: false) signals "would block" with
    # :wait_writable. Checking for :wait_readable here would hand that symbol
    # to the caller as if it were a byte count.
    if res == :wait_writable
      Executor.reactor.register_write(io, waker)
      return :pending
    end
    Executor.reactor.deregister_write(io)
    res
  end
end
# Mixin for pollable objects: anything whose #poll(waker) returns :pending
# until it can return its final value.
module Future
  # With a block: returns an InspectFuture that passes this future's result to
  # the block once it is ready. Without a block: the ordinary Object#inspect.
  def inspect(&block)
    block ? InspectFuture.new(self, block) : super
  end

  # Inside an +async+ fiber, suspends by yielding an AwaitFuture wrapping self
  # and returns whatever value the fiber is resumed with (Async#poll resumes
  # it with this future's result).
  def await
    marker = AwaitFuture.new(self)
    Fiber.yield(marker)
  end

  # Chains a follow-up: +block+ receives this future's result and must return
  # the next future to poll.
  def then(&block)
    ThenFuture.new(self, block)
  end
end
# Wraps a parent future. Once the parent is ready, the block is called with its
# result, and the result is then passed through unchanged.
class InspectFuture
  include Future

  # @param parent [#poll] future being observed
  # @param block [#call] called once per ready poll with the parent's result
  def initialize(parent, block)
    @parent = parent
    @block = block
  end

  # @return :pending, or the parent's result after the block has seen it
  def poll(waker)
    result = @parent.poll(waker)
    @block.call(result) unless :pending == result
    result
  end
end
# Sequences two futures. It polls the parent until ready, passes the result to
# the block to get a child future, and from then on polls only that child.
class ThenFuture
  include Future

  # @param parent [#poll] first future
  # @param block [#call] maps the parent's result to the next future
  def initialize(parent, block)
    @parent = parent
    @block = block
    @child = nil
  end

  # @return :pending, or the child's final result
  def poll(waker)
    unless @child
      outcome = @parent.poll(waker)
      return outcome if :pending == outcome

      @child = @block.call(outcome)
    end
    @child.poll(waker)
  end
end
# Wrapper that Future#await yields out of an async fiber. Async#poll keeps
# polling while the fiber yields these (it checks is_a?(AwaitFuture)).
# Polling simply delegates to the wrapped future.
class AwaitFuture
  include Future

  # @param parent [#poll] the future being awaited
  def initialize(parent)
    @parent = parent
  end

  # @return whatever the wrapped future's poll returns
  def poll(waker)
    target = @parent
    target.poll(waker)
  end
end
# Wraps a plain IO so it gains AsyncRead#poll_read and AsyncWrite#poll_write.
# Both mixins reach the underlying object through #to_io.
class FutureIO
  include AsyncRead
  include AsyncWrite

  # @param io [IO] object every read and write is performed on
  def initialize(io)
    @io = io
  end

  # @return [IO] the wrapped object
  def to_io
    io
  end

  private

  attr_reader :io
end
# Drives a non-blocking connect of +sock+ to +addrinfo+ and resolves to the
# +sock+ object it was given once connect_nonblock stops asking to wait.
class ConnectFuture
  include Future

  # @param sock [#to_io] socket wrapper to connect (FutureSocket.connect passes
  #   a FutureSocket)
  # @param addrinfo [Addrinfo] destination address
  def initialize(sock, addrinfo)
    @sock = sock
    @addrinfo = addrinfo
  end

  # @return :pending while the connection is in progress, otherwise the sock
  def poll(waker)
    io = @sock.to_io
    status = io.connect_nonblock(@addrinfo, exception: false)
    if status == :wait_writable
      # In progress: the socket becomes writable when connect completes.
      Executor.reactor.register_write(io, waker)
      :pending
    else
      Executor.reactor.deregister_write(io)
      @sock
    end
  end
end
require "socket" | |
# FutureIO around a Socket, with a future-returning connect.
class FutureSocket < FutureIO
  # Opens a socket of +addrinfo+'s protocol family and socket type, wraps it
  # in an instance of the receiving class, and returns a ConnectFuture that
  # resolves to that wrapper.
  #
  # @param addrinfo [Addrinfo]
  # @return [ConnectFuture]
  def self.connect(addrinfo)
    raw = Socket.new(addrinfo.pfamily, addrinfo.socktype)
    wrapper = new(raw)
    ConnectFuture.new(wrapper, addrinfo)
  end
end
# FutureSocket specialised for TCP.
class FutureTCPSocket < FutureSocket
  # Starts a non-blocking TCP connect to +host+:+port+.
  #
  # The arguments are now actually used; previously every call connected to
  # example.com:80 regardless of what was passed.
  #
  # NOTE(review): Addrinfo.tcp resolves +host+ synchronously, so the name
  # lookup is not driven by the reactor.
  #
  # @param host [String] hostname or IP address
  # @param port [Integer, String] TCP port or service name
  # @return [ConnectFuture] resolves to the connected FutureTCPSocket
  def self.connect(host, port)
    super(Addrinfo.tcp(host, port))
  end
end
# Future that reads +io+ until end of file and resolves to all bytes read.
class ReadAll
  include Future

  # @param io [#poll_read] e.g. a FutureIO
  def initialize(io)
    @io = io
    @buffer = +""
  end

  # Reads as much as is available without blocking.
  # @return :pending if more data may arrive, otherwise the accumulated buffer
  def poll(waker)
    chunk = +""
    while (count = @io.poll_read(waker, chunk))
      return count if count == :pending

      @buffer << chunk
    end
    # poll_read returned nil: end of file.
    @buffer
  end
end

# Returns a future that reads +io+ to end of file.
# @param io [#poll_read]
# @return [ReadAll]
def read_all(io)
  ReadAll.new(io)
end
# Future that writes all of +buffer+ to +io+ and resolves to the total number
# of bytes written.
class WriteAll
  include Future

  # @param io [#poll_write] e.g. a FutureIO
  # @param buffer [String] data to write; the caller's string is not modified
  def initialize(io, buffer)
    @io = io
    @buffer = buffer
    @total = 0
  end

  # Writes as much as possible without blocking.
  # @return :pending while data remains and the write would block, otherwise
  #   the total number of bytes written
  def poll(waker)
    until @buffer.empty?
      written = @io.poll_write(waker, @buffer)
      return written if written == :pending

      # poll_write reports a BYTE count. String#[] indexes characters, so
      # slicing with it would drop or repeat data in multibyte text.
      @buffer = @buffer.byteslice(written, @buffer.bytesize - written)
      @total += written
    end
    @total
  end
end

# Returns a future that writes every byte of +buffer+ to +io+.
# @param io [#poll_write]
# @param buffer [String]
# @return [WriteAll]
def write_all(io, buffer)
  WriteAll.new(io, buffer)
end
require "fiber" | |
# Future backed by a Fiber running an +async+ block. The block suspends with
# Future#await, which yields an AwaitFuture. Async polls that future and
# resumes the fiber with its result once ready. The block's return value is
# this future's result.
class Async
  include Future

  # @param fiber [Fiber] not yet started; runs the async block
  def initialize(fiber)
    @fiber = fiber
    @parent = nil
    # Explicit flag instead of testing @parent: a block whose result is
    # nil/false must not be resumed again after it has finished.
    @started = false
  end

  # @return :pending while an awaited future is pending, otherwise the
  #   block's return value (also on later polls after completion)
  def poll(waker)
    unless @started
      @started = true
      @parent = @fiber.resume
    end
    # Only a live fiber can be resumed. Once it has finished, @parent holds
    # its return value.
    while @fiber.alive? && @parent.is_a?(AwaitFuture)
      res = @parent.poll(waker)
      return :pending if :pending == res

      @parent = @fiber.resume(res)
    end
    # No blanket `rescue FiberError`: it used to mask both the re-resume
    # above and genuine errors raised inside the task, returning nil instead.
    @parent
  end
end

# Wraps +block+ in a future that runs it inside a Fiber, so the block can
# call #await on other futures.
# @return [Async]
def async(&block)
  Async.new(Fiber.new(&block))
end
# Demo: fetch http://example.com/ with HTTP/1.0 and print the raw response.
x = async do
  sock = FutureTCPSocket.connect("example.com", 80).await
  begin
    write_all(sock, <<~REQ).await
      GET / HTTP/1.0\r
      Host: example.com\r
      \r
    REQ
    puts read_all(sock).await
  ensure
    # Nothing else owns the socket; release it whether the exchange succeeds
    # or raises.
    sock.to_io.close
  end
end
e = Executor.new
e.spawn(x)
e.run
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.