-
-
Save icezyclon/124df594496dee71ce8455a31b1dd29f to your computer and use it in GitHub Desktop.
Python Reentrant Read Write Lock: Allowing Multithreaded Read Access But Only a Single Writer. Fixed a bug (deadlock while using read and write with context managers), updated with 3.10+ type information and added some tests.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import contextlib | |
import threading | |
from typing import Generator | |
class ReentrantRWLock:
    """Reentrant read-write lock.

    A read-write lock can be acquired in read mode or in write mode or both.
    Many different readers are allowed while no thread holds the write lock.
    While a writer holds the write lock, no other threads, aside from the writer,
    may hold the read or the write lock.
    A thread may upgrade the lock to write mode while already holding the read lock.
    Similarly, a thread already having write access may acquire the read lock
    (or may already have it), to retain read access when releasing the write lock.

    A reentrant lock must be released by the thread that acquired it. Once a
    thread has acquired a reentrant lock (read or write), the same thread may
    acquire it again any number of times without blocking; the thread must
    release each lock (read/write) the same number of times it has acquired it!

    Notes:
    * While a thread waits to upgrade from read to write access, its read locks
      are temporarily not counted (otherwise two threads upgrading at the same
      time would deadlock). Another thread may therefore be granted the write
      lock first, so data read before the upgrade may have changed afterwards.
    * New readers do not wait for pending writers, so a steady stream of readers
      can keep a writer waiting.

    The lock provides context managers in the form of `for_read()` and
    `for_write()`, which automatically acquire and release the corresponding
    lock, e.g.::

        with lock.for_read():  # get read access until end of context
            ...
            with lock.for_write():  # upgrade to write access until end of inner
                ...
    """

    def __init__(self) -> None:
        self._writer: int | None = None  # ident of the thread holding the write lock
        self._writer_count: int = 0  # number of times the writer holds the write lock
        # maps thread ident -> number of times that thread holds the read lock;
        # an entry is removed instead of being kept at 0
        self._readers: dict[int, int] = {}
        # main lock + condition, used for:
        # * protecting access to _writer, _writer_count and _readers
        # * being held for the whole time a thread has write access
        #   (so no other thread can read or write meanwhile)
        # * letting would-be writers wait() until nobody else is reading
        self._lock = threading.Condition(threading.RLock())  # reentrant

    @contextlib.contextmanager
    def for_read(self) -> Generator["ReentrantRWLock", None, None]:
        """Hold one read lock for the duration of a ``with`` block.

        Usage: ``with lock.for_read(): ...``
        """
        # Acquire outside of ``try``: if acquiring fails there is nothing to
        # release, and calling release_read() would only mask the real error.
        self.acquire_read()
        try:
            yield self
        finally:
            self.release_read()

    @contextlib.contextmanager
    def for_write(self) -> Generator["ReentrantRWLock", None, None]:
        """Hold one write lock for the duration of a ``with`` block.

        Usage: ``with lock.for_write(): ...``
        """
        # Acquire outside of ``try``: if acquiring fails there is nothing to
        # release, and calling release_write() would only mask the real error.
        self.acquire_write()
        try:
            yield self
        finally:
            self.release_write()

    def acquire_read(self) -> None:
        """Acquire one read lock.

        Blocks only while another thread holds the write lock (the writer keeps
        the internal lock acquired for as long as it has write access).
        """
        ident = threading.get_ident()
        with self._lock:
            self._readers[ident] = self._readers.get(ident, 0) + 1

    def release_read(self) -> None:
        """Release one read lock currently held by this thread.

        Raises:
            RuntimeError: if the calling thread does not hold a read lock.
        """
        ident = threading.get_ident()
        with self._lock:
            count = self._readers.get(ident, 0)
            if count == 0:
                raise RuntimeError(
                    f"Read lock was released while not holding it by thread {ident}"
                )
            if count == 1:
                del self._readers[ident]
            else:
                self._readers[ident] = count - 1
            if not self._readers:  # if no other readers remain
                self._lock.notify()  # wake the next waiting writer, if any

    def acquire_write(self) -> None:
        """Acquire one write lock.

        Blocks until there are no acquired read or write locks from other
        threads. If the wait is interrupted by an exception, this thread's read
        locks are restored, the internal lock is released and the exception is
        re-raised.
        """
        ident = threading.get_ident()
        # The RLock is reentrant, so the current writer can acquire it again;
        # its recursion level always matches _writer_count for the writer.
        self._lock.acquire()
        if self._writer == ident:
            self._writer_count += 1
            return
        # Do not count as a reader while waiting: our own read lock would keep
        # _readers non-empty, so release_read() would never notify us.
        times_reading = self._readers.pop(ident, 0)
        try:
            self._lock.wait_for(lambda: not self._readers)
        except BaseException:
            # Interrupted while waiting (e.g. KeyboardInterrupt). Undo our
            # changes and give up the internal lock so others are not blocked.
            if times_reading:
                self._readers[ident] = times_reading
            elif not self._readers:
                # we may have consumed a wake-up meant for another writer
                self._lock.notify()
            self._lock.release()
            raise
        self._writer = ident
        self._writer_count = 1
        if times_reading:
            # restore the number of read locks this thread originally held
            self._readers[ident] = times_reading

    def release_write(self) -> None:
        """Release one write lock currently held by this thread.

        Raises:
            RuntimeError: if the calling thread is not the current writer.
        """
        ident = threading.get_ident()
        # Reading _writer without the internal lock is fine here: it can only
        # equal our ident if this thread itself set it in acquire_write().
        if self._writer != ident:
            raise RuntimeError(
                f"Write lock was released while not holding it by thread {ident}"
            )
        self._writer_count -= 1
        if self._writer_count == 0:
            self._writer = None
            self._lock.notify()  # wake the next waiting writer, if any
        self._lock.release()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import time | |
from threading import Thread | |
import pytest # pip install pytest pytest-timeout | |
from read_write_lock import ReentrantRWLock | |
# Some tests use sleeps to simulate possible race conditions.
# Define how long this should be - test precision depends on this!
SLEEP_TIME = 0.2
# Note: set a timeout for these tests, in case of deadlock we want to fail the test.
# The slowest test (test_multi_threaded_writeread_writeread_exclusive) takes about
# 4 * SLEEP_TIME, so TIMEOUT must stay comfortably above that.
TIMEOUT = 3
def _assert_released(lock):
    """Assert that `lock` is held neither for reading nor for writing anymore."""
    assert not lock._readers, "Did not release read lock correctly"
    assert lock._writer is None, "Did not release write lock correctly"


@pytest.mark.timeout(TIMEOUT)
def test_single_threaded_upgrade():
    lock = ReentrantRWLock()
    with lock.for_read():
        with lock.for_write():
            pass
        assert lock._readers, "Released read lock incorrectly"
    _assert_released(lock)


@pytest.mark.timeout(TIMEOUT)
def test_single_threaded_reentrant_read():
    lock = ReentrantRWLock()
    with lock.for_read():
        with lock.for_read():
            pass
        assert lock._readers, "Released read lock too early"
    _assert_released(lock)


@pytest.mark.timeout(TIMEOUT)
def test_single_threaded_reentrant_write():
    lock = ReentrantRWLock()
    with lock.for_write():
        with lock.for_write():
            pass
        assert lock._writer is not None, "Released write lock too early"
    _assert_released(lock)


@pytest.mark.timeout(TIMEOUT)
def test_single_threaded_read_when_write():
    lock = ReentrantRWLock()
    with lock.for_write():
        with lock.for_read():
            pass
        assert lock._writer is not None, "Released write lock incorrectly"
    _assert_released(lock)


@pytest.mark.timeout(TIMEOUT)
def test_single_threaded_deep():
    lock = ReentrantRWLock()
    with lock.for_read():
        with lock.for_read():
            with lock.for_write():
                with lock.for_read():
                    with lock.for_write():
                        pass
            assert lock._writer is None, "Did not release write lock correctly"
        assert lock._readers, "Released read lock too early"
    assert not lock._readers, "Did not release read lock correctly"
    assert lock._writer is None, "Aquired write lock again??"


@pytest.mark.timeout(TIMEOUT)
def test_lock_no_ambiguous_context():
    lock = ReentrantRWLock()
    # The lock itself is not a context manager: callers must choose for_read()
    # or for_write(). Python < 3.11 raises AttributeError for such a `with`,
    # Python 3.11+ raises TypeError instead.
    with pytest.raises((AttributeError, TypeError)):
        with lock:  # type: ignore
            pass


@pytest.mark.timeout(TIMEOUT)
def test_lock_wrong_release():
    lock = ReentrantRWLock()
    # releasing a lock that was never acquired is an error for both modes
    with pytest.raises(RuntimeError):
        lock.release_read()
    with pytest.raises(RuntimeError):
        lock.release_write()
# Delay between starting the first and the second thread in the ordered tests,
# so that the first thread (very likely) gets its lock before the second starts.
START_DELAY = 0.01


def _run_threads(*targets, start_delay=0.0):
    """Run every target in its own daemon thread and return the elapsed seconds.

    The threads are started in the given order, `start_delay` seconds apart, and
    the time is measured until all of them have been joined.
    """
    threads = [
        Thread(name=f"{target.__name__}{number}", target=target, daemon=True)
        for number, target in enumerate(targets, start=1)
    ]
    start = time.perf_counter()
    for index, thread in enumerate(threads):
        if index and start_delay:
            time.sleep(start_delay)
        thread.start()
    for thread in threads:
        thread.join()
    return time.perf_counter() - start


def _assert_serialized(delta):
    """Assert that two SLEEP_TIME-long critical sections ran one after the other."""
    # definitely at least 2 * SLEEP_TIME (minus a little slack for timer precision)
    assert (
        delta > 1.9 * SLEEP_TIME
    ), f"Time for both joins should be {delta=} > {1.9 * SLEEP_TIME}"


def _reader(lock, direct=False):
    """Return a thread target holding a read lock for SLEEP_TIME.

    With `direct=True` it uses acquire_read()/release_read() instead of for_read().
    """

    def read():
        if direct:
            lock.acquire_read()
            time.sleep(SLEEP_TIME)
            lock.release_read()
        else:
            with lock.for_read():
                time.sleep(SLEEP_TIME)

    return read


def _writer(lock, direct=False):
    """Return a thread target holding a write lock for SLEEP_TIME.

    With `direct=True` it uses acquire_write()/release_write() instead of for_write().
    """

    def write():
        if direct:
            lock.acquire_write()
            time.sleep(SLEEP_TIME)
            lock.release_write()
        else:
            with lock.for_write():
                time.sleep(SLEEP_TIME)

    return write


def _upgrader(lock, direct=False):
    """Return a thread target that takes a read lock, upgrades it to a write
    lock and holds both for SLEEP_TIME.

    With `direct=True` it uses the acquire_*/release_* methods directly.
    """

    def readwrite():
        if direct:
            lock.acquire_read()
            lock.acquire_write()
            time.sleep(SLEEP_TIME)
            lock.release_write()
            lock.release_read()
        else:
            with lock.for_read():
                with lock.for_write():
                    time.sleep(SLEEP_TIME)

    return readwrite


@pytest.mark.timeout(TIMEOUT)
def test_multi_threaded_many_reads():
    lock = ReentrantRWLock()
    delta = _run_threads(_reader(lock), _reader(lock))
    # readers share the lock: a bit more than SLEEP_TIME, definitely less than 2 * SLEEP_TIME!
    assert (
        delta < 1.5 * SLEEP_TIME
    ), f"Time for both joins should be {delta=} < {1.5 * SLEEP_TIME}"


@pytest.mark.timeout(TIMEOUT)
def test_multi_threaded_write_exclusive():
    lock = ReentrantRWLock()
    _assert_serialized(_run_threads(_writer(lock), _writer(lock)))


@pytest.mark.timeout(TIMEOUT)
def test_multi_threaded_read_write_exclusive():
    lock = ReentrantRWLock()
    delta = _run_threads(_reader(lock), _writer(lock), start_delay=START_DELAY)
    _assert_serialized(delta)


@pytest.mark.timeout(TIMEOUT)
def test_multi_threaded_write_read_exclusive():
    lock = ReentrantRWLock()
    delta = _run_threads(_writer(lock), _reader(lock), start_delay=START_DELAY)
    _assert_serialized(delta)


@pytest.mark.timeout(TIMEOUT)
def test_multi_threaded_read_write_exclusive_direct():
    lock = ReentrantRWLock()
    delta = _run_threads(
        _reader(lock, direct=True),
        _writer(lock, direct=True),
        start_delay=START_DELAY,
    )
    _assert_serialized(delta)


@pytest.mark.timeout(TIMEOUT)
def test_multi_threaded_write_read_exclusive_direct():
    lock = ReentrantRWLock()
    delta = _run_threads(
        _writer(lock, direct=True),
        _reader(lock, direct=True),
        start_delay=START_DELAY,
    )
    _assert_serialized(delta)


@pytest.mark.timeout(TIMEOUT)
def test_multi_threaded_read_readwrite_exclusive():
    lock = ReentrantRWLock()
    delta = _run_threads(_reader(lock), _upgrader(lock), start_delay=START_DELAY)
    _assert_serialized(delta)


@pytest.mark.timeout(TIMEOUT)
def test_multi_threaded_readwrite_read_exclusive():
    lock = ReentrantRWLock()
    delta = _run_threads(_upgrader(lock), _reader(lock), start_delay=START_DELAY)
    _assert_serialized(delta)


@pytest.mark.timeout(TIMEOUT)
def test_multi_threaded_read_readwrite_exclusive_direct():
    lock = ReentrantRWLock()
    delta = _run_threads(
        _reader(lock, direct=True),
        _upgrader(lock, direct=True),
        start_delay=START_DELAY,
    )
    _assert_serialized(delta)


@pytest.mark.timeout(TIMEOUT)
def test_multi_threaded_readwrite_read_exclusive_direct():
    lock = ReentrantRWLock()
    delta = _run_threads(
        _upgrader(lock, direct=True),
        _reader(lock, direct=True),
        start_delay=START_DELAY,
    )
    _assert_serialized(delta)


@pytest.mark.timeout(TIMEOUT)
def test_multi_threaded_readwrite_readwrite_exclusive():
    lock = ReentrantRWLock()

    def readwrite():
        with lock.for_read():
            time.sleep(SLEEP_TIME)
            with lock.for_write():
                time.sleep(SLEEP_TIME)

    delta = _run_threads(readwrite, readwrite, start_delay=START_DELAY)
    # the two read phases overlap, the two write phases run one after the other:
    # definitely at least 3 * SLEEP_TIME but less than 4 * SLEEP_TIME!
    assert (
        2.9 * SLEEP_TIME < delta < 3.5 * SLEEP_TIME
    ), f"Time for both joins should be {2.9 * SLEEP_TIME} < {delta=} < {3.5 * SLEEP_TIME}"


@pytest.mark.timeout(TIMEOUT)
def test_multi_threaded_writeread_writeread_exclusive():
    lock = ReentrantRWLock()

    def writeread():
        with lock.for_write():
            time.sleep(SLEEP_TIME)
            with lock.for_read():
                time.sleep(SLEEP_TIME)

    delta = _run_threads(writeread, writeread, start_delay=START_DELAY)
    # each thread holds the write lock for its whole run, so the runs are
    # serialized: definitely at least 4 * SLEEP_TIME but less than 4.5 * SLEEP_TIME!
    assert (
        3.9 * SLEEP_TIME < delta < 4.5 * SLEEP_TIME
    ), f"Time for both joins should be {3.9 * SLEEP_TIME} < {delta=} < {4.5 * SLEEP_TIME}"
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment