Ordered multiprocess executor to be used with Keras (similar to PyTorch's DataLoader).
import os
import time
from concurrent.futures import ProcessPoolExecutor
from itertools import cycle
from queue import Queue
from threading import Thread, Event

from keras.engine.training import GeneratorEnqueuer
class Dataset():
    """Minimal map-style dataset interface, similar to PyTorch's Dataset."""

    def __getitem__(self, index):
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError


class ExampleDataset(Dataset):
    """Toy dataset: each item takes one second to 'load' and returns (pid, index)."""

    def __getitem__(self, index):
        time.sleep(1)
        return os.getpid(), index

    def __len__(self):
        return 100
class MultiProcessExecutor():
    """Runs dataset item loading in a process pool while preserving item order."""

    def __init__(self, dataset, workers=1, max_q_size=5):
        self.workers = workers
        self.executor = ProcessPoolExecutor(self.workers)
        self.dataset = dataset
        self.queue = Queue(max_q_size)
        self.run_thread = None
        self.stop_signal = Event()

    def is_running(self):
        return not self.stop_signal.is_set()

    def start(self):
        self.run_thread = Thread(target=self.run)
        self.run_thread.daemon = True
        self.run_thread.start()

    def run(self):
        """Queue up tasks in order; the bounded queue provides back-pressure."""
        indexes = cycle(range(len(self.dataset)))
        for i in indexes:
            if self.stop_signal.is_set():
                return
            # Futures are enqueued in submission order, so results come out ordered.
            self.queue.put(self.executor.submit(self.dataset.__getitem__, i), block=True)

    def get_item(self):
        """Generator yielding results in the same order the indices were submitted."""
        try:
            while True:
                yield self.queue.get(block=True).result()
        except Exception as e:
            self.stop()
            print('MultiProcessExecutor has stopped because of:', type(e).__name__, str(e), flush=True)
            return  # PEP 479: raising StopIteration inside a generator is an error.

    def stop(self):
        self.executor.shutdown()
        self.stop_signal.set()
        # Drain the queue and wake up a producer blocked on a full queue.
        with self.queue.mutex:
            self.queue.queue.clear()
            self.queue.unfinished_tasks = 0
            self.queue.not_full.notify()
        self.run_thread.join()
if __name__ == '__main__':
    # The guard is required for ProcessPoolExecutor on platforms that spawn processes.
    dataset = ExampleDataset()
    executor = MultiProcessExecutor(dataset)
    executor.start()
    getter = executor.get_item()

    start = time.time()
    for i in range(100):
        result = next(getter)
    print("Took executor", time.time() - start)

    # Comparing to Keras
    def keras_gen():
        while True:
            time.sleep(1)
            yield os.getpid()

    qu = GeneratorEnqueuer(keras_gen(), pickle_safe=True)
    qu.start(5, 10)
    start = time.time()
    for i in range(100):
        while not qu.queue.qsize():
            time.sleep(0.5)
        result = qu.queue.get()
    print("Took Keras", time.time() - start)
This is as expensive as Keras' GeneratorEnqueuer while preserving order.
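For completeness, here is a minimal sketch (not part of the original gist) of how the classes above could feed Keras' Keras-2.0-era fit_generator. The BatchDataset, its shapes, and the tiny model are hypothetical and only illustrate the wiring; it assumes the Dataset and MultiProcessExecutor definitions from the gist are in scope.

# Hedged sketch: wiring MultiProcessExecutor into model.fit_generator.
# BatchDataset, the shapes and the model below are made-up examples.
import numpy as np
from keras.models import Sequential
from keras.layers import Dense


class BatchDataset(Dataset):
    """Hypothetical dataset whose items are whole (x, y) batches."""

    def __getitem__(self, index):
        x = np.random.rand(32, 10)            # batch of 32 samples, 10 features
        y = np.random.randint(0, 2, (32, 1))  # binary targets
        return x, y

    def __len__(self):
        return 50  # number of batches per epoch


if __name__ == '__main__':
    model = Sequential([Dense(1, activation='sigmoid', input_shape=(10,))])
    model.compile(optimizer='sgd', loss='binary_crossentropy')

    batch_executor = MultiProcessExecutor(BatchDataset(), workers=4)
    batch_executor.start()

    # get_item() yields (x, y) batches in submission order, so it can be passed
    # straight to fit_generator, which only needs something it can call next() on.
    model.fit_generator(batch_executor.get_item(),
                        steps_per_epoch=len(BatchDataset()),
                        epochs=2)
    batch_executor.stop()

Because the futures are enqueued in submission order, batches reach the model in the same order the indices were generated, which is the property the plain GeneratorEnqueuer does not guarantee with multiple workers.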