Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save yxlwfds/e194d43443a09753db3d440965109584 to your computer and use it in GitHub Desktop.
Save yxlwfds/e194d43443a09753db3d440965109584 to your computer and use it in GitHub Desktop.
map_concurrently_in_subthread_trio
import random
from functools import partial
from heapq import heappush, heappop
from time import sleep, perf_counter
import trio
import queue
# Upper bound on the number of items being processed at the same time
# (worker_task holds one limiter token while it runs an item).
CONCURRENCY_LIMIT = 8
# Module-wide trio capacity limiter with CONCURRENCY_LIMIT tokens. It is
# created at import time, outside any trio run.
# NOTE(review): trio synchronization primitives are not thread-safe; this
# object must not be used by two trio runs on different threads at once.
limiter = trio.CapacityLimiter(total_tokens=CONCURRENCY_LIMIT)
def sync_work(item):
    """Block the calling thread for ``item`` seconds, then hand ``item`` back.

    A stand-in for blocking (non-async) work, used by the demo at the bottom
    of this file.
    """
    delay = item
    sleep(delay)
    return delay
async def asyncify_iterator(iter, limiter=None):
    """Adapt a blocking iterator into an async iterator.

    Every ``next()`` call is made in a trio worker thread via
    ``trio.to_thread.run_sync`` (cancellable), so a slow or blocking source
    does not stall the event loop.

    Args:
        iter: an iterator, i.e. something ``next()`` accepts. A plain iterable
            such as a list must be wrapped with ``iter()`` first. This name
            shadows the builtin only inside this function.
        limiter: optional ``trio.CapacityLimiter`` gating those thread calls;
            ``None`` means trio's default thread limiter.

    Yields:
        The iterator's items in order, until it is exhausted.
    """
    exhausted = object()  # unique marker that next() returns only at the end
    fetch = partial(
        trio.to_thread.run_sync,
        next,
        iter,
        exhausted,
        limiter=limiter,
        cancellable=True,
    )
    while (item := await fetch()) is not exhausted:
        yield item
async def worker_task(i, func, send_chan, task_status, *, concurrency_limiter=None):
    """Run ``func()`` in a worker thread and send ``(i, result)`` on ``send_chan``.

    Meant to be launched with ``nursery.start``: ``task_status.started()`` is
    signalled only after a limiter token has been acquired, so the caller's
    ``await nursery.start(...)`` waits for a free slot (back-pressure).

    Args:
        i: index of the item, used downstream to restore input order.
        func: zero-argument blocking callable to run in a thread.
        send_chan: memory-channel send end; closed when this task exits.
        task_status: supplied by ``nursery.start``.
        concurrency_limiter: ``trio.CapacityLimiter`` bounding how many
            workers run at once. ``None`` falls back to the module-level
            ``limiter``. ``amain`` passes a limiter private to its own run.
    """
    if concurrency_limiter is None:
        concurrency_limiter = limiter
    async with concurrency_limiter, send_chan:
        task_status.started()
        # A fresh single-token limiter per call keeps this thread out of
        # trio's shared default thread limiter; concurrency is bounded by
        # ``concurrency_limiter`` instead.
        result = await trio.to_thread.run_sync(
            func, limiter=trio.CapacityLimiter(1), cancellable=True
        )
        await send_chan.send((i, result))


async def result_sender(result_queue, recv_chan):
    """Forward ``(i, result)`` pairs from ``recv_chan`` to ``result_queue`` in index order.

    Results arrive in completion order. Each one is parked until every lower
    index has been forwarded. Indices are expected to be 0, 1, 2, ... with no
    gaps or duplicates, as produced by ``amain``.
    """
    pending = {}  # index -> result, for results that arrived out of order
    next_index = 0
    async with recv_chan:
        async for i, result in recv_chan:
            pending[i] = result
            while next_index in pending:
                # put() runs in a thread so a blocking queue (e.g. a bounded
                # queue.Queue) cannot stall the event loop. A SimpleQueue, as
                # used by map_concurrently_in_subthread_trio, never blocks.
                await trio.to_thread.run_sync(
                    result_queue.put, pending.pop(next_index)
                )
                next_index += 1


async def amain(result_queue, func, items, args, kwargs):
    """Apply ``func(item, *args, **kwargs)`` to every item concurrently in threads.

    Results are put on ``result_queue`` in input order by ``result_sender``.
    At most ``CONCURRENCY_LIMIT`` items are processed at once.

    A fresh ``trio.CapacityLimiter`` is created for every call. trio's
    synchronization primitives are not thread-safe, and
    ``map_concurrently_in_subthread_trio`` starts each trio run in its own
    thread, so several runs can be alive at once: concurrent callers, or a new
    call while an earlier run is still finishing. A module-global limiter
    would be shared between them.
    """
    run_limiter = trio.CapacityLimiter(CONCURRENCY_LIMIT)
    send_chan, recv_chan = trio.open_memory_channel(0)
    async with trio.open_nursery() as nursery:
        nursery.start_soon(result_sender, result_queue, recv_chan)
        # Closing the original send end, on success or error, lets
        # result_sender finish once every worker's clone is closed too.
        async with send_chan:
            i = 0
            # Fetching the next item also needs a token from run_limiter, so
            # the source is not read further ahead than the workers allow.
            async for item in asyncify_iterator(iter(items), run_limiter):
                await nursery.start(
                    partial(worker_task, concurrency_limiter=run_limiter),
                    i,
                    partial(func, item, *args, **kwargs),
                    send_chan.clone(),
                )
                i += 1
def map_concurrently_in_subthread_trio(func, items, args=(), kwargs=None):
    """Lazily yield ``func(item, *args, **kwargs)`` for each item, in input order.

    The work runs in a separate trio event loop started on a background
    thread, so this generator can be consumed from ordinary synchronous code.
    ``items`` is advanced from worker threads, one ``next()`` at a time.

    If the consumer stops early (``break``, ``close()``, or garbage collection
    of an unfinished generator), the background trio run is cancelled, so no
    further items are started. Calls to ``func`` already running in threads
    cannot be interrupted; they are left to finish and their results are
    discarded.

    Args:
        func: blocking callable applied to each item.
        items: iterable of inputs.
        args: extra positional arguments passed to ``func`` after the item.
        kwargs: extra keyword arguments passed to ``func``; ``None`` means none.

    Yields:
        Results in the same order as ``items``.

    Raises:
        Whatever the trio run raised (for example an exception from ``func``).
        It is re-raised after any results that were already queued have been
        yielded.
    """
    if kwargs is None:
        kwargs = {}  # avoid a shared mutable default argument
    sentinel = object()  # marks the end of results in result_queue
    result_queue = queue.SimpleQueue()
    trio_outcome = None
    # Set from the trio thread once the run is live. They let this
    # (consumer-side) generator cancel the run if it is abandoned.
    run_token = None
    run_scope = None

    async def cancellable_amain():
        nonlocal run_token, run_scope
        with trio.CancelScope() as scope:
            run_scope = scope
            run_token = trio.lowlevel.current_trio_token()
            await amain(result_queue, func, items, args, kwargs)

    def trio_main():
        trio.run(cancellable_amain)

    def deliver(result):
        # Called with the outcome of trio_main once the background run ends.
        nonlocal trio_outcome
        trio_outcome = result
        result_queue.put(sentinel)

    trio.lowlevel.start_thread_soon(trio_main, deliver)
    exhausted = False
    try:
        while (result := result_queue.get()) is not sentinel:
            yield result
        exhausted = True
    finally:
        if not exhausted and run_token is not None:
            try:
                # Thread-safe: schedules the cancel to run inside the trio thread.
                run_token.run_sync_soon(run_scope.cancel)
            except trio.RunFinishedError:
                pass  # the run already ended on its own; nothing to cancel
    return trio_outcome.unwrap()
if __name__ == "__main__":
    # Demo: 100 random sleeps of under 1 s each (about 50 s in total if run
    # one after another), processed concurrently; results print in input order.
    start = perf_counter()
    delays = (random.random() for _ in range(100))
    for value in map_concurrently_in_subthread_trio(sync_work, delays):
        print(value)
    elapsed = perf_counter() - start
    print(elapsed, "is less than 50")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment