Last active
June 24, 2025 18:50
-
-
Save ddelange/13b0f9da3147f3754b9e1e88c13303ba to your computer and use it in GitHub Desktop.
Multithreaded S3 downloads
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# pip install smart_open[s3] | |
from collections import deque | |
from concurrent.futures import ThreadPoolExecutor as _ThreadPoolExecutor | |
from functools import partial | |
from typing import Callable, Dict, Optional, Iterable, Iterator, Sequence | |
import boto3 | |
import botocore | |
import smart_open | |
class URIDownloader:
    """Stream S3 URIs into memory using multithreading.

    The worker pool is owned by the instance: use it as a context manager
    (``with URIDownloader() as dl: ...``) or call :meth:`close` when done so
    its threads are released.
    """

    def __init__(self, threads: int = 64):
        """Create the worker pool and a shared boto3 S3 client.

        Args:
            threads: Number of worker threads; also used as the S3 client's
                ``max_pool_connections``.
        """
        self.threads = threads
        self.executor = ThreadPoolExecutor(max_workers=threads)
        config = botocore.client.Config(
            max_pool_connections=threads,
            tcp_keepalive=True,
            retries={"max_attempts": 6, "mode": "adaptive"},
        )
        # thread-safe ref https://github.com/boto/boto3/blob/1.38.41/docs/source/guide/clients.rst?plain=1#L111
        client = boto3.session.Session().client("s3", config=config)
        # One client shared by all worker threads (see the thread-safety ref above).
        self._open = partial(smart_open.open, transport_params={"client": client})

    def __enter__(self):
        """Return self so the downloader can be used in a ``with`` statement."""
        return self

    def __exit__(self, exc_type, exc_value, exc_tb):
        """Shut down the worker pool; exceptions are not suppressed."""
        self.close()

    def close(self, wait: bool = True) -> None:
        """Shut down the worker pool.

        Args:
            wait: Block until already-submitted downloads have finished.

        After closing, iterating :meth:`read_multi` (or calling
        :meth:`read_multi_dict`) raises ``RuntimeError`` from the executor.
        """
        self.executor.shutdown(wait=wait)

    def read(self, uri: str, /, mode="rb", **kwargs) -> bytes:
        """Download (and decompress) a single URI using smart_open.

        Args:
            uri: Location to read (positional-only).
            mode: Open mode forwarded to ``smart_open.open``. The default
                ``"rb"`` yields bytes; a text mode yields ``str`` instead.
            **kwargs: Extra keyword arguments forwarded to ``smart_open.open``.

        Returns:
            The full contents of the object.
        """
        with self._open(uri, mode, **kwargs) as fp:
            return fp.read()

    def read_multi(self, uris: Iterable[str], **kwargs) -> Iterator[bytes]:
        """Download (and decompress) URIs concurrently with the shared client.

        Results are yielded in the same order as ``uris``. ``uris`` is
        consumed lazily, a bounded number of items ahead of the consumer.
        Keyword arguments are forwarded to :meth:`read`.
        """
        yield from self.executor.imap(partial(self.read, **kwargs), uris)

    def read_multi_dict(self, uris: Sequence[str], **kwargs) -> Dict[str, bytes]:
        """Download (and decompress) URIs concurrently into a ``dict[uri, contents]``.

        ``uris`` must be re-iterable (it is walked twice: once to download,
        once to pair each URI with its result). For duplicate URIs, the last
        download wins.
        """
        return dict(zip(uris, self.read_multi(uris, **kwargs)))
class ThreadPoolExecutor(_ThreadPoolExecutor): | |
"""Subclass with a lazy consuming imap method.""" | |
def imap(self, fn, *iterables, timeout=None, queued_tasks_per_worker=2): | |
"""Ordered imap that consumes iterables just-in-time ref https://gist.github.com/ddelange/c98b05437f80e4b16bf4fc20fde9c999. | |
Args: | |
fn: Function to apply. | |
iterables: One (or more) iterable(s) to pass to fn (using zip) as positional argument(s). | |
timeout: Per-future result retrieval timeout in seconds. | |
queued_tasks_per_worker: Amount of additional items per worker to fetch from iterables to fill the queue: this determines the total queue size. | |
Setting 0 will result in a true just-in-time behaviour: when a worker finishes a task, it waits until a result is consumed from the imap generator, at which point next() is called on the input iterable(s) and a new task is submitted. | |
Default 2 ensures there is always some work to pick up. Note that at imap startup, the queue will fill up before the first yield occurs. | |
Example: | |
long_generator = itertools.count() | |
with ThreadPoolExecutor(42) as pool: | |
result_generator = pool.imap(fn, long_generator) | |
for result in result_generator: | |
print(result) | |
""" | |
futures, maxlen = deque(), self._max_workers * (queued_tasks_per_worker + 1) | |
popleft, append, submit = futures.popleft, futures.append, self.submit | |
def get(): | |
"""Block until the next task is done and return the result.""" | |
return popleft().result(timeout) | |
for args in zip(*iterables, strict=True): | |
append(submit(fn, *args)) | |
if len(futures) == maxlen: | |
yield get() | |
while futures: | |
yield get() |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
For an asynchronous variant, see https://gist.github.com/ddelange/643fbb791b398783c04d1ceb90102163#file-proxy-py-L28