Created
September 8, 2024 08:55
-
-
Save synodriver/ae8c9afb6d52b54d1bf4705f0921d30c to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""
Copyright (c) 2008-2022 synodriver <[email protected]>
"""
import asyncio | |
from enum import IntEnum | |
from typing import List, Literal, Optional | |
class NotAvailable(Exception):
    """Raised by a non-blocking ``RWLock`` when the lock cannot be taken at once."""
class LockState(IntEnum):
    """Observable states of an ``RWLock``."""

    empty = 0  # neither readers nor a writer hold the lock
    reading = 1  # held by readers only
    writing = 2  # held by a writer
    # Readers hold the lock and no writer does, but a writer is queued, so
    # further read acquisitions are not allowed for now.
    waiting_write = 3
class RWLock:
    """Writer-preferring read/write lock for asyncio coroutines.

    Any number of readers may hold the lock at the same time; a writer holds
    it exclusively.  As soon as a writer is queued, new readers are made to
    wait so that the writer is not starved.

    Args:
        blocking: When true (the default) ``acquire`` waits until the lock is
            available.  When false it raises ``NotAvailable`` instead.
        loop: Event loop used to create waiter futures.  When omitted, the
            running loop is looked up the first time a waiter is needed, so
            the lock can be constructed before the loop is started.
    """

    def __init__(
        self, blocking: bool = True, *, loop: Optional[asyncio.AbstractEventLoop] = None
    ):
        self.blocking = blocking
        self._read_waiters: List[asyncio.Future] = []
        self._write_waiters: List[asyncio.Future] = []
        self._pending_reads: int = 0  # number of readers holding the lock
        self._pending_writes: Literal[0, 1] = 0  # 1 while a writer holds the lock
        # Resolved lazily in _get_loop(); binding at construction time would
        # pick the wrong loop when the lock is created outside a coroutine.
        self._loop: Optional[asyncio.AbstractEventLoop] = loop

    @property
    def state(self) -> LockState:
        """Return the current ``LockState`` derived from holders and waiters."""
        if not self._pending_reads and not self._pending_writes:
            return LockState.empty
        if self._pending_reads and not self._write_waiters:
            return LockState.reading
        if self._pending_writes:
            return LockState.writing
        if self._pending_reads and self._write_waiters:
            return LockState.waiting_write
        raise RuntimeError  # defensive: the branches above are exhaustive

    def _get_loop(self) -> asyncio.AbstractEventLoop:
        """Return the loop used for waiter futures, binding it on first use."""
        if self._loop is None:
            self._loop = asyncio.get_running_loop()
        return self._loop

    def _wake_waiters(self) -> None:
        """Hand the lock to the next writer, or wake every waiting reader.

        Futures that are already done (cancelled, or woken but not yet
        resumed) are skipped: calling ``set_result`` on them would raise
        ``InvalidStateError``, and their owning coroutine removes them itself.
        """
        if self._pending_writes:
            return
        next_writer = next((w for w in self._write_waiters if not w.done()), None)
        if next_writer is not None:
            if not self._pending_reads:
                # Direct hand-off: mark the lock as written before the writer
                # resumes so nobody can slip in between.
                self._pending_writes = 1
                next_writer.set_result(None)
            return  # readers stay parked while a writer is queued
        for reader in self._read_waiters:
            if not reader.done():
                reader.set_result(None)

    async def _acquire_read(self) -> None:
        """Take a read lock, waiting while a writer holds or awaits the lock."""
        while self.state not in (LockState.empty, LockState.reading):
            if not self.blocking:
                raise NotAvailable
            waiter = self._get_loop().create_future()
            self._read_waiters.append(waiter)
            try:
                await waiter
            finally:
                # Runs on normal wake-up and on cancellation alike.
                self._read_waiters.remove(waiter)
        self._pending_reads += 1

    async def _acquire_write(self) -> None:
        """Take the write lock, waiting until all current holders are gone."""
        if self.state == LockState.empty:
            self._pending_writes = 1
            return
        if not self.blocking:
            raise NotAvailable
        waiter = self._get_loop().create_future()
        self._write_waiters.append(waiter)
        granted = False
        try:
            await waiter
            granted = True
        finally:
            self._write_waiters.remove(waiter)
            if not granted:
                if waiter.done() and not waiter.cancelled():
                    # The lock had already been handed to us when we were
                    # cancelled; give it back so it is not lost forever.
                    self._pending_writes = 0
                # Our departure may unblock another writer or parked readers.
                self._wake_waiters()

    async def acquire(self, mode: Literal["r", "w"] = "r") -> None:
        """Acquire the lock for reading (``"r"``) or writing (``"w"``).

        Raises:
            NotAvailable: the lock is busy and ``blocking`` is false.
            ValueError: ``mode`` is neither ``"r"`` nor ``"w"``.
        """
        if mode == "r":
            await self._acquire_read()
        elif mode == "w":
            await self._acquire_write()
        else:
            raise ValueError("mode must be 'r' or 'w'")

    async def release(self, mode: Literal["r", "w"] = "r") -> None:
        """Release a lock previously acquired with the same ``mode``.

        Raises:
            ValueError: no lock of that kind is currently held, or ``mode``
                is neither ``"r"`` nor ``"w"``.
        """
        if mode == "r":
            if self.state not in (LockState.reading, LockState.waiting_write):
                raise ValueError("can not release more than acquire")
            self._pending_reads -= 1
        elif mode == "w":
            if self.state != LockState.writing:
                raise ValueError("can not release write lock without acquire it")
            self._pending_writes = 0
        else:
            raise ValueError("mode must be 'r' or 'w'")
        self._wake_waiters()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment