Last active
March 9, 2021 18:18
-
-
Save xoelop/46a18e311d59ef7ed20f023f47d7ec93 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from tqdm import tqdm | |
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor, as_completed | |
def parallel_process(array,
                     function,
                     type_pool: str = 'multithreading',
                     use_kwargs=False,
                     n_jobs=16,
                     front_num=3,
                     desc='',
                     **func_kwargs):
    """
    A parallel version of the map function with a progress bar.

    Adapted from http://danshiebler.com/2016-09-14-parallel-progress-bar/

    Args:
        array (array-like): A sliceable sequence to iterate over.
        function (function): A python function to apply to the elements of array.
        type_pool (str, default='multithreading'):
            'multiprocessing': run the calls in a ProcessPoolExecutor
            'multithreading': run the calls in a ThreadPoolExecutor
            Ignored when n_jobs == 1.
        use_kwargs (boolean, default=False): Whether to consider the elements of array
            as dictionaries of keyword arguments to function.
        n_jobs (int, default=16): The number of pool workers (threads or processes).
            With n_jobs == 1 no pool is created and everything runs serially.
        front_num (int, default=3): The number of iterations to run serially before
            kicking off the parallel job. Useful for catching bugs: exceptions raised
            by these first calls propagate to the caller. Values below 0 are treated
            as 0.
        desc (str, default=''): Label shown on the progress bar of the parallel run.
        **func_kwargs: additional keyword arguments for the function to use.

    Returns:
        When n_jobs != 1, a tuple ``(results, errors)``: ``results`` holds the
        serial front results followed by the results of the pooled calls that
        succeeded, in submission order; ``errors`` holds the exceptions raised
        by the pooled calls that failed. Failed calls contribute no entry to
        ``results``, so positions may not line up with ``array``.

        When n_jobs == 1, only the plain list of results is returned (no
        ``errors`` list) and any exception propagates immediately. This shape
        is kept for backward compatibility.

    Raises:
        ValueError: if a pool is needed (n_jobs != 1) and type_pool is not
            'multithreading' or 'multiprocessing'.
    """
    pool_classes = {
        'multithreading': ThreadPoolExecutor,
        'multiprocessing': ProcessPoolExecutor,
    }
    # Fail fast, before any work is done, instead of hitting an unbound
    # local variable once the serial front has already been executed.
    if n_jobs != 1 and type_pool not in pool_classes:
        raise ValueError(
            "type_pool must be 'multithreading' or 'multiprocessing', "
            "got {!r}".format(type_pool)
        )

    def call(item):
        # Single place that decides how an element is handed to `function`.
        if use_kwargs:
            return function(**item, **func_kwargs)
        return function(item, **func_kwargs)

    # A negative front_num would make array[front_num:] select only the tail
    # of the input; treat it as "no serial warm-up" instead.
    front_num = max(front_num, 0)

    # We run the first few iterations serially to catch bugs. `front` is
    # always bound (possibly empty) so later concatenations cannot fail.
    front = [call(a) for a in array[:front_num]]
    remaining = array[front_num:]

    # If we set n_jobs to 1, just run a list comprehension.
    # This is useful for benchmarking and debugging.
    if n_jobs == 1:
        return front + [call(a) for a in tqdm(remaining)]

    # Assemble the workers
    with pool_classes[type_pool](max_workers=n_jobs) as pool:
        # Pass the elements of array into function
        if use_kwargs:
            futures = [pool.submit(function, **a, **func_kwargs) for a in remaining]
        else:
            futures = [pool.submit(function, a, **func_kwargs) for a in remaining]
        # Print out the progress as tasks complete
        progress = tqdm(as_completed(futures), desc=desc, total=len(futures),
                        unit='it', unit_scale=True, leave=True)
        for _ in progress:
            pass

    # Get the results from the futures. Every future is already done here,
    # so result() returns (or re-raises the worker's exception) immediately.
    out = []
    errors = []
    for future in futures:
        try:
            out.append(future.result())
        except Exception as e:
            errors.append(e)
    return front + out, errors
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment