Last active
July 31, 2022 09:11
-
-
Save graingert/d20fdaa41511c4cccb756259ee477444 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import asyncio | |
import collections.abc | |
import contextlib | |
import functools | |
import types | |
import httpx | |
import sniffio | |
import trio | |
class _TrioDone:
    """Type of the ``TRIO_DONE`` sentinel; exists only to give it a readable repr."""

    # No per-instance state, matching a bare ``object()`` (no attribute assignment).
    __slots__ = ()

    def __repr__(self):
        return "TRIO_DONE"


# Sentinel yielded from the end of ``as_trio``; ``WrapCoro`` compares results
# against it by identity (``is``) and turns it into ``StopIteration``.
TRIO_DONE = _TrioDone()
class DumbFuture:
    """Minimal future-like object used to park the current asyncio task.

    It advertises ``_asyncio_future_blocking`` so that an asyncio Task yielding
    it treats it as a future. The wake-up callback the task registers is not
    stored; it is handed to the *add_done_callback* hook given at construction,
    which decides when and how the task is resumed. ``cancel()`` forwards to
    the *on_cancel* hook.

    Both hooks are single-use: the instance attribute is deleted after use,
    which breaks the reference cycle. Later lookups then fall back to the
    class-level ``None`` default, which marks the hook as consumed.
    """

    _asyncio_future_blocking = True
    # Class-level fallbacks: seen once the instance attributes are deleted.
    _add_done_callback = None
    _on_cancel = None

    def __init__(self, add_done_callback, on_cancel):
        self._add_done_callback = add_done_callback
        self._on_cancel = on_cancel

    def cancel(self, *args, **kwargs):
        """Forward the first cancellation to *on_cancel*; later calls are no-ops.

        Extra arguments (e.g. asyncio's ``msg=``) are accepted and ignored.
        Always returns True.
        """
        on_cancel = self._on_cancel
        if on_cancel is not None:
            # Break the reference cycle. Only level-triggered cancel is
            # supported: one call to the hook is enough, so a repeated cancel()
            # must not fail (deleting the already-deleted instance attribute
            # used to raise AttributeError here).
            del self._on_cancel
            on_cancel()
        # asyncio.Task.cancel calls:
        #
        # if self._fut_waiter is not None:
        #     if self._fut_waiter.cancel(msg=msg):
        #         # Leave self._fut_waiter; it may be a Task that
        #         # catches and ignores the cancellation so we may have
        #         # to cancel it again later.
        #         return True
        # # It must be the case that self.__step is already scheduled.
        # self._must_cancel = True
        # self._cancel_message = msg
        # fut_waiter (that's us) needs to return True otherwise task._must_cancel
        # is set to True, which means when we wake up the task it will call
        # coro.throw(CancelledError)!
        return True

    def get_loop(self):
        """Return the running asyncio event loop."""
        return asyncio.get_running_loop()

    def add_done_callback(self, fn, *, context):
        """Pass ``(fn, context)`` to the *add_done_callback* hook, once.

        Raises:
            AssertionError: if called a second time (only one task may wait
                on this object).
        """
        register = self._add_done_callback
        # Check before deleting: deleting first made a second call raise
        # AttributeError and left this intended error unreachable.
        if register is None:
            raise AssertionError("only one task can listen to a Future at a time")
        # break a reference cycle; later calls now see the class default None
        del self._add_done_callback
        register(fn, context)
@types.coroutine | |
def _async_yield(v): | |
return (yield v) | |
@collections.abc.Coroutine.register
class WrapCoro:
    """Awaitable proxy that steps another coroutine inside a given context.

    Each ``send``/``throw`` is forwarded to the wrapped coroutine through
    ``context.run``. When the wrapped coroutine yields the ``TRIO_DONE``
    sentinel, the proxy raises ``StopIteration``. Whoever is awaiting the
    proxy then sees the await finish with ``None``, while the wrapped
    coroutine itself stays suspended at that yield.
    """

    def __init__(self, coro, context):
        self._coro = coro
        self._context = context

    def _step(self, operation, *args):
        # Run one step of the wrapped coroutine and intercept the sentinel.
        yielded = self._context.run(operation, *args)
        if yielded is TRIO_DONE:
            raise StopIteration
        return yielded

    def __await__(self):
        return self

    def __iter__(self):
        return self

    def __next__(self):
        return self.send(None)

    def send(self, v):
        return self._step(self._coro.send, v)

    def throw(self, *exc_info):
        return self._step(self._coro.throw, *exc_info)
class NullFuture:
    """Future-like object whose ``result()`` is always ``None``.

    ``done_callback`` hands an instance to the parked asyncio task's wake-up
    callback.
    """

    def result(self):
        """Return ``None``: there is nothing to report."""
class NullContext:
    """Stand-in for a ``contextvars.Context`` that performs no context switch.

    ``run`` calls *fn* in whatever context is already current. sniffio stores
    the current async library on the context rather than in a thread-local,
    so stepping a coroutine through this object keeps the driver's
    (presumably trio's) context, and with it sniffio's answer, in effect.
    """

    def run(self, fn, /, *args, **kwargs):
        """Call ``fn(*args, **kwargs)`` directly and return its result."""
        value = fn(*args, **kwargs)
        return value
def done_callback(outcome, call_soon, callback, context):
    """trio guest-run completion hook: schedule the parked task's wake-up.

    Args:
        outcome: result of the trio run, as passed by trio
            (``Value``/``Error`` with an ``unwrap()`` method).
        call_soon: scheduler used to run *callback* later
            (``loop.call_soon`` in ``as_trio``).
        callback: the wake-up callback asyncio registered on the DumbFuture.
        context: context *callback* must run in.

    The future handed to *callback* reports the run's outcome. The normal
    case, ``Value(None)``, yields ``None``. An ``Error`` is re-raised from
    ``result()`` so it reaches the task instead of being silently dropped.
    """

    class _OutcomeFuture:
        def result(self):
            # Value -> its value (normally None); Error -> raises the error.
            return outcome.unwrap()

    call_soon(callback, _OutcomeFuture(), context=context)
@contextlib.asynccontextmanager
async def as_trio():
    """Run the body of an ``async with`` block under a trio guest run on asyncio.

    Entering the block parks the current asyncio task on a ``DumbFuture``.
    When asyncio registers its wake-up callback on it, a trio guest run is
    started on the same event loop. Its main task drives this task's own
    coroutine, wrapped in ``WrapCoro``.

    Leaving the block yields ``TRIO_DONE``. That completes the ``WrapCoro``
    await and so ends the trio run. ``done_callback`` then wakes the parked
    asyncio task, and the coroutine continues under asyncio.

    Cancelling the asyncio task while inside the block cancels the trio
    ``CancelScope`` that wraps the guest run.

    If ``trio.lowlevel.start_guest_run`` raises, the exception is raised from
    entering ``async with as_trio()`` instead of the task staying parked
    forever. For instance, trio refuses to start a second run while another
    is active in the same thread.
    """
    cancel_scope = trio.CancelScope()

    # trio's main task: drive the hijacked asyncio coroutine inside the cancel
    # scope that DumbFuture.cancel (asyncio-side cancellation) trips.
    async def trio_main(coro):
        with cancel_scope:
            return await coro

    def add_done_callback(callback, context):
        # Reached via DumbFuture.add_done_callback with the asyncio task's
        # wake-up callback and the context it must run in.
        task = asyncio.current_task()
        loop = task.get_loop()
        start_guest_run = functools.partial(
            trio.lowlevel.start_guest_run,
            functools.partial(
                trio_main, WrapCoro(task.get_coro(), context=NullContext())
            ),
            run_sync_soon_not_threadsafe=loop.call_soon,
            run_sync_soon_threadsafe=loop.call_soon_threadsafe,
            done_callback=functools.partial(
                done_callback,
                call_soon=loop.call_soon,
                callback=callback,
                context=context,
            ),
        )

        def start():
            try:
                start_guest_run()
            except Exception as exc:
                # Raised here, the error would only reach the loop's exception
                # handler and the task would never be woken. Wake the task
                # with a failed future so the error surfaces where it is
                # suspended (the DumbFuture await below).
                failed = loop.create_future()
                failed.set_exception(exc)
                loop.call_soon(callback, failed, context=context)

        # Start trio from its own loop callback, not from inside the asyncio
        # task step that is currently parking this coroutine.
        loop.call_soon(start)

    # suspend the current task so trio can take over its coroutine
    await _async_yield(
        DumbFuture(add_done_callback=add_done_callback, on_cancel=cancel_scope.cancel)
    )
    try:
        yield
    finally:
        # tell our WrapCoro that trio is done; execution continues here once
        # done_callback has woken the asyncio task again
        await _async_yield(TRIO_DONE)
async def demo(client):
    """GET https://google.com with *client* and print the response object."""
    print(await client.get("https://google.com"))
async def main():
    """Demo: run trio code inside an asyncio task, then return to asyncio.

    The task schedules its own cancellation after one second, while it is
    still inside the ten-second trio section. That cancellation is expected
    to surface from ``as_trio`` as ``trio.Cancelled``. The second half issues
    the same requests again with plain asyncio.
    """
    this_task = asyncio.current_task()
    this_task.get_loop().call_later(1, this_task.cancel)

    try:
        async with as_trio():
            print(sniffio.current_async_library())
            async with httpx.AsyncClient() as client, trio.open_nursery() as nursery:
                for _ in range(2):
                    nursery.start_soon(demo, client)
                await trio.sleep(10)
    except trio.Cancelled:
        print("cancelled")

    # back under asyncio
    print(sniffio.current_async_library())
    async with httpx.AsyncClient() as client:
        await asyncio.gather(*(demo(client) for _ in range(2)))
if __name__ == "__main__":
    # Run the network demo only when executed as a script, not on import.
    asyncio.run(main())
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment