Created
January 2, 2025 03:09
-
-
Save tejashah88/00f4f31eee1d07683dedb9d22545f391 to your computer and use it in GitHub Desktop.
An implementation of an async version of click's CLI progress bar with lock-protected (thread-safe) progress reporting.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import asyncio | |
import concurrent.futures | |
import threading | |
import click | |
def safe_run_async(async_fn, *argv):
    """Run ``async_fn(*argv)`` to completion on a fresh event loop and return its result.

    A new event loop is created for every call and torn down afterwards
    (async generators finalized, default executor shut down, loop closed),
    so this helper can safely be called repeatedly from synchronous code.

    Args:
        async_fn: A coroutine function.
        *argv: Positional arguments forwarded to ``async_fn``.

    Returns:
        Whatever the coroutine returns.

    Raises:
        RuntimeError: If called from a thread whose event loop is already running.
        Any exception raised by the coroutine propagates unchanged.
    """
    # asyncio.run() owns the whole loop lifecycle. The previous pattern
    # (get_event_loop() + close()) left a *closed* loop installed as the
    # thread's current loop, so the second call failed with
    # "Event loop is closed"; get_event_loop() outside a running loop is
    # also deprecated.
    return asyncio.run(async_fn(*argv))
class AsyncProgressBar:
    """Map a function over items on a thread pool while driving a click progress bar.

    ``process`` is called synchronously and blocks until every item has been
    handled. Internally the work is fanned out to a ``ThreadPoolExecutor``
    through asyncio, and each worker advances the shared progress bar under a
    lock once its item finishes (lock-protected, not atomic, progress reporting).
    """

    def __init__(self, max_workers):
        """Create a progress-reporting mapper.

        Args:
            max_workers: Passed straight to
                ``concurrent.futures.ThreadPoolExecutor(max_workers=...)``.
        """
        self.max_workers = max_workers
        # Serializes bar.update() calls coming from different worker threads so
        # concurrent workers do not overwrite each other's progress.
        self._lock = threading.Lock()

    def process(self, lst, map_fn, label):
        """Apply ``map_fn`` to every item of ``lst`` on worker threads, showing progress.

        Args:
            lst: Items to process. Any iterable is accepted; it is materialized
                into a list first so the progress bar knows the total length.
            map_fn: Callable invoked as ``map_fn(item)`` on a worker thread.
            label: Text shown next to the progress bar.

        Returns:
            A list of ``map_fn`` results, in the same order as the input items.

        Raises:
            Exceptions raised by ``map_fn`` propagate to the caller.
        """
        # Materialize once: len() is needed for the bar, and this also accepts
        # generators and other non-sized iterables.
        items = list(lst)

        def run_map_fn(item, bar):
            # Executed on a worker thread.
            result = map_fn(item)
            # Only the bar update is locked; map_fn itself runs concurrently.
            with self._lock:
                bar.update(1)
            return result

        async def run_progressive_task(work_items):
            loop = asyncio.get_running_loop()
            with click.progressbar(length=len(work_items), label=label) as bar:
                with concurrent.futures.ThreadPoolExecutor(
                    max_workers=self.max_workers
                ) as executor:
                    futures = [
                        loop.run_in_executor(executor, run_map_fn, item, bar)
                        for item in work_items
                    ]
                    # gather() preserves submission order, so results line up
                    # with the input items regardless of completion order.
                    return await asyncio.gather(*futures)

        return safe_run_async(run_progressive_task, items)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment