Last active
August 18, 2022 20:53
-
-
Save stereobutter/64a102aca892fce6916a7264b056c9e1 to your computer and use it in GitHub Desktop.
soft cancellation with trio
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import trio | |
from .soft_cancel_scope import SoftCancelScope | |
async def main():
    """Demonstrate SoftCancelScope with two concurrent tasks.

    A worker runs up to ten one-second jobs inside a soft cancel scope and
    checks for a soft cancel request after each job. A second task requests
    the soft cancel at t=5.5 and allows 3 seconds (via ``trio.fail_after``)
    before the request escalates into a hard cancel.
    """
    start_time = trio.current_time()

    def elapsed():
        # Seconds since main() started, measured on trio's clock.
        return trio.current_time() - start_time

    async def do_stuff(num):
        print(f'starting job {num} at t={elapsed()}')
        await trio.sleep(1)
        print(f'finished job {num} at t={elapsed()}')

    async def some_task(cleanup_time, scope):
        with scope:
            for job in range(10):
                await do_stuff(job)
                if not scope.cancel_called:
                    continue
                # A soft cancel was requested: clean up, unless a hard cancel
                # interrupts the cleanup first.
                print(f'attempting cleanup at t={elapsed()}')
                try:
                    await trio.sleep(cleanup_time)
                except trio.Cancelled:
                    print(f'cleanup aborted at t={elapsed()}')
                    raise
                print(f'finished cleanup at t={elapsed()}')
                break

    async def cancel_after(timeout, scope):
        await trio.sleep(5.5)
        print(f'soft cancel requested at t={elapsed()}')
        try:
            # Give the worker `timeout` seconds to wind down on its own.
            with trio.fail_after(timeout):
                await scope.cancel()
                print(f'soft cancel completed at t={elapsed()}')
        except trio.TooSlowError:
            print(f'hard cancelled initiated at t={elapsed()}')

    soft_scope = SoftCancelScope()
    async with trio.open_nursery() as nursery:
        for task in (some_task, cancel_after):
            nursery.start_soon(task, 3, soft_scope)


trio.run(main)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import trio | |
from contextlib import contextmanager | |
from contextvars import ContextVar | |
# Innermost SoftCancelScope entered in the current context (None outside any scope).
CURRENT_SCOPE = ContextVar('CURRENT_SCOPE', default=None)


class SoftCancelScope:
    """A cooperative ("soft") cancel scope for trio, backed by a hard one.

    Requesting cancellation only sets :attr:`cancel_called`; code inside the
    scope is expected to poll that flag and wind down by itself. The async
    :meth:`cancel` then waits for the scope to exit. If the task awaiting
    ``cancel`` is itself cancelled (for example by an enclosing
    ``trio.fail_after``), the scope's internal ``trio.CancelScope`` is
    cancelled, which aborts the body at its next checkpoint.

    Scopes nest through the ``CURRENT_SCOPE`` context variable. A scope that
    is entered while another is current registers with that parent. It is
    soft cancelled together with the parent, and it starts out soft cancelled
    if the parent already was. Each instance can be entered only once.
    """

    def __init__(self):
        # Set when the body finishes without raising (see cancel_caught).
        self._event = trio.Event()
        # Set whenever the scope exits, cleanly or not; cancel() waits on it.
        self._exited = trio.Event()
        self._cancel_called = False
        # Child scopes currently entered while this scope was CURRENT_SCOPE.
        self._nested_scopes = set()
        # Generator-based context manager created on the first __enter__.
        self._ctx = None
        # The scope that was CURRENT_SCOPE when this one was entered.
        self._parent_scope = None
        # Used to hard cancel the body when a soft cancel() is itself cancelled.
        self._hard_cancel_scope = trio.CancelScope()

    @property
    def cancel_called(self):
        """True once soft cancellation has been requested for this scope."""
        return self._cancel_called

    @property
    def cancel_caught(self):
        """True once the scope's body has exited without raising.

        This is also True for a body that completed without ever being
        cancelled.
        """
        # Previously read the nonexistent ``self._end_event`` (AttributeError).
        return self._event.is_set()

    def cancel_nowait(self):
        """Mark this scope and all currently nested scopes as soft cancelled.

        Unlike :meth:`cancel`, this neither waits for the scope to exit nor
        escalates to a hard cancel.
        """
        self._cancel_called = True
        for scope in self._nested_scopes:
            scope.cancel_nowait()

    async def cancel(self):
        """Soft cancel this scope and wait until it has exited.

        Every nested scope registered at call time is soft cancelled
        concurrently. If the scope has not been entered yet, this waits for it
        to be entered and to exit. Use :attr:`cancel_caught` to tell whether
        the exit was clean.

        If the waiting task is cancelled, the scope's internal
        ``trio.CancelScope`` is cancelled and ``trio.Cancelled`` is re-raised.

        Raises:
            RuntimeError: if called from inside this scope or one of its
                nested scopes. The scope cannot exit while its own body waits
                for that exit.
        """
        current_scope = CURRENT_SCOPE.get()
        if current_scope is not None and (
            current_scope is self or self in current_scope._ancestral_scopes()
        ):
            raise RuntimeError('Cannot cancel from within scope')
        try:
            async with trio.open_nursery() as nursery:
                for scope in self._nested_scopes:
                    nursery.start_soon(scope.cancel)
                self._cancel_called = True
                # Wait on _exited rather than _event: a body that exits by
                # raising would otherwise leave this call hanging forever.
                await self._exited.wait()
        except trio.Cancelled:
            self._hard_cancel_scope.cancel()
            raise

    @contextmanager
    def _contextmanager(self):
        """Register with the parent scope, run the body, and always unregister."""
        parent_scope = CURRENT_SCOPE.get()
        self._parent_scope = parent_scope
        token = CURRENT_SCOPE.set(self)
        # Everything after CURRENT_SCOPE.set() sits inside try/finally so the
        # context var and the parent's registry are restored even if setup fails.
        try:
            if parent_scope is not None:
                parent_scope._nested_scopes.add(self)
                if parent_scope._cancel_called:
                    # Entered after the parent was soft cancelled: start cancelled too.
                    self.cancel_nowait()
            with self._hard_cancel_scope:
                yield self
                # Reached only when the body finished without raising.
                self._event.set()
        finally:
            self._exited.set()
            if parent_scope is not None:
                parent_scope._nested_scopes.discard(self)
            CURRENT_SCOPE.reset(token)

    def _ancestral_scopes(self):
        """Yield the enclosing scopes of this one, innermost first."""
        parent_scope = self._parent_scope
        while parent_scope is not None:
            yield parent_scope
            parent_scope = parent_scope._parent_scope

    def __enter__(self):
        if self._ctx is None:
            self._ctx = self._contextmanager()
        else:
            raise RuntimeError(f'{self} cannot be re-entered')
        return self._ctx.__enter__()

    def __exit__(self, et, ev, tb):
        if self._ctx is None:
            raise RuntimeError(f'{self} has not been entered')
        return self._ctx.__exit__(et, ev, tb)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment