Skip to content

Instantly share code, notes, and snippets.

@synodriver
Created September 8, 2024 08:55
Show Gist options
  • Save synodriver/ae8c9afb6d52b54d1bf4705f0921d30c to your computer and use it in GitHub Desktop.
Save synodriver/ae8c9afb6d52b54d1bf4705f0921d30c to your computer and use it in GitHub Desktop.
"""
Copyright (c) 2008-2022 synodriver <[email protected]>
"""
import asyncio
from enum import IntEnum
from typing import List, Literal, Optional
class NotAvailable(Exception):
    """Raised when a non-blocking lock cannot be acquired right away."""
class LockState(IntEnum):
    """States a readers-writer lock can be observed in."""

    # Nobody holds the lock.
    empty = 0
    # Only read locks are held.
    reading = 1
    # A write lock is held.
    writing = 2
    # Read locks are held and no write lock is held, but a writer is queued,
    # so no further read locks may be taken for now.
    waiting_write = 3
class RWLock:
    """Asynchronous readers-writer lock that gives queued writers priority.

    Any number of readers may hold the lock at the same time, while a writer
    holds it exclusively.  Once a writer is queued, new readers are held back
    until the writer queue has drained.

    Args:
        blocking: When true (the default) ``acquire`` suspends until the lock
            can be taken; when false it raises ``NotAvailable`` instead of
            waiting.
        loop: Event loop used to create waiter futures.  When omitted, the
            running loop is looked up at the moment a waiter is needed, so the
            lock may be constructed before any loop is running.
    """

    def __init__(
        self, blocking: bool = True, *, loop: Optional[asyncio.AbstractEventLoop] = None
    ):
        self.blocking = blocking
        self._read_waiters = []  # type: List[asyncio.Future]
        self._write_waiters = []  # type: List[asyncio.Future]
        self._pending_reads = 0  # type: int  # number of read locks currently held
        self._pending_writes = 0  # type: Literal[0, 1]  # 1 while a write lock is held
        # Resolved lazily (see _new_waiter) so that building the lock outside a
        # running loop does not bind it to the wrong loop.
        self._loop = loop

    @property
    def state(self) -> LockState:
        """Return the current ``LockState`` derived from holders and waiters."""
        readers = self._pending_reads
        writer = self._pending_writes
        if not readers and not writer:
            return LockState.empty
        if readers and not self._write_waiters:
            return LockState.reading
        if writer:
            return LockState.writing
        # Only remaining combination: readers hold the lock and a writer is queued.
        return LockState.waiting_write

    def _new_waiter(self) -> "asyncio.Future":
        """Create a future on the configured loop, or on the running loop."""
        loop = self._loop if self._loop is not None else asyncio.get_running_loop()
        return loop.create_future()

    def _wake_next_writer(self) -> bool:
        """Hand the write lock to the first live queued writer.

        Waiters that are already done (for example cancelled but not yet
        cleaned up) are skipped, because calling ``set_result`` on them would
        raise ``InvalidStateError``.

        Returns:
            True if a writer received the lock, False if none was waiting.
        """
        for waiter in self._write_waiters:
            if not waiter.done():
                # The lock is marked as held on behalf of the woken writer.
                self._pending_writes = 1
                waiter.set_result(None)
                return True
        return False

    def _wake_readers(self) -> None:
        """Wake every live queued reader so it can retry its acquisition."""
        for waiter in self._read_waiters:
            if not waiter.done():
                waiter.set_result(None)

    async def _acquire_read(self) -> None:
        """Take a read lock, waiting (or raising) while a writer holds or awaits it."""
        # Read locks can be taken immediately when the lock is empty or held
        # by readers only; otherwise wait and re-check after every wake-up.
        while self.state not in (LockState.empty, LockState.reading):
            if not self.blocking:
                raise NotAvailable
            waiter = self._new_waiter()
            self._read_waiters.append(waiter)
            try:
                await waiter
            finally:
                # Runs on cancellation as well, so the queue never keeps a stale entry.
                self._read_waiters.remove(waiter)
        self._pending_reads += 1

    async def _acquire_write(self) -> None:
        """Take the write lock, waiting (or raising) until it is handed over."""
        if self.state == LockState.empty:
            self._pending_writes = 1
            return
        if not self.blocking:
            raise NotAvailable

        waiter = self._new_waiter()
        self._write_waiters.append(waiter)
        granted = False
        try:
            await waiter
            granted = True
        finally:
            self._write_waiters.remove(waiter)
            if not granted:
                if waiter.done() and not waiter.cancelled():
                    # The lock had already been handed to this writer before it
                    # was interrupted: pass it on instead of leaking it.
                    self._pending_writes = 0
                    if not self._wake_next_writer():
                        self._wake_readers()
                elif not self._write_waiters and not self._pending_writes:
                    # This was the last queued writer; readers that were held
                    # back only because of it must be allowed to proceed.
                    self._wake_readers()

    async def acquire(self, mode: Literal["r", "w"] = "r") -> None:
        """Acquire the lock for reading (``"r"``) or writing (``"w"``).

        Raises:
            NotAvailable: The lock is busy and ``blocking`` is false.
            ValueError: ``mode`` is neither ``"r"`` nor ``"w"``.
        """
        if mode == "r":
            await self._acquire_read()
        elif mode == "w":
            await self._acquire_write()
        else:
            raise ValueError("mode must be 'r' or 'w'")

    async def release(self, mode: Literal["r", "w"] = "r") -> None:
        """Release a previously acquired read (``"r"``) or write (``"w"``) lock.

        This is a coroutine for interface compatibility; it never suspends.

        Raises:
            ValueError: ``mode`` is invalid, or no lock of that kind is held.
        """
        if mode == "r":
            if self.state not in (LockState.reading, LockState.waiting_write):
                raise ValueError("can not release more than acquire")
            self._pending_reads -= 1
            # The last reader out hands the lock to the next queued writer.
            if self._pending_reads == 0 and not self._wake_next_writer():
                self._wake_readers()
        elif mode == "w":
            if self.state != LockState.writing:
                raise ValueError("can not release write lock without acquire it")
            self._pending_writes = 0
            # Writers take priority; readers are woken only when none is queued.
            if not self._wake_next_writer():
                self._wake_readers()
        else:
            raise ValueError("mode must be 'r' or 'w'")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment