Skip to content

Instantly share code, notes, and snippets.

@zhaowb
Last active September 26, 2022 01:48
Show Gist options
  • Save zhaowb/77a6b7272cdd15751f1f88bcb8fff64a to your computer and use it in GitHub Desktop.
Save zhaowb/77a6b7272cdd15751f1f88bcb8fff64a to your computer and use it in GitHub Desktop.
multiprocessing.pool.ThreadPool.imap with a limited memory footprint
import queue
from multiprocessing.pool import ThreadPool
from typing import Any, Callable, Iterable, Iterator
def threadpool_imap(func: Callable[[Any], Any],
                    iterable: Iterable[Any],
                    *,
                    num_thread: int = 4
                    ) -> Iterator[Any]:
    """Lazily map *func* over *iterable* on a thread pool, yielding results in input order.

    Unlike a plain ``ThreadPool.imap``, the queue that feeds the worker
    threads is bounded to ``num_thread + 1`` items. The pool therefore reads
    *iterable* only shortly before the items are processed, instead of
    draining it as fast as possible.

    Note: only the input side is bounded. Results that workers finish before
    the caller consumes them are still buffered inside the pool.

    Args:
        func: callable applied to each item in a worker thread.
        iterable: input items, read lazily by the pool.
        num_thread: number of worker threads (keyword-only).

    Yields:
        ``func(item)`` for each item, in the same order as *iterable*.

    Raises:
        Exceptions raised by *func* (or while iterating *iterable*) surface
        here through ``ThreadPool.imap`` when the corresponding result is reached.

    The pool is shut down when this generator finishes:

    * On normal exhaustion it is closed and joined.
    * On any early exit it is terminated and joined, so the rest of *iterable*
      is not read or processed. Early exits are an exception, or the generator
      being closed or garbage-collected before exhaustion.

    When abandoning iteration early, call ``close()`` on the returned
    generator (e.g. via ``contextlib.closing``) to release the threads promptly.
    """
    class BlockingPool(ThreadPool):
        """ThreadPool whose worker input queue holds at most ``num_thread + 1`` tasks.

        Adapted from
        https://github.com/elastic/elasticsearch-py/blob/cf0196f5fca6187e1221ed25c43091bb9ca05122/elasticsearch/helpers/actions.py#L424
        By default ThreadPool consumes its input as fast as it can, which
        defeats the point of passing an iterator. An iterator is usually used
        to keep memory usage low by reading input only when it is about to be
        processed. With a bounded queue, putting new tasks blocks while the
        queue is full.
        """

        def _setup_queues(self):
            # Default: https://github.com/python/cpython/blob/main/Lib/multiprocessing/pool.py#L932
            # There, self._inqueue is an unbounded SimpleQueue.
            super()._setup_queues()
            self._inqueue = queue.Queue(num_thread + 1)
            # Rebind the fast-path put so tasks go into the bounded queue too.
            self._quick_put = self._inqueue.put

    pool = BlockingPool(num_thread)
    exhausted = False
    try:
        yield from pool.imap(func, iterable)
        exhausted = True
    finally:
        if exhausted:
            # Every task is done; shut the pool down gracefully.
            pool.close()
        else:
            # Early exit: an error, or the consumer closed/dropped this generator.
            # close() would keep feeding the rest of *iterable* to the workers,
            # and join() would wait for all of it (forever on infinite input).
            # terminate() stops further input from being read.
            pool.terminate()
        pool.join()
if __name__ == '__main__':
    # Demo / smoke tests for threadpool_imap.
    # Recorded run: 10000 used 51.135 seconds vs 50.000 seconds eff=102.27%
    # ("eff" as printed below is used / ideal, so values above 100% mean overhead.)
    import time
    import traceback

    imap = threadpool_imap

    def func(i):
        """Simulate 0.1 s of blocking work, then return ``i * 2``."""
        time.sleep(0.1)
        return i * 2

    def test_imap_succ():
        """Map func over range(N) and compare wall time with the ideal N / num_thread * 0.1 s."""
        N = 10000
        num_thread = 20
        iterable = range(N)
        ts0 = time.perf_counter()
        results = list(imap(func, iterable, num_thread=num_thread))
        used = time.perf_counter() - ts0
        ideal = N/num_thread*0.1
        print(len(results), f'used {used:.3f} seconds vs {ideal:.3f} seconds eff={used/ideal:.2%}')
        assert len(results) == N
        assert sorted(results) == results
        assert results == [i*2 for i in range(N)]

    def test_imap_err():
        """An error raised by the input iterable must reach the consumer after the earlier results."""
        def iterable2():
            yield 1
            yield 2
            raise RuntimeError('error')

        ts0 = time.perf_counter()
        results = []
        raised = False
        try:
            for result in imap(func, iterable2(), num_thread=4):
                print('result', result)
                results.append(result)
        except RuntimeError:
            # Narrow catch: a bare except would also swallow KeyboardInterrupt/SystemExit.
            raised = True
            traceback.print_exc()
        used = time.perf_counter() - ts0
        print(f'used {used:.3f} seconds')
        assert raised, 'expected RuntimeError from the input iterable'
        assert results == [2, 4]

    test_imap_succ()
    test_imap_err()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment