Created: April 5, 2022 21:07
-
-
Save onecrayon/975c541f4387a1fd960751d6e5e734c0 to your computer and use it in GitHub Desktop.
QueuedProcessor
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import signal
import sys
from queue import Empty, Queue
from threading import Lock, Thread
from time import sleep
from typing import Callable, Optional
class QueuedProcessor:
    """Context manager class for processing script actions in threads

    Mainly useful if you want nice-looking output on the command line to visualize how many
    concurrent operations are ongoing. Not intended for use in web applications or similar.

    NOTE: This uses threads, NOT multi-processing, which means that it's only appropriate
    for I/O bound script actions. It will not give you any benefit when performing CPU-bound
    actions.

    See <https://stackoverflow.com/a/50174144/38666> for prior art on capturing signals.

    Example usage:

        import random
        from time import sleep

        def fake_action(throw_error=False):
            sleep(random.random() * 2)
            if throw_error:
                raise Exception("No good!")

        if __name__ == "__main__":
            TOTAL = 50
            THREADS = 5
            with QueuedProcessor(
                fake_action,
                to_process=TOTAL,
                max_threads=THREADS,
                use_progress_tracker=True,
            ) as queue:
                for _ in range(0, TOTAL):
                    zero_or_one = random.randint(0, 1)
                    queue.put(zero_or_one)
                processed = queue.await_completion()
            print(f"{processed} items run through queue!")
    """

    queue: Queue
    process_method: Callable
    to_process: int
    max_threads: int
    # Internal state; all of these are assigned per-instance in __init__ (they were
    # previously mutable class attributes, which made separate instances share state)
    processed: int
    stop_workers: bool
    workers: list
    _log_delay: int
    _overwrite_progress_logs: bool
    _previous_line_overwrote: bool
    _log_lock: Lock
    _count_lock: Lock
    _exited: bool
    # Used to track historical SIGINT and SIGTERM handlers
    old_sigint = None
    old_sigterm = None

    def __init__(
        self,
        process_method: Callable,
        to_process: int = 0,
        max_threads: int = 10,
        use_progress_tracker: bool = False,
        log_lock: Optional[Lock] = None,
    ):
        """Create a QueuedProcessor

        :param process_method: Callable method that will be invoked for each item in the queue within its
            own thread.
        :param to_process: Total number of items to be processed (used for progress reports)
        :param max_threads: Maximum number of allowed concurrent threads
        :param use_progress_tracker: Whether to use a fancy, auto-updating progress tracker at 1 second
            intervals, or a standard log output at 10 second intervals
        :param log_lock: An optional Lock instance that will be acquired and released when outputting
            log entries (only specify to share with outer logic)
        """
        self.process_method = process_method
        self.to_process = to_process
        self.max_threads = max_threads
        self._log_delay = 1 if use_progress_tracker else 10
        self._overwrite_progress_logs = use_progress_tracker
        # BUGFIX: this state used to live on the class, so a second instance would
        # share (and corrupt) the first instance's counters and worker list
        self.processed = 0
        self.stop_workers = False
        self.workers = []
        self._previous_line_overwrote = False
        self._exited = False
        self._log_lock = log_lock if log_lock is not None else Lock()
        # Guards `self.processed`, which is incremented from multiple worker threads
        # (+= is read-modify-write and not atomic, so unguarded updates can be lost)
        self._count_lock = Lock()
        self.queue = Queue(maxsize=self.max_threads)
        # Create our processing threads
        for _ in range(self.max_threads):
            worker = Thread(target=self._threaded_worker)
            worker.start()
            self.workers.append(worker)
        # And create our progress tracker (note: it never reads the queue)
        tracker = Thread(target=self._progress_tracker)
        tracker.start()
        self.workers.append(tracker)

    def put(self, item):
        """Mirrors Queue.put(), but refuses to accept new items if we have exited for some reason"""
        if self.stop_workers:
            return
        self.queue.put(item)

    def await_completion(self) -> int:
        """Waits until queue is cleared, then cleans up the threads and returns the number of items processed"""
        self.queue.join()
        self.exit()
        return self.processed

    def log(self, message: str, overwrite=False):
        """Logs a message to the output, overwriting the previous line, if requested"""
        # `with` guarantees the lock is released even if stdout raises
        with self._log_lock:
            if self._previous_line_overwrote:
                message = "\r" + message
            if overwrite:
                sys.stdout.write(message)
                sys.stdout.flush()
                self._previous_line_overwrote = True
            else:
                print(message, flush=True)
                self._previous_line_overwrote = False

    def output_progress(self, final=False):
        """Emits a progress line: processed count plus a bar showing queue fullness

        :param final: When True, always emits a permanent (non-overwriting) line
        """
        # Construct console output; the bar reflects queue fullness relative to its
        # maxsize (self.max_threads), not overall completion
        bar_length = 50
        queue_length = self.queue.qsize()
        filled_length = int(round(bar_length * queue_length / float(self.max_threads)))
        queued_bar = "=" * filled_length + "-" * (bar_length - filled_length)
        # Pad to width 6 so that the total line length doesn't change
        processed_padded = "{:>6}".format(self.processed)
        to_process_padded = "{:>6}".format(self.to_process)
        # Counter tracking progress
        output = f"🔄 Processed: {processed_padded} of {to_process_padded}; Queue: [{queued_bar}]"
        self.log(output, overwrite=self._overwrite_progress_logs and not final)

    def exit(self):
        """Shuts down all threads and emits a final progress line

        Safe to call more than once (e.g. await_completion() followed by a signal);
        repeat calls are no-ops. Previously a second call would enqueue sentinels
        that no live worker would ever consume and block forever on a full queue.
        """
        if self._exited:
            return
        self._exited = True
        self.stop_workers = True
        # Drain the queue (otherwise we might block on putting None if the queue is full)
        while True:
            try:
                self.queue.get_nowait()
                self.queue.task_done()
            except Empty:
                break
        # One None sentinel per queue worker so none of them waits on the queue forever.
        # (Exactly max_threads — the progress tracker never reads the queue, so putting
        # one per thread in self.workers would leave a stray sentinel behind.)
        for _ in range(self.max_threads):
            self.queue.put(None)
        # Wait for all threads to shut down
        for thread in self.workers:
            thread.join()
        self.output_progress(final=True)

    def _threaded_worker(self):
        """Processes queued items in a thread until it receives None or the process is killed"""
        while not self.stop_workers:
            item = self.queue.get()
            if item is None:
                # Sentinel from exit(); balance its put() so the queue's unfinished
                # count stays consistent
                self.queue.task_done()
                break
            processed_num = self.processed + 1
            try:
                self.process_method(item)
            except Exception as e:
                self.log(f"⛔️ {item} ({processed_num}/{self.to_process}): {e}")
            finally:
                # Always mark the item done, even if process_method blew up, so
                # queue.join() can't hang
                self.queue.task_done()
            with self._count_lock:
                self.processed += 1

    def _progress_tracker(self):
        """Tracks progress and outputs to the command line

        See: https://gist.github.com/vladignatyev/06860ec2040cb497f0f3
        """
        while not self.stop_workers:
            # No need to output progress if we've completed everything, because a last progress will be output on exit;
            # instead, we'll break out of the progress tracker and let the other threads end naturally. If we don't
            # do this, then we'll have up to `self._log_delay`
            if self.processed >= self.to_process:
                break
            self.output_progress()
            # Sleep in short slices so exit() isn't blocked joining this thread for up
            # to _log_delay seconds (10s in the non-tracker mode)
            for _ in range(self._log_delay * 10):
                if self.stop_workers:
                    break
                sleep(0.1)

    def _kill_handler(self, signum, frame):
        """Signal handler: drains the queue, stops the threads, and exits the process"""
        self.log("SIGINT or SIGTERM: finishing open threads and exiting...")
        self.exit()
        sys.exit(0)

    def __enter__(self):
        # Capture SIGINT/SIGTERM so Ctrl-C shuts the threads down cleanly, remembering
        # the prior handlers so __exit__ can restore them
        self.old_sigint = signal.signal(signal.SIGINT, self._kill_handler)
        self.old_sigterm = signal.signal(signal.SIGTERM, self._kill_handler)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Restore whatever handlers were installed before we entered
        signal.signal(signal.SIGINT, self.old_sigint)
        signal.signal(signal.SIGTERM, self.old_sigterm)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.