Skip to content

Instantly share code, notes, and snippets.

@ddelange
Last active September 24, 2024 10:28
Show Gist options
  • Save ddelange/c98b05437f80e4b16bf4fc20fde9c999 to your computer and use it in GitHub Desktop.
The missing ThreadPoolExecutor.imap (also works for ProcessPoolExecutor.imap)
from collections import deque
from concurrent.futures import ThreadPoolExecutor as _ThreadPoolExecutor
class ThreadPoolExecutor(_ThreadPoolExecutor):
"""Subclass with a lazy consuming imap method."""
def imap(self, fn, *iterables, timeout=None, queued_tasks_per_worker=2):
"""Ordered imap that lazily consumes iterables ref https://gist.github.com/ddelange/c98b05437f80e4b16bf4fc20fde9c999."""
futures, maxlen = deque(), self._max_workers * (queued_tasks_per_worker + 1)
popleft, append, submit = futures.popleft, futures.append, self.submit
def get():
"""Block until the next task is done and return the result."""
return popleft().result(timeout)
for args in zip(*iterables, strict=True):
append(submit(fn, *args))
if len(futures) == maxlen:
yield get()
while futures:
yield get()
import logging
import os
import typing
from collections import deque
from concurrent.futures import ThreadPoolExecutor as _ThreadPoolExecutor
from functools import wraps
from django import db
class ThreadPoolExecutor(_ThreadPoolExecutor):
    """Subclass with a django-specific lazy consuming imap method."""

    def imap(self, fn, *iterables, timeout=None, queued_tasks_per_worker=2):
        """Ordered imap that lazily consumes iterables ref https://gist.github.com/ddelange/c98b05437f80e4b16bf4fc20fde9c999.

        Args:
            fn: Callable applied to each tuple of items zipped from ``iterables``.
            *iterables: Input iterables, consumed lazily and in lockstep
                (``strict=True``: they must be of equal length).
            timeout: Per-result timeout passed to ``Future.result``.
            queued_tasks_per_worker: How many extra tasks to keep queued per
                worker before blocking on the oldest result.

        Yields:
            Results of ``fn`` in input order.
        """
        pending = deque()
        # Upper bound on submitted-but-unconsumed tasks, bounding memory use
        # for very long input iterables.
        limit = self._max_workers * (queued_tasks_per_worker + 1)
        for args in zip(*iterables, strict=True):
            pending.append(self.submit(fn, *args))
            if len(pending) == limit:
                # Oldest future first keeps output ordered like the input.
                yield pending.popleft().result(timeout)
        # Input exhausted: drain the remaining in-flight futures in order.
        while pending:
            yield pending.popleft().result(timeout)

    @staticmethod
    def closing(fn: typing.Callable) -> typing.Callable:
        """Close db connections created by a thread, before returning to the parent thread.

        Avoids lingering db connections from threads that don't exist anymore.

        References:
            https://stackoverflow.com/a/73792156/5511061

        Args:
            fn: Function to decorate.

        Returns:
            Decorated function.
        """
        @wraps(fn)
        def wrapped(*args, **kwargs):
            try:
                return fn(*args, **kwargs)
            finally:
                # Thread-local connections opened while running fn are closed
                # here, on the worker thread that created them.
                db.connections.close_all()

        return wrapped

    def submit(self, fn, /, *args, **kwargs):
        """Subclass submit, including db connection clean up."""
        wrapped_fn = self.closing(fn)
        try:
            return super().submit(wrapped_fn, *args, **kwargs)
        except RuntimeError:  # cannot schedule new futures after interpreter shutdown
            logging.exception("Failed to submit future")
            # Hard-exit: the interpreter is already shutting down, so there is
            # no way to schedule further work.
            os._exit(1)
@ddelange
Copy link
Author

ddelange commented Jan 24, 2023

All of ThreadPoolExecutor.map, multiprocessing.pool.ThreadPool.imap, multiprocessing.pool.ThreadPool.imap_unordered will exhaust input iterables immediately to fill the queue.

For very long iterables this results in unnecessary memory consumption, hence the missing ThreadPoolExecutor.imap above. Small difference with ThreadPoolExecutor.map: here, timeout is per future result retrieval (instead of for the map operation as a whole).

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment