Skip to content

Instantly share code, notes, and snippets.

@stereobutter
Last active June 23, 2021 15:01
Show Gist options
  • Save stereobutter/306702ef8571123f42ecbcd3518db8aa to your computer and use it in GitHub Desktop.
Save stereobutter/306702ef8571123f42ecbcd3518db8aa to your computer and use it in GitHub Desktop.
graceful termination with trio
from contextlib import contextmanager

import trio

from .lifetime import Lifetime
@contextmanager
def merge_cancellation():
    """Collapse a ``trio.MultiError`` made only of ``Cancelled`` into one ``Cancelled``.

    If the body raises a ``trio.MultiError`` whose members are all
    ``trio.Cancelled``, the first member is raised instead, so an ordinary
    ``except trio.Cancelled`` can catch it.  A ``MultiError`` containing any
    other exception is re-raised unchanged; other exceptions pass through.

    NOTE(review): ``trio.MultiError`` is deprecated in newer Trio releases in
    favour of exception groups -- confirm against the pinned Trio version.
    """
    try:
        yield
    except trio.MultiError as group:
        members = group.exceptions
        if any(not isinstance(member, trio.Cancelled) for member in members):
            # Mixed failure: keep the whole group intact.
            raise
        raise members[0]
async def child_task(name, shutdown_delay):
    """Demo worker: ping roughly once a second until asked to stop, then shut down.

    The task runs inside its own ``Lifetime``.  That Lifetime registers itself
    under whichever Lifetime is current when it is entered, so a soft stop
    requested on the parent reaches this task.

    Args:
        name: label used in the printed messages.
        shutdown_delay: seconds the simulated graceful shutdown takes.

    Raises:
        trio.Cancelled: re-raised after logging when the task is hard-cancelled.
    """
    with Lifetime() as lifetime:
        try:
            while not lifetime.stop_called:
                print(f'ping from {name}')
                # Wait up to one second between pings, but wake up as soon as
                # a stop is requested instead of sleeping out the interval;
                # this leaves the whole grace period for the actual shutdown.
                with trio.move_on_after(1):
                    await lifetime.wait_for_stop()
            print(f'stopping {name}')
            await trio.sleep(shutdown_delay)
            print(f'stopped {name} gracefully')
        except trio.Cancelled:
            print(f'stopped {name} forcefully')
            raise
        finally:
            print(f'cleaning up {name}')
async def main_task(lifetime):
    """Run the two demo children under *lifetime*, then do a short wind-down.

    Prints whether the task finished gracefully or was cancelled, and always
    prints a clean-up message.  A ``trio.Cancelled`` is logged and re-raised.
    """
    workers = (('alice', 1), ('bob', 2))
    with lifetime:
        try:
            # merge_cancellation turns a MultiError consisting only of
            # Cancelled into a single Cancelled, so the handler below sees it.
            with merge_cancellation():
                async with trio.open_nursery() as children:
                    for worker_name, worker_delay in workers:
                        children.start_soon(child_task, worker_name, worker_delay)
            print('stopping main task')
            await trio.sleep(1)
        except trio.Cancelled:
            print('stopped main task forcefully')
            raise
        else:
            print('stopped main task gracefully')
        finally:
            print('cleaning up main task')
async def cancel_tasks(lifetime, delay, grace_period):
    """After *delay* seconds, soft-stop *lifetime*; hard-cancel it if that is too slow.

    Args:
        lifetime: the Lifetime to stop.
        delay: seconds to wait before requesting the soft stop.
        grace_period: seconds allowed for ``lifetime.stop()`` to complete before
            the lifetime's cancel scope is cancelled.
    """
    await trio.sleep(delay)
    with trio.move_on_after(grace_period) as grace:
        print('>>> calling soft cancel')
        await lifetime.stop()
    # cancelled_caught is True exactly when the grace period ran out.
    if grace.cancelled_caught:
        print('>>> calling hard cancel')
        lifetime.cancel_scope.cancel()
async def main(delay=3, grace_period=4):
    """Run the demo: start ``main_task``, then stop it via ``cancel_tasks``.

    Args:
        delay: seconds before the soft stop is requested.
        grace_period: seconds the soft stop may take before a hard cancel.
            Try different values to see graceful vs. forced shutdown.
    """
    lifetime = Lifetime()
    async with trio.open_nursery() as nursery:
        nursery.start_soon(main_task, lifetime)
        nursery.start_soon(cancel_tasks, lifetime, delay, grace_period)


# Only run the demo when executed as a script, not on import.
if __name__ == '__main__':
    trio.run(main)
import trio
from contextlib import contextmanager
from contextvars import ContextVar
# The Lifetime most recently entered in the current context, or None.
# NOTE(review): nesting across tasks relies on new Trio tasks starting with a
# copy of the spawning task's context, so they see the spawner's Lifetime as
# their parent -- confirm against the Trio version in use.
CURRENT_SCOPE = ContextVar('CURRENT_SCOPE', default=None)


class Lifetime:
    """A cancel scope that also supports a cooperative ("soft") stop.

    Entering a ``Lifetime`` (as a sync context manager) opens a
    ``trio.CancelScope`` and registers the lifetime as nested under whichever
    ``Lifetime`` is current in this context.  ``stop()`` requests a soft stop
    for this lifetime and every lifetime nested in it, at any depth.  The body
    is expected to notice via ``stop_called`` or ``wait_for_stop()`` and exit.
    ``cancel_scope.cancel()`` is the hard fallback.

    An instance can be entered only once.
    """

    def __init__(self):
        self._stop_event = trio.Event()     # set once a stop has been requested
        self._stopped_event = trio.Event()  # set once the body exited without raising
        self._nested_scopes = set()         # Lifetimes currently entered inside this one
        self._cancel_scope = trio.CancelScope()
        self._ctx = None                    # generator context manager, created by __enter__

    @property
    def stop_called(self):
        """True once a soft stop has been requested for this lifetime."""
        return self._stop_event.is_set()

    @property
    def stopped(self):
        """True once the body has exited without raising an exception."""
        return self._stopped_event.is_set()

    @property
    def cancel_scope(self):
        """The ``trio.CancelScope`` wrapping the body; cancel it for a hard stop."""
        return self._cancel_scope

    def _signal_stop(self):
        """Request a soft stop for this lifetime and, recursively, all nested ones."""
        self._stop_event.set()
        # Recurse so that grandchildren and deeper scopes are reached too.
        # Iterate over a snapshot so the set is never mutated mid-iteration.
        for scope in tuple(self._nested_scopes):
            scope._signal_stop()

    async def stop(self):
        """Request a soft stop and wait until the body exits without raising.

        NOTE: the stopped event is set only on a clean exit.  If the body exits
        via an exception (including cancellation), this call keeps waiting.
        Bound it with a timeout (e.g. ``trio.fail_after``) and fall back to
        ``cancel_scope.cancel()``.
        """
        self._signal_stop()
        await self._stopped_event.wait()

    async def wait_for_stop(self):
        """Block until a soft stop has been requested."""
        await self._stop_event.wait()

    @contextmanager
    def _contextmanager(self):
        """Run the body inside the cancel scope and track nesting via CURRENT_SCOPE."""
        with self._cancel_scope:
            parent_scope = CURRENT_SCOPE.get()
            token = CURRENT_SCOPE.set(self)
            if parent_scope is not None:
                parent_scope._nested_scopes.add(self)
                # A stop already requested on the parent applies to us as well.
                if parent_scope.stop_called:
                    self._signal_stop()
            try:
                yield self
                # Reached only when the body finished without raising
                # (exceptions are thrown in at the yield and skip this line).
                self._stopped_event.set()
            finally:
                if parent_scope is not None:
                    parent_scope._nested_scopes.remove(self)
                CURRENT_SCOPE.reset(token)

    def __enter__(self):
        if self._ctx is None:
            self._ctx = self._contextmanager()
        else:
            raise RuntimeError(f'{self} cannot be re-entered')
        return self._ctx.__enter__()

    def __exit__(self, et, ev, tb):
        return self._ctx.__exit__(et, ev, tb)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment