Last active
November 29, 2024 11:09
-
-
Save jonashaag/7a4d5627331c749f2b5b85869d9499e7 to your computer and use it in GitHub Desktop.
Python ThreadPoolExecutor Work Stealing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import concurrent.futures.thread as _thread_impl | |
import threading | |
import time | |
import weakref | |
from concurrent.futures import Future | |
class WorkStealThreadPoolExecutor(_thread_impl.ThreadPoolExecutor):
    """Thread pool whose futures can be executed by the thread waiting on them.

    Work stealing prevents worker starvation: instead of blocking on a queued
    task, the caller of `future.result()` may run that task itself. This is
    implemented by handing out `WorkStealFuture` objects that keep a weak
    reference to their queued work item.
    """

    def submit(self, fn, /, *args, **kwargs):
        """Queue `fn(*args, **kwargs)` and return a `WorkStealFuture` for it.

        Raises:
            BrokenThreadPool: if the pool is marked as broken.
            RuntimeError: if the pool or the interpreter is shutting down.
        """
        # NOTE: This mirrors _thread_impl.ThreadPoolExecutor.submit almost verbatim;
        # the only additions are the custom future type and the weak reference.
        with self._shutdown_lock, _thread_impl._global_shutdown_lock:
            if self._broken:
                raise _thread_impl.BrokenThreadPool(self._broken)
            if self._shutdown:
                raise RuntimeError("cannot schedule new futures after shutdown")
            if _thread_impl._shutdown:
                raise RuntimeError("cannot schedule new futures after interpreter shutdown")

            future = WorkStealFuture()
            work_item = _thread_impl._WorkItem(future, fn, args, kwargs)
            # Weak reference only: the future must not keep the work item alive.
            future._work_item_weakref = weakref.ref(work_item)

            self._work_queue.put(work_item)
            self._adjust_thread_count()
            return future

    def map(self, fn, *iterables, timeout=None, chunksize=1, enable_work_stealing=True):
        """Apply `fn` to the zipped `iterables` and return an iterator of results.

        The iterables are consumed eagerly. With work stealing enabled, no
        timeout and fewer than two calls, nothing is submitted to the pool and
        the calls run lazily in the consuming thread. `chunksize` is accepted
        but not used by this implementation.
        """
        call_args_list = list(zip(*iterables))

        # Optimization: a pool round trip is pointless for zero or one call.
        run_inline = enable_work_stealing and timeout is None and len(call_args_list) < 2
        if run_inline:
            return (fn(*call_args) for call_args in call_args_list)

        # NOTE: Everything below mirrors _thread_impl.ThreadPoolExecutor.map almost verbatim.
        deadline = (timeout or 0) + time.monotonic()
        pending = [self.submit(fn, *call_args) for call_args in call_args_list]

        def result_iterator():
            try:
                # Reverse once so results can be popped from the end in submission order.
                pending.reverse()
                while pending:
                    # The popped future is passed straight through so this frame
                    # never holds a reference to it.
                    yield _result_or_cancel(
                        pending.pop(),
                        timeout=None if timeout is None else deadline - time.monotonic(),
                        enable_work_stealing=enable_work_stealing,
                    )
            finally:
                # Runs on exhaustion, error, or generator close: drop what is left.
                for leftover in pending:
                    leftover.cancel()

        return result_iterator()
def _result_or_cancel(fut, *, timeout, enable_work_stealing):
    """Return `fut.result(timeout, enable_work_stealing)`, then always cancel `fut`.

    `fut.cancel()` is called whether fetching the result succeeded or raised.
    """
    try:
        try:
            outcome = fut.result(timeout, enable_work_stealing)
        finally:
            fut.cancel()
        return outcome
    finally:
        # Break a reference cycle with the exception in self._exception
        del fut
class WorkStealFuture(Future):
    """A `Future` that supports work stealing, for use in `WorkStealThreadPoolExecutor`.

    The executor stores a weak reference to the queued work item in
    `_work_item_weakref`. A thread calling `result()` can then run the work
    item itself instead of waiting for a pool worker. `_starting_thread_id`
    records which thread claimed the task, so that it runs at most once.
    """

    # Set by WorkStealThreadPoolExecutor.submit right after construction.
    _work_item_weakref: weakref.ref[_thread_impl._WorkItem]

    def __init__(self):
        super().__init__()
        # Guards `_starting_thread_id`. Must not be held while running the work
        # item, because the work item calls `set_running_or_notify_cancel`.
        self._lock = threading.Lock()
        self._starting_thread_id = None

    def result(self, timeout=None, enable_work_stealing=True):
        """Get the result of the future.

        If enable_work_stealing is True, attempt to execute the task itself instead of waiting.

        Args:
            timeout: With work stealing disabled, forwarded to `Future.result`.
                With work stealing enabled, only `None` or a non-positive
                value is supported.
            enable_work_stealing: Whether the calling thread may run the task.

        Raises:
            NotImplementedError: positive `timeout` with work stealing enabled.
            TimeoutError: non-positive `timeout` and the task has neither been
                claimed by any thread nor completed.
            CancelledError: the future was cancelled.
        """
        if not enable_work_stealing:
            return super().result(timeout)
        if timeout is not None and timeout > 0:
            raise NotImplementedError("Positive 'timeout' not supported with 'enable_work_stealing=True'")
        try:
            with self._lock:
                if self._starting_thread_id is not None:
                    # Some thread already claimed the task; just wait for it.
                    should_start = False
                elif timeout is not None and timeout <= 0:
                    # Stealing takes time, so with a zero timeout we never start
                    # the task. A future that is already done (finished or
                    # cancelled) is still answered immediately by the base class.
                    if not self.done():
                        raise TimeoutError
                    should_start = False
                else:
                    should_start = True
                    self._starting_thread_id = threading.get_ident()
            # The weak reference is dead once the work item has been dropped.
            if should_start and (work_item := self._work_item_weakref()):
                work_item.run()
            return super().result(timeout)
        finally:
            # Not sure if this is needed. See ThreadPoolExecutor implementation.
            self = None

    def set_running_or_notify_cancel(self):
        """Claim the task for the calling thread, or report that it must not run.

        Returns:
            True if the calling thread should run the task. False if another
            thread claimed it, it already ran or is running, or it was cancelled.
        """
        current_thread_id = threading.get_ident()
        with self._lock:
            claimed_by = self._starting_thread_id
            if claimed_by is not None and claimed_by != current_thread_id:
                return False
            if self._state == "CANCELLED":
                # Let the base class move the future to CANCELLED_AND_NOTIFIED and
                # inform waiters; otherwise wait()/as_completed() never see it as done.
                return super().set_running_or_notify_cancel()
            if self._state != "PENDING":
                # Already running, finished, or cancellation already notified. The
                # base class would raise RuntimeError for these states.
                return False
            self._starting_thread_id = current_thread_id
            return super().set_running_or_notify_cancel()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment