Use concurrent.futures.ThreadPoolExecutor to do something similar to multiprocessing.pool.ThreadPool.imap, with a limited memory footprint.
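A minimal usage sketch before the full source (`slow_double` is a stand-in worker, not part of the gist): wrap a worker function and an iterable, then iterate results lazily inside the `with` block.

```
import time

def slow_double(x):
    time.sleep(0.1)  # pretend to do I/O
    return x * 2

with concurrent_imap(slow_double, range(100), num_thread=8) as results:
    for r in results:
        print(r)  # streams in input order; only a handful of futures are buffered
```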
import queue
from concurrent.futures import ThreadPoolExecutor
from contextlib import contextmanager
from typing import Any, Callable, Iterator


@contextmanager
def concurrent_imap(func: Callable[[Any], Any],
                    iterable: Iterator[Any],
                    *,
                    raise_func_err: bool = True,
                    num_thread: int = 4
                    ) -> Iterator[Any]:
"""multithread imap as contextmanager. If consumer stopped, threads are also stopped automatically. | |
:param raise_func_err: What to do when func() raises an error: | |
False: yield exception as result. | |
True: (default) raise exception and stop executor | |
Usage: | |
``` | |
with concurrent_imap(func, iterable) as ctx: | |
for result in ctx: # [1] | |
process result # [2] | |
``` | |
- When iterable fails, [1] raises error | |
- When loop [1] ends before all results are consume, ctx closes all threads. | |
Eg [2] failed, or [2] break the loop early | |

    Simple use case:
    Read S3 files into memory as {key: bytes}:
    ```
    def load_s3_files(bucket: str, s3_prefix: str, num_thread=20) -> Mapping[str, bytes]:
        # load all files and return {key: bytes} using multiple threads
        def func(s3_file):
            return (
                s3_file['Key'],
                s3.get_object(
                    Bucket=bucket,
                    Key=s3_file['Key']
                )['Body'].read()
            )
        with concurrent_imap(func, iter_s3_files(bucket, s3_prefix), num_thread=num_thread) as ctx:
            return dict(ctx)
    ```
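    (Here `s3` is assumed to be a boto3 S3 client and `iter_s3_files` a helper
    that yields entries from list_objects_v2; both live outside this snippet.)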

    Complex use case:
    Read multiple files, parse to json, filter, build docs, bulk-update an elasticsearch index.
    The process needs to check the remaining time limit and return a breakpoint for the
    next process to continue from. A file key can be used as the breakpoint.
    Pseudo code:
    ```
    from elasticsearch.helpers import streaming_bulk
    def iter_list_files(start_after='', ...) -> Iterator[str]:
        ...
    def read_file_to_json(key: str) -> dict:
        ...
    def json_to_es_doc(d: dict) -> Optional[Document]:
        ...  # return Document to index, or None to skip
    def time_left() -> float:
        ...  # return remaining time limit in seconds
    def upsert(doc):  # helper for bulk
        d = doc.to_dict(True)
        d['_op_type'] = 'update'
        d['doc'] = d['_source']
        d['doc_as_upsert'] = True
        del d['_source']
        return d
    # Notice each function above can raise an exception
    def process(..., start_after='') -> Optional[str]:
        # return a breakpoint, or None when finished
        func = lambda key: json_to_es_doc(read_file_to_json(key))
        stopped = False
        def iterable():
            for key in iter_list_files(start_after=start_after):
                if time_left() < 60:
                    # 60 seconds left: clean up and return a breakpoint
                    nonlocal stopped
                    stopped = True
                    return
                yield key
        es_client = ...
        last_succ = None
        with concurrent_imap(func, iterable(), raise_func_err=True) as ctx:
            docs = (upsert(doc) for doc in ctx if doc is not None)
            try:
                for ok, item in streaming_bulk(es_client, docs, max_retries=3, max_backoff=10):
                    if ok:
                        result = pydash.get(item, 'update.result')
                        last_succ = item
            except elasticsearch.exceptions.TransportError:
                stopped = True  # stopped is local to process() here, no nonlocal needed
        if stopped:
            doc_id = pydash.get(last_succ, 'update._id')
            last_succ_doc = MyDoc.get(doc_id)
            last_succ_file_key = ...  # re-construct file key from the document
            return last_succ_file_key
        return None
    ```
    """
    q = queue.Queue(int(num_thread*1.2)+2)
    alive = True
    executor = None

    def runner():
        def producer():
            for i in iterable:
                future = exe.submit(func, i)
                while True:
                    try:
                        q.put(future, timeout=0.1)
                        break
                    except queue.Full:
                        if not alive:
                            # the caller failed; nobody is going to consume
                            # from the queue, so stop submitting
                            future.cancel()
                            return
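
        # num_thread threads run func; the +1 is an extra slot for the producer,
        # which runs as a task in the same executor.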
        with ThreadPoolExecutor(num_thread + 1) as exe:
            nonlocal executor
            executor = exe
            producer_future = exe.submit(producer)
            while True:
                try:
                    future = q.get(timeout=0.1)
                except queue.Empty:
                    if producer_future.done():
                        break
                    continue
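                # Futures were enqueued in submission order, so results are
                # yielded in input order (imap semantics, not imap_unordered).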
                try:
                    yield future.result()
                except Exception as exc:
                    if not raise_func_err:
                        yield exc
                    else:
                        nonlocal alive
                        alive = False
                        raise
            executor = None
            # If iterable failed, re-raise the error here.
            # The caller gets the error if the caller is still waiting.
            # If the caller is not waiting (it failed in the loop, or finished the
            # loop before consuming all results), nobody cares about the error here.
            producer_future.result()
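
    # Hand the generator to the caller. The finally below runs when the caller's
    # `with` block exits, normally or via an exception, and tears down either way.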
    try:
        yield runner()
    finally:
        alive = False
        while q.qsize():
            # if the caller failed there may be futures left in the queue;
            # they don't need to continue
            try:
                q.get(timeout=0.01).cancel()
            except queue.Empty:
                break
        if executor:
            executor.shutdown()


if __name__ == '__main__':
    imap = concurrent_imap
    import threading
    import time
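
    # Test workers: time.sleep() returns None, so `time.sleep(0.1) or expr`
    # sleeps first and then evaluates to expr; fail() raises ZeroDivisionError
    # once i reaches 10.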
    def func(i): return time.sleep(0.1) or i*2
    def fail(i): return time.sleep(0.1) or (i*2 if i < 10 else i/0)

    def test_imap_succ(N, num_thread=20):
        print('>>> Test succ loop <<<')
        iterable = range(N)
        active_count_before_imap = threading.active_count()
        ts0 = time.perf_counter()
        with imap(func, iterable, num_thread=num_thread) as ctx:
            results = list(ctx)
        used = time.perf_counter() - ts0
        ideal = N/num_thread*0.1
        print(f'results {len(results)} '
              f'used {used:.3f} secs vs ideal {ideal:.3f} secs overhead={used/ideal:.2%}')
        assert len(results) == N
        assert results == [i*2 for i in range(N)]
        assert active_count_before_imap == threading.active_count(), f'active_count={threading.active_count()} != {active_count_before_imap}'
        print()

    def test_imap_input_err():
        print('>>> Test input data failure <<<')
        def iterable2():
            yield 1
            yield 2
            raise RuntimeError('simulate input err')
        active_count_before_imap = threading.active_count()
        from itertools import islice
        try:
            with imap(func, iterable2()) as ctx:
                for result in islice(ctx, 5):
                    print('result', result)
                assert False, 'SHOULD NOT REACH THIS LINE'
        except RuntimeError as exc:
            if str(exc) == 'simulate input err':
                print('OK: failure raised as expected')
            else:
                raise
        assert active_count_before_imap == threading.active_count(), f'active_count={threading.active_count()} != {active_count_before_imap}'
        print()

    def test_imap_return_worker_err():
        print('>>> Test return worker failure <<<')
        iterable = range(12)
        active_count_before_imap = threading.active_count()
        # raise_func_err=False so exceptions are yielded as results
        with imap(fail, iterable, raise_func_err=False) as ctx:
            results = list(ctx)
        print('example when worker failed', results)
        assert active_count_before_imap == threading.active_count(), f'active_count={threading.active_count()} != {active_count_before_imap}'
        print()

    def test_imap_raise_worker_err():
        print('>>> Test raise worker failure <<<')
        iterable = range(14)
        active_count_before_imap = threading.active_count()
        results = []
        with imap(fail, iterable, raise_func_err=True) as ctx:
            try:
                for result in ctx:
                    results.append(result)
            except ZeroDivisionError:
                print('OK: worker error raised')
        assert len(results) == 10
        assert active_count_before_imap == threading.active_count(), f'active_count={threading.active_count()} != {active_count_before_imap}'
        print()

    def test_imap_caller_err():
        print('>>> Test caller failure <<<')
        iterable = range(10)
        active_count_before_imap = threading.active_count()
        try:
            with imap(func, iterable) as ctx:
                for result in ctx:
                    assert result < 10, 'simulate caller fail'
        except AssertionError:
            print('OK: failure raised as expected')
        assert active_count_before_imap == threading.active_count(), f'active_count={threading.active_count()} != {active_count_before_imap}'
        print()

    test_imap_succ(500)
    # test_imap_succ(5000)
    test_imap_input_err()
    test_imap_return_worker_err()
    test_imap_raise_worker_err()
    test_imap_caller_err()
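For contrast, the stdlib construct named in the description: `multiprocessing.pool.ThreadPool.imap` also yields results in input order, but it dispatches the input iterable eagerly and can buffer an unbounded number of pending results when the consumer is slower than the workers, which is the memory footprint this gist bounds with its queue. A rough equivalent, reusing the hypothetical `slow_double` from the sketch at the top:

```
from multiprocessing.pool import ThreadPool

with ThreadPool(8) as pool:
    for r in pool.imap(slow_double, range(100)):
        print(r)  # ordered like concurrent_imap, but no cap on buffered results
```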