Last active
January 30, 2019 23:19
-
-
Save Yomguithereal/65a79ed5270765ef55b65e44eec0ec1c to your computer and use it in GitHub Desktop.
Multithreaded iterators
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import time
from collections import Counter, deque
from queue import Queue
from threading import Condition, Event, Lock, Thread, Timer
# TODO: add a shutdown path that stops the worker threads
# TODO: use a thread-safe urllib3 pool
# TODO: in the ordered case, an output buffer could also be used
# Timeout, in seconds (one year), passed to the blocking queue operations
# below so that they effectively wait indefinitely.
FOREVER = 365 * 24 * 60 * 60
# Unique sentinel pushed on the output queue to tell the consumer that no
# further results will arrive.
THE_END_IS_NIGH = object()
class ThreadSafeIterator:
    """Iterator wrapper whose `next()` calls are serialized by a lock.

    The wrapped iterable is converted with `iter()` once at construction;
    every subsequent item is fetched while holding `self.lock`.
    """

    def __init__(self, iterator):
        self.lock = Lock()
        self.__iterator = iter(iterator)

    def __iter__(self):
        # The wrapper is its own iterator.
        return self

    def __next__(self):
        # StopIteration from the wrapped iterator propagates to the caller;
        # the lock is released either way by the `with` block.
        with self.lock:
            item = next(self.__iterator)
        return item
def multithreaded(iterator, n, ordered=True):
    """Process the items of `iterator` with `n` worker threads.

    Each item must be indexable: `item[0]` is a name (used as a grouping
    key) and `item[1]` is a number of seconds that the worker sleeps to
    simulate the work. At most one item per name is processed at any time;
    one further item per name may be parked in a buffer, and threads that
    pull yet another item for a busy name block until the buffer slot frees.

    Args:
        iterator: iterable of `(name, seconds)`-like items.
        n: number of worker threads; must be at least 1.
        ordered: when true (the default), results are yielded in the input
            order; when false, they are yielded as soon as they complete.

    Returns:
        A generator yielding the processed items. It ends once every
        scheduling chain has found the input exhausted.

    Raises:
        ValueError: if `n` is smaller than 1 (the end sentinel could never
            be emitted, so the consumer would block).
    """
    if n < 1:
        raise ValueError('n must be a positive integer')

    safe_iterator = ThreadSafeIterator(enumerate(iterator))

    task_queue = Queue(maxsize=n)
    output_queue = Queue(maxsize=n)

    finished_lock = Lock()
    finished_times = 0

    # Index of the last item emitted in ordered mode, guarded by the condition.
    last_index = -1
    last_index_condition = Condition()

    # `current_domains_lock` guards the three structures below AND the
    # consumption of the iterator, so that jobs sharing a name are always
    # routed (run / buffered / queued as waiter) in input order.
    current_domains_lock = Lock()
    current_domains = Counter()  # name -> 1 while a job of that name is running
    buffer = {}  # name -> the single job parked behind the running one
    waiters = {}  # name -> deque of (job, Event) blocked on a full buffer slot

    def enqueue(last_name=None):
        """Schedule the next job, `last_name` being the name just completed."""
        nonlocal finished_times

        job = None
        waiter_to_release = None

        while True:
            with current_domains_lock:
                if last_name is not None:
                    if last_name in buffer:
                        # Run the parked job next; the name stays busy.
                        job = buffer.pop(last_name)

                        pending = waiters.get(last_name)
                        if pending:
                            # Move the oldest waiting job into the freed slot
                            # ourselves, under the lock, so it can never be
                            # stranded by a late re-insertion.
                            buffer[last_name], waiter_to_release = pending.popleft()
                            if not pending:
                                del waiters[last_name]
                        break

                    # Nothing parked: the name becomes idle. Done exactly once
                    # (last_name is cleared) and keyed by the *name*.
                    current_domains[last_name] -= 1
                    if current_domains[last_name] <= 0:
                        del current_domains[last_name]
                    last_name = None

                # Let's consume the iterator.
                job = next(safe_iterator, None)

                if job is None:
                    break

                name = job[1][0]

                if current_domains[name] <= 0:
                    # Name is idle: mark it busy and run the job.
                    current_domains[name] += 1
                    break

                if name not in buffer:
                    # Name is busy but its slot is free: park and keep pulling.
                    buffer[name] = job
                    continue

                # Name is busy and its slot is taken: register as a waiter
                # while still holding the lock to avoid a lost wake-up.
                event = Event()
                waiters.setdefault(name, deque()).append((job, event))

            # Only reached on the waiter path. When the event fires, our job
            # has already been moved into the buffer by the releasing thread,
            # so this thread simply goes on to pull another job.
            event.wait()
            job = None

        if waiter_to_release is not None:
            waiter_to_release.set()

        if job is not None:
            task_queue.put(job, timeout=FOREVER)
            return

        # Input exhausted for this chain. `boot` starts n chains, so the nth
        # one to end is the last and emits the sentinel.
        with finished_lock:
            finished_times += 1
            all_finished = finished_times == n

        if all_finished:
            output_queue.put(THE_END_IS_NIGH, timeout=FOREVER)

    def work_ordered(job):
        """Process `job` and emit its item only after all earlier items."""
        nonlocal last_index

        index, item = job
        time.sleep(item[1])

        with last_index_condition:
            while last_index != index - 1:
                last_index_condition.wait()

            last_index = index
            last_index_condition.notify_all()
            # Emitted while still holding the condition so that a later item
            # cannot reach the output queue first.
            output_queue.put(item, timeout=FOREVER)

    def work(job):
        """Process `job` and emit its item as soon as it is done."""
        item = job[1]
        time.sleep(item[1])
        output_queue.put(item, timeout=FOREVER)

    process = work_ordered if ordered else work

    def worker():
        while True:
            job = task_queue.get(timeout=FOREVER)

            # None is treated as a stop signal; nothing in this function
            # currently enqueues it.
            if job is None:
                break

            process(job)
            task_queue.task_done()
            enqueue(job[1][0])

    for _ in range(n):
        thread = Thread(target=worker, daemon=True)
        thread.start()

    def boot():
        # Start the n scheduling chains.
        for _ in range(n):
            enqueue()

    # Booting happens on a separate thread so this function returns at once;
    # it is a daemon like the workers so it cannot keep the process alive.
    t = Timer(0.00001, boot)
    t.daemon = True
    t.start()

    def output():
        while True:
            result = output_queue.get(timeout=FOREVER)

            if result is None or result is THE_END_IS_NIGH:
                break

            yield result

    return output()
if __name__ == '__main__':
    # Demo run: two workers over a small list of (name, seconds) items.
    N = 2

    FILE = [
        ('A', 3), ('A', 2),
        ('B', 1), ('B', 2), ('B', 3), ('B', 1),
        ('C', 1),
        ('D', 1)
    ]

    # Alternative input made of a single name only:
    # FILE = [('B', 1), ('B', 1), ('B', 1), ('B', 1), ('B', 2), ('B', 1), ('B', 2)]

    results = multithreaded(iter(FILE), N)

    for item in results:
        print('YIELDED', item)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment