Created
January 7, 2021 22:11
Trio channels with priorities
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from heapq import heappush, heappop | |
from math import inf | |
from trio import BrokenResourceError, ClosedResourceError, EndOfChannel, WouldBlock | |
from trio.abc import ReceiveChannel, SendChannel | |
from trio.lowlevel import ParkingLot, checkpoint, checkpoint_if_cancelled, cancel_shielded_checkpoint | |
from trio._channel import MemoryChannelStats | |
from trio._util import NoPublicConstructor | |
class MemoryChannelState:
    """Bookkeeping shared by every endpoint of one priority memory channel.

    Holds the buffered values, the counts of open send/receive endpoints,
    and the parking lots on which blocked senders and receivers wait.
    """

    __slots__ = ('data', 'max_buffer_size', 'number', 'open_send_channels',
                 'open_receive_channels', 'priority', 'receivers',
                 'senders')

    def __init__(self, max_buffer_size, priority):
        """Initialise an empty channel state.

        Args:
            max_buffer_size: Maximum number of buffered values.
            priority: Callable mapping a sent value to its heap key.
        """
        self.max_buffer_size = max_buffer_size
        self.priority = priority
        # Binary heap of (priority, number, value) entries.
        self.data = []
        # Sequence number stamped on the next sent value; it breaks ties
        # between equal priorities.  It must exist before the first send,
        # because send_nowait reads it before writing it back.
        self.number = 0
        self.open_send_channels = 0
        self.open_receive_channels = 0
        self.senders = ParkingLot()
        self.receivers = ParkingLot()

    def statistics(self):
        """Return a ``MemoryChannelStats`` snapshot of the current state."""
        return MemoryChannelStats(
            current_buffer_used=len(self.data),
            max_buffer_size=self.max_buffer_size,
            open_send_channels=self.open_send_channels,
            open_receive_channels=self.open_receive_channels,
            tasks_waiting_send=self.senders.statistics().tasks_waiting,
            tasks_waiting_receive=self.receivers.statistics().tasks_waiting,
        )
class MemorySendChannel(SendChannel, metaclass=NoPublicConstructor):
    """Sending endpoint of a priority memory channel.

    Each value is pushed onto the shared heap as
    ``(state.priority(value), sequence_number, value)``, so values with a
    smaller priority key are received first and equal keys keep their send
    order.
    """

    __slots__ = ('_closed', '_state')

    def __init__(self, state):
        self._state = state
        self._closed = False
        state.open_send_channels += 1

    def clone(self):
        """Return another send endpoint sharing the same channel state.

        Raises:
            ClosedResourceError: If this endpoint is already closed.
        """
        if self._closed:
            raise ClosedResourceError
        return MemorySendChannel._create(self._state)

    def statistics(self):
        """Return the shared channel statistics."""
        return self._state.statistics()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()

    def close(self):
        """Close this endpoint; calling it again is a no-op.

        When the last send endpoint closes, every parked task is woken so
        it can re-check the channel instead of waiting forever.
        """
        if self._closed:
            return
        self._closed = True
        self._state.open_send_channels -= 1
        if not self._state.open_send_channels:
            # Senders still parked here belong to closed endpoints; once
            # woken, their retry raises ClosedResourceError.
            self._state.senders.unpark_all()
            # Receivers re-check the buffer: they either drain a remaining
            # value or observe that no sender is left.
            self._state.receivers.unpark_all()

    async def aclose(self):
        """Close this endpoint, then execute a checkpoint."""
        self.close()
        await checkpoint()

    def send_nowait(self, value, *, _could_block=True):
        """Buffer ``value`` without blocking.

        Args:
            value: The object to send.
            _could_block: Internal flag; when false, a full buffer trips an
                assertion before ``WouldBlock`` is raised.

        Raises:
            ClosedResourceError: If this endpoint is closed.
            BrokenResourceError: If no receive endpoint is open.
            WouldBlock: If the buffer is full.
        """
        state = self._state
        if self._closed:
            raise ClosedResourceError
        if not state.open_receive_channels:
            raise BrokenResourceError
        if len(state.data) >= state.max_buffer_size:
            assert _could_block
            raise WouldBlock
        number = state.number
        state.number = number + 1
        heappush(state.data, (state.priority(value), number, value))
        # Wake one waiting receiver per buffered value.  The buffer may
        # already hold values for receivers that were woken earlier but have
        # not run yet, so its length is not necessarily 1 here.
        if state.receivers:
            state.receivers.unpark()

    async def send(self, value):
        """Send ``value``, waiting while the buffer is full.

        Raises:
            ClosedResourceError: If this endpoint is closed.
            BrokenResourceError: If no receive endpoint is open.
        """
        await checkpoint_if_cancelled()
        try:
            self.send_nowait(value)
        except WouldBlock:
            pass
        else:
            await cancel_shielded_checkpoint()
            return
        while True:
            await self._state.senders.park()
            try:
                self.send_nowait(value)
            except WouldBlock:
                # Another sender took the freed slot before this task ran;
                # park again and wait for the next one.
                continue
            return
class MemoryReceiveChannel(ReceiveChannel, metaclass=NoPublicConstructor):
    """Receiving endpoint of a priority memory channel.

    Values are popped from the shared heap, so the entry with the smallest
    ``(priority, sequence_number)`` pair is returned first.
    """

    __slots__ = ('_closed', '_state')

    def __init__(self, state):
        self._state = state
        self._closed = False
        state.open_receive_channels += 1

    def clone(self):
        """Return another receive endpoint sharing the same channel state.

        Raises:
            ClosedResourceError: If this endpoint is already closed.
        """
        if self._closed:
            raise ClosedResourceError
        return MemoryReceiveChannel._create(self._state)

    def statistics(self):
        """Return the shared channel statistics."""
        return self._state.statistics()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()

    def close(self):
        """Close this endpoint; calling it again is a no-op.

        When the last receive endpoint closes, every parked task is woken so
        it can re-check the channel instead of waiting forever.
        """
        if self._closed:
            return
        self._closed = True
        self._state.open_receive_channels -= 1
        if not self._state.open_receive_channels:
            # Receivers still parked here belong to closed endpoints; once
            # woken, their retry raises ClosedResourceError.
            self._state.receivers.unpark_all()
            # Senders re-check the channel and find no receiver left.
            self._state.senders.unpark_all()

    async def aclose(self):
        """Close this endpoint, then execute a checkpoint."""
        self.close()
        await checkpoint()

    def receive_nowait(self, *, _could_block=True):
        """Pop and return the highest-priority buffered value without blocking.

        Args:
            _could_block: Internal flag; when false, an empty buffer with
                open senders trips an assertion before ``WouldBlock`` is
                raised.

        Raises:
            ClosedResourceError: If this endpoint is closed.
            EndOfChannel: If the buffer is empty and no send endpoint is open.
            WouldBlock: If the buffer is empty but a send endpoint is open.
        """
        state = self._state
        if self._closed:
            raise ClosedResourceError
        try:
            _priority, _number, value = heappop(state.data)
        except IndexError:
            if not state.open_send_channels:
                raise EndOfChannel from None
            assert _could_block
            raise WouldBlock from None
        # A slot was freed: wake one waiting sender.  Other receivers may
        # still be parked, and earlier-woken senders may not have refilled
        # the buffer yet, so no exact buffer length is assumed here.
        if state.senders:
            state.senders.unpark()
        return value

    async def receive(self):
        """Return the next value, waiting while the buffer is empty.

        Raises:
            ClosedResourceError: If this endpoint is closed.
            EndOfChannel: If the buffer is empty and no send endpoint is open.
        """
        await checkpoint_if_cancelled()
        try:
            value = self.receive_nowait()
        except WouldBlock:
            pass
        else:
            await cancel_shielded_checkpoint()
            return value
        while True:
            await self._state.receivers.park()
            try:
                return self.receive_nowait()
            except WouldBlock:
                # Another receiver took the value before this task ran;
                # park again and wait for the next one.
                continue
def open_memory_channel(max_buffer_size, *, priority=None):
    """Create a connected pair of priority memory channel endpoints.

    Args:
        max_buffer_size: Buffer capacity; a non-negative ``int`` or
            ``math.inf``.
        priority: Optional callable applied to each sent value to compute
            its ordering key.  When omitted every value gets key ``0``.

    Returns:
        A ``(send_channel, receive_channel)`` tuple sharing one
        ``MemoryChannelState``.

    Raises:
        TypeError: If ``max_buffer_size`` is neither an ``int`` nor
            ``math.inf``.
        ValueError: If ``max_buffer_size`` is negative.

    NOTE(review): ``send_nowait`` raises ``WouldBlock`` whenever
    ``len(data) >= max_buffer_size``, which always holds for a size of 0,
    so such a channel never accepts a value.  Confirm whether 0 should be
    rejected here.
    """
    size_is_int = isinstance(max_buffer_size, int)
    if not size_is_int and max_buffer_size != inf:
        raise TypeError("max_buffer_size must be an integer or math.inf")
    if max_buffer_size < 0:
        raise ValueError("max_buffer_size must be >= 0")

    key = priority if priority is not None else (lambda _value: 0)
    shared = MemoryChannelState(max_buffer_size, key)
    sender = MemorySendChannel._create(shared)
    receiver = MemoryReceiveChannel._create(shared)
    return sender, receiver
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment