Last active
November 19, 2020 05:39
-
-
Save ptmcg/eae1eee566e8bacb948fc2d2d565a7a5 to your computer and use it in GitHub Desktop.
RCU class for read-copy-update synchronization to a shared resource
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#
# rcu.py
#
# Paul McGuire - November, 2020
#
from contextlib import contextmanager | |
import copy | |
import threading | |
class RcuSynchronizer:
    """
    Class to implement Read-Copy-Update synchronization of a shared resource, in which
    multiple readers can access the resource without requiring any locking, while
    updaters may update the resource by solely synchronizing among each other (because
    they work with a copy of the shared object instead of modifying it in-place).

    Parameters:
        managed_object: object to be accessed using RCU synchronization
        copy_fn: (optional) method to make a copy of the managed object, with method
            signature:
                copy_function(obj: T) -> T:
            if not provided, copy.copy is used (a shallow copy, so nested mutable
            members are still shared between the copy and the published object;
            pass copy.deepcopy or a custom function if updaters modify nested data)

    Readers may access the shared value by calling rcu.get() (or rcu.read()).

    Writers access the shared value using a context manager returned by
    calling rcu.updating():

        with rcu.updating() as shared_updater:
            value_copy = shared_updater.update_value
            ... code that modifies value_copy ...
            # if the shared object is mutable, then no additional
            # code is needed
            # if the shared object is immutable, then the writer must
            # write back into the context manager
            shared_updater.update_value = value_copy

    The working copy is published when the with block completes normally. If the
    with block raises, the working copy is discarded and readers keep seeing the
    previously published value.
    """

    def __init__(self, managed_object, copy_fn=None):
        # one-element list used as the slot holding the currently published object;
        # publishing an update is a single item assignment into this slot
        self._shared = [managed_object]
        self._copy_fn = copy_fn if copy_fn is not None else copy.copy
        # serializes updaters only; readers never acquire it
        self._update_lock = threading.Lock()

    def read(self):
        """Return the currently published object, without taking any lock."""
        return self._shared[0]

    # alias so readers may call either rcu.read() or rcu.get()
    get = read

    @contextmanager
    def updating(self):
        """
        Context manager for updaters.

        While the update lock is held, a copy of the published object is stored
        in self.update_value and self is yielded. On normal exit the object in
        self.update_value becomes the published object; on an exception nothing
        is published and the exception propagates. The update_value attribute is
        removed on exit in both cases.
        """
        with self._update_lock:
            # Copy before entering the try: if copy_fn raises there is nothing to
            # clean up, and the caller sees the copy error itself.
            self.update_value = self._get_copy()
            try:
                yield self
                # Only reached when the with body finished without raising, so a
                # half-modified copy is never published to readers.
                self._shared[0] = self.update_value
            finally:
                del self.update_value

    def _get_copy(self):
        """Return a copy of the published object, made with the configured copy_fn."""
        return self._copy_fn(self._shared[0])

    def _update(self, obj):
        # internal method to do synchronized update to shared object
        # NOTE: acquires the same non-reentrant lock as updating(), so calling it
        # from inside a `with rcu.updating()` block on the same thread would block
        with self._update_lock:
            self._shared[0] = obj
if __name__ == '__main__':
    #
    # demo code
    #
    # Reader threads repeatedly sum the shared list slowly while writer threads
    # keep replacing it through RcuSynchronizer.updating(). Every published
    # list's sum is recorded in shared_history; a reader whose sum matches none
    # of the recorded sums has seen an inconsistent snapshot.
    #
    from collections import deque
    from random import randint, randrange, shuffle
    import time

    shared = RcuSynchronizer([randint(1, 10000) for _ in range(10)])

    # sums of the most recently published lists (bounded to the last 50)
    shared_history = deque(maxlen=50)
    shared_history.append(sum(shared.get()))

    # set to True by any reader that detects an inconsistent snapshot
    busted = False

    def do_read():
        global busted

        # get the current shared value of the list
        shared_list = shared.get()
        print("reading", flush=True)

        # compute sum slowly, giving updaters a chance to modify the list, but our copy
        # will remain unchanged - so we should always have a consistent snapshot for the life
        # of our access to the list, without having to make a copy
        slow_sum = 0
        for i in shared_list:
            time.sleep(0.005)
            slow_sum += i

        # call attention to instances where slow_sum is not the same as the sum of any
        # of the historical values of the shared value (if we got an inconsistent snapshot,
        # then it shouldn't match up with any of the historical lists)
        if slow_sum not in shared_history:
            print("\n>>>>>>>>>>>>BUST!!!<<<<<<<<<<<<<<<<<<<<"*3, flush=True)
            # report the recorded sums (this previously referenced an undefined
            # name, so the diagnostic itself raised NameError)
            print("\ngot {}, expected one of {}, (using list {})".format(slow_sum, list(shared_history), shared_list))
            busted = True

    def do_update():
        print("updating", flush=True)

        # modify the shared resource
        with shared.updating() as shared_updater:
            shared_list = shared_updater.update_value

            # make up to 5 changes in the list
            for _ in range(randint(1, 5)):
                if randint(1, 20) < len(shared_list) / 2:
                    shared_list.pop(randrange(len(shared_list)))
                else:
                    shared_list.append(randint(1, 10000))

            # update shared_history for reader validation; done inside the with
            # block so the sum is recorded before the new list is published
            shared_history.append(sum(shared_list))

        # show the modified resource
        print(shared.get(), flush=True)

    iters = 500

    # flag polled by reader threads; cleared once all writers have finished
    updating = True

    def reader():
        """
        target method for threaded readers
        """
        while updating:
            do_read()

    def writer():
        """
        target method for threaded updaters
        """
        for _ in range(iters):
            do_update()
            time.sleep(0.2)

    readers = [threading.Thread(target=reader) for _ in range(25)]
    writers = [threading.Thread(target=writer) for _ in range(5)]

    # kick off threads - after shuffling readers and writers
    all_threads = readers + writers
    shuffle(all_threads)
    for t in all_threads:
        t.start()

    # wait for all writers to finish updates
    for t in writers:
        t.join()

    # clear flag telling readers that updating is still in progress, so
    # they can stop reading
    updating = False
    for t in readers:
        t.join()

    # any failures?
    print(f"{busted=}")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment