Last active
April 8, 2020 15:18
-
-
Save schwartzmx/af3910099ecd76928159cf5c3228bacb to your computer and use it in GitHub Desktop.
Wrapper around concurrent.futures.ThreadPoolExecutor
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import logging | |
from concurrent.futures import ( | |
ThreadPoolExecutor, | |
wait, | |
CancelledError, | |
TimeoutError, | |
ALL_COMPLETED, | |
) | |
import os | |
logger = logging.getLogger(__name__)


class Executor(object):
    '''
    Thin wrapper around concurrent.futures.ThreadPoolExecutor that remembers
    every submitted future so the results can be collected in one call.

    Intended to be used as a context manager so the pool is always shut down:
    ```
    with Executor() as executor:
        executor.submit_task(fn, fn_arg=1, fn_arg2=3)
        results = executor.results()
    ```
    '''

    # Sentinel meaning "default_result was not supplied". Using a sentinel
    # (instead of None) keeps an explicit `default_result=None` meaningful.
    _UNSET = object()

    def __init__(self,
                 max_workers=os.cpu_count(),
                 task_timeout=60,
                 default_result=_UNSET):
        '''
        max_workers: size of the underlying thread pool
        task_timeout: seconds to wait for submitted tasks when collecting results
        default_result: value reported for a task that timed out or was
            cancelled; defaults to a new empty list for each Executor
        '''
        self.max_workers = max_workers
        self.pool = ThreadPoolExecutor(max_workers=self.max_workers)
        self.task_timeout = task_timeout
        self.task_futures = []
        # Build the fallback list per instance; a `[]` literal in the signature
        # would be a single object shared by every Executor ever created.
        if default_result is self._UNSET:
            default_result = []
        self.default_result = default_result

    def __enter__(self):
        return self

    def submit_task(self, fn, **kwargs):
        '''
        submits a task to the thread pool,
        fn: is a callable function
        kwargs: is any function key word arguments to pass to the callable
        '''
        logger.debug('Task queued: (fn: %s kwargs: %s)', fn, kwargs)
        self.task_futures.append(self.pool.submit(fn, **kwargs))

    def results(self):
        '''
        returns a list with one entry per submitted task, in submission order,
        returns an empty list if task_futures is empty

        Waits up to task_timeout seconds for the tasks to finish. A task that
        has not finished by then, or that was cancelled, contributes
        default_result. Any other exception raised by a task propagates to
        the caller.
        '''
        if not self.task_futures:
            return []

        # wait for all tasks (bounded by task_timeout)
        completed_tasks, _ = wait(self.task_futures, timeout=self.task_timeout)

        results = []
        # Walk the submission list rather than the (unordered) completed set so
        # callers can match results to submissions and unfinished tasks are
        # reported instead of being silently dropped.
        for task in self.task_futures:
            if task not in completed_tasks:
                logger.error('Timeout error occurred in executor for task %s', task)
                results.append(self.default_result)
                continue

            task_result = self.default_result
            try:
                task_result = task.result()
                logger.debug('Task %s Result: %s', task, task_result)
            except TimeoutError:
                logger.exception('Timeout error occurred in executor for task %s', task)
            except CancelledError:
                logger.exception('Cancelled task occurred in executor for task %s', task)
            results.append(task_result)
        return results

    def wait_all_completed(self, timeout=None):
        '''Wait until all task_futures are completed before returning, this method
        returns tuple of: completed_tasks, uncompleted_tasks
        If timeout is not provided (None), the Executor.task_timeout is used.
        A timeout of 0 is honoured and returns the current state immediately.'''
        # Compare against None explicitly: 0 is a valid "do not block" timeout.
        wait_timeout = self.task_timeout if timeout is None else timeout
        return wait(self.task_futures, timeout=wait_timeout, return_when=ALL_COMPLETED)

    def shutdown(self, wait=True):
        ''' explicitly shutdown the underlying ThreadPoolExecutor, it is recommended
        to use this class using a `with` statement, for example:
        ```
        with Executor() as executor:
            executor.submit_task(fn, fn_arg=1, fn_arg2=3)
            executor.submit_task(task, **kwargs)
            ...
            results = executor.results()
        ```
        wait: passed through to ThreadPoolExecutor.shutdown
        '''
        self.pool.shutdown(wait=wait)

    def __repr__(self):
        return '<Executor(max_workers={0}, task_timeout={1}, task_futures_count={2}, default_result={3})>'.format(
            self.max_workers,
            self.task_timeout,
            len(self.task_futures),
            self.default_result
        )

    def __exit__(self, exc_type, exc_value, traceback):
        logger.debug('Exiting Executor %s - closing pool.', self)
        self.shutdown()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment