Created
October 13, 2016 08:35
-
-
Save dansondergaard/5ef797450517e14ba2f8e3c1ac018868 to your computer and use it in GitHub Desktop.
A lightweight subprocess task queue.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import random | |
| import time | |
| from multiprocessing.connection import Client | |
| from server import StatusRequest, SubmitRequest | |
class GWFClient:
    """Thin client for talking to the task-queue server.

    Every positional and keyword argument is forwarded unchanged to
    ``multiprocessing.connection.Client``, so the usual call is
    ``GWFClient((host, port))``.

    The object can be used as a context manager; leaving the ``with`` block
    closes the underlying connection.
    """

    def __init__(self, *args, **kwargs):
        self.client = Client(*args, **kwargs)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()

    def close(self):
        """Close the connection to the server."""
        self.client.close()

    def submit(self, target, deps=None):
        """Submit *target* and return whatever the server replies.

        *deps* is a list of values previously returned by ``submit`` that
        must finish before *target* runs; ``None`` means no dependencies.
        """
        if deps is None:
            deps = []
        return self._roundtrip(SubmitRequest(target=target, deps=deps))

    def status(self):
        """Ask the server for its task states and return the reply."""
        return self._roundtrip(StatusRequest())

    def _roundtrip(self, request):
        """Send *request* and block until a reply arrives."""
        self.client.send(request)
        return self.client.recv()
# Connect to a server listening on this machine, port 25000.
client = GWFClient(('localhost', 25000))

# Demo dependency graph: (target number, numbers of the targets it depends on).
# Entries are submitted in this order, so every dependency is submitted
# before the targets that refer to it.
_GRAPH = (
    (1, ()),
    (2, ()),
    (3, (1, 2)),
    (4, (1,)),
    (5, (3,)),
    (6, (2,)),
    (7, (2,)),
    (8, (2,)),
    (9, (2,)),
    (10, (3,)),
    (11, (3,)),
    (12, (3,)),
    (13, (6, 7, 8, 9)),
    (14, (6, 7, 8, 9)),
    (15, (6, 7, 8, 9)),
    (16, (6, 7, 8, 9)),
    (17, (6, 7, 8, 9)),
    (18, (6, 7, 8, 9)),
    (19, (13, 14, 15)),
    (20, (13, 14, 15)),
    (21, (13, 14, 15)),
    (22, (13, 14, 15)),
    (23, (16, 17, 18, 19, 20)),
    (24, (16, 17, 18, 19, 20)),
    (25, (16, 17, 18, 19, 20)),
    (26, (16, 17, 18, 19, 20)),
)

# Maps target number -> the value the server returned for that submission.
submitted = {}
for number, dep_numbers in _GRAPH:
    submitted[number] = client.submit(
        'Target%d' % number,
        deps=[submitted[dep] for dep in dep_numbers],
    )

# Poll forever: print one line per task (id truncated to 10 characters),
# wait ten seconds, then print a separator line.
while True:
    snapshot = client.status()
    for task_id in snapshot:
        print(task_id[:10], snapshot[task_id])
    time.sleep(10)
    print('-' * 80)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import logging | |
| import random | |
| import sys | |
| import time | |
| import traceback | |
| import uuid | |
| from multiprocessing import Manager | |
| from multiprocessing.connection import Listener | |
| from multiprocessing.pool import ThreadPool | |
# Configure the root logger at DEBUG level; each record carries the timestamp,
# level, logger name and the name of the emitting thread.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(name)s - %(threadName)s - %(message)s',
    level=logging.DEBUG,
)

# Module-level logger used by everything below.
logger = logging.getLogger(__name__)
def gen_task_id():
    """Return a new random task id: 32 lowercase hexadecimal characters."""
    random_uuid = uuid.uuid4()
    return '{:032x}'.format(random_uuid.int)
class GWFServerError(Exception):
    """Exception type reserved for server errors (not raised in this module)."""
class State:
    """Integer codes describing where a task is in its lifecycle."""

    # pending   = 0: submitted, not yet picked up for execution
    # started   = 1: a worker is running it
    # completed = 2: finished without raising
    # failed    = 3: raised, or a dependency failed
    pending, started, completed, failed = range(4)
class Request:
    """Common base class for the request objects clients send to the server."""
class SubmitRequest(Request):
    """Request asking the server to queue *target* for execution.

    *deps* holds the task ids that must complete before this task may run.
    ``None`` is accepted and treated as "no dependencies".
    """

    def __init__(self, target, deps=None):
        self.target = target
        # Always store a real list so consumers can iterate it unconditionally
        # and later mutation of the caller's sequence cannot leak in.
        self.deps = [] if deps is None else list(deps)

    def __repr__(self):
        # Requests are logged with %r; the default object repr is useless there.
        return '{}(target={!r}, deps={!r})'.format(
            type(self).__name__, self.target, self.deps
        )

    def handle(self, task_queue, status_dict):
        """Register a new pending task, enqueue it and return its id."""
        task_id = gen_task_id()
        # Record the state before queueing so whoever dequeues the task
        # always finds a status entry for it.
        status_dict[task_id] = State.pending
        task_queue.put((task_id, self))
        return task_id
class StatusRequest(Request):
    """Request for task states, optionally limited to specific task ids."""

    def __init__(self, task_ids=None):
        # None means "report every task".
        self.task_ids = task_ids

    def handle(self, task_queue, status_dict):
        """Return a plain dict mapping task id to state.

        With ``task_ids`` unset, every entry of *status_dict* is returned;
        otherwise only entries whose key is contained in ``task_ids``.
        *task_queue* is accepted for interface parity and not used.
        """
        wanted = self.task_ids
        if wanted is None:
            return dict(status_dict)

        selected = {}
        for task_id, state in status_dict.items():
            if task_id in wanted:
                selected[task_id] = state
        return selected
def _claim_task(task_id, request, status_dict, deps_dict, deps_lock):
    """Decide, atomically under *deps_lock*, what happens to a dequeued task.

    Returns ``State.started`` when the task was claimed and must be run by
    the caller, ``State.failed`` when it was failed because of its
    dependencies, and ``None`` when it was left untouched (already handled
    elsewhere, or parked until its dependencies finish).
    """
    with deps_lock:
        # A task waiting on several dependencies is requeued once per
        # dependency. Checking and leaving the pending state under the same
        # lock guarantees only one worker ever runs it.
        if status_dict[task_id] != State.pending:
            return None

        deps = request.deps or ()
        dep_states = {dep_id: status_dict.get(dep_id) for dep_id in deps}

        # A dependency the server has never seen can never complete.
        unknown = [dep_id for dep_id, state in dep_states.items() if state is None]
        if unknown:
            status_dict[task_id] = State.failed
            logger.error(
                'Task %s failed since dependencies %s are unknown.',
                task_id,
                unknown,
            )
            return State.failed

        if any(state == State.failed for state in dep_states.values()):
            status_dict[task_id] = State.failed
            logger.error(
                'Task %s failed since a dependency failed.',
                task_id,
            )
            return State.failed

        unmet = [
            dep_id for dep_id, state in dep_states.items()
            if state != State.completed
        ]
        for dep_id in unmet:
            logger.debug(
                'Task %s set to wait for %s.',
                task_id,
                dep_id,
            )
            waiting = deps_dict.get(dep_id, [])
            if all(waiting_id != task_id for waiting_id, _ in waiting):
                waiting.append((task_id, request))
                # Assign the list back: a manager dict proxy does not notice
                # in-place mutation of the values it hands out.
                deps_dict[dep_id] = waiting
        if unmet:
            return None

        status_dict[task_id] = State.started
        return State.started


def _requeue_dependents(task_id, task_queue, deps_dict, deps_lock):
    """Put every task parked on *task_id* back on the queue and forget them."""
    with deps_lock:
        waiting = deps_dict.pop(task_id, None)
        if not waiting:
            return
        logger.debug(
            'Task %s has waiting dependents. Requeueing.', task_id)
        for dep_task_id, dep_request in waiting:
            task_queue.put((dep_task_id, dep_request))


def worker(task_queue, status_dict, deps_dict, deps_lock):
    """Consume tasks from *task_queue* forever; never returns.

    task_queue:  queue of ``(task_id, request)`` pairs.
    status_dict: maps task id to a ``State`` value.
    deps_dict:   maps a task id to the ``(task_id, request)`` pairs parked
                 until that task finishes.
    deps_lock:   guards *deps_dict* and the state transitions that depend on it.

    A task runs only when all of its dependencies are completed. If any
    dependency failed (or is unknown) the task is failed too, and that
    failure is propagated to the tasks waiting on it.
    """
    while True:
        task_id, request = task_queue.get()

        outcome = _claim_task(
            task_id, request, status_dict, deps_dict, deps_lock)
        if outcome == State.failed:
            # Wake dependents so they observe the failure instead of
            # staying pending forever.
            _requeue_dependents(task_id, task_queue, deps_dict, deps_lock)
            continue
        if outcome != State.started:
            continue

        logger.debug(
            'Task %s started target %r.',
            task_id, request.target
        )

        # Assume failure until the task body returns normally, so the task
        # never stays "started" if something unexpected escapes.
        final_state = State.failed
        try:
            # Fake running a task. Just sleep for a random number of seconds.
            time.sleep(random.randint(2, 10))
        except Exception:
            logger.error(
                'Task %s failed.',
                task_id,
                exc_info=True
            )
        else:
            final_state = State.completed
            logger.debug(
                'Task %s completed target %r.',
                task_id, request.target
            )
        finally:
            # Publish the final state before waking dependents, so they see
            # it when they are re-examined.
            status_dict[task_id] = final_state
            _requeue_dependents(task_id, task_queue, deps_dict, deps_lock)
def handle_request(request, task_queue, status_dict):
    """Dispatch *request* and return its response.

    Returns whatever ``request.handle(task_queue, status_dict)`` returns.
    If handling raises, the error is logged and ``None`` is returned.
    """
    logger.debug('Received request %r.', request)
    try:
        return request.handle(task_queue, status_dict)
    except Exception:
        # Deliberately not a bare ``except``: KeyboardInterrupt and
        # SystemExit must propagate so the server can still be shut down.
        logger.error('Invalid request %r.', request, exc_info=True)
        return None
def handle_client(conn, task_queue, status_dict):
    """Serve one client connection until the peer disconnects.

    Reads requests from *conn* in a loop, passes each to ``handle_request``
    and sends back every non-``None`` response. The connection is closed
    when this function returns or raises.
    """
    logger.debug('Accepted client connection.')
    try:
        while True:
            # SECURITY NOTE: recv() unpickles whatever the peer sent, and
            # unpickling untrusted data can execute arbitrary code. Only
            # expose this server to trusted clients.
            request = conn.recv()
            response = handle_request(request, task_queue, status_dict)
            # NOTE(review): a request whose handling failed yields None and
            # gets no reply, so a client blocked in recv() will wait
            # indefinitely. Confirm whether that is intended.
            if response is not None:
                conn.send(response)
    except EOFError:
        logger.debug('Client connection closed.')
    finally:
        # Release the socket on every exit path, not only a clean EOF.
        conn.close()
def wait_for_clients(address, task_queue, status_dict):
    """Accept client connections on *address* forever, one at a time.

    Each accepted connection is served to completion by ``handle_client``
    before the next one is accepted. Ordinary exceptions raised while
    accepting or serving are logged and the loop continues. The listening
    socket is closed when the function is left (for example on
    KeyboardInterrupt).
    """
    with Listener(address) as serv:
        while True:
            try:
                # The context manager closes the accepted connection even
                # when serving it fails.
                with serv.accept() as client:
                    handle_client(client, task_queue, status_dict)
            except Exception:
                logger.error('Error while serving a client.', exc_info=True)
def start(hostname='', port=25000, workers=None):
    """Run the server until interrupted with Ctrl-C.

    hostname, port: address handed to the listener. The empty string is the
                    usual socket convention for "all interfaces".
    workers:        number of worker threads; ``None`` lets ``ThreadPool``
                    pick its default.

    Exits the process with status 0 on KeyboardInterrupt.
    """
    # Kept separate from the ``workers`` argument so the count is not
    # overwritten, and so the handler below can tell whether a pool exists.
    pool = None
    try:
        with Manager() as manager:
            status_dict = manager.dict()
            deps_dict = manager.dict()
            task_queue = manager.Queue()
            deps_lock = manager.Lock()

            # ``worker`` never returns, so it is passed as the pool's
            # initializer: every pool thread just runs the worker loop.
            pool = ThreadPool(
                processes=workers,
                initializer=worker,
                initargs=(task_queue, status_dict, deps_dict, deps_lock),
            )

            wait_for_clients((hostname, port), task_queue, status_dict)
    except KeyboardInterrupt:
        logger.debug('Shutting down...')
        if pool is not None:
            pool.close()
        sys.exit(0)


if __name__ == '__main__':
    start()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment