Skip to content

Instantly share code, notes, and snippets.

@jonashaag
Last active November 29, 2024 11:09
Show Gist options
  • Save jonashaag/7a4d5627331c749f2b5b85869d9499e7 to your computer and use it in GitHub Desktop.
Save jonashaag/7a4d5627331c749f2b5b85869d9499e7 to your computer and use it in GitHub Desktop.
Python ThreadPoolExecutor Work Stealing
import concurrent.futures.thread as _thread_impl
import threading
import time
import weakref
from concurrent.futures import Future
class WorkStealThreadPoolExecutor(_thread_impl.ThreadPoolExecutor):
    """Thread pool whose futures may be run by the thread that waits on them.

    Work stealing is used to prevent worker starvation: `submit()` hands out
    `WorkStealFuture` objects, and calling `future.result()` on one of them
    can execute the pending task in the calling thread instead of blocking
    until a pool worker picks it up.
    """

    def submit(self, fn, /, *args, **kwargs):
        """Queue `fn(*args, **kwargs)` and return a `WorkStealFuture` for it.

        Raises:
            BrokenThreadPool: if the pool is marked as broken.
            RuntimeError: if the pool or the interpreter is shutting down.
        """
        # NOTE: This follows _thread_impl.ThreadPoolExecutor.submit almost
        # verbatim; the differences are the future class and the weak
        # back-reference from the future to its work item.
        with self._shutdown_lock, _thread_impl._global_shutdown_lock:
            if self._broken:
                raise _thread_impl.BrokenThreadPool(self._broken)
            if self._shutdown:
                raise RuntimeError("cannot schedule new futures after shutdown")
            if _thread_impl._shutdown:
                raise RuntimeError("cannot schedule new futures after interpreter shutdown")

            future = WorkStealFuture()
            work_item = _thread_impl._WorkItem(future, fn, args, kwargs)
            # Weak on purpose: the future must not keep the work item alive.
            future._work_item_weakref = weakref.ref(work_item)

            self._work_queue.put(work_item)
            self._adjust_thread_count()
            return future

    def map(self, fn, *iterables, timeout=None, chunksize=1, enable_work_stealing=True):
        """Return an iterator over `fn(*args)` for the zipped `iterables`.

        Args:
            fn: Callable applied to each tuple of zipped arguments.
            *iterables: Argument sources; they are zipped and fully consumed
                before anything is scheduled.
            timeout: Overall budget in seconds for collecting all results, or
                None for no limit.
            chunksize: Accepted for signature compatibility; not used here.
            enable_work_stealing: Forwarded to each future's `result()`.

        With work stealing enabled, no timeout and fewer than two calls, the
        pool is bypassed and `fn` runs lazily in the consuming thread.
        """
        call_args_list = list(zip(*iterables))

        # Optimization: zero or one call gains nothing from the pool.
        if enable_work_stealing and timeout is None and len(call_args_list) < 2:
            return (fn(*call_args) for call_args in call_args_list)

        # NOTE: The remainder follows _thread_impl.ThreadPoolExecutor.map
        # almost verbatim.
        deadline = time.monotonic() + (timeout or 0)
        pending = [self.submit(fn, *call_args) for call_args in call_args_list]

        def time_left():
            # Seconds remaining in the overall budget; None means "no limit".
            if timeout is None:
                return None
            return deadline - time.monotonic()

        def result_iterator():
            try:
                # Reverse once so that pop() hands futures back in submission order.
                pending.reverse()
                while pending:
                    # The popped future is passed straight through so this
                    # frame never holds a reference to it.
                    yield _result_or_cancel(
                        pending.pop(),
                        timeout=time_left(),
                        enable_work_stealing=enable_work_stealing,
                    )
            finally:
                # Iterator abandoned or failed: drop whatever is still queued.
                for leftover in pending:
                    leftover.cancel()

        return result_iterator()
def _result_or_cancel(fut, *, timeout, enable_work_stealing):
    """Return `fut.result(timeout, enable_work_stealing)`, then cancel `fut`.

    `fut.cancel()` is invoked whether fetching the result succeeded or raised;
    an exception raised by `fut.result()` propagates to the caller.
    """
    try:
        try:
            outcome = fut.result(timeout, enable_work_stealing)
        finally:
            fut.cancel()
        return outcome
    finally:
        # Break a reference cycle with the exception stored on the future.
        del fut
class WorkStealFuture(Future):
    """A `Future` that supports work stealing, for use in `WorkStealThreadPoolExecutor`.

    The first thread to claim the future -- a thread calling
    `set_running_or_notify_cancel()` or a caller of `result()` -- records its
    thread id in `_starting_thread_id`. Any other thread that later tries to
    start the same task is refused, so the task body runs at most once.
    """

    # Weak reference to the work item wrapping this future. It is assigned
    # after construction by whoever creates the work item, so it may be
    # missing on a future that was instantiated directly.
    _work_item_weakref: weakref.ref[_thread_impl._WorkItem]

    def __init__(self):
        super().__init__()
        # Guards `_starting_thread_id`. When both locks are needed, this one
        # is taken before the base class's `_condition`.
        self._lock = threading.Lock()
        # Ident of the thread that claimed the task, or None while unclaimed.
        self._starting_thread_id = None

    def result(self, timeout=None, enable_work_stealing=True):
        """Get the result of the future.

        If enable_work_stealing is True, attempt to execute the task itself
        instead of waiting.

        Args:
            timeout: Seconds to wait, or None to wait indefinitely. With work
                stealing enabled only None or a non-positive value is accepted.
            enable_work_stealing: When False, behave exactly like
                `Future.result(timeout)`.

        Returns:
            The value produced by the task.

        Raises:
            NotImplementedError: positive `timeout` with work stealing enabled.
            TimeoutError: non-positive `timeout` and the task has neither been
                claimed by any thread nor finished.
            CancelledError: the future was cancelled.
            Exception: whatever the task itself raised.
        """
        if not enable_work_stealing:
            return super().result(timeout)
        if timeout is not None and timeout > 0:
            raise NotImplementedError("Positive 'timeout' not supported with 'enable_work_stealing=True'")
        try:
            # An already finished or cancelled future has nothing left to
            # steal; report its outcome directly. Without this, a future that
            # was cancelled before any thread claimed it would raise
            # TimeoutError for timeout=0 instead of CancelledError.
            if self.done():
                return super().result(timeout)
            with self._lock:
                if self._starting_thread_id:
                    # Some thread already owns the task; just wait for it.
                    should_start = False
                elif timeout is not None and timeout <= 0:
                    # Not allowed to block, and running the task here would block.
                    raise TimeoutError
                else:
                    should_start = True
                    self._starting_thread_id = threading.get_ident()
            if should_start:
                # The weakref attribute may never have been assigned (future
                # created directly), or the work item may already be gone.
                work_item_ref = getattr(self, "_work_item_weakref", None)
                work_item = work_item_ref() if work_item_ref is not None else None
                if work_item is not None:
                    work_item.run()
            return super().result(timeout)
        finally:
            # Mirrors the stdlib `Future.result`, which clears `self` to break
            # a reference cycle with a stored exception.
            self = None

    def set_running_or_notify_cancel(self):
        """Claim the task for the calling thread, or report that it must not run.

        Returns:
            True if the calling thread should execute the task. False if the
            future was cancelled, is owned by a different thread, or has
            already been started or finished.
        """
        # `_condition` wraps a re-entrant lock, so the base-class call below
        # can take it again; holding it here keeps `_state` stable between
        # the check and the delegation.
        with self._lock, self._condition:
            state = self._state
            if state == "CANCELLED":
                # Delegate so the base class moves the future to
                # CANCELLED_AND_NOTIFIED and informs its waiters; otherwise
                # `wait()` / `as_completed()` never see the cancellation.
                return super().set_running_or_notify_cancel()
            current_thread_id = threading.get_ident()
            if self._starting_thread_id and self._starting_thread_id != current_thread_id:
                # Another thread has claimed the task.
                return False
            if state != "PENDING":
                # Already running, finished, or cancellation already notified:
                # a late second attempt is a no-op, not an error.
                return False
            self._starting_thread_id = current_thread_id
            return super().set_running_or_notify_cancel()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment