@zhaowb
Last active September 26, 2022 01:47
Use concurrent.futures.ThreadPoolExecutor to do something similar to multiprocessing.pool.ThreadPool.imap, with a limited memory footprint.
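For comparison, the standard-library route is multiprocessing.pool.ThreadPool.imap. A minimal sketch (the `work` function is a stand-in, not part of the gist); as far as I know, Pool.imap's task handler reads the input iterable ahead of the consumer rather than applying backpressure, which is the memory-footprint problem the helper below works around with a bounded queue:

```
from multiprocessing.pool import ThreadPool

def work(i):
    # stand-in worker used only for this sketch
    return i * 2

with ThreadPool(4) as pool:
    # results arrive lazily in input order, but tasks may be queued far ahead of consumption
    for result in pool.imap(work, range(100)):
        print(result)
```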
import queue
from concurrent.futures import ThreadPoolExecutor
from contextlib import contextmanager
from typing import Any, Callable, Iterator
@contextmanager
def concurrent_imap(func: Callable[[Any], Any],
                    iterable: Iterator[Any],
                    *,
                    raise_func_err: bool = True,
                    num_thread: int = 4
                    ) -> Iterator[Any]:
"""multithread imap as contextmanager. If consumer stopped, threads are also stopped automatically.
:param raise_func_err: What to do when func() raises an error:
False: yield exception as result.
True: (default) raise exception and stop executor
Usage:
```
with concurrent_imap(func, iterable) as ctx:
for result in ctx: # [1]
process result # [2]
```
- When iterable fails, [1] raises error
- When loop [1] ends before all results are consume, ctx closes all threads.
Eg [2] failed, or [2] break the loop early
Simple use case:
Read S3 files into memory as {key: bytes}
    ```
    def load_s3_files(bucket: str, s3_prefix: str, num_thread=20) -> Mapping[str, bytes]:
        # load all files, return {key: bytes} using multiple threads
        def func(s3_file):
            return (
                s3_file['Key'],
                s3.get_object(
                    Bucket=bucket,
                    Key=s3_file['Key']
                )['Body'].read()
            )
        with concurrent_imap(func, iter_s3_files(bucket, s3_prefix), num_thread=num_thread) as ctx:
            return dict(ctx)
    ```
    Complex use case:
    Read multiple files, parse to json, filter, build docs, bulk update an elasticsearch index.
    The process needs to check the remaining time limit and return a breakpoint for the next run to continue from.
    A file key can be used as a breakpoint.
    Pseudo code:
    ```
    from elasticsearch.helpers import streaming_bulk

    def iter_list_files(start_after='', ...) -> Iterator[str]:
        ...
    def read_file_to_json(key: str) -> dict:
        ...
    def json_to_es_doc(data: dict) -> Optional[Document]:
        ...  # return Document to index or None to skip
    def time_left() -> float:
        ...  # return remaining time limit in number of seconds
    def upsert(doc):  # helper for bulk
        d = doc.to_dict(True)
        d['_op_type'] = 'update'
        d['doc'] = d['_source']
        d['doc_as_upsert'] = True
        del d['_source']
        return d
    # Notice each function above can raise an exception

    def process(..., start_after='') -> Optional[str]:
        # return breakpoint or None when finished
        func = lambda key: json_to_es_doc(read_file_to_json(key))
        stopped = False
        def iterable():
            for key in iter_list_files(start_after=start_after):
                if time_left() < 60:
                    # 60 seconds left, clean up and return breakpoint
                    nonlocal stopped
                    stopped = True
                    return
                yield key
        es_client = ...
        last_succ = None
        with concurrent_imap(func, iterable(), raise_func_err=True) as ctx:
            docs = (upsert(doc) for doc in ctx if doc is not None)
            try:
                for ok, item in streaming_bulk(es_client, docs, max_retries=3, max_backoff=10):
                    if ok:
                        result = pydash.get(item, 'update.result')
                        last_succ = item
            except elasticsearch.exceptions.TransportError:
                stopped = True
        if stopped:
            doc_id = pydash.get(last_succ, 'update._id')
            last_succ_doc = MyDoc.get(doc_id)
            last_succ_file_key = ...  # re-construct file key from document
            return last_succ_file_key
        return None
    ```
    """
    q = queue.Queue(int(num_thread*1.2)+2)
    alive = True
    executor = None

    def runner():
        def producer():
            nonlocal alive
            for i in iterable:
                future = exe.submit(func, i)
                while True:
                    try:
                        q.put(future, timeout=0.1)
                        break
                    except queue.Full:
                        if not alive:
                            # caller failed, nobody is going to consume from queue, must stop
                            future.cancel()
                            return
        with ThreadPoolExecutor(num_thread + 1) as exe:  # + 1 thread for the producer
            nonlocal executor
            executor = exe
            producer_future = exe.submit(producer)
            while True:
                try:
                    future = q.get(timeout=0.1)
                except queue.Empty:
                    if producer_future.done():
                        break
                    continue
                try:
                    yield future.result()
                except Exception as exc:
                    if not raise_func_err:
                        yield exc
                    else:
                        nonlocal alive
                        alive = False
                        raise
            executor = None
            # If the iterable failed, raise the error here.
            # The caller can get the error if the caller is still waiting.
            # If the caller is not waiting (it failed in the loop or finished the loop before consuming all results),
            # then nobody cares about the error here.
            try:
                producer_future.result()
            except Exception as exc:
                raise exc
    try:
        yield runner()
    finally:
        alive = False
        while q.qsize():
            # if the caller failed, there might be futures left in the queue; they don't need to continue
            try:
                q.get(timeout=0.01).cancel()
            except queue.Empty:
                break
        if executor:
            executor.shutdown()
if __name__ == '__main__':
    imap = concurrent_imap
    import threading
    import time

    def func(i): return time.sleep(0.1) or i*2
    def fail(i): return time.sleep(0.1) or (i*2 if i < 10 else i/0)
    def test_imap_succ(N, num_thread=20):
        print('>>> Test succ loop <<<')
        iterable = range(N)
        active_count_before_imap = threading.active_count()
        ts0 = time.perf_counter()
        with imap(func, iterable, num_thread=num_thread) as ctx:
            results = list(ctx)
        used = time.perf_counter() - ts0
        ideal = N/num_thread*0.1
        print(f'results {len(results)} '
              f'used {used:.3f} secs vs ideal {ideal:.3f} secs overhead={used/ideal:.2%}')
        assert len(results) == N
        assert results == [i*2 for i in range(N)]
        assert active_count_before_imap == threading.active_count(), f'active_count={threading.active_count()} != {active_count_before_imap}'
        print()
    def test_imap_input_err():
        print('>>> Test input data failure <<<')
        def iterable2():
            yield 1
            yield 2
            raise RuntimeError('simulate input err')
        active_count_before_imap = threading.active_count()
        from itertools import islice
        try:
            with imap(func, iterable2()) as ctx:
                for result in islice(ctx, 5):
                    print('result', result)
            assert False, 'SHOULD NOT REACH THIS LINE'
        except RuntimeError as exc:
            if str(exc) == 'simulate input err':
                print('OK: failure raised as expected')
            else:
                raise
        assert active_count_before_imap == threading.active_count(), f'active_count={threading.active_count()} != {active_count_before_imap}'
        print()
    def test_imap_return_worker_err():
        print('>>> Test return worker failure <<<')
        iterable = range(12)
        active_count_before_imap = threading.active_count()
        with imap(fail, iterable, raise_func_err=False) as ctx:
            results = list(ctx)
        print('example when worker failed', results)
        assert active_count_before_imap == threading.active_count(), f'active_count={threading.active_count()} != {active_count_before_imap}'
        print()
    def test_imap_raise_worker_err():
        print('>>> Test raise worker failure <<<')
        iterable = range(14)
        active_count_before_imap = threading.active_count()
        results = []
        with imap(fail, iterable, raise_func_err=True) as ctx:
            try:
                for result in ctx:
                    results.append(result)
            except ZeroDivisionError:
                print('OK: worker error raised')
        assert len(results) == 10
        assert active_count_before_imap == threading.active_count(), f'active_count={threading.active_count()} != {active_count_before_imap}'
        print()
    def test_imap_caller_err():
        print('>>> Test caller failure <<<')
        iterable = range(10)
        active_count_before_imap = threading.active_count()
        try:
            with imap(func, iterable) as ctx:
                for result in ctx:
                    assert result < 10, 'simulate caller fail'
        except AssertionError:
            print('OK: failure raised as expected')
        assert active_count_before_imap == threading.active_count(), f'active_count={threading.active_count()} != {active_count_before_imap}'
        print()
    test_imap_succ(500)
    # test_imap_succ(5000)
    test_imap_input_err()
    test_imap_return_worker_err()
    test_imap_raise_worker_err()
    test_imap_caller_err()
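The docstring notes that leaving the `with` block before all results are consumed cancels the queued work and shuts the pool down. A minimal sketch of that early-exit pattern (`slow_square` is a stand-in worker, not part of the gist):

```
import itertools
import time

def slow_square(i):
    # stand-in worker used only for this sketch
    time.sleep(0.1)
    return i * i

with concurrent_imap(slow_square, iter(range(1_000_000)), num_thread=8) as ctx:
    first_ten = list(itertools.islice(ctx, 10))  # stop consuming after 10 results
# exiting the block flips `alive` to False, cancels futures still in the queue and calls executor.shutdown()
print(first_ten)
```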