Skip to content

Instantly share code, notes, and snippets.

@smurfix
Last active April 14, 2025 08:57
Show Gist options
  • Save smurfix/9c986d3f20beb5e55b9c9404e795e5ed to your computer and use it in GitHub Desktop.
Save smurfix/9c986d3f20beb5e55b9c9404e795e5ed to your computer and use it in GitHub Desktop.
from __future__ import annotations
import sys
import trio
from outcome import Outcome, Value, Error, AlreadyUsedError
from attrs import define,field
from contextlib import asynccontextmanager
from enum import IntEnum
class ErrorCapture(IntEnum):
    """What a task started via ``start_task`` does with an exception it raises.

    * ``Never``  -- the exception propagates into the nursery as usual.
    * ``Cancel`` -- the exception is stored in the task's Future and the
      nursery's cancel scope is cancelled.
    * ``Always`` -- the exception is stored in the task's Future only.

    Being an IntEnum, ``Never`` compares equal to ``0`` / ``False``.
    """

    Never = 0   # do not capture
    Cancel = 1  # capture, then cancel the nursery
    Always = 2  # capture, keep the nursery running
@define(eq=False)
class Future:
    """Handle for the result of a task started with ``start_task``.

    The task wrapper in ``NurseryPlusManager`` sets ``result`` to the task's
    ``Outcome`` when the task finishes, then sets ``ready``.  ``eq=False``
    keeps equality identity-based, consistent with ``__hash__``, so futures
    behave predictably inside the manager's ``_todo`` set.
    """

    # The NurseryPlusManager that created this future (not the trio Nursery
    # itself); ``wait`` reaches into its ``_todo`` set.
    nursery: NurseryPlusManager = field()
    # Cancel scope wrapping the task; cancelled by ``cancel``.
    scope: trio.CancelScope = field()
    # Task outcome: None until the task finishes, and None again once it has
    # been consumed by ``wait`` or collected by the manager on exit.
    result: Outcome | None = field(default=None, init=False)
    # Set by the task wrapper once ``result`` has been filled in.
    ready: trio.Event = field(factory=trio.Event, init=False)

    def __hash__(self):
        return id(self)

    async def wait(self):
        """Wait for the task to finish and return its value.

        The result can be consumed only once.

        Raises:
            AlreadyUsedError: the result was already consumed, by an earlier
                ``wait`` or by the nursery collecting it on exit.
            Exception: whatever the task's outcome holds (the task's own
                exception or a ``NoResult``).  Because it is delivered here,
                the future is removed from the manager's unclaimed-error set.
        """
        await self.ready.wait()
        if self.result is None:
            raise AlreadyUsedError
        try:
            return self.result.unwrap()
        except Exception:
            # The error reached a waiter, so the nursery must not re-raise it
            # on exit.  discard(): futures whose error was not captured were
            # never added; getattr(): the manager deletes _todo on exit.
            todo = getattr(self.nursery, "_todo", None)
            if todo is not None:
                todo.discard(self)
            raise
        finally:
            self.result = None

    def cancel(self):
        """Cancel the cancel scope that wraps this future's task."""
        self.scope.cancel()
class NoResult(Exception):
    """Stand-in error for a task that ended without producing a value.

    The task wrapper stores it in a Future when the task exits with a
    BaseException that is not an ``Exception`` (e.g. cancellation), and on
    the non-capturing path; its single argument is the task function.
    """
class NurseryPlusManager:
    """Async context manager yielding a trio nursery with a ``start_task`` method.

    ``nursery.start_task(fn, *args, capture=...)`` starts ``fn(*args)`` and
    returns a :class:`Future` for its result.  Exceptions captured into
    futures (see :class:`ErrorCapture`) that nobody retrieved through
    ``Future.wait`` are raised as an ``ExceptionGroup`` when the nursery exits.

    This mirrors trio's own ``NurseryManager`` and relies on trio internals
    (``trio._core.Nursery._create``, ``Nursery._nested_child_finished``,
    ``CancelScope._close``), so it may break with other trio versions.
    """

    strict_exception_groups: bool = True

    @trio.lowlevel.enable_ki_protection
    async def __aenter__(self) -> trio.Nursery:
        # Futures holding a captured error that no waiter has retrieved yet.
        self._todo = set()
        self._scope = trio.CancelScope()
        self._scope.__enter__()
        self._nursery = trio._core.Nursery._create(
            trio.lowlevel.current_task(),
            self._scope,
            self.strict_exception_groups,
        )
        # Expose start_task on the nursery object handed to the caller.
        self._nursery.start_task = self._start_task
        return self._nursery

    async def _start_task(
        self,
        async_fn: Callable[[Unpack[PosArgT]], Awaitable[object]],
        *args: Unpack[PosArgT],
        name: object = None,
        capture: ErrorCapture = ErrorCapture.Never,
    ) -> Future:
        """Start ``async_fn(*args)`` in the nursery and return its Future.

        Args:
            async_fn: the async function to run.
            *args: positional arguments for ``async_fn``.
            name: task name, forwarded to ``Nursery.start``.
            capture: how to treat an exception raised by the task (see
                :class:`ErrorCapture`).  ``False`` compares equal to the
                default ``ErrorCapture.Never``.

        Returns:
            The task's :class:`Future`, available as soon as the task started.
        """

        async def _wrap(fn, fn_args, nursery, *, task_status):
            with trio.CancelScope() as scope:
                fut = Future(self, scope)
                task_status.started(fut)
                try:
                    fut.result = Value(await fn(*fn_args))
                except BaseException as exc:
                    if capture == ErrorCapture.Never:
                        # The error propagates into the nursery, but waiters
                        # still need a real Outcome (a bare exception has no
                        # ``unwrap``).
                        fut.result = Error(NoResult(fn))
                        raise
                    # Only Exceptions are stored as-is; cancellation and other
                    # BaseExceptions are recorded as NoResult and re-raised.
                    fut.result = Error(
                        exc if isinstance(exc, Exception) else NoResult(fn)
                    )
                    if capture == ErrorCapture.Cancel:
                        nursery.cancel_scope.cancel()
                    # NOTE(review): cancelled tasks are registered too, so their
                    # NoResult surfaces on nursery exit unless awaited -- confirm
                    # this is intended.
                    self._todo.add(fut)
                    if not isinstance(exc, Exception):
                        raise
                finally:
                    fut.ready.set()

        return await self._nursery.start(
            _wrap, async_fn, args, self._nursery, name=name
        )

    def _collect_unclaimed_errors(self) -> ExceptionGroup | None:
        """Group the captured errors no waiter retrieved, or return None.

        Each collected future is marked consumed (``result = None``), so a
        later ``Future.wait`` on it raises ``AlreadyUsedError``.
        """
        if not self._todo:
            return None
        errors = []
        for fut in self._todo:
            errors.append(fut.result.error)
            fut.result = None
        return ExceptionGroup("Futures from Trio nursery", errors)

    @trio.lowlevel.enable_ki_protection
    async def __aexit__(
        self,
        etype: type[BaseException] | None,
        exc: BaseException | None,
        tb: TracebackType | None,
    ) -> bool:
        new_exc = await self._nursery._nested_child_finished(exc)
        # Tracebacks show the 'raise' line below out of context, so let's give
        # this variable a name that makes sense out of context.
        combined_error_from_nursery = self._scope._close(new_exc)

        futures_error = self._collect_unclaimed_errors()
        del self._todo
        if futures_error is not None:
            if combined_error_from_nursery is None:
                combined_error_from_nursery = futures_error
            else:
                # BaseExceptionGroup returns an ExceptionGroup when every
                # member is an Exception, and -- unlike type(err)(...) -- also
                # works when the nursery error is not itself a group.
                combined_error_from_nursery = BaseExceptionGroup(
                    "Exceptions and Futures from Trio nursery",
                    [combined_error_from_nursery, futures_error],
                )
        # Drop the extra reference before raising to avoid a frame cycle.
        del futures_error

        if combined_error_from_nursery is None:
            return True
        elif combined_error_from_nursery is exc:
            return False
        else:
            # Copied verbatim from the old MultiErrorCatcher. Python doesn't
            # allow us to encapsulate this __context__ fixup.
            old_context = combined_error_from_nursery.__context__
            try:
                raise combined_error_from_nursery
            finally:
                _, value, _ = sys.exc_info()
                assert value is combined_error_from_nursery
                value.__context__ = old_context
                # delete references from locals to avoid creating cycles
                # see test_cancel_scope_exit_doesnt_create_cyclic_garbage
                del _, combined_error_from_nursery, value, new_exc
def open_plus_nursery():
    """Create a :class:`NurseryPlusManager`.

    Intended for ``async with open_plus_nursery() as nursery:``; the yielded
    nursery has an extra ``start_task`` method returning a Future.
    """
    manager = NurseryPlusManager()
    return manager
import trio
from nursery_plus import open_plus_nursery, ErrorCapture
async def res(x):
    """Sleep 0.1 s, then raise *x* if it is an Exception, else return it."""
    await trio.sleep(0.1)
    if not isinstance(x, Exception):
        return x
    raise x
async def test():
    """Check a returned value, an awaited error and an unawaited error.

    All three tasks use ``ErrorCapture.Always``.  The value and the
    RuntimeError are retrieved through ``wait``; the ValueError never is, so
    leaving the nursery must raise it and the last line must not be reached.
    """
    mode = ErrorCapture.Always
    async with open_plus_nursery() as nursery:
        ok = await nursery.start_task(res, 12, capture=mode)
        boom = await nursery.start_task(res, RuntimeError(), capture=mode)
        unawaited = await nursery.start_task(res, ValueError(), capture=mode)

        assert await ok.wait() == 12

        try:
            await boom.wait()
        except RuntimeError:
            pass
        else:
            raise AssertionError("nope")

    # The ValueError held by `unawaited` should have been raised on exit.
    raise AssertionError("nope either")
async def main():
errs = 0
try:
await test()
except* ValueError:
errs += 1
assert errs == 1
# Only run the demo when executed as a script, not on import.
if __name__ == "__main__":
    trio.run(main)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment