Last active
April 14, 2025 08:57
-
-
Save smurfix/9c986d3f20beb5e55b9c9404e795e5ed to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import annotations | |
import sys | |
import trio | |
from outcome import Outcome, Value, Error, AlreadyUsedError | |
from attrs import define,field | |
from contextlib import asynccontextmanager | |
from enum import IntEnum | |
class ErrorCapture(IntEnum):
    """Policy selecting what ``start_task`` does when the started task raises.

    The values are plain integers, so ``False`` (0) compares equal to
    ``Never``.
    """

    # Do not capture: the exception is re-raised into the nursery.
    Never = 0
    # Capture the exception in the Future and also cancel the nursery.
    Cancel = 1
    # Capture the exception in the Future; the nursery keeps running.
    Always = 2
@define
class Future:
    """Handle for the eventual result of a task started with ``start_task``.

    The task's outcome is stored in ``result`` and ``ready`` is set once the
    task has finished. The outcome can be collected exactly once via
    ``wait()``.
    """

    # Owner of the ``_todo`` set that tracks captured, not-yet-collected
    # errors. ``_start_task`` passes the NurseryPlusManager itself here.
    nursery: NurseryPlusManager = field()
    # Cancel scope wrapped around the task; ``cancel()`` cancels it.
    scope: trio.CancelScope = field()
    # Outcome of the task. ``None`` while it is still running and again
    # after ``wait()`` has consumed it.
    result: Outcome | None = field(default=None, init=False)
    # Set when the task has finished, whether it succeeded or failed.
    ready: trio.Event = field(factory=trio.Event, init=False)

    def __hash__(self):
        # Identity hash: futures are stored in the owner's ``_todo`` set.
        return id(self)

    async def wait(self):
        """Wait for the task to finish and return its value.

        Raises:
            AlreadyUsedError: the result has already been collected.
            Exception: whatever error was recorded for the task. Collecting
                it here takes this future off the owner's list of errors
                that are re-raised when the nursery exits.
        """
        await self.ready.wait()
        if self.result is None:
            raise AlreadyUsedError
        try:
            return self.result.unwrap()
        except Exception:
            # The future is only tracked in ``_todo`` when its error was
            # captured, and ``_todo`` is deleted once the nursery has exited.
            # Tolerate both so the task's own error is what propagates, not
            # a KeyError/AttributeError from the bookkeeping.
            pending = getattr(self.nursery, "_todo", None)
            if pending is not None:
                pending.discard(self)
            raise
        finally:
            # One-shot: mark the result as consumed either way.
            self.result = None

    def cancel(self):
        """Cancel the task by cancelling the scope it runs in."""
        self.scope.cancel()
class NoResult(Exception):
    """Placeholder error for a task that finished without a usable result.

    ``start_task`` uses it when the task's exception is not captured, or when
    the task ended with a ``BaseException`` that is not an ``Exception``.
    The task's function is passed as the sole argument.
    """
class NurseryPlusManager:
    """Async context manager yielding a Trio nursery extended with ``start_task``.

    ``start_task`` starts a child task and returns a ``Future`` for its
    result. Errors that were captured in a future but never collected with
    ``Future.wait()`` are raised as an exception group when the block exits.
    """

    strict_exception_groups: bool = True

    @trio.lowlevel.enable_ki_protection
    async def __aenter__(self) -> trio.Nursery:
        """Open the cancel scope and nursery; return the nursery."""
        # Futures holding a captured error that nobody has collected yet.
        self._todo = set()
        self._scope = trio.CancelScope()
        self._scope.__enter__()
        self._nursery = trio._core.Nursery._create(
            trio.lowlevel.current_task(),
            self._scope,
            self.strict_exception_groups,
        )
        # Expose the future-returning starter on the nursery object itself.
        self._nursery.start_task = self._start_task
        return self._nursery

    async def _start_task(
        self,
        async_fn: Callable[[Unpack[PosArgT]], Awaitable[object]],
        *args: Unpack[PosArgT],
        name: object = None,
        capture: ErrorCapture = ErrorCapture.Never,
    ) -> Future:
        """Start ``async_fn(*args)`` in the nursery and return its ``Future``.

        Args:
            async_fn: the async function to run.
            *args: positional arguments for ``async_fn``.
            name: task name, forwarded to ``nursery.start``.
            capture: what to do when the task raises, see ``ErrorCapture``.
                ``False`` is accepted as an equivalent of
                ``ErrorCapture.Never``.

        Returns:
            The ``Future`` for the task, available as soon as it has started.
        """

        async def _wrap(fn, args, nursery, *, task_status):
            with trio.CancelScope() as scope:
                fut = Future(self, scope)
                task_status.started(fut)
                try:
                    fut.result = Value(await fn(*args))
                except BaseException as exc:
                    if capture == ErrorCapture.Never:
                        # ``result`` must always be an Outcome so that
                        # Future.wait() can unwrap it.
                        fut.result = Error(NoResult(fn))
                        raise
                    # Only ``Exception`` instances are handed to the waiter;
                    # anything else is replaced by NoResult and re-raised.
                    fut.result = Error(
                        exc if isinstance(exc, Exception) else NoResult(fn)
                    )
                    if capture == ErrorCapture.Cancel:
                        nursery.cancel_scope.cancel()
                    self._todo.add(fut)
                    if not isinstance(exc, Exception):
                        raise
                finally:
                    # Wake up waiters no matter how the task ended.
                    fut.ready.set()

        return await self._nursery.start(
            _wrap, async_fn, args, self._nursery, name=name
        )

    @trio.lowlevel.enable_ki_protection
    async def __aexit__(
        self,
        etype: type[BaseException] | None,
        exc: BaseException | None,
        tb: TracebackType | None,
    ) -> bool:
        """Wait for the children, then raise any uncollected captured errors.

        Returns ``True`` when nothing is left to raise and ``False`` when the
        incoming exception should propagate unchanged. Otherwise the combined
        error is raised.
        """
        new_exc = await self._nursery._nested_child_finished(exc)
        # Tracebacks show the 'raise' line below out of context, so let's give
        # this variable a name that makes sense out of context.
        combined_error_from_nursery = self._scope._close(new_exc)
        if self._todo:
            # We have some unfinished business: captured errors nobody
            # collected through Future.wait() must not vanish silently.
            unclaimed = []
            for fut in self._todo:
                unclaimed.append(fut.result.error)
                fut.result = None
            futures_error = ExceptionGroup("Futures from Trio nursery", unclaimed)
            if combined_error_from_nursery is None:
                combined_error_from_nursery = futures_error
            else:
                # Build the group explicitly: the nursery's error need not be
                # an exception group, and BaseExceptionGroup yields a plain
                # ExceptionGroup when every member is an Exception.
                combined_error_from_nursery = BaseExceptionGroup(
                    "Exceptions and Futures from Trio nursery",
                    [combined_error_from_nursery, futures_error],
                )
            # Drop the extra references to the exceptions (see below).
            del unclaimed, futures_error
        del self._todo
        if combined_error_from_nursery is None:
            return True
        elif combined_error_from_nursery is exc:
            return False
        else:
            # Copied verbatim from the old MultiErrorCatcher. Python doesn't
            # allow us to encapsulate this __context__ fixup.
            old_context = combined_error_from_nursery.__context__
            try:
                raise combined_error_from_nursery
            finally:
                _, value, _ = sys.exc_info()
                assert value is combined_error_from_nursery
                value.__context__ = old_context
                # delete references from locals to avoid creating cycles
                # see test_cancel_scope_exit_doesnt_create_cyclic_garbage
                del _, combined_error_from_nursery, value, new_exc
def open_plus_nursery() -> NurseryPlusManager:
    """Return a fresh ``NurseryPlusManager`` for use with ``async with``."""
    return NurseryPlusManager()
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import trio | |
from nursery_plus import open_plus_nursery, ErrorCapture | |
async def res(x):
    """Sleep briefly, then raise ``x`` if it is an exception, else return it."""
    await trio.sleep(0.1)
    if isinstance(x, Exception):
        raise x
    return x
async def test():
    """Run three captured tasks, collect two of them and leave one error behind.

    The uncollected ``ValueError`` of the third task is expected to be raised
    when the nursery block exits, so the statement after the block must
    never be reached.
    """
    async with open_plus_nursery() as n:
        r1 = await n.start_task(res, 12, capture=ErrorCapture.Always)
        r2 = await n.start_task(res, RuntimeError(), capture=ErrorCapture.Always)
        r3 = await n.start_task(res, ValueError(), capture=ErrorCapture.Always)

        assert 12 == await r1.wait()
        try:
            await r2.wait()
        except RuntimeError:
            pass
        else:
            raise AssertionError("nope")
        # r3 will fall thru
    raise AssertionError("nope either")
async def main(): | |
errs = 0 | |
try: | |
await test() | |
except* ValueError: | |
errs += 1 | |
assert errs == 1 | |
# Only start the event loop when executed as a script, not on import.
if __name__ == "__main__":
    trio.run(main)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment