Skip to content

Instantly share code, notes, and snippets.

@stereobutter
Last active August 18, 2022 20:53
Show Gist options
Save stereobutter/64a102aca892fce6916a7264b056c9e1 to your computer and use it in GitHub Desktop.
Soft cancellation with Trio
import trio
from .soft_cancel_scope import SoftCancelScope
async def main():
    """Demonstrate soft cancellation that escalates to a hard cancel.

    A worker runs ten one-second jobs inside a SoftCancelScope and checks for
    a soft-cancel request after each job. When it sees one, it performs a
    cleanup that takes ``cleanup_time`` seconds and then stops. A second task
    requests soft cancellation at t=5.5 and allows ``timeout`` seconds for it
    to complete. If that deadline passes, ``trio.fail_after`` interrupts the
    wait and the scope's body is hard cancelled.
    """

    def elapsed():
        # Seconds since the demo started. start_time is bound further down,
        # before any task is started.
        return trio.current_time() - start_time

    async def do_stuff(job_id):
        print(f'starting job {job_id} at t={elapsed()}')
        await trio.sleep(1)
        print(f'finished job {job_id} at t={elapsed()}')

    async def some_task(cleanup_time, soft_cancel_scope):
        with soft_cancel_scope:
            for job_id in range(10):
                await do_stuff(job_id)
                if not soft_cancel_scope.cancel_called:
                    continue
                # A soft cancel was requested: clean up, then stop taking jobs.
                print(f'attempting cleanup at t={elapsed()}')
                try:
                    await trio.sleep(cleanup_time)
                except trio.Cancelled:
                    # The canceller ran out of patience and hard-cancelled us.
                    print(f'cleanup aborted at t={elapsed()}')
                    raise
                print(f'finished cleanup at t={elapsed()}')
                break

    async def cancel_after(timeout, soft_cancel_scope):
        await trio.sleep(5.5)
        print(f'soft cancel requested at t={elapsed()}')
        try:
            with trio.fail_after(timeout):
                await soft_cancel_scope.cancel()
        except trio.TooSlowError:
            print(f'hard cancelled initiated at t={elapsed()}')
        else:
            print(f'soft cancel completed at t={elapsed()}')

    start_time = trio.current_time()
    scope = SoftCancelScope()
    async with trio.open_nursery() as nursery:
        nursery.start_soon(some_task, 3, scope)
        nursery.start_soon(cancel_after, 3, scope)


trio.run(main)
import trio
from contextlib import contextmanager
from contextvars import ContextVar
# The innermost SoftCancelScope entered in the current context, or None when
# no scope is active. Used to link nested scopes to their parent on entry.
CURRENT_SCOPE = ContextVar('CURRENT_SCOPE', default=None)


class SoftCancelScope:
    """A cancel scope that *asks* its body to stop instead of forcing it.

    ``cancel()`` sets ``cancel_called``. Code inside the scope is expected to
    poll that flag and leave the ``with`` block on its own, and ``cancel()``
    waits until it has done so. If the ``cancel()`` call is itself interrupted
    before the body has exited cleanly (for example by an enclosing
    ``trio.fail_after``), the body is hard cancelled through an internal
    ``trio.CancelScope`` that wraps it.

    A scope entered while another SoftCancelScope is current (per
    ``CURRENT_SCOPE``) is registered as a nested scope of it. Nested scopes
    are soft-cancelled together with their parent.

    An instance can be entered only once.
    """

    def __init__(self):
        self._event = trio.Event()  # set when the body exits without raising
        self._cancel_called = False
        self._nested_scopes = set()  # child scopes currently entered
        self._parent_scope = None  # enclosing scope, recorded on entry
        self._ctx = None  # generator context manager, created on first entry
        self._hard_cancel_scope = trio.CancelScope()

    @property
    def cancel_called(self):
        """True once soft cancellation has been requested for this scope."""
        return self._cancel_called

    @property
    def cancel_caught(self):
        """True once the scope's body has exited without raising."""
        return self._event.is_set()

    def cancel_nowait(self):
        """Request soft cancellation of this scope and its nested scopes.

        Unlike ``cancel()``, this does not wait for the body to exit and never
        escalates to a hard cancel. It is therefore safe to call from inside
        the scope.
        """
        self._cancel_called = True
        for scope in tuple(self._nested_scopes):
            scope.cancel_nowait()

    async def cancel(self):
        """Soft-cancel this scope and wait for its body to exit cleanly.

        Nested scopes are soft-cancelled concurrently. If this call is left
        before the body has exited cleanly, the body is hard cancelled and the
        interruption propagates. That includes a bare ``trio.Cancelled`` and
        one wrapped in an exception group by the internal nursery.

        NOTE(review): if the body exits by *raising*, the event is never set,
        so this waits until it is itself interrupted.

        Raises:
            RuntimeError: if called from inside this scope or one of its
                nested scopes, which would otherwise wait on itself forever.
        """
        current_scope = CURRENT_SCOPE.get()
        if current_scope is not None and (
            current_scope is self or self in current_scope._ancestral_scopes()
        ):
            raise RuntimeError('Cannot cancel from within scope')
        try:
            async with trio.open_nursery() as nursery:
                # Snapshot the set: nested scopes deregister when they exit.
                for scope in tuple(self._nested_scopes):
                    nursery.start_soon(scope.cancel)
                self._cancel_called = True
                await self._event.wait()
        finally:
            # Leaving without the body having exited cleanly means we were
            # interrupted, so escalate to a hard cancel of the body.
            if not self._event.is_set():
                self._hard_cancel_scope.cancel()

    @contextmanager
    def _contextmanager(self):
        """Register this scope as current (and as a child of its parent) for the body."""
        parent_scope = CURRENT_SCOPE.get()
        self._parent_scope = parent_scope
        token = CURRENT_SCOPE.set(self)
        try:
            if parent_scope is not None:
                parent_scope._nested_scopes.add(self)
                # The parent was already soft-cancelled: inherit that request.
                if parent_scope._cancel_called:
                    self.cancel_nowait()
            with self._hard_cancel_scope:
                yield self
                # Only reached when the body exits without raising.
                self._event.set()
        finally:
            if parent_scope is not None:
                parent_scope._nested_scopes.discard(self)
            CURRENT_SCOPE.reset(token)

    def _ancestral_scopes(self):
        """Yield the enclosing scopes of this scope, innermost first."""
        scope = self._parent_scope
        while scope is not None:
            yield scope
            scope = scope._parent_scope

    def __enter__(self):
        if self._ctx is None:
            self._ctx = self._contextmanager()
        else:
            raise RuntimeError(f'{self} cannot be re-entered')
        return self._ctx.__enter__()

    def __exit__(self, et, ev, tb):
        if self._ctx is None:
            raise RuntimeError(f'{self} has not been entered')
        return self._ctx.__exit__(et, ev, tb)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment