Last active
April 12, 2021 14:25
-
-
Save mniip/83bb6d7463b5ab6db34c8388d95d2dba to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import asyncio | |
import collections | |
class Var:
    """A transactional variable: a value plus the bookkeeping used to commit it.

    Attributes (declared through ``__slots__``):
      value       -- the currently committed value
      lock        -- asyncio.Lock held while the value/version pair is
                     snapshotted by ``Tx.read`` or committed by ``run_stm``
      version     -- counter starting at 0, incremented by ``run_stm`` on
                     every commit to this variable
      subscribers -- set of events that ``run_stm`` sets when this variable
                     is committed
    """

    __slots__ = ("value", "lock", "version", "subscribers")

    def __init__(self, value=None):
        # Bookkeeping first, payload last; the assignments are independent.
        self.subscribers = set()
        self.version = 0
        self.lock = asyncio.Lock()
        self.value = value
# Snapshot of a Var taken the first time a transaction reads it: the version
# observed at that moment and the value that went with it.
ReadSetEntry = collections.namedtuple("ReadSetEntry", "version value")
class TxRetryNow(BaseException):
    """Signal that the current attempt is stale and must be re-run at once.

    ``run_stm`` raises this during commit when a variable in the read set
    has a newer version than the one recorded, and catches it to restart
    the transaction function without waiting.  Being a ``BaseException``
    subclass, it is not intercepted by ``except Exception`` handlers.
    """
class TxRetryLater(BaseException):
    """Signal that the attempt should be abandoned and re-run after a change.

    Raised by ``Tx.retry``.  ``run_stm`` catches it, waits until one of the
    variables in the read set is committed to, and then runs the
    transaction function again.  Being a ``BaseException`` subclass, it is
    not intercepted by ``except Exception`` handlers.
    """
class Tx:
    """State of one transaction attempt: what was read and what will be written.

    Attributes (declared through ``__slots__``):
      read_set  -- maps each variable read so far to its ReadSetEntry snapshot
      write_set -- maps each variable to the value buffered for it by ``write``
      event     -- asyncio.Event that ``run_stm`` waits on after ``retry``
    """

    __slots__ = ("read_set", "write_set", "event")

    def __init__(self):
        self.event = asyncio.Event()
        self.write_set = {}
        self.read_set = {}

    async def read(self, var):
        """Return the value of *var* as seen by this transaction.

        Lookup order: a value buffered by ``write`` wins; otherwise a
        snapshot already in the read set is reused; otherwise the variable's
        version and value are captured together under ``var.lock`` and
        remembered in the read set.
        """
        buffered = self.write_set
        if var in buffered:
            return buffered[var]

        snapshot = self.read_set.get(var)
        if snapshot is None:
            # First read of this variable: take version and value as a pair.
            async with var.lock:
                snapshot = ReadSetEntry(version=var.version, value=var.value)
            self.read_set[var] = snapshot
        return snapshot.value

    def write(self, var, value):
        """Buffer *value* for *var*; *var* itself is not modified here."""
        self.write_set[var] = value

    def retry(self):
        """Abandon this attempt by raising ``TxRetryLater``."""
        raise TxRetryLater
def _read_set_is_stale(tx):
    """Return True if any variable read by *tx* has been committed to since."""
    return any(
        var.version != entry.version for var, entry in tx.read_set.items()
    )


async def run_stm(fun, *args, **kwargs):
    """Run the coroutine function *fun* as a transaction and commit its writes.

    ``fun`` is awaited as ``fun(tx, *args, **kwargs)`` where ``tx`` is a
    ``Tx``.  When it returns, the locks of all written variables are taken
    in ``id`` order, the read set is validated against the variables'
    current versions, and the buffered writes are applied (bumping each
    version and setting every subscribed event).

    The attempt is repeated with empty read/write sets when:
      * validation fails (``TxRetryNow``) -- immediately;
      * ``fun`` raises ``TxRetryLater`` (via ``tx.retry()``) -- after one of
        the variables it read has been committed to.  If the read set is
        empty there is nothing to wake the transaction, so it waits until
        cancelled.

    Returns:
        The value returned by ``fun`` on the attempt that committed.

    Raises:
        Any exception raised by ``fun`` other than the two retry signals
        propagates to the caller unchanged; nothing is committed then.
    """
    tx = Tx()
    while True:
        try:
            # Kept under its own name: the commit loop below must not
            # overwrite what the transaction function returned.
            result = await fun(tx, *args, **kwargs)

            # Track exactly the locks that were obtained, so the cleanup is
            # correct for an empty write set and when acquire() is
            # cancelled part-way through.
            held = []
            try:
                # A global (id-based) acquisition order prevents deadlock
                # between transactions with overlapping write sets.
                for var in sorted(tx.write_set, key=id):
                    await var.lock.acquire()
                    held.append(var)

                if _read_set_is_stale(tx):
                    raise TxRetryNow()

                for var, new_value in tx.write_set.items():
                    var.version += 1
                    var.value = new_value
                    for event in var.subscribers:
                        event.set()
            finally:
                for var in reversed(held):
                    var.lock.release()
            return result
        except TxRetryLater:
            watched = list(tx.read_set)
            try:
                for var in watched:
                    var.subscribers.add(tx.event)
                # A variable may have been committed between the read and
                # the subscription above; that wake-up is already gone, so
                # only sleep when every snapshot is still current.  There is
                # no await between subscribing and this check.
                if not _read_set_is_stale(tx):
                    await tx.event.wait()
            finally:
                for var in watched:
                    var.subscribers.discard(tx.event)
                tx.event.clear()
        except TxRetryNow:
            pass
        # Start the next attempt from a clean slate.
        tx.read_set = {}
        tx.write_set = {}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment