Created
December 15, 2024 10:50
-
-
Save graingert/3c75a9bec75ea1535e22a2f7d938a344 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import annotations | |
import dataclasses | |
import math | |
from collections.abc import Callable, Coroutine, Generator | |
from typing import TYPE_CHECKING | |
import trio.lowlevel | |
from typing_extensions import ParamSpec, Self, TypeVar, overload | |
if TYPE_CHECKING: | |
from types import TracebackType | |
# Parameter specification of the callable decorated by ``edge_cancel``.
_P = ParamSpec("_P")
# Type of the values the wrapped coroutine yields to its driver.
_YieldT_co = TypeVar("_YieldT_co", covariant=True)
# Send/return type variables that fall back to ``None`` when left unspecified.
# NOTE(review): neither of these two is referenced in the visible code.
_SendT_contra = TypeVar("_SendT_contra", contravariant=True, default=None)
_ReturnT_co = TypeVar("_ReturnT_co", covariant=True, default=None)
# Same roles as above but without a default (presumably "_nd" = "no default");
# these are the variables ``WrapCoro`` and ``edge_cancel`` are written against.
_SendT_contra_nd = TypeVar("_SendT_contra_nd", contravariant=True)
_ReturnT_co_nd = TypeVar("_ReturnT_co_nd", covariant=True)
@dataclasses.dataclass
class WrapCoro(
    Generator[_YieldT_co, _SendT_contra_nd, _ReturnT_co_nd],
    Coroutine[_YieldT_co, _SendT_contra_nd, _ReturnT_co_nd],
):
    """Awaitable proxy that forwards ``send``/``throw``/``close`` to a coroutine.

    While forwarding ``send`` it watches trio's effective deadline. After the
    first ``send`` that ran while the deadline was ``-inf`` (the value trio
    reports for an already-cancelled scope), every later ``send`` first turns
    on ``shield`` for the supplied cancel scope before resuming the wrapped
    coroutine.
    """

    # Task the wrapper was created in; stored but not read by any method here.
    _current_task: trio.lowlevel.Task
    # Scope whose ``shield`` flag is raised once a cancellation has been seen.
    _cancel_scope: trio.CancelScope
    # The coroutine every operation is delegated to.
    _coro: Coroutine[_YieldT_co, _SendT_contra_nd, _ReturnT_co_nd]
    # Becomes True after a ``send`` that started with a ``-inf`` deadline.
    _was_cancelled: bool = False

    def __await__(self) -> Self:
        """Return ``self``; the object is its own await iterator."""
        return self

    def send(self, value: _SendT_contra_nd) -> _YieldT_co:
        """Resume the wrapped coroutine with *value* and return what it yields.

        Exceptions raised by the wrapped coroutine (including
        ``StopIteration`` on completion) propagate unchanged.
        """
        if self._was_cancelled:
            # A cancellation was already observed: shield the scope before
            # letting the wrapped coroutine run any further.
            self._cancel_scope.shield = True
            return self._coro.send(value)

        # Sample the cancellation state *before* resuming, so the flag is
        # only recorded when this step began under a cancelled scope.
        cancelled = trio.current_effective_deadline() == -math.inf
        r = self._coro.send(value)
        if cancelled:
            self._was_cancelled = True
        return r

    @overload
    def throw(
        self,
        typ: type[BaseException],
        val: BaseException | object = None,
        tb: TracebackType | None = None,
    ) -> _YieldT_co: ...

    @overload
    def throw(
        self,
        typ: BaseException,
        val: None = None,
        tb: TracebackType | None = None,
    ) -> _YieldT_co: ...

    def throw(
        self,
        typ: type[BaseException] | BaseException,
        val: object = None,
        tb: TracebackType | None = None,
    ) -> _YieldT_co:
        """Raise an exception inside the wrapped coroutine.

        The single-argument form of ``throw`` is used when neither *val* nor
        *tb* is given; otherwise all three arguments are forwarded as-is.
        """
        if val is None and tb is None:
            return self._coro.throw(typ)
        return self._coro.throw(typ, val, tb)  # type: ignore[arg-type]

    def close(self) -> None:
        """Close the wrapped coroutine.

        Delegating (instead of doing nothing) lets the wrapped coroutine run
        its pending cleanup handlers as soon as the wrapper is closed, rather
        than leaving that to garbage collection.
        """
        self._coro.close()
def edge_cancel(
    fn: Callable[_P, Coroutine[_YieldT_co, _SendT_contra_nd, _ReturnT_co_nd]],
) -> Callable[_P, Coroutine[_YieldT_co, _SendT_contra_nd, _ReturnT_co_nd]]:
    """Decorate an async function so its coroutine is driven via ``WrapCoro``.

    Each call of the returned function opens a fresh ``trio.CancelScope``,
    creates the coroutine ``fn(*args, **kwargs)`` and awaits it through a
    ``WrapCoro`` bound to the current task and that scope. The awaited
    result is returned unchanged.

    The wrapper keeps the metadata of *fn* (``__name__``, ``__doc__``,
    ``__wrapped__`` ...) through ``functools.wraps``.
    """
    # Local import keeps this block self-contained with respect to the
    # module's import list.
    import functools

    @functools.wraps(fn)
    async def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _ReturnT_co_nd:
        with trio.CancelScope() as scope:
            return await WrapCoro(
                trio.lowlevel.current_task(),
                scope,
                fn(*args, **kwargs),
            )

    return wrapper
async def demo() -> None:
    """Start an ``edge_cancel``-decorated task and cancel its nursery at once.

    The task prints a line when it starts, when it catches the exception
    raised out of its infinite sleep, and after each of two short follow-up
    sleeps; it then re-raises the caught exception.
    """
    pause_seconds = 0.1

    @edge_cancel
    async def edge_cancelled() -> None:
        print("started")
        try:
            await trio.sleep(math.inf)
        except BaseException:
            # BaseException is caught on purpose so that whatever interrupts
            # the sleep is seen; it is always re-raised below.
            print("cancelled!")
            for _ in range(2):
                await trio.sleep(pause_seconds)
                print("slept!")
            raise

    async with trio.open_nursery() as task_group:
        task_group.start_soon(edge_cancelled)
        # Cancel right after scheduling, before the nursery block exits.
        task_group.cancel_scope.cancel()
# Script entry point: run the demo under trio. There is no ``__main__`` guard,
# so this also executes when the module is imported.
trio.run(demo)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment