Created
May 5, 2023 16:21
-
-
Save thewisenerd/5719ee49a62d14eeb2fc0be6407c6169 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import threading | |
import typing | |
from dataclasses import dataclass, field | |
from queue import Queue | |
from typing import TypeVar, Generic | |
T = TypeVar('T')  # type of the payload handed to a job
R = TypeVar('R')  # type of the value a job produces


@dataclass(eq=False)
class Job(Generic[T, R]):
    """A unit of work plus a one-shot slot for its outcome.

    The producer creates a job around ``input``; whoever executes it calls
    exactly one of :meth:`complete_ok` or :meth:`complete_err`, which sets
    the ``done`` event. Consumers block on :meth:`wait` or :meth:`get`.

    ``eq=False`` keeps identity-based ``==`` and ``hash``: a generated
    ``__eq__`` would make jobs unhashable, so they could not be stored in
    sets or used as dict keys.
    """

    input: T
    # Set once the job has finished, successfully or not.
    done: threading.Event = field(init=False, default_factory=threading.Event)
    # Stays True unless complete_err() is called.
    success: bool = field(init=False, default=True)
    result: typing.Optional[R] = field(init=False, default=None)
    error: typing.Optional[Exception] = field(init=False, default=None)

    def complete_ok(self, result: R) -> None:
        """Store ``result`` and wake every waiter."""
        self.result = result
        self.done.set()

    def complete_err(self, err: Exception) -> None:
        """Mark the job as failed with ``err`` and wake every waiter."""
        self.success = False
        self.error = err
        self.done.set()

    def wait(self, timeout: typing.Optional[float] = None) -> bool:
        """Block until the job is done or ``timeout`` seconds elapse.

        Args:
            timeout: Maximum seconds to wait; ``None`` waits indefinitely.

        Returns:
            True if the job has completed, False if the wait timed out.
        """
        return self.done.wait(timeout)

    def get(self, timeout: typing.Optional[float] = None) -> typing.Optional[R]:
        """Return the job's result, re-raising its error if it failed.

        Args:
            timeout: Maximum seconds to wait; ``None`` waits indefinitely.

        Raises:
            TimeoutError: The job did not complete within ``timeout``.
            Exception: Whatever was passed to :meth:`complete_err`.
        """
        if not self.wait(timeout):
            raise TimeoutError("job did not complete within the timeout")
        if self.success:
            return self.result
        raise self.error
class WorkerPool(Generic[T, R]):
    """Fixed-size pool of daemon threads that run ``runner`` on queued jobs.

    Each submitted task is wrapped in a ``Job`` and placed on ``self.q``.
    A worker takes the job, calls ``runner(job.input, pool)`` and records
    the return value or the raised exception on the job. The pool itself is
    passed to the runner as its second argument.
    """

    # Seconds an idle worker blocks on the queue before re-checking the
    # kill switch. A blocking get() without a timeout would never notice
    # shutdown(), and the thread would stay alive forever.
    _POLL_INTERVAL = 0.1

    def __init__(self,
                 name: str,
                 threads: int,
                 runner: typing.Callable[[T, 'WorkerPool'], R],
                 q: typing.Optional[Queue[Job[T, R]]] = None):
        """Create the pool and start its worker threads immediately.

        Args:
            name: Prefix for worker thread names (``{name}-{idx}``).
            threads: Number of worker threads to start.
            runner: Callable invoked as ``runner(task, pool)`` for each job.
            q: Optional queue of jobs to consume from; a fresh unbounded
                ``Queue`` is created when omitted.
        """
        # Imported here so the block does not rely on a module-level import
        # the file does not have.
        from queue import Empty

        if q is None:
            q = Queue()

        self.name = name
        self.threads = threads
        self.q = q
        self.runner = runner
        self.kill_switch = threading.Event()
        self.worker_threads: typing.List[threading.Thread] = []

        def worker():
            while not self.kill_switch.is_set():
                try:
                    job = self.q.get(timeout=self._POLL_INTERVAL)
                except Empty:
                    # Nothing to do; loop round to re-check the kill switch.
                    continue
                try:
                    outcome = self.runner(job.input, self)
                except Exception as exc:
                    job.complete_err(exc)
                else:
                    job.complete_ok(outcome)
                finally:
                    # Always balance the get(), otherwise q.join() in
                    # shutdown() could block forever.
                    self.q.task_done()

        for idx in range(self.threads):
            thread = threading.Thread(
                target=worker,
                name=f'{self.name}-{idx}',
                daemon=True,
            )
            thread.start()
            self.worker_threads.append(thread)

    def submit(self, task: T) -> Job[T, R]:
        """Queue ``task`` for execution and return its ``Job`` handle.

        Raises:
            RuntimeError: The pool has already been shut down.
        """
        if self.kill_switch.is_set():
            raise RuntimeError("cannot submit jobs for a shut down worker pool")

        job: Job[T, R] = Job(task)
        self.q.put(job)
        return job

    def shutdown(self):
        """Wait for all queued jobs to finish, then stop the workers.

        Blocks until the queue is fully processed, sets the kill switch so
        further ``submit`` calls are rejected, and joins every worker
        thread. Calling it again is harmless.
        """
        self.q.join()
        self.kill_switch.set()
        # Idle workers wake within _POLL_INTERVAL, see the switch and exit.
        for thread in self.worker_threads:
            thread.join()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment