Skip to content

Instantly share code, notes, and snippets.

@icezyclon
Forked from Eboubaker/read_write_lock.py
Last active July 31, 2024 11:28
Show Gist options
  • Save icezyclon/124df594496dee71ce8455a31b1dd29f to your computer and use it in GitHub Desktop.
Save icezyclon/124df594496dee71ce8455a31b1dd29f to your computer and use it in GitHub Desktop.
Python Reentrant Read Write Lock: Allowing Multithreaded Read Access But Only a Single Writer. Fixed a bug (deadlock while using read and write with context managers), updated with 3.10+ type information and added some tests.
import contextlib
import threading
from typing import Generator
class ReentrantRWLock:
    """Reentrant read-write lock: many concurrent readers or one exclusive writer.

    A read-write lock can be acquired in read mode, in write mode, or in both.
    Any number of threads may hold the read lock while no thread holds the
    write lock. While a writer holds the write lock, no other thread may hold
    the read or the write lock.

    A thread may upgrade to write mode while already holding the read lock.
    The upgrade is not atomic: while the thread waits for the other readers to
    leave, its own read locks are temporarily given up, so another thread may
    obtain write access first. Re-validate anything read before the upgrade.
    Similarly, a thread that has write access may acquire the read lock (or
    may already hold it) to retain read access after releasing the write lock.

    A reentrant lock must be released by the thread that acquired it. Once a
    thread has acquired the lock (read or write), the same thread may acquire
    it again any number of times without blocking; it must release each lock
    (read/write) exactly as many times as it acquired it.

    Writers get no priority: new readers may enter while a writer is waiting
    for the current readers to leave.

    The context managers `for_read()` and `for_write()` acquire and release
    the corresponding lock automatically. Example::

        with lock.for_read():       # read access until the end of the block
            ...
            with lock.for_write():  # write access until the end of the inner block
                ...
    """

    def __init__(self) -> None:
        # ident of the thread that currently has write access, if any
        self._writer: int | None = None
        # number of times the writer currently holds the write lock
        self._writer_count: int = 0
        # thread ident -> number of read locks held by that thread;
        # a thread without read locks has no entry (no 0 values)
        self._readers: dict[int, int] = {}
        # main lock + condition, used for:
        # * protecting access to _writer, _writer_count and _readers
        # * being actively held while a thread has write access
        #   (so no other thread can get in)
        # * letting future writers wait() until nobody is reading/writing anymore
        self._lock = threading.Condition(threading.RLock())  # reentrant

    @contextlib.contextmanager
    def for_read(self) -> Generator["ReentrantRWLock", None, None]:
        """Hold one read lock for the duration of a ``with`` block.

        Yields:
            This lock object.
        """
        # Acquire outside the try: if acquiring fails there is nothing to
        # release, and releasing anyway would mask the original error.
        self.acquire_read()
        try:
            yield self
        finally:
            self.release_read()

    @contextlib.contextmanager
    def for_write(self) -> Generator["ReentrantRWLock", None, None]:
        """Hold one write lock for the duration of a ``with`` block.

        Yields:
            This lock object.
        """
        # Acquire outside the try: if acquiring fails there is nothing to
        # release, and releasing anyway would mask the original error.
        self.acquire_write()
        try:
            yield self
        finally:
            self.release_write()

    def acquire_read(self) -> None:
        """Acquire one read lock.

        Blocks only while another thread holds the write lock.
        """
        ident = threading.get_ident()
        # The internal lock is reentrant, so the current writer passes straight
        # through; every other thread blocks here while a writer is active.
        with self._lock:
            self._readers[ident] = self._readers.get(ident, 0) + 1

    def release_read(self) -> None:
        """Release one read lock held by the calling thread.

        Raises:
            RuntimeError: If the calling thread holds no read lock.
        """
        ident = threading.get_ident()
        with self._lock:
            held = self._readers.get(ident)
            if held is None:
                raise RuntimeError(
                    f"Read lock was released while not holding it by thread {ident}"
                )
            if held == 1:
                del self._readers[ident]  # keep the "no 0 values" invariant
            else:
                self._readers[ident] = held - 1
            if not self._readers:  # no readers remain
                self._lock.notify()  # wake the next writer if any

    def acquire_write(self) -> None:
        """Acquire one write lock.

        Blocks until no other thread holds a read or write lock. Read locks
        already held by the calling thread are given up while waiting and
        restored afterwards, so the upgrade is not atomic.
        """
        ident = threading.get_ident()
        self._lock.acquire()  # reentrant, so the current writer can acquire again
        if self._writer == ident:
            self._writer_count += 1
            return
        # Do not count as a reader while waiting for write access, otherwise
        # two upgrading threads would wait for each other forever.
        times_reading = self._readers.pop(ident, 0)
        try:
            while self._readers:
                self._lock.wait()
        except BaseException:
            # The wait was interrupted: put the bookkeeping back and release
            # the internal lock so that other threads are not blocked forever.
            if times_reading:
                self._readers[ident] = times_reading
            self._lock.notify()  # pass on a wake-up this thread may have consumed
            self._lock.release()
            raise
        self._writer = ident
        self._writer_count += 1
        if times_reading:
            # restore the number of read locks the thread originally had
            self._readers[ident] = times_reading

    def release_write(self) -> None:
        """Release one write lock held by the calling thread.

        Raises:
            RuntimeError: If the calling thread does not hold the write lock.
        """
        ident = threading.get_ident()
        # Reading _writer without the lock is fine here: it can only equal this
        # thread's ident if this thread is the writer and thus holds the lock.
        if self._writer != ident:
            raise RuntimeError(
                f"Write lock was released while not holding it by thread {ident}"
            )
        self._writer_count -= 1
        if self._writer_count == 0:
            self._writer = None
            self._lock.notify()  # wake the next writer if any
        # undo the internal acquire made by the matching acquire_write()
        self._lock.release()
import time
from threading import Thread
import pytest # pip install pytest pytest-timeout
from read_write_lock import ReentrantRWLock
# Every test carries this timeout so that a deadlock fails the test instead of
# hanging the run. It must exceed the slowest test, which is expected to take
# up to 4.5 * SLEEP_TIME.
TIMEOUT = 3
# Some tests use sleeps to simulate possible race conditions.
# This defines how long each sleep lasts - test precision depends on it!
SLEEP_TIME = 0.2
@pytest.mark.timeout(TIMEOUT)
def test_single_threaded_upgrade():
    """Upgrading to write and leaving it again must keep the read lock."""
    rwlock = ReentrantRWLock()
    with rwlock.for_read():
        with rwlock.for_write():
            pass
        # back in read-only mode: the read entry has to be restored
        still_reading = rwlock._readers
        assert still_reading, "Released read lock incorrectly"
@pytest.mark.timeout(TIMEOUT)
def test_single_threaded_reentrant_read():
    """Leaving a nested read block must not drop the outer read lock."""
    rwlock = ReentrantRWLock()
    with rwlock.for_read():
        with rwlock.for_read():
            pass
        # the outer read lock is still held at this point
        still_reading = rwlock._readers
        assert still_reading, "Released read lock too early"
@pytest.mark.timeout(TIMEOUT)
def test_single_threaded_reentrant_write():
    """Leaving a nested write block must not drop the outer write lock."""
    rwlock = ReentrantRWLock()
    with rwlock.for_write():
        with rwlock.for_write():
            pass
        # the outer write lock is still held at this point
        current_writer = rwlock._writer
        assert current_writer is not None, "Released write lock too early"
@pytest.mark.timeout(TIMEOUT)
def test_single_threaded_read_when_write():
    """Taking and releasing a read lock as the writer must keep write access."""
    rwlock = ReentrantRWLock()
    with rwlock.for_write():
        with rwlock.for_read():
            pass
        # releasing the read lock must not touch the write lock
        current_writer = rwlock._writer
        assert current_writer is not None, "Released write lock incorrectly"
@pytest.mark.timeout(TIMEOUT)
def test_single_threaded_deep():
    """Deeply nested read/write blocks unwind to the expected lock state."""
    rwlock = ReentrantRWLock()
    # read -> read
    with rwlock.for_read(), rwlock.for_read():
        # write -> read -> write
        with rwlock.for_write(), rwlock.for_read(), rwlock.for_write():
            pass
        # all write locks are gone, the read locks are still held
        assert rwlock._writer is None, "Did not release write lock correctly"
        assert rwlock._readers, "Released read lock too early"
    # everything has been released
    assert not rwlock._readers, "Did not release read lock correctly"
    assert rwlock._writer is None, "Aquired write lock again??"
@pytest.mark.timeout(TIMEOUT)
def test_lock_no_ambiguous_context():
    """The lock itself is not a context manager; the mode must be chosen explicitly.

    Using an object without ``__enter__``/``__exit__`` in a ``with`` statement
    raises AttributeError up to Python 3.10 and TypeError from Python 3.11 on,
    so both are accepted.
    """
    lock = ReentrantRWLock()
    with pytest.raises((AttributeError, TypeError)):
        with lock:  # type: ignore
            pass
@pytest.mark.timeout(TIMEOUT)
def test_lock_wrong_release():
    """Releasing a read or write lock that is not held raises RuntimeError."""
    rwlock = ReentrantRWLock()
    # read first, then write - neither lock has been acquired
    for release in (rwlock.release_read, rwlock.release_write):
        with pytest.raises(RuntimeError):
            release()
@pytest.mark.timeout(TIMEOUT)
def test_multi_threaded_many_reads():
    """Two readers may hold the read lock at the same time."""
    lock = ReentrantRWLock()

    def read():
        with lock.for_read():
            time.sleep(SLEEP_TIME)

    t1 = Thread(name="read1", target=read, daemon=True)
    t2 = Thread(name="read2", target=read, daemon=True)
    start = time.perf_counter()
    t1.start()
    t2.start()
    t1.join()
    t2.join()
    delta = time.perf_counter() - start
    # Concurrent readers: a bit more than SLEEP_TIME, definitely less than 2 * SLEEP_TIME!
    assert (
        delta < 1.5 * SLEEP_TIME
    ), f"Time for both joins should be {delta=} < {1.5 * SLEEP_TIME}"
@pytest.mark.timeout(TIMEOUT)
def test_multi_threaded_write_exclusive():
    """Two writers must hold the write lock one after the other."""
    rwlock = ReentrantRWLock()

    def hold_write():
        with rwlock.for_write():
            time.sleep(SLEEP_TIME)

    workers = [
        Thread(name=label, target=hold_write, daemon=True)
        for label in ("write1", "write2")
    ]
    start = time.perf_counter()
    for worker in workers:
        worker.start()
    for worker in workers:
        worker.join()
    delta = time.perf_counter() - start
    # the two sleeps cannot overlap: definitely at least 2 * SLEEP_TIME!
    minimum = 1.9 * SLEEP_TIME
    assert delta > minimum, f"Time for both joins should be {delta=} > {1.9 * SLEEP_TIME}"
@pytest.mark.timeout(TIMEOUT)
def test_multi_threaded_read_write_exclusive():
    """A writer started after a reader must not overlap with it."""
    rwlock = ReentrantRWLock()

    def hold_read():
        with rwlock.for_read():
            time.sleep(SLEEP_TIME)

    def hold_write():
        with rwlock.for_write():
            time.sleep(SLEEP_TIME)

    first = Thread(name="read", target=hold_read, daemon=True)
    second = Thread(name="write", target=hold_write, daemon=True)
    start = time.perf_counter()
    first.start()
    time.sleep(0.01)  # head start for the first thread
    second.start()
    for worker in (first, second):
        worker.join()
    delta = time.perf_counter() - start
    # the two sleeps cannot overlap: definitely at least 2 * SLEEP_TIME!
    minimum = 1.9 * SLEEP_TIME
    assert delta > minimum, f"Time for both joins should be {delta=} > {1.9 * SLEEP_TIME}"
@pytest.mark.timeout(TIMEOUT)
def test_multi_threaded_write_read_exclusive():
    """A reader started after a writer must not overlap with it."""
    rwlock = ReentrantRWLock()

    def hold_read():
        with rwlock.for_read():
            time.sleep(SLEEP_TIME)

    def hold_write():
        with rwlock.for_write():
            time.sleep(SLEEP_TIME)

    first = Thread(name="write", target=hold_write, daemon=True)
    second = Thread(name="read", target=hold_read, daemon=True)
    start = time.perf_counter()
    first.start()
    time.sleep(0.01)  # head start for the first thread
    second.start()
    for worker in (first, second):
        worker.join()
    delta = time.perf_counter() - start
    # the two sleeps cannot overlap: definitely at least 2 * SLEEP_TIME!
    minimum = 1.9 * SLEEP_TIME
    assert delta > minimum, f"Time for both joins should be {delta=} > {1.9 * SLEEP_TIME}"
@pytest.mark.timeout(TIMEOUT)
def test_multi_threaded_read_write_exclusive_direct():
    """Reader then writer, using acquire/release calls instead of context managers."""
    rwlock = ReentrantRWLock()

    def hold_read():
        rwlock.acquire_read()
        time.sleep(SLEEP_TIME)
        rwlock.release_read()

    def hold_write():
        rwlock.acquire_write()
        time.sleep(SLEEP_TIME)
        rwlock.release_write()

    first = Thread(name="read", target=hold_read, daemon=True)
    second = Thread(name="write", target=hold_write, daemon=True)
    start = time.perf_counter()
    first.start()
    time.sleep(0.01)  # head start for the first thread
    second.start()
    for worker in (first, second):
        worker.join()
    delta = time.perf_counter() - start
    # the two sleeps cannot overlap: definitely at least 2 * SLEEP_TIME!
    minimum = 1.9 * SLEEP_TIME
    assert delta > minimum, f"Time for both joins should be {delta=} > {1.9 * SLEEP_TIME}"
@pytest.mark.timeout(TIMEOUT)
def test_multi_threaded_write_read_exclusive_direct():
    """Writer then reader, using acquire/release calls instead of context managers."""
    rwlock = ReentrantRWLock()

    def hold_read():
        rwlock.acquire_read()
        time.sleep(SLEEP_TIME)
        rwlock.release_read()

    def hold_write():
        rwlock.acquire_write()
        time.sleep(SLEEP_TIME)
        rwlock.release_write()

    first = Thread(name="write", target=hold_write, daemon=True)
    second = Thread(name="read", target=hold_read, daemon=True)
    start = time.perf_counter()
    first.start()
    time.sleep(0.01)  # head start for the first thread
    second.start()
    for worker in (first, second):
        worker.join()
    delta = time.perf_counter() - start
    # the two sleeps cannot overlap: definitely at least 2 * SLEEP_TIME!
    minimum = 1.9 * SLEEP_TIME
    assert delta > minimum, f"Time for both joins should be {delta=} > {1.9 * SLEEP_TIME}"
@pytest.mark.timeout(TIMEOUT)
def test_multi_threaded_read_readwrite_exclusive():
    """An upgrading thread started after a plain reader must not overlap with it."""
    rwlock = ReentrantRWLock()

    def hold_read():
        with rwlock.for_read():
            time.sleep(SLEEP_TIME)

    def upgrade_and_hold_write():
        with rwlock.for_read(), rwlock.for_write():
            time.sleep(SLEEP_TIME)

    first = Thread(name="read", target=hold_read, daemon=True)
    second = Thread(name="readwrite", target=upgrade_and_hold_write, daemon=True)
    start = time.perf_counter()
    first.start()
    time.sleep(0.01)  # head start for the first thread
    second.start()
    for worker in (first, second):
        worker.join()
    delta = time.perf_counter() - start
    # the two sleeps cannot overlap: definitely at least 2 * SLEEP_TIME!
    minimum = 1.9 * SLEEP_TIME
    assert delta > minimum, f"Time for both joins should be {delta=} > {1.9 * SLEEP_TIME}"
@pytest.mark.timeout(TIMEOUT)
def test_multi_threaded_readwrite_read_exclusive():
    """A plain reader started after an upgrading thread must not overlap with it."""
    rwlock = ReentrantRWLock()

    def hold_read():
        with rwlock.for_read():
            time.sleep(SLEEP_TIME)

    def upgrade_and_hold_write():
        with rwlock.for_read(), rwlock.for_write():
            time.sleep(SLEEP_TIME)

    first = Thread(name="readwrite", target=upgrade_and_hold_write, daemon=True)
    second = Thread(name="read", target=hold_read, daemon=True)
    start = time.perf_counter()
    first.start()
    time.sleep(0.01)  # head start for the first thread
    second.start()
    for worker in (first, second):
        worker.join()
    delta = time.perf_counter() - start
    # the two sleeps cannot overlap: definitely at least 2 * SLEEP_TIME!
    minimum = 1.9 * SLEEP_TIME
    assert delta > minimum, f"Time for both joins should be {delta=} > {1.9 * SLEEP_TIME}"
@pytest.mark.timeout(TIMEOUT)
def test_multi_threaded_read_readwrite_exclusive_direct():
    """Reader then upgrading thread, using acquire/release calls directly."""
    rwlock = ReentrantRWLock()

    def hold_read():
        rwlock.acquire_read()
        time.sleep(SLEEP_TIME)
        rwlock.release_read()

    def upgrade_and_hold_write():
        rwlock.acquire_read()
        rwlock.acquire_write()
        time.sleep(SLEEP_TIME)
        rwlock.release_write()
        rwlock.release_read()

    first = Thread(name="read", target=hold_read, daemon=True)
    second = Thread(name="readwrite", target=upgrade_and_hold_write, daemon=True)
    start = time.perf_counter()
    first.start()
    time.sleep(0.01)  # head start for the first thread
    second.start()
    for worker in (first, second):
        worker.join()
    delta = time.perf_counter() - start
    # the two sleeps cannot overlap: definitely at least 2 * SLEEP_TIME!
    minimum = 1.9 * SLEEP_TIME
    assert delta > minimum, f"Time for both joins should be {delta=} > {1.9 * SLEEP_TIME}"
@pytest.mark.timeout(TIMEOUT)
def test_multi_threaded_readwrite_read_exclusive_direct():
    """Upgrading thread then reader, using acquire/release calls directly."""
    rwlock = ReentrantRWLock()

    def hold_read():
        rwlock.acquire_read()
        time.sleep(SLEEP_TIME)
        rwlock.release_read()

    def upgrade_and_hold_write():
        rwlock.acquire_read()
        rwlock.acquire_write()
        time.sleep(SLEEP_TIME)
        rwlock.release_write()
        rwlock.release_read()

    first = Thread(name="readwrite", target=upgrade_and_hold_write, daemon=True)
    second = Thread(name="read", target=hold_read, daemon=True)
    start = time.perf_counter()
    first.start()
    time.sleep(0.01)  # head start for the first thread
    second.start()
    for worker in (first, second):
        worker.join()
    delta = time.perf_counter() - start
    # the two sleeps cannot overlap: definitely at least 2 * SLEEP_TIME!
    minimum = 1.9 * SLEEP_TIME
    assert delta > minimum, f"Time for both joins should be {delta=} > {1.9 * SLEEP_TIME}"
@pytest.mark.timeout(TIMEOUT)
def test_multi_threaded_readwrite_readwrite_exclusive():
    """Two threads that read and then upgrade share the read phase only."""
    rwlock = ReentrantRWLock()

    def read_then_upgrade():
        with rwlock.for_read():
            time.sleep(SLEEP_TIME)
            with rwlock.for_write():
                time.sleep(SLEEP_TIME)

    first, second = (
        Thread(name=label, target=read_then_upgrade, daemon=True)
        for label in ("readwrite1", "readwrite2")
    )
    start = time.perf_counter()
    first.start()
    time.sleep(0.01)  # head start for the first thread
    second.start()
    for worker in (first, second):
        worker.join()
    delta = time.perf_counter() - start
    # one shared read sleep plus two exclusive write sleeps:
    # definitely at least 3 * SLEEP_TIME but less than 4 * SLEEP_TIME!
    in_expected_window = 2.9 * SLEEP_TIME < delta < 3.5 * SLEEP_TIME
    assert (
        in_expected_window
    ), f"Time for both joins should be {2.9 * SLEEP_TIME} < {delta=} < {3.5 * SLEEP_TIME}"
@pytest.mark.timeout(TIMEOUT)
def test_multi_threaded_writeread_writeread_exclusive():
    """Two threads that write and then read inside the write block run in sequence."""
    rwlock = ReentrantRWLock()

    def write_then_read():
        with rwlock.for_write():
            time.sleep(SLEEP_TIME)
            with rwlock.for_read():
                time.sleep(SLEEP_TIME)

    first, second = (
        Thread(name=label, target=write_then_read, daemon=True)
        for label in ("writeread1", "writeread2")
    )
    start = time.perf_counter()
    first.start()
    time.sleep(0.01)  # head start for the first thread
    second.start()
    for worker in (first, second):
        worker.join()
    delta = time.perf_counter() - start
    # both sleeps of each thread happen under the write lock:
    # definitely at least 4 * SLEEP_TIME but less than 4.5 * SLEEP_TIME!
    in_expected_window = 3.9 * SLEEP_TIME < delta < 4.5 * SLEEP_TIME
    assert (
        in_expected_window
    ), f"Time for both joins should be {3.9 * SLEEP_TIME} < {delta=} < {4.5 * SLEEP_TIME}"
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment