Last active
September 26, 2022 01:48
-
-
Save zhaowb/77a6b7272cdd15751f1f88bcb8fff64a to your computer and use it in GitHub Desktop.
multiprocessing.pool.ThreadPool.imap with limited memory footprint
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import queue
from multiprocessing.pool import ThreadPool
from typing import Any, Callable, Iterable, Iterator
def threadpool_imap(func: Callable[[Any], Any],
                    iterable: Iterable[Any],
                    *,
                    num_thread: int = 4
                    ) -> Iterator[Any]:
    """Lazily map ``func`` over ``iterable`` using a pool of worker threads.

    Results are yielded in input order, like ``ThreadPool.imap``, but the
    input is only read a few items ahead of the workers instead of being
    consumed as fast as possible.

    Args:
        func: Callable applied to every item; runs in a worker thread.
        iterable: Source of input items. Any iterable is accepted.
        num_thread: Number of worker threads. The pending-input queue is
            bounded to ``num_thread + 1`` items.

    Yields:
        ``func(item)`` for each item of ``iterable``, in input order.

    Raises:
        Whatever ``pool.imap`` raises while iterating, e.g. an exception
        raised by ``func`` or by ``iterable`` itself.

    Note:
        This is a generator, so the pool is created on the first ``next()``
        and shut down when the generator finishes or is closed. If the
        caller stops iterating early (``break`` / ``.close()``) or an
        exception propagates, the pool is terminated and input that has not
        been dispatched yet is dropped instead of being drained.
    """

    class BlockingPool(ThreadPool):
        """ThreadPool whose input queue is bounded.

        Copied from https://github.com/elastic/elasticsearch-py/blob/cf0196f5fca6187e1221ed25c43091bb9ca05122/elasticsearch/helpers/actions.py#L424
        By default ThreadPool consumes as much input as possible, which
        defeats the point of passing an iterator: an iterator is usually
        used to minimize memory usage by reading input only when it is
        about to be processed.
        """

        def _setup_queues(self):
            # default is https://github.com/python/cpython/blob/main/Lib/multiprocessing/pool.py#L932
            # where self._inqueue is a SimpleQueue without any size limit.
            super()._setup_queues()
            # Swap in a bounded queue so putting a task blocks once
            # num_thread + 1 inputs are already waiting for a worker.
            self._inqueue = queue.Queue(num_thread + 1)
            # _quick_put was bound to the old queue by super(); rebind it.
            self._quick_put = self._inqueue.put

    pool = BlockingPool(num_thread)
    exhausted = False
    try:
        yield from pool.imap(func, iterable)
        exhausted = True
    finally:
        if exhausted:
            # Normal completion: every task is done, shut down gracefully.
            pool.close()
        else:
            # Early stop or error: close() + join() would block until the
            # whole remaining input is read and processed (forever for an
            # unbounded input), so drop the pending work instead.
            pool.terminate()
        pool.join()
if __name__ == '__main__':
    # Manual smoke test / benchmark; only runs when executed as a script.
    # performance: 10000 used 51.135 seconds vs 50.000 seconds eff=102.27%
    import time
    import traceback

    imap = threadpool_imap

    def func(i):
        """Simulate a blocking 0.1 s task and return ``i`` doubled."""
        time.sleep(0.1)
        return i * 2

    def test_imap_succ():
        """Map over a large range and check count, order and values."""
        N = 10000
        num_thread = 20
        iterable = range(N)
        ts0 = time.perf_counter()
        results = list(imap(func, iterable, num_thread=num_thread))
        used = time.perf_counter() - ts0
        # Lower bound on wall time if all threads were busy all the time.
        ideal = N / num_thread * 0.1
        print(len(results), f'used {used:.3f} seconds vs {ideal:.3f} seconds eff={used/ideal:.2%}')
        assert len(results) == N
        assert sorted(results) == results
        assert results == [i * 2 for i in range(N)]

    def test_imap_err():
        """Check that an error raised by the input iterable reaches the caller."""
        def iterable2():
            yield 1
            yield 2
            raise RuntimeError('error')

        ts0 = time.perf_counter()
        try:
            for result in imap(func, iterable2(), num_thread=4):
                print('result', result)
        except Exception:
            # Was a bare ``except:``, which also swallowed KeyboardInterrupt
            # and SystemExit; only ordinary errors are expected here.
            traceback.print_exc()
        used = time.perf_counter() - ts0
        print(f'used {used:.3f} seconds')

    test_imap_succ()
    test_imap_err()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment