Skip to content

Instantly share code, notes, and snippets.

@graingert
Created December 15, 2024 10:50
Show Gist options
  • Save graingert/3c75a9bec75ea1535e22a2f7d938a344 to your computer and use it in GitHub Desktop.
Save graingert/3c75a9bec75ea1535e22a2f7d938a344 to your computer and use it in GitHub Desktop.
from __future__ import annotations
import dataclasses
import math
from collections.abc import Callable, Coroutine, Generator
from typing import TYPE_CHECKING
import trio.lowlevel
from typing_extensions import ParamSpec, Self, TypeVar, overload
if TYPE_CHECKING:
from types import TracebackType
# Parameter specification used by ``edge_cancel`` to carry the decorated
# function's call signature over to the wrapper it returns.
_P = ParamSpec("_P")
# Yield type of the wrapped coroutine (covariant).
_YieldT_co = TypeVar("_YieldT_co", covariant=True)
# Send/return type variables that default to ``None``.
# NOTE(review): neither is referenced elsewhere in the visible source.
_SendT_contra = TypeVar("_SendT_contra", contravariant=True, default=None)
_ReturnT_co = TypeVar("_ReturnT_co", covariant=True, default=None)
# Send/return type variables declared without a default (the "_nd" suffix
# presumably stands for "no default"); these are the ones ``WrapCoro`` and
# ``edge_cancel`` are parameterised with.
_SendT_contra_nd = TypeVar("_SendT_contra_nd", contravariant=True)
_ReturnT_co_nd = TypeVar("_ReturnT_co_nd", covariant=True)
@dataclasses.dataclass
class WrapCoro(
    Generator[_YieldT_co, _SendT_contra_nd, _ReturnT_co_nd],
    Coroutine[_YieldT_co, _SendT_contra_nd, _ReturnT_co_nd],
):
    """Awaitable proxy that forwards the coroutine protocol to ``_coro``.

    ``send``, ``throw`` and ``close`` are all delegated to the wrapped
    coroutine.  ``send`` additionally tracks cancellation: once a value has
    been sent while the effective deadline was ``-inf``, every later send
    first turns on ``_cancel_scope.shield``.

    NOTE(review): the intent appears to be "edge-triggered" cancellation,
    i.e. the wrapped coroutine sees the cancellation once and can then keep
    awaiting inside the shielded scope -- confirm against trio's cancellation
    delivery rules before relying on this outside the demo.
    """

    # Task captured by the creator; stored but not read by any method here.
    _current_task: trio.lowlevel.Task
    # Scope whose ``shield`` flag is switched on after the first cancelled send.
    _cancel_scope: trio.CancelScope
    # The wrapped coroutine every protocol call is forwarded to.
    _coro: Coroutine[_YieldT_co, _SendT_contra_nd, _ReturnT_co_nd]
    # True once a send has been made while a cancellation was in effect.
    _was_cancelled: bool = False

    def __await__(self) -> Self:
        """Return ``self``; this object acts as its own ``await`` iterator."""
        return self

    def send(self, value: _SendT_contra_nd) -> _YieldT_co:
        """Forward ``value`` to the wrapped coroutine and return what it yields.

        If an earlier send already happened under cancellation, the cancel
        scope is shielded before forwarding.  Otherwise the cancellation
        state is sampled *before* forwarding and recorded afterwards, so the
        shield only takes effect from the following send onwards.

        Any exception raised by the wrapped coroutine (including
        ``StopIteration`` on completion) propagates unchanged.
        """
        if self._was_cancelled:
            self._cancel_scope.shield = True
            return self._coro.send(value)
        # Per trio's documentation an effective deadline of -inf means the
        # current scope is cancelled.  Sample it before the send: the order
        # matters, because the wrapped coroutine runs during the send.
        cancelled = trio.current_effective_deadline() == -math.inf
        r = self._coro.send(value)
        if cancelled:
            self._was_cancelled = True
        return r

    @overload
    def throw(
        self,
        typ: type[BaseException],
        val: BaseException | object = None,
        tb: TracebackType | None = None,
    ) -> _YieldT_co: ...

    @overload
    def throw(
        self,
        typ: BaseException,
        val: None = None,
        tb: TracebackType | None = None,
    ) -> _YieldT_co: ...

    def throw(
        self,
        typ: type[BaseException] | BaseException,
        val: object = None,
        tb: TracebackType | None = None,
    ) -> _YieldT_co:
        """Raise an exception inside the wrapped coroutine.

        The single-argument form is used whenever ``val`` and ``tb`` are both
        ``None`` (presumably to avoid the three-argument signature that newer
        Python versions deprecate); otherwise all three are forwarded.
        """
        if val is None and tb is None:
            return self._coro.throw(typ)
        return self._coro.throw(typ, val, tb)  # type: ignore[arg-type]

    def close(self) -> None:
        """Close the wrapped coroutine.

        Python calls ``close()`` on the object being awaited when the
        awaiting coroutine itself is closed.  Forwarding the call lets the
        wrapped coroutine run its ``finally``/``with`` cleanup right away
        instead of whenever it happens to be garbage collected.  Closing an
        already finished coroutine is a no-op.
        """
        self._coro.close()
def edge_cancel(
    fn: Callable[_P, Coroutine[_YieldT_co, _SendT_contra_nd, _ReturnT_co_nd]],
) -> Callable[_P, Coroutine[_YieldT_co, _SendT_contra_nd, _ReturnT_co_nd]]:
    """Decorate an async function so its coroutine runs inside ``WrapCoro``.

    Each call of the returned wrapper opens a fresh ``trio.CancelScope``,
    wraps ``fn(*args, **kwargs)`` in a ``WrapCoro`` bound to that scope and
    the current task, awaits it, and returns its result.

    Args:
        fn: The async callable to wrap.

    Returns:
        An async callable with the same parameters as ``fn``.  It carries
        ``fn``'s metadata (``__name__``, ``__doc__``, ``__wrapped__`` ...), so
        names derived from the function show ``fn`` rather than ``wrapper``.
    """
    # Imported here so the decorator is self-contained and does not rely on
    # the module's import block providing ``functools``.
    import functools

    @functools.wraps(fn)
    async def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _ReturnT_co_nd:
        with trio.CancelScope() as scope:
            return await WrapCoro(
                trio.lowlevel.current_task(),
                scope,
                fn(*args, **kwargs),
            )

    return wrapper
async def demo() -> None:
    """Show an ``edge_cancel``-decorated task running inside a cancelled nursery.

    A child task is started and the nursery's cancel scope is cancelled
    straight away.  The child prints a message when its infinite sleep is
    interrupted, then performs two further short sleeps (printing after
    each) before re-raising the exception it caught.
    """

    @edge_cancel
    async def edge_cancelled() -> None:
        print("started")
        try:
            await trio.sleep(math.inf)
        except BaseException:
            print("cancelled!")
            # Two more awaits after the interruption, each followed by a
            # message once it has finished.
            for _ in range(2):
                await trio.sleep(0.1)
                print("slept!")
            raise

    async with trio.open_nursery() as nursery:
        nursery.start_soon(edge_cancelled)
        nursery.cancel_scope.cancel()
# Run the demo with ``trio.run``.  There is no ``__main__`` guard, so this
# executes whenever the module is run or imported.
trio.run(demo)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment