Last active
November 16, 2018 10:14
-
-
Save vxgmichel/cfc34f2ee34c150b9988034082622640 to your computer and use it in GitHub Desktop.
Safely merging async generators with trio
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import random | |
from contextlib import asynccontextmanager | |
import trio | |
@asynccontextmanager
async def aitercontext(resource):
    """Yield an async iterator for *resource* and finalize it on exit.

    ``resource.__aiter__()`` is called once and the resulting iterator is
    handed to the ``async with`` body. When the body exits:

    * if it raised, the exception is first thrown into the iterator with
      ``athrow`` (when the iterator has that method) so the iterator can run
      its own ``except``/``finally`` clauses, and the exception is then
      re-raised, even if the iterator handled it;
    * in every case ``aclose`` is awaited when the iterator provides it.

    Args:
        resource: Any object implementing ``__aiter__``.

    Yields:
        The iterator returned by ``resource.__aiter__()``.

    Raises:
        BaseException: The exception raised by the body, or the exception
            the iterator raises in response to ``athrow``.
    """
    aiterator = resource.__aiter__()
    try:
        yield aiterator
    except BaseException as exc:
        if hasattr(aiterator, 'athrow'):
            try:
                await aiterator.athrow(exc)
            except StopAsyncIteration:
                # The iterator either swallowed the exception and finished,
                # or was already exhausted. Letting StopAsyncIteration leak
                # out of this async generator would turn it into a
                # RuntimeError and hide the caller's real exception.
                pass
        # Always propagate the original failure to the caller.
        raise
    finally:
        if hasattr(aiterator, 'aclose'):
            await aiterator.aclose()
async def merge(agens):
    """Interleave the items of several async iterables into one stream.

    One task per source is started in a nursery; each task forwards the
    items of its source into a shared memory channel, and this generator
    yields whatever arrives on the receiving end.

    Args:
        agens: Iterable of async iterables to merge.

    Yields:
        Items from the sources, in the order they arrive on the channel.
    """

    async def pump(source, outlet):
        # `outlet` is this task's own clone of the send channel; it is
        # closed when the task leaves the block.
        async with outlet, aitercontext(source) as guarded:
            async for element in guarded:
                await outlet.send(element)

    async with trio.open_nursery() as nursery:
        sender, receiver = trio.open_memory_channel(0)
        async with receiver:
            # The original send handle is only used to hand out clones and
            # is closed before the receive loop starts.
            async with sender:
                for source in agens:
                    nursery.start_soon(pump, source, sender.clone())
            async for element in receiver:
                yield element
async def random_agen(i):
    """Demo async generator yielding ``(i, value)`` pairs at random pace.

    Produces between 1 and 10 pairs, sleeping for a random fraction of a
    second before each one. After each yield there is a small random chance
    (``random.random() > 0.95``) of raising ``RuntimeError('Oops')``.

    Args:
        i: Tag placed first in every yielded pair.

    Yields:
        Tuples ``(i, value)`` with ``value`` counting up from 0.

    Raises:
        RuntimeError: Randomly, to simulate a failing source.
    """
    how_many = random.randint(1, 10)
    for value in range(how_many):
        pause = random.random()
        await trio.sleep(pause)
        yield i, value
        should_fail = random.random() > 0.95
        if should_fail:
            raise RuntimeError('Oops')
async def main():
    """Merge ten random demo generators and print every item received."""
    sources = [random_agen(index) for index in range(10)]
    merged = merge(sources)
    async with aitercontext(merged) as stream:
        async for pair in stream:
            # Slow consumer: sleep before printing each item.
            await trio.sleep(0.1)
            print(*pair)
# Script entry point: run the demo under the trio event loop.
if __name__ == "__main__":
    trio.run(main)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment