Last active
November 15, 2020 15:14
-
-
Save anurag-7/6eeb8684010a7f38c18052da23bb3113 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import asyncio
import contextlib
import heapq
import types
import warnings
# Public API. The trailing comma makes this a one-element tuple; without it
# the value is a plain string and ``from module import *`` would try to
# import each character as a name.
__all__ = ('PrioritySemaphore',)
class PrioritySemaphore(asyncio.Semaphore):
    """A priority-queue based semaphore (adapted from the CPython stdlib).

    Tasks blocked in :meth:`acquire` are kept in a :mod:`heapq` min-heap, so
    the waiter with the *smallest* ``priority`` value is woken first.
    Waiters with equal priority are woken in the order they started waiting.

    Parameters
    ----------
    value:
        Initial value of the internal counter; defaults to 1. If the value
        given is less than 0, ``ValueError`` is raised.
    default_priority:
        Priority used when no priority is passed while acquiring the
        semaphore (including plain ``async with sem:``); defaults to 1.
    loop:
        Optional event loop on which waiter futures are created. When
        omitted, the loop running at the time of :meth:`acquire` is used.

    Usage
    -----
    ::

        sem = PrioritySemaphore(3)
        # later
        async with sem(priority=2):
            ...  # code with said priority goes here

        # Other way of usage, not recommended.
        await sem.acquire(priority=1)
        try:
            ...  # work with shared resource
        finally:
            sem.release()
    """

    class _LegacyContextManager:
        """Release the owning semaphore when a plain ``with`` block exits.

        Returned by the deprecated ``with (yield from sem)`` and
        ``with await sem`` idioms. It replaces ``asyncio.locks._ContextManager``,
        which is not available on every Python version.
        """

        def __init__(self, sem):
            self._sem = sem

        def __enter__(self):
            # Locks have no use for the "as ..." clause.
            return None

        def __exit__(self, *exc_info):
            try:
                self._sem.release()
            finally:
                # Drop the reference so this manager cannot release twice.
                self._sem = None

    def __init__(self, value=1, default_priority=1, *, loop=None):
        if value < 0:
            raise ValueError("Semaphore initial value must be >= 0")
        self._value = value
        # Min-heap of (priority, sequence number, future) entries.
        self._waiters = []
        self._default_prio = default_priority
        # Monotonic counter. It breaks priority ties in FIFO order and
        # guarantees the heap never has to compare two futures.
        self._seq = 0
        # None means "use the running loop at acquire time". This avoids
        # binding to a stale loop when the semaphore is created before
        # asyncio.run().
        self._loop = loop

    def _current_loop(self):
        """Return the event loop on which waiter futures are created."""
        if self._loop is not None:
            return self._loop
        return asyncio.get_running_loop()

    @types.coroutine
    def __iter__(self):
        # Not a coroutine. It enables the deprecated idiom
        #
        #     with (yield from sem):
        #         <block>
        #
        # Prefer ``async with sem:`` instead.
        warnings.warn("'with (yield from lock)' is deprecated "
                      "use 'async with lock' instead",
                      DeprecationWarning, stacklevel=2)
        yield from self.acquire(self._default_prio)
        return self._LegacyContextManager(self)

    # The flag is needed for legacy asyncio.iscoroutine(). The private
    # marker may be absent on some Python versions, so set it only if present.
    _legacy_marker = getattr(asyncio.coroutines, '_is_coroutine', None)
    if _legacy_marker is not None:
        __iter__._is_coroutine = _legacy_marker
    del _legacy_marker

    def _wake_up_next(self):
        """Wake the highest-priority waiter whose future is still pending.

        Heap entries whose future is already done (e.g. cancelled) are
        popped and discarded along the way.
        """
        while self._waiters:
            *_, waiter = heapq.heappop(self._waiters)
            if not waiter.done():
                waiter.set_result(None)
                return

    def locked(self):
        """Return True if the semaphore cannot be acquired immediately.

        Overridden because ``self._waiters`` holds heap tuples rather than
        bare futures. An inherited implementation that inspects the waiters
        directly would not understand them.
        """
        return self._value == 0

    @contextlib.asynccontextmanager
    async def __call__(self, priority=None):
        """Async context manager that acquires the semaphore with ``priority``.

        ``priority`` defaults to ``default_priority``. The semaphore is
        released when the block exits. The ``as`` target receives the value
        returned by :meth:`acquire` (``True``).
        """
        acquired = await self.acquire(priority)
        try:
            yield acquired
        finally:
            self.release()

    async def acquire(self, priority=None):
        """Acquire the semaphore, waiting in priority order if necessary.

        Lower ``priority`` values are woken first. ``None`` (the default)
        means ``default_priority``. Returns True once acquired.
        """
        if priority is None:
            priority = self._default_prio
        while self._value <= 0:
            fut = self._current_loop().create_future()
            self._seq += 1
            heapq.heappush(self._waiters, (priority, self._seq, fut))
            try:
                await fut
            except BaseException:
                # Also covers CancelledError. If this waiter had already been
                # woken (fut done but not cancelled), hand the wake-up to the
                # next waiter so it is not lost.
                fut.cancel()
                if self._value > 0 and not fut.cancelled():
                    self._wake_up_next()
                raise
        self._value -= 1
        return True

    async def __acquire_ctx(self, priority):
        await self.acquire(priority)
        return self._LegacyContextManager(self)

    def __await__(self, priority=None):
        # Enables the deprecated ``with await sem:`` idiom.
        warnings.warn("'with await lock' is deprecated "
                      "use 'async with lock' instead",
                      DeprecationWarning, stacklevel=2)
        # acquire() maps priority=None to the default priority.
        return self.__acquire_ctx(priority).__await__()

    async def __aenter__(self):
        await self.acquire(self._default_prio)
        # We have no use for the "as ..." clause in the with
        # statement for locks.
        return None
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment