Created
November 28, 2018 22:06
-
-
Save mumbleskates/2c9cfd76c8ea747b35e2eb16ed2d00fc to your computer and use it in GitHub Desktop.
closeable channel and threadsafe iterator wrapper
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# coding=utf-8 | |
from collections import deque | |
from queue import Empty, Full | |
from threading import Condition, RLock, Thread | |
from time import monotonic as now | |
from weakref import finalize | |
class ChannelClosed(Exception):
    """Raised when an operation cannot proceed because the channel has been closed."""
class Channel(object):
    """
    Like queue.Queue, but can be closed. Iteration goes until the channel is closed.

    Every item put() into a channel may be provided exactly once to a caller of its get() method,
    unless drain() is called. The structure is suitable for any number of producers and any
    number of consumers. The channel will not accept new items if the size of its queue is
    currently at or over `maxsize`, and will reject or block until space is available.

    For race condition reasons, channels are not reusable once closed.
    """

    def __init__(self, maxsize=float('inf')):
        """
        Create an open, empty channel.

        :param maxsize: queue size at or above which put() blocks or fails; must be >= 1.
            Defaults to infinity (unbounded).
        :raises ValueError: if maxsize is less than 1.
        """
        if maxsize < 1:
            raise ValueError('maxsize must be 1 or more')
        self.maxsize = maxsize
        self._closed = False  # set to True when the channel is flagged for closure
        self.mutex = RLock()  # lock held whenever the channel's queue is mutated
        # All three conditions share the same mutex, so any of them may be notified
        # while holding `self.mutex`.
        self.is_closed = Condition(self.mutex)  # notified when the channel is closed and drained
        self.not_empty = Condition(self.mutex)  # notified when item(s) exist in the queue
        self.not_full = Condition(self.mutex)  # notified when space exists in the queue
        self._init()

    # Override these methods to implement other queueing models, as with standard queue.Queue.

    def _init(self):
        """Initialize the queue representation. Also used by drain() to discard all items."""
        self.queue = deque()

    def _qsize(self):
        """Return the current size of the queue, in whatever unit. MUST be falsy when empty."""
        return len(self.queue)

    def _put(self, item):
        """Put an item into the queue."""
        self.queue.append(item)

    def _get(self):
        """Get an item from the queue."""
        return self.queue.popleft()

    def get(self, timeout=None):
        """
        Take and return a single item from the channel.

        With the default timeout of None, blocks until an item is available in the queue.
        Positive values will block for up to that many seconds waiting, and zero or negative
        values of timeout do not block.

        :raises queue.Empty: if the timeout is reached while the channel is open and empty.
        :raises ChannelClosed: if the channel is closed and drained.
        """
        with self.mutex:
            endtime = None if timeout is None else now() + timeout
            # ensure there are items to get
            while not self._qsize():
                if self._closed:
                    # closed and drained: let wait_closed() callers know before bailing out
                    self.is_closed.notify_all()
                    raise ChannelClosed
                if endtime is None:  # blocking indefinitely
                    self.not_empty.wait()
                else:  # block for up to timeout seconds
                    remaining = endtime - now()
                    if remaining <= 0:
                        raise Empty  # timed out
                    self.not_empty.wait(remaining)
            self.not_full.notify()
            item = self._get()
            if self._closed and not self._qsize():
                # We just took the last item of a closed channel: it is now fully drained.
                # Notify here rather than relying on some consumer calling get() once more.
                self.is_closed.notify_all()
            return item

    def put(self, item, timeout=None):
        """
        Send an item to the channel.

        With the default timeout of None, blocks until there is space in the queue to accept
        it. Positive values will block for up to that many seconds waiting, and zero or
        negative values of timeout do not block.

        :raises queue.Full: if the timeout is reached while the channel is open and full.
        :raises ChannelClosed: if the channel is (or becomes) closed before the item is queued.
        """
        with self.mutex:
            endtime = None if timeout is None else now() + timeout
            # ensure the channel can accept items; closure is checked before capacity so a
            # closed channel never accepts an item even when it has room
            while True:
                if self._closed:
                    raise ChannelClosed
                if self._qsize() < self.maxsize:
                    break
                if endtime is None:  # blocking indefinitely
                    self.not_full.wait()
                else:  # block for up to timeout seconds
                    remaining = endtime - now()
                    if remaining <= 0:
                        raise Full  # timed out
                    self.not_full.wait(remaining)
            self._put(item)
            self.not_empty.notify()

    def put_all(self, items):
        """
        Send all the items in the provided iterable to the channel, blocking until done.

        This is slightly faster than looping put(), as it reduces mutex thrashing.

        Items are only pulled from `items` while there is space in the queue, and the
        iterable is consumed while holding the channel's mutex. If ChannelClosed is raised
        part-way through, the items inserted so far remain in the channel.

        Note that exhaustion of `items` is only discovered by asking for the next item, which
        happens only when space is available: if the final item fills the queue, this call
        keeps waiting for space (or closure) before it can return.

        :raises ChannelClosed: if the channel is closed before all items were accepted.
        """
        it = iter(items)
        with self.mutex:
            while True:
                # wait until the channel can accept items (or has been closed)
                while True:
                    if self._closed:
                        raise ChannelClosed
                    if self._qsize() < self.maxsize:
                        break
                    self.not_full.wait()
                # add items to the queue in bulk while there is space
                for item in it:
                    self._put(item)
                    self.not_empty.notify()
                    if self._qsize() >= self.maxsize:
                        break  # full again; go back to waiting for space
                else:
                    return  # run out of items to insert, we are done!

    def close(self):
        """
        Flag the channel for closure. This is not reversible.

        Once this method is called, no new items can be sent to the channel; items already
        queued can still be retrieved with get().
        """
        with self.mutex:
            self._closed = True
            # Producers blocked on a full queue must wake up and fail with ChannelClosed.
            self.not_full.notify_all()
            # Consumers must always be woken too, even if items remain: a consumer may still
            # be parked on not_empty while the queue is non-empty (its wake-up went to another
            # consumer), and once the remaining items are drained nothing else would ever
            # notify it. Woken consumers re-check the queue and either take an item or raise.
            self.not_empty.notify_all()
            if not self._qsize():
                # closed and already drained: release wait_closed() callers
                self.is_closed.notify_all()

    def drain(self):
        """
        !Danger! Discard all items currently in the channel's queue.

        Obviously, this can lead to data loss. This method is intended for use when a
        channel's consumers have died irreplaceably, but other threads are still awaiting the
        closure of this channel for a clean shutdown.
        """
        with self.mutex:
            if self._closed:
                # closed + emptied means fully closed: release wait_closed() callers
                self.is_closed.notify_all()
            else:
                # still open: the queue is about to have space again
                self.not_full.notify_all()
            self._init()

    def wait_closed(self, timeout=None):
        """
        Wait for the channel to be completely closed and drained.

        With the default timeout None, blocks indefinitely. Positive values for timeout
        specify the maximum amount of time to block in seconds, and values less than or equal
        to zero do not block.

        :raises TimeoutError: if the timeout is reached.
        """
        with self.mutex:
            endtime = None if timeout is None else now() + timeout
            while not self._closed or self._qsize():
                if endtime is None:  # blocking indefinitely
                    self.is_closed.wait()
                else:  # block for up to timeout seconds
                    remaining = endtime - now()
                    if remaining <= 0:
                        raise TimeoutError
                    self.is_closed.wait(remaining)

    def status(self):
        """
        Non-authoritative status for the channel. Returns one of three status strings:

        * 'open' -- available for processing items (does not indicate at which end the
          bottleneck lies)
        * 'draining' -- close() has been called, but items still remain to be processed.
        * 'closed' -- the channel is fully closed and will never reopen.

        The status returned may be subject to race conditions (it is read without taking the
        mutex) and is meant to be advisory only; the channel may have since progressed to a
        later status.
        """
        if not self._closed:
            return 'open'
        return 'draining' if self._qsize() else 'closed'

    def __iter__(self):
        """Yield values from this channel until it is closed and drained, blocking indefinitely."""
        try:
            while True:
                yield self.get()
        except ChannelClosed:
            return
class TaskTrackChannel(Channel):
    """
    A subclass of Channel that provides task completion tracking like queue.Queue.

    Every item queued through _put() counts as one task; consumers report completion with
    task_done(), and join() blocks until the two counts match.

    NOTE(review): drain() discards queued items without adjusting these counters, so join()
    may wait indefinitely for tasks that were drained -- confirm whether that is intended.
    """

    def __init__(self, *args, **kwargs):
        """Accepts the same arguments as Channel and starts with zero tasks put and done."""
        super().__init__(*args, **kwargs)
        self.tasks_put = 0  # number of items ever queued via _put()
        self.tasks_done = 0  # number of task_done() calls accepted so far
        self.all_tasks_done = Condition(self.mutex)  # notified when tasks_done == tasks_put

    def _put(self, item):
        """Queue the item and count it as an outstanding task."""
        super()._put(item)
        self.tasks_put += 1

    def task_done(self):
        """
        Report one task as done.

        :raises ValueError: if called more times than items were put. The counters are left
            untouched in that case, so one stray call cannot mask a later unfinished task.
        """
        with self.mutex:
            done = self.tasks_done + 1
            if done > self.tasks_put:
                # validate before committing, so the count never exceeds tasks_put
                raise ValueError('task_done() called too many times')
            self.tasks_done = done
            if done == self.tasks_put:
                self.all_tasks_done.notify_all()

    def join(self):
        """
        Block until the number of items put into the channel is the same as the number of
        tasks reported done.
        """
        with self.mutex:
            while self.tasks_done < self.tasks_put:
                self.all_tasks_done.wait()
class IterProvider(object):
    """
    Provide a generator to multiple threads.

    Every time this object provides a new iterator, it spawns a worker thread that gets a new
    iterator from the provided generator; the iterator returned can then be used safely by any
    number of threads. Objects provided by this iterator are guaranteed to be passed exactly
    one time.
    """

    class _Yielder(object):
        """Iterator that hands out items produced by a background worker thread."""

        def __init__(self, generator, queue_length):
            """
            Start a worker thread that feeds `generator()` into a bounded channel.

            :param generator: zero-argument callable returning an iterable of items.
            :param queue_length: maximum number of items buffered ahead of the consumers.
            """
            channel = self.channel = Channel(maxsize=queue_length)

            # must not hold a reference to self to prevent the thread from keeping the
            # iterator alive
            def work():
                try:
                    for thing in generator():
                        try:
                            channel.put(thing)
                        except ChannelClosed:
                            # channel was closed by someone else (e.g. the finalizer below)
                            return
                finally:
                    # Always close, whether the generator was exhausted or raised: otherwise
                    # consumers would block forever in get() waiting for items that never
                    # arrive. Closing an already-closed channel is harmless.
                    channel.close()

            Thread(target=work).start()
            # when the iterator is garbage collected, close the channel so the worker's
            # next put() raises ChannelClosed and the thread ends
            finalize(self, channel.close)

        def __next__(self):
            """Return the next item, or raise StopIteration once the channel is closed and drained."""
            try:
                return self.channel.get()
            except ChannelClosed:
                raise StopIteration from None

        def __iter__(self):
            return self

    def __init__(self, generator, queue_length=16):
        """
        :param generator: zero-argument callable returning a fresh iterable on each call.
        :param queue_length: size of the buffer between the worker thread and the consumers.
        """
        self.generator = generator
        self.queue_length = queue_length

    def __iter__(self):
        """Spawn a new worker thread and return a thread-safe iterator over its output."""
        return self._Yielder(self.generator, self.queue_length)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment