Skip to content

Instantly share code, notes, and snippets.

@richardsheridan
Last active January 2, 2023 22:14
Show Gist options
  • Save richardsheridan/42d99cbcfcc1d77bc61890b2fae4bfaa to your computer and use it in GitHub Desktop.
Save richardsheridan/42d99cbcfcc1d77bc61890b2fae4bfaa to your computer and use it in GitHub Desktop.
map_concurrently_in_subthread_trio
import queue
import random
from functools import partial
from time import sleep, perf_counter
import trio
# Maximum number of items processed concurrently.
CONCURRENCY_LIMIT = 8
# Module-wide trio.CapacityLimiter holding CONCURRENCY_LIMIT tokens.
# NOTE(review): trio synchronization primitives are tied to a single trio.run
# and are not thread-safe; sharing this object between trio runs living in
# different threads at the same time is unsafe.
limiter = trio.CapacityLimiter(total_tokens=CONCURRENCY_LIMIT)
def sync_work(item):
    """Block for ``item[1]`` seconds, then return ``item`` unchanged.

    Demo stand-in for blocking work: ``item`` must be indexable, with a sleep
    duration in seconds at index 1. The same object is handed back.
    """
    duration = item[1]
    sleep(duration)
    return item
async def asyncify_iterator(iter, limiter=None):
    """Expose a blocking iterator as an async generator.

    Each ``next()`` call runs in a worker thread via
    ``trio.to_thread.run_sync`` so a slow iterator never blocks the event loop.

    Args:
        iter: The (blocking) iterator to drain. The name shadows the builtin
            but is kept for interface compatibility.
        limiter: Forwarded to ``trio.to_thread.run_sync``; ``None`` means
            trio's default thread limiter.
    """
    done = object()  # unique marker returned by next() once iter is exhausted
    while True:
        item = await trio.to_thread.run_sync(next, iter, done, limiter=limiter)
        if item is done:
            return
        yield item
async def worker_task(i, func, send_chan, task_status, run_limiter=None):
    """Run ``func`` in a worker thread and send ``(i, result)`` on ``send_chan``.

    Intended to be launched with ``nursery.start``: ``task_status.started()``
    is only called after a concurrency slot has been acquired, so the caller's
    ``await nursery.start(...)`` blocks while every slot is busy.

    Args:
        i: Position of the item in the input, attached to the result so the
            sorter can restore input order.
        func: Zero-argument blocking callable doing the actual work.
        send_chan: A channel clone owned by this task; closed on exit.
        task_status: Supplied by ``nursery.start``.
        run_limiter: CapacityLimiter bounding concurrent workers. Defaults to
            the module-level ``limiter`` for backward compatibility.
    """
    if run_limiter is None:
        run_limiter = limiter
    async with run_limiter, send_chan:
        task_status.started()
        # A private one-token limiter for this thread, presumably to bypass
        # trio's default thread limiter since run_limiter already bounds
        # concurrency.
        result = await trio.to_thread.run_sync(func, limiter=trio.CapacityLimiter(1))
        await send_chan.send((i, result))


async def result_sorter(result_send, recv_chan):
    """Forward ``(i, result)`` pairs from ``recv_chan`` to ``result_send`` in index order.

    Results that arrive early wait in a dict until all lower indices have been
    sent. That buffer is unbounded if an early item is slow to finish.
    ``recv_chan`` is closed on exit. ``result_send`` is deliberately left open:
    the consumer stops when its pending receive is cancelled or the trio run
    finishes.
    """
    pending = {}
    next_index = 0
    async with recv_chan:
        async for i, result in recv_chan:
            pending[i] = result
            # Flush every result that is now contiguous with what was sent.
            while next_index in pending:
                await result_send.send(pending.pop(next_index))
                next_index += 1


async def amain(run_data_queue, func, items, args, kwargs):
    """Trio entry point: map ``func`` over ``items`` using worker threads.

    Puts ``(trio_token, cancel, receive)`` on ``run_data_queue`` so the calling
    thread can pull ordered results and request cancellation. It then starts one
    worker per item until ``items`` is exhausted.
    """
    # A fresh limiter per run. trio primitives belong to one trio.run, so
    # sharing the module-level ``limiter`` between generators that are alive
    # at the same time (each with its own trio.run in its own thread) is
    # unsafe.
    run_limiter = trio.CapacityLimiter(CONCURRENCY_LIMIT)
    send_chan, recv_chan = trio.open_memory_channel(0)
    result_send, result_recv = trio.open_memory_channel(0)
    async with trio.open_nursery() as nursery:
        run_data_queue.put(
            (
                trio.lowlevel.current_trio_token(),
                nursery.cancel_scope.cancel,
                result_recv.receive,
            )
        )
        nursery.start_soon(result_sorter, result_send, recv_chan)
        start_worker = partial(worker_task, run_limiter=run_limiter)
        # The iterator's next() calls share run_limiter with the workers, so
        # the next item is only pulled once a slot is available.
        item_aiter = asyncify_iterator(iter(items), run_limiter)
        try:
            i = 0
            async for item in item_aiter:
                await nursery.start(
                    start_worker,
                    i,
                    partial(func, item, *args, **kwargs),
                    send_chan.clone(),
                )
                i += 1
        finally:
            # Workers own clones; closing the original lets result_sorter's
            # loop end once the last clone is closed.
            send_chan.close()
def map_concurrently_in_subthread_trio(func, items, args=(), kwargs=None):
    """Lazily map ``func`` over ``items`` in worker threads, yielding in input order.

    A trio event loop running ``amain`` is started in a background thread via
    ``trio.lowlevel.start_thread_soon``. This generator pulls results out of it
    one at a time with ``trio.from_thread.run``.

    Args:
        func: Blocking callable, invoked as ``func(item, *args, **kwargs)``.
        items: Iterable of inputs; it is iterated from a worker thread.
        args: Extra positional arguments for ``func``.
        kwargs: Extra keyword arguments for ``func``; ``None`` means none.

    Yields:
        ``func(item, *args, **kwargs)`` for each item, in the order of ``items``.

    Raises:
        Whatever the trio run raised (e.g. an exception from ``func``), once the
        run has finished.
    """
    if kwargs is None:
        kwargs = {}  # avoid a shared mutable default argument
    run_data_queue = queue.SimpleQueue()

    def trio_main():
        return trio.run(amain, run_data_queue, func, items, args, kwargs)

    def deliver(result):
        # Receives the outcome of trio_main once the background thread ends.
        run_data_queue.put(result)

    trio.lowlevel.start_thread_soon(trio_main, deliver)
    run_data = run_data_queue.get()
    if not isinstance(run_data, tuple):
        # trio.run ended before amain published its handles; surface that
        # outcome instead of failing to unpack it.
        return run_data.unwrap()
    token, cancel, receive = run_data
    while True:
        try:
            value = trio.from_thread.run(receive, trio_token=token)
        except (trio.RunFinishedError, trio.Cancelled):
            break  # don't unwrap here, to avoid chaining exceptions
        try:
            yield value
        except BaseException:
            # The consumer closed the generator or threw into it: cancel the
            # run, wait for the background thread to finish, then re-raise.
            try:
                token.run_sync_soon(cancel)
            except trio.RunFinishedError:
                pass
            run_data_queue.get().unwrap()
            raise
    return run_data_queue.get().unwrap()
if __name__ == "__main__":
    # Demo: 100 items, each sleeping a random 0-1 s (roughly 50 s if run
    # serially), consumed with an extra 0.1 s pause per result.
    started = perf_counter()
    work_items = ((n, random.random()) for n in range(100))
    for done in map_concurrently_in_subthread_trio(sync_work, work_items):
        print(done)
        sleep(0.1)  # simulate a slow consumer
    elapsed = perf_counter() - started
    print(elapsed, "is less than 50")
@richardsheridan
Copy link
Author

richardsheridan commented Apr 15, 2021

Notes:

  1. Needs a way to shut down the trio thread if the generator is GC'd or if someone throws an exception into it.
  2. If the first item never finishes processing, the result heap will build up forever.
    • This could be solved with a timeout on work or asserting something about the size of the heap before doing heappush.
  3. Beware cancellable-thread semantics: the intention is to let the OS kill daemon threads during an unclean shutdown, but if you, e.g., catch KeyboardInterrupt, processing will silently continue in the background until the jobs finish, and their results are discarded.

@richardsheridan
Copy link
Author

Fixed 1 and 3 above. It also now responds to backpressure from the queue when results are consumed slowly.

@richardsheridan
Copy link
Author

Now uses trio.from_thread.run* to improve readability 100x and make result_sender fully natively async.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment