@thehesiod
Last active September 26, 2024 03:14
Asynchronous Worker Pool, allows for limiting number of concurrent tasks
import asyncio
from datetime import datetime, timezone
import os


def utc_now():
    # utcnow() returns a naive datetime, so we have to set the timezone manually <sigh>
    return datetime.utcnow().replace(tzinfo=timezone.utc)


class Terminator:
    pass


_unittest = os.environ.get('UNITTEST')
class AsyncWorkerPool:
    def __init__(self, loop, num_workers: int, name: str, logger, worker_co, load_factor: int=1,
                 job_accept_duration: int=None, max_task_time: int=None, return_futures: bool=False,
                 raise_on_join: bool=False, log_every_n: int=None, expected_total=None):
        """
        This class will create `num_workers` asyncio tasks to work against a queue of
        `num_workers * load_factor` items of back-pressure (IOW we will block after that
        many items of work are in the queue).  `worker_co` will be called
        against each item retrieved from the queue.  If any exceptions are raised out of
        `worker_co`, self.exceptions will be set to True.

        @param loop: asyncio loop to use
        @param num_workers: number of async tasks which will pull from the internal queue
        @param name: name of the worker pool (used for logging)
        @param logger: logger to use
        @param worker_co: async coroutine to call with each item retrieved from the queue
        @param load_factor: multiplier used for number of items in queue
        @param job_accept_duration: maximum number of seconds from first push to last push before a TimeoutError is thrown.
            Set to None for no limit.  Note this does not get reset on aenter/aexit.
        @param max_task_time: maximum time allowed for each task before a CancelledError is raised in the task.
            Set to None for no limit.
        @param return_futures: set to True to return a future for each `push` (imposes CPU overhead)
        @param raise_on_join: raise on join if any exceptions have occurred, default is False
        @param log_every_n: (optional) log a progress statement every N `push`es (default: no periodic logging)
        @param expected_total: (optional) expected total number of jobs (used for `log_every_n` logging)
        @return: instance of AsyncWorkerPool
        """
        loop = loop or asyncio.get_event_loop()
        self._loop = loop
        self._num_workers = num_workers
        self._logger = logger
        self._queue = asyncio.Queue(num_workers * load_factor)
        self._workers = None
        self._exceptions = False
        self._job_accept_duration = job_accept_duration
        self._first_push_dt = None
        self._max_task_time = max_task_time
        self._return_futures = return_futures
        self._raise_on_join = raise_on_join
        self._name = name
        self._worker_co = worker_co
        self._total_queued = 0
        self._log_every_n = log_every_n
        self._expected_total = expected_total
    async def _worker_loop(self):
        while True:
            got_obj = False
            future = None

            try:
                item = await self._queue.get()
                got_obj = True

                if item.__class__ is Terminator:
                    break

                future, args, kwargs = item
                # wait_for will cancel the task (the task sees CancelledError) and raise a TimeoutError from here,
                # so be wary of catching TimeoutErrors in this loop
                result = await asyncio.wait_for(self._worker_co(*args, **kwargs), self._max_task_time, loop=self._loop)
                if future:
                    future.set_result(result)
            except (KeyboardInterrupt, MemoryError, SystemExit) as e:
                if future:
                    future.set_exception(e)
                self._exceptions = True
                raise
            except BaseException as e:
                self._exceptions = True

                if future:
                    # don't log the failure when the client is receiving the future
                    future.set_exception(e)
                elif not _unittest:
                    self._logger.exception('Worker call failed')
            finally:
                if got_obj:
                    self._queue.task_done()
    @property
    def exceptions(self):
        return self._exceptions

    @property
    def total_queued(self):
        return self._total_queued

    async def __aenter__(self):
        self.start()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        await self.join()
    async def push(self, *args, **kwargs) -> asyncio.Future:
        """ Method to push work to `worker_co` passed to `__init__`.
        :param args: positional arguments to be passed to `worker_co`
        :param kwargs: keyword arguments to be passed to `worker_co`
        :return: future of result """
        if self._first_push_dt is None:
            self._first_push_dt = utc_now()

        if self._job_accept_duration is not None and \
                (utc_now() - self._first_push_dt).total_seconds() > self._job_accept_duration:
            raise TimeoutError("Maximum lifetime of {} seconds of AsyncWorkerPool: {} exceeded".format(
                self._job_accept_duration, self._name))

        future = asyncio.futures.Future(loop=self._loop) if self._return_futures else None
        await self._queue.put((future, args, kwargs))
        self._total_queued += 1

        if self._log_every_n is not None and (self._total_queued % self._log_every_n) == 0:
            self._logger.info("pushed {}/{} items to {} AsyncWorkerPool".format(
                self._total_queued, self._expected_total, self._name))

        return future
    def start(self):
        """ Will start up worker pool and reset exception state """
        assert self._workers is None
        self._exceptions = False
        self._workers = [asyncio.ensure_future(self._worker_loop(), loop=self._loop) for _ in range(self._num_workers)]
    async def join(self):
        # no-op if workers aren't running
        if not self._workers:
            return

        self._logger.debug('Joining {}'.format(self._name))

        # The Terminators will kick each worker out of being blocked on _queue.get() and allow
        # each one to exit
        for _ in range(self._num_workers):
            await self._queue.put(Terminator())

        try:
            # wait_raise_on_exception is a helper defined outside this snippet; it awaits the
            # worker tasks and re-raises any fatal exception from them
            await wait_raise_on_exception(self._workers, loop=self._loop, logger=self._logger)
            self._workers = None
        except:
            self._logger.exception('Exception joining {}'.format(self._name))
            raise
        finally:
            self._logger.debug('Completed {}'.format(self._name))

        if self._exceptions and self._raise_on_join:
            raise Exception("Exception occurred in pool {}".format(self._name))
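
A minimal, self-contained usage sketch follows; the `double` coroutine, the worker count of 4, and the names are illustrative assumptions, not part of the gist, and it assumes the `wait_raise_on_exception` helper referenced in `join()` is available:

import asyncio
import logging

async def double(x):
    # hypothetical worker: simulate some I/O-bound work
    await asyncio.sleep(0.1)
    return x * 2

async def demo():
    logger = logging.getLogger('pool-demo')
    loop = asyncio.get_event_loop()
    # return_futures=True makes each push() hand back a future for its result
    async with AsyncWorkerPool(loop, 4, 'demo', logger, double, return_futures=True) as pool:
        futures = [await pool.push(i) for i in range(10)]
    # __aexit__ joins the pool, so by now every future is resolved
    print([f.result() for f in futures])

asyncio.get_event_loop().run_until_complete(demo())

Because the queue is bounded at `num_workers * load_factor`, `push()` blocks once that many items are pending, which is what keeps memory bounded when the producer outruns the workers.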
@SylvainAudel

I'm new to asyncio. Your async worker pool is interesting, but I don't understand exactly how to use it. Is there a short example?

@thehesiod
Author

thehesiod commented May 25, 2017

Sorry, for some reason I didn't get a notification of your comment. Here's an example:

async def process_files(s3_client: S3ClientPool, bucket: str, root: str, content: dict or None, fc_queue: asyncio.Queue, src_bucket: str, src_root: str):
    pass  # do work

async def main():
    async with AsyncWorkerPool(None, s3_client.max_tasks(), 's3_lq bucket:{} prefix: {}'.format(bucket, list_prefix), logger, process_files,
                                   log_every_n=4000, raise_on_join=True) as wq:
        async for result in s3_client.list_pager(Bucket=bucket, Prefix=list_prefix):
            for c in result.get('Contents', []):
                # if c["Key"].endswith('/') and c["Size"] == 0: continue
                await wq.push(s3_client, bucket, path, c, *action_args)

@Fitblip

Fitblip commented Nov 14, 2017

Hey, this is super awesome! I find myself needing these sorts of coro pools often, and it's easy for things to become unruly.

Major 👍 for this. Have you considered making this a module/library?

@Fitblip

Fitblip commented Dec 7, 2017

Just as a heads up @thehesiod - I pulled this into its own package as I needed it in multiple projects: https://github.com/CaliDog/asyncpool

@thehesiod
Author

Ya, may do that. I actually have both this and a multiprocess version.

@uogbuji

uogbuji commented Jun 6, 2018

@thehesiod do you have a link to the multiprocess version? Thanks.

@perkfly

perkfly commented Oct 13, 2018

Why not use a semaphore?
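
For comparison, a semaphore caps concurrency but works quite differently from the pool above: every item gets its own task object up front, so there is no back-pressure on the producer, and there is no built-in per-task timeout or join/exception bookkeeping. A rough sketch, where `fetch` and the limit of 4 are illustrative assumptions:

import asyncio

async def fetch(i):
    # hypothetical worker
    await asyncio.sleep(0.1)
    return i

async def main():
    sem = asyncio.Semaphore(4)  # at most 4 fetches run concurrently

    async def limited(i):
        async with sem:
            return await fetch(i)

    # all 20 coroutines are scheduled immediately; the semaphore only gates execution
    print(await asyncio.gather(*(limited(i) for i in range(20))))

asyncio.get_event_loop().run_until_complete(main())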
