Skip to content

Instantly share code, notes, and snippets.

@njsmith
Last active October 1, 2018 13:03
Show Gist options
  • Select an option

  • Save njsmith/a655f364e61cfda94d377a2ee2f9c7cb to your computer and use it in GitHub Desktop.

Select an option

Save njsmith/a655f364e61cfda94d377a2ee2f9c7cb to your computer and use it in GitHub Desktop.
import curio
# A simple task supervisor. Very loosely inspired by Erlang's supervisor
# trees:
#
# http://erlang.org/doc/design_principles/sup_princ.html
# http://erlang.org/doc/man/supervisor.html
#
# ...but much simpler (no respawning, no one-for-one policy, etc.).
#
# Semantics are:
# - tracks a set of tasks
# - new tasks can be added by calling the spawn() method
# - if a task exits normally or is cancelled, that's fine, we let it go
# - if a task crashes, then we cancel all tasks and exit, propagating errors
# - if we are cancelled, then we cancel all tasks and exit, propagating errors
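# - start_shutdown() can be called to cancel all tasks and trigger run() to exit
# - start_graceful_shutdown() stops accepting new tasks without cancelling
#   the existing ones
# - run() does not exit until a shutdown has started (explicitly, or because
#   a task crashed or we were cancelled) and every tracked task has finished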
#
# Multiple supervisors can be nested in a tree structure like
#
# root = Supervisor()
# child = Supervisor()
# _, root_task = await root.start(curio.spawn)
# _, child_task = await child.start(root.spawn)
__all__ = ["Supervisor"]
# async def start_listener(spawn_fn, listener_config, shutdown_deadline):
# ...
# await spawn_fn(_accept_loop(listen_sock,
# listener_config,
# shutdown_broadcast,
# ))
# return url
# async def _accept_loop(listen_sock, listener_config, shutdown_deadline):
# async with self._listen_sock:
# while True:
# await curio.sleep(0)
# ...
async def _join_and_notify(task, q):
    """Wait for ``task`` to finish, then put it on the queue ``q``.

    Intended to run as a free-floating helper task that channels the "done"
    notification of ``task`` into ``q``.  Any ``Exception`` raised by
    ``task.join()`` is deliberately discarded here: this helper only reports
    *that* the task finished, and whoever takes the task off the queue can
    join it again to learn *how* it finished.
    """
    try:
        await task.join()
    except Exception:
        # The outcome is not our concern; only the completion is.
        pass
    await q.put(task)
class Supervisor:
    """Track a set of tasks and tear them all down together on failure.

    Tasks are added with ``spawn()`` once the supervisor has been started
    via ``run()`` or ``start()``.  A task that returns ``None`` or is
    cancelled is simply dropped from the set.  A task that raises (or returns
    a non-``None`` value) causes every remaining task to be cancelled, and
    the error is re-raised from the supervise loop once all tasks are done.
    """

    def __init__(self):
        # Set once by start(); a supervisor can only be run a single time.
        self._started = False
        # Once True, spawn() refuses new tasks and the supervise loop exits
        # as soon as the task set drains.
        self._shutting_down = False
        # Set when the supervise loop has finished; see wait().
        self._finished = curio.Event()
        # Tasks currently being supervised.
        self._tasks = set()
        # Finished tasks are delivered here by _join_and_notify helpers.
        self._q = curio.Queue()

    async def spawn(self, coro):
        """Spawn ``coro`` as a supervised task and return the new task.

        Raises:
            RuntimeError: if the supervisor has not been started yet, or if
                it is already shutting down.
        """
        if not self._started:
            # We refuse to accept tasks when we haven't started yet, because
            # we can't guarantee that we'll actually supervise them.
            raise RuntimeError("supervisor has not started")
        if self._shutting_down:
            raise RuntimeError("supervisor is shutting down, can't accept "
                               "new tasks")
        ## Start critical region
        task = await curio.spawn(coro)
        # can't fail
        self._tasks.add(task)
        # can't fail
        await curio.spawn(_join_and_notify(task, self._q))
        ## End of critical region (task is now supervised)
        return task

    async def run(self, bootstrap_coro=None):
        """Run the supervisor inline, in the calling task.

        Equivalent to ``start()`` with a spawn function that simply awaits
        the supervise loop, so this does not return until the loop exits.
        """
        async def fake_spawn_fn(coro):
            await coro

        await self.start(fake_spawn_fn, bootstrap_coro)

    # Start running this supervisor as a task. The important thing about this
    # as compared to just doing
    #
    #   spawn_fn(supervisor.run())
    #
    # is that start() doesn't return until bootstrap_coro has completed, and
    # errors in bootstrap_coro are propagated.
    async def start(self, spawn_fn, bootstrap_coro=None):
        """Run ``bootstrap_coro`` and hand the supervise loop to ``spawn_fn``.

        Args:
            spawn_fn: async callable that receives the supervise-loop
                coroutine (for example ``curio.spawn`` or another
                supervisor's ``spawn``).
            bootstrap_coro: optional coroutine awaited in the caller's
                context before the supervise loop is handed off.

        Returns:
            A ``(bootstrap_result, spawn_result)`` tuple, where
            ``bootstrap_result`` is ``None`` when no bootstrap was given and
            ``spawn_result`` is whatever ``spawn_fn`` returned.

        Raises:
            RuntimeError: if the supervisor was already started.
            Any exception raised by ``bootstrap_coro`` or by ``spawn_fn``.
        """
        if self._started:
            raise RuntimeError("supervisor can only be run once")
        self._started = True
        # Some things this is careful about:
        # - we run the bootstrap coro in *this* context
        # - we don't return until bootstrap_coro finishes
        # - any errors raised by bootstrap_coro propagate to our caller
        # - if bootstrap_coro fails, we shut down everything properly
        # - we work even if spawn_fn is the trivial fake one used by run()
        #   (this is why we can't just spawn run() at the top)
        try:
            try:
                if bootstrap_coro is not None:
                    bootstrap_result = await bootstrap_coro
                else:
                    bootstrap_result = None
            except BaseException:
                # Whatever went wrong (including cancellation), cancel the
                # tasks the bootstrap already spawned, then re-raise.
                await self.start_shutdown()
                raise
        finally:
            # Always start the supervise loop, even after a failed bootstrap,
            # so that already-spawned tasks get reaped.
            spawn_result = await spawn_fn(self._supervise_loop())
        return (bootstrap_result, spawn_result)

    async def _process_finished_task(self, task):
        """Drop a terminated task from the bookkeeping and check its outcome.

        Cancellation of the task is ignored; any other failure is re-raised,
        and a non-``None`` return value raises ``ValueError``.  This is
        called with cancellation disabled.
        """
        assert task.terminated
        self._tasks.remove(task)
        try:
            result = await task.join()
        except curio.TaskError as e:
            if not isinstance(e.__cause__, curio.CancelledError):
                raise
        else:
            if result is not None:
                raise ValueError("Supervised tasks should return None, "
                                 "not {!r}".format(result))

    async def _supervise_loop(self):
        """Reap finished tasks until shut down and the task set is empty.

        The first error (from a task, or from this loop being cancelled)
        starts a shutdown.  Errors seen along the way are linked together
        through ``__context__`` and the most recent one is raised at the end.
        """
        assert self._started
        chained_exc = None
        async with curio.defer_cancellation:
            while self._tasks or not self._shutting_down:
                try:
                    # This is the only line where we want to allow
                    # cancellation / timeouts:
                    async with curio.allow_cancellation:
                        task = await self._q.get()
                    await self._process_finished_task(task)
                except BaseException as exc:
                    if not self._shutting_down:
                        await self.start_shutdown()
                    if chained_exc is not None:
                        exc.__context__ = chained_exc
                    chained_exc = exc
            # All done!
            await self._finished.set()
            if chained_exc is not None:
                raise chained_exc

    async def start_graceful_shutdown(self):
        """Stop accepting new tasks without cancelling the existing ones.

        The supervise loop then exits once all current tasks have finished.
        """
        self._shutting_down = True

    async def start_shutdown(self):
        """Stop accepting new tasks and request cancellation of all tasks.

        Cancellation is requested with ``blocking=False``, so this does not
        wait for the tasks to actually finish.
        """
        self._shutting_down = True
        # Iterate over a snapshot: every cancel() is an await, and if the
        # task set shrinks while we are suspended, iterating the live set
        # would raise "Set changed size during iteration" and leave the
        # remaining tasks uncancelled.
        for task in list(self._tasks):
            await task.cancel(blocking=False)

    async def wait(self):
        """Block until the supervise loop has finished."""
        await self._finished.wait()
import pytest
import curio
from .tutil import Sequencer
from .._supervisor import Supervisor
@pytest.mark.curio
async def test_supervisor():
    """Exercise crash propagation, start(), bootstrap errors and shutdown."""
    happytask_canceled = False

    async def happytask(seq):
        # Runs "forever"; records whether it was cancelled.
        nonlocal happytask_canceled
        async with seq(0):
            pass
        try:
            await curio.sleep(10000)
        except curio.CancelledError:
            happytask_canceled = True
            raise

    async def crashytask(seq):
        # Fails at sequence point 1 (presumably after happytask's point 0).
        async with seq(1):
            raise ValueError("whatever")

    happycrashybootstrap_finished = False

    async def happycrashybootstrap(sup):
        nonlocal happycrashybootstrap_finished
        seq = Sequencer()
        await sup.spawn(happytask(seq))
        await sup.spawn(crashytask(seq))
        happycrashybootstrap_finished = True
        return "hi"

    # A crashing task takes the supervisor down and cancels its sibling.
    supervisor = Supervisor()
    with pytest.raises(curio.TaskError) as exc_info:
        await supervisor.run(happycrashybootstrap(supervisor))
    assert isinstance(exc_info.value.__cause__, ValueError)
    assert happytask_canceled

    # Can't add tasks to a non-running supervisor
    supervisor = Supervisor()
    with pytest.raises(RuntimeError):
        await supervisor.spawn(happytask(Sequencer()))

    # Test start -- in particular, the bootstrap should finish before it
    # returns
    supervisor = Supervisor()
    happycrashybootstrap_finished = False
    b, t = await supervisor.start(curio.spawn,
                                  happycrashybootstrap(supervisor))
    assert happycrashybootstrap_finished
    assert b == "hi"
    with pytest.raises(curio.TaskError):
        await t.join()

    # Test errors in bootstrap are propagated and cleanup still happens
    async def happy_but_doomed_bootstrap(sup):
        seq = Sequencer()
        await sup.spawn(happytask(seq))
        raise ValueError("oops")

    happytask_canceled = False
    supervisor = Supervisor()
    with pytest.raises(ValueError):
        await supervisor.run(happy_but_doomed_bootstrap(supervisor))
    assert happytask_canceled

    # Test cancelling the supervisor itself
    supervisor = Supervisor()
    happytask_canceled = False
    bootstrap_result, t = await supervisor.start(curio.spawn)
    assert bootstrap_result is None
    h = await supervisor.spawn(happytask(Sequencer()))
    await t.cancel()
    # By the time the supervisor task completes, happytask should have been
    # fully canceled and cleaned up.
    assert h.terminated
    assert happytask_canceled

    # Test supervisor.start_shutdown()
    supervisor = Supervisor()
    _, t = await supervisor.start(curio.spawn)
    seq = Sequencer()
    h = await supervisor.spawn(happytask(seq))
    # Give happytask time to get started:
    async with seq(1):
        await supervisor.start_shutdown()
    # the supervisor task finishes without error:
    await t.join()
    # happytask was cancelled:
    assert h.terminated
    with pytest.raises(curio.TaskError) as exc_info:
        await h.join()
    assert isinstance(exc_info.value.__cause__, curio.CancelledError)
@pytest.mark.curio
async def test_the_shutdown_is_coming_from_INSIDE_the_supervisor():
    """A supervised task may itself trigger the supervisor's shutdown."""
    # kill_them_all here actually cancels *itself*, which used to make curio
    # uncomfortable until I fixed it.

    async def happytask(seq):
        async with seq(0):
            pass
        await curio.sleep(1000)

    async def kill_them_all(seq, sup):
        async with seq(1):
            await sup.start_shutdown()

    async def bootstrap(sup):
        seq = Sequencer()
        await sup.spawn(happytask(seq))
        await sup.spawn(kill_them_all(seq, sup))

    supervisor = Supervisor()
    # Must complete without raising.
    await supervisor.run(bootstrap(supervisor))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment