Skip to content

Instantly share code, notes, and snippets.

@x42005e1f
Last active August 23, 2025 21:32
Show Gist options
  • Save x42005e1f/e50cc904867f2458a546c9e2f51128fe to your computer and use it in GitHub Desktop.
Save x42005e1f/e50cc904867f2458a546c9e2f51128fe to your computer and use it in GitHub Desktop.
Some patches for thread-safety (threading & eventlet)
#!/usr/bin/env python3
# SPDX-FileCopyrightText: 2024 Ilya Egorov <[email protected]>
# SPDX-License-Identifier: ISC
# mypy: disable-error-code="attr-defined, import-untyped, no-untyped-def"
# pyright: reportAttributeAccessIssue=false, reportOptionalMemberAccess=false
from __future__ import annotations
import sys
from functools import wraps
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from types import ModuleType
def patch_threading(threading: ModuleType | None = None, /) -> None:
"""
Fixes race conditions in :meth:`threading.Thread.join`
(:exc:`AssertionError`, :exc:`RuntimeError`) for PyPy.
Code to reproduce:
.. code:: python
from threading import Thread, current_thread
def func():
thread = Thread(target=current_thread().join)
thread.start()
for _ in range(1000):
thread = Thread(target=func)
thread.start()
thread.join()
Why is this happening? Because :meth:`threading.Thread.join` implementation
is not thread-safe (see CPython's `gh-116372 <https://github.com/python/
cpython/issues/116372>`__), and PyPy can switch execution to another thread
right after :meth:`threading.Lock.acquire` (CPython can't):
.. code:: python
from threading import Barrier, Lock, Thread
lock = Lock()
barrier = Barrier(1000)
def func():
barrier.wait()
lock.acquire()
lock.release()
assert not lock.locked() # succeeds on CPython, fails on PyPy
for _ in range(1000):
Thread(target=func).start()
You may ask: why not just suppress :exc:`AssertionError` and
:exc:`RuntimeError`? Well, even if the user suppresses them, the
:func:`threading._shutdown` function won't. As a result, running threads
will be killed as daemon threads:
.. code:: python
import time
from threading import Thread, current_thread
def func(i):
try:
thread = Thread(target=current_thread().join)
thread.start()
for j in range(3, 0, -1):
if i == 0:
print(f"{j}...")
time.sleep(1)
finally:
if i == 999:
time.sleep(0.5) # ensure a long runtime
print("OK") # will be printed on CPython, but not on PyPy
for i in range(1000):
thread = Thread(target=func, args=[i])
thread.start()
A related problem is that if a first thread calls
:meth:`threading.Thread.join()` of a second thread and in the first thread
an exception (for example, :exc:`KeyboardInterrupt`) is raised, then the
second thread will be killed as a daemon thread:
.. code:: python
import time
from threading import Thread, main_thread
blackmail = " ".join([
"Victor Stinner, I'll take your beer away!",
"Muahahahahaha...",
"Hahahahahahaha...",
"AAAAHAHAHAHAHAHA!",
])
def func():
try:
main_thread().join()
finally:
time.sleep(0.5) # ensure a long runtime
print(blackmail) # will never be printed
thread = Thread(target=func)
thread.start()
try:
thread.join()
except KeyboardInterrupt:
pass # PyPy doesn't wait for threads after KeyboardInterrupt
The reason for this is that on September 27, 2021, CPython applied a fix
for a different race condition on the same :meth:`threading.Thread.join()`
that caused hangs, and mostly on Windows (see `bpo-21822 <https://
bugs.python.org/issue21822>`__, `bpo-45274 <https://bugs.python.org/
issue45274>`__, and `gh-28532 <https://github.com/python/cpython/pull/
28532>`__). Actually, there are no hangs now, because threads can now kill
each other. So the :exc:`KeyboardInterrupt` exception raised by the signal
handler has become really dangerous. This patch doesn't undo that fix. If
you really want safe Control-C handling, set your own signal handler.
In Python 3.13, all of these issues have been resolved at C level as part
of the free-threaded mode implementation (see `gh-114271 <https://
github.com/python/cpython/issues/114271>`__).
Does not affect PyPy 7.3.18 and newer (see `pypy/pypy#5080 <https://
github.com/pypy/pypy/pull/5080>`__).
"""
if not hasattr(sys, "pypy_version_info"):
return
if sys.pypy_version_info >= (7, 3, 18):
return
if threading is None:
threading = __import__("threading")
Thread = threading.Thread
if not hasattr(threading, "_maintain_shutdown_locks"):
def _maintain_shutdown_locks():
_shutdown_locks = threading._shutdown_locks
_shutdown_locks.difference_update([
lock for lock in _shutdown_locks if not lock.locked()
])
threading._maintain_shutdown_locks = _maintain_shutdown_locks
@wraps(Thread._set_tstate_lock)
def Thread__set_tstate_lock(self):
self._tstate_lock = threading._set_sentinel()
self._tstate_lock.acquire()
if not self.daemon:
with threading._shutdown_locks_lock:
threading._maintain_shutdown_locks()
threading._shutdown_locks.add(self._tstate_lock)
Thread._set_tstate_lock = Thread__set_tstate_lock
@wraps(Thread._stop)
def Thread__stop(self):
self._is_stopped = True
if not self.daemon:
if self._tstate_lock is not None:
with threading._shutdown_locks_lock:
if self._tstate_lock is not None:
threading._maintain_shutdown_locks()
self._tstate_lock = None
else:
self._tstate_lock = None
Thread._stop = Thread__stop
@wraps(Thread._wait_for_tstate_lock)
def Thread__wait_for_tstate_lock(self, block=True, timeout=-1):
lock = self._tstate_lock
if lock is None:
assert self._is_stopped
return
try:
if lock.acquire(block, timeout):
try:
lock.release()
except RuntimeError:
pass
self._stop()
except:
if lock.locked():
try:
lock.release()
except RuntimeError:
pass
self._stop()
raise
Thread._wait_for_tstate_lock = Thread__wait_for_tstate_lock
@wraps(threading._shutdown)
def _shutdown():
_main_thread = threading._main_thread
_threading_atexits = threading._threading_atexits
_shutdown_locks = threading._shutdown_locks
_shutdown_locks_lock = threading._shutdown_locks_lock
try:
_is_main_interpreter = threading._is_main_interpreter
except AttributeError:
if _main_thread._is_stopped:
return
else:
if _main_thread._is_stopped and _is_main_interpreter():
return
threading._SHUTTING_DOWN = True
for atexit_call in reversed(_threading_atexits):
atexit_call()
_threading_atexits.clear()
if _main_thread.ident == threading.get_ident():
tlock = _main_thread._tstate_lock
assert tlock is not None
assert tlock.locked()
tlock.release()
_main_thread._stop()
while True:
with _shutdown_locks_lock:
locks = list(_shutdown_locks)
_shutdown_locks.clear()
if not locks:
break
for lock in locks:
lock.acquire()
try:
lock.release()
except RuntimeError:
pass
threading._shutdown = _shutdown
def patch_eventlet() -> None:
"""
Injects ``destroy()`` into :class:`eventlet.hubs.hub.BaseHub` to fix EMFILE
("too many open files") + ENOMEM (memory leak).
Code to reproduce:
.. code:: python
from threading import Thread
import eventlet
import eventlet.hubs
stop = False
def func():
global stop
try:
eventlet.sleep()
except:
stop = True
raise
finally:
hub = eventlet.hubs.get_hub()
try:
destroy = hub.destroy
except AttributeError:
pass
else:
destroy()
while not stop:
thread = Thread(target=func)
thread.start()
thread.join()
Also injects ``schedule_call_threadsafe()`` (a thread-safe variant of
:meth:`eventlet.hubs.hub.BaseHub.schedule_call_global`).
"""
from eventlet.hubs.asyncio import Hub as AsyncioHub
from eventlet.hubs.hub import BaseHub
if not hasattr(BaseHub, "destroy"):
def BaseHub_destroy(self, /):
if not self.greenlet.dead:
self.abort(wait=True)
BaseHub.destroy = BaseHub_destroy
def AsyncioHub_destroy(self, /):
super(self.__class__, self).destroy()
self.loop.close()
AsyncioHub.destroy = AsyncioHub_destroy
try:
from eventlet.hubs.epolls import Hub as EpollHub
except ImportError:
pass
else:
def EpollHub_destroy(self, /):
super(self.__class__, self).destroy()
self.poll.close()
EpollHub.destroy = EpollHub_destroy
try:
from eventlet.hubs.kqueue import Hub as KqueueHub
except ImportError:
pass
else:
def KqueueHub_destroy(self, /):
super(self.__class__, self).destroy()
self.kqueue.close()
KqueueHub.destroy = KqueueHub_destroy
if not hasattr(BaseHub, "schedule_call_threadsafe"):
from eventlet.hubs.timer import Timer
from eventlet.patcher import original
socket = original("socket")
BaseHub_destroy_impl = BaseHub.destroy
@wraps(BaseHub.destroy)
def BaseHub_destroy_redef(self, /):
try:
rsock = self._threadsafe_rsock
except AttributeError:
pass
else:
rsock.close()
try:
wsock = self._threadsafe_wsock
except AttributeError:
pass
else:
wsock.close()
return BaseHub_destroy_impl(self)
BaseHub.destroy = BaseHub_destroy_redef
def BaseHub__init_socketpair(self, /):
if not hasattr(self, "_threadsafe_wsock"):
rsock, wsock = socket.socketpair()
rsock_fd = rsock.fileno()
rsock.setblocking(False)
wsock.setblocking(False)
rsock.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, 1)
wsock.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 1)
try:
wsock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
except OSError:
pass
def rsock_recv(_):
while True:
try:
data = rsock.recv(4096)
except InterruptedError:
continue
except BlockingIOError:
break
else:
if not data:
break
def rsock_throw(exc, /):
raise exc
self.mark_as_reopened(rsock_fd)
self.add(self.READ, rsock_fd, rsock_recv, rsock_throw, None)
self._threadsafe_rsock = rsock
self._threadsafe_wsock = wsock
BaseHub._init_socketpair = BaseHub__init_socketpair
def AsyncioHub__init_socketpair(self, /):
pass
AsyncioHub._init_socketpair = AsyncioHub__init_socketpair
BaseHub_prepare_timers_impl = BaseHub.prepare_timers
@wraps(BaseHub.prepare_timers)
def BaseHub_prepare_timers(self):
self._init_socketpair()
try:
timers = self._threadsafe_timers
except AttributeError:
pass
else:
if timers:
items = timers.copy()
self.next_timers.extend(items)
del timers[: len(items)]
return BaseHub_prepare_timers_impl(self)
BaseHub.prepare_timers = BaseHub_prepare_timers
def BaseHub__threadsafe_wakeup(self, /):
try:
wsock = self._threadsafe_wsock
except AttributeError:
pass
else:
try:
wsock.send(b"\x00")
except OSError:
pass
BaseHub._threadsafe_wakeup = BaseHub__threadsafe_wakeup
def AsyncioHub__threadsafe_wakeup(self, /):
try:
self.loop.call_soon_threadsafe(self.sleep_event.set)
except RuntimeError: # event loop is closed
pass
AsyncioHub._threadsafe_wakeup = AsyncioHub__threadsafe_wakeup
def BaseHub_schedule_call_threadsafe(
self,
seconds,
callback,
/,
*args,
**kwargs,
):
timer = Timer(seconds, callback, *args, **kwargs)
scheduled_time = self.clock() + seconds
try:
timers = self._threadsafe_timers
except AttributeError:
timers = vars(self).setdefault("_threadsafe_timers", [])
timers.append((scheduled_time, timer))
self._threadsafe_wakeup()
# timer methods are not thread-safe, so we do not return it
BaseHub.schedule_call_threadsafe = BaseHub_schedule_call_threadsafe
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment