Created September 15, 2009 12:26
Index: Lib/multiprocessing/pool.py
===================================================================
--- Lib/multiprocessing/pool.py (revision 74797)
+++ Lib/multiprocessing/pool.py (working copy)
@@ -12,11 +12,14 @@
 # Imports
 #
 
+import os
+import errno
 import threading
 import Queue
 import itertools
 import collections
 import time
+from signal import signal, SIGUSR1
 
 from multiprocessing import Process, cpu_count, TimeoutError
 from multiprocessing.util import Finalize, debug
@@ -42,9 +45,21 @@
 # Code run by worker processes
 #
 
-def worker(inqueue, outqueue, initializer=None, initargs=()):
+class TimeLimitExceeded(Exception):
+    """The time limit has been exceeded and the job has been terminated."""
+
+class SoftTimeLimitExceeded(Exception):
+    """The soft time limit has been exceeded. This exception
+    is raised to give the job a chance to clean up."""
+
+def soft_timeout_sighandler(signum, frame):
+    raise SoftTimeLimitExceeded()
+
+def worker(inqueue, outqueue, ackqueue, initializer=None, initargs=()):
+    pid = os.getpid()
     put = outqueue.put
     get = inqueue.get
+    ack = ackqueue.put
     if hasattr(inqueue, '_writer'):
         inqueue._writer.close()
         outqueue._reader.close()
@@ -52,6 +67,8 @@
     if initializer is not None:
         initializer(*initargs)
 
+    signal(SIGUSR1, soft_timeout_sighandler)
+
     while 1:
         try:
             task = get()
@@ -64,6 +81,7 @@
             break
 
         job, i, func, args, kwds = task
+        ack((job, i, time.time(), pid))
         try:
             result = (True, func(*args, **kwds))
         except Exception, e:
@@ -80,9 +98,14 @@ | |
''' | |
Process = Process | |
- def __init__(self, processes=None, initializer=None, initargs=()): | |
+ def __init__(self, processes=None, initializer=None, initargs=(), | |
+ timeout=None, soft_timeout=None): | |
self._setup_queues() | |
self._taskqueue = Queue.Queue() | |
+ self.timeout = timeout | |
+ self.soft_timeout = soft_timeout | |
+ self._initializer = initializer | |
+ self._initargs = initargs | |
self._cache = {} | |
self._state = RUN | |
@@ -95,16 +118,7 @@
         if initializer is not None and not hasattr(initializer, '__call__'):
             raise TypeError('initializer must be a callable')
 
-        self._pool = []
-        for i in range(processes):
-            w = self.Process(
-                target=worker,
-                args=(self._inqueue, self._outqueue, initializer, initargs)
-                )
-            self._pool.append(w)
-            w.name = w.name.replace('Process', 'PoolWorker')
-            w.daemon = True
-            w.start()
+        self._pool = [self._add_worker() for i in range(processes)]
 
         self._task_handler = threading.Thread(
             target=Pool._handle_tasks,
@@ -114,6 +128,31 @@
         self._task_handler._state = RUN
         self._task_handler.start()
 
+        # Thread processing acknowledgements from the ackqueue.
+        self._ack_handler = threading.Thread(
+            target=Pool._handle_ack,
+            args=(self._ackqueue, self._quick_get_ack, self._cache)
+            )
+        self._ack_handler.daemon = True
+        self._ack_handler._state = RUN
+        self._ack_handler.start()
+
+        # Thread killing timed-out jobs.
+        if self.timeout or self.soft_timeout:
+            self._timeout_handler_stopped = threading.Event()
+            self._timeout_handler = threading.Thread(
+                target=Pool._handle_timeouts,
+                args=(self, self._timeout_handler_stopped, self._cache,
+                      self.soft_timeout, self.timeout)
+                )
+            self._timeout_handler.daemon = True
+            self._timeout_handler._state = RUN
+            self._timeout_handler.start()
+        else:
+            self._timeout_handler_stopped = None
+            self._timeout_handler = None
+
+        # Thread processing results in the outqueue.
         self._result_handler = threading.Thread(
             target=Pool._handle_results,
             args=(self._outqueue, self._quick_get, self._cache)
@@ -124,17 +163,37 @@
         self._terminate = Finalize(
             self, self._terminate_pool,
-            args=(self._taskqueue, self._inqueue, self._outqueue, self._pool,
-                  self._task_handler, self._result_handler, self._cache),
+            args=(self._taskqueue, self._inqueue, self._outqueue,
+                  self._ackqueue, self._pool, self._ack_handler,
+                  self._task_handler, self._result_handler, self._cache,
+                  self._timeout_handler,
+                  self._timeout_handler_stopped),
             exitpriority=15
             )
+
+    def _add_worker(self):
+        """Add another worker to the pool."""
+        w = self.Process(
+            target=worker,
+            args=(self._inqueue, self._outqueue, self._ackqueue,
+                  self._initializer, self._initargs)
+            )
+        w.name = w.name.replace('Process', 'PoolWorker')
+        w.daemon = True
+        w.start()
+        return w
+    def grow(self, n=1):
+        self._pool.extend([self._add_worker() for i in range(n)])
+
     def _setup_queues(self):
         from .queues import SimpleQueue
         self._inqueue = SimpleQueue()
         self._outqueue = SimpleQueue()
+        self._ackqueue = SimpleQueue()
         self._quick_put = self._inqueue._writer.send
         self._quick_get = self._outqueue._reader.recv
+        self._quick_get_ack = self._ackqueue._reader.recv
 
     def apply(self, func, args=(), kwds={}):
         '''
@@ -186,12 +245,25 @@
                          for i, x in enumerate(task_batches)), result._set_length))
         return (item for chunk in result for item in chunk)
 
-    def apply_async(self, func, args=(), kwds={}, callback=None):
+    def apply_async(self, func, args=(), kwds={},
+                    callback=None, accept_callback=None):
         '''
-        Asynchronous equivalent of `apply()` builtin
+        Asynchronous equivalent of `apply()` builtin.
+
+        The callback is called when the function's return value is ready.
+        The accept callback is called when the job is accepted for execution.
+
+        Simplified, the flow is like this:
+
+        >>> if accept_callback:
+        ...     accept_callback()
+        >>> retval = func(*args, **kwds)
+        >>> if callback:
+        ...     callback(retval)
+
         '''
         assert self._state == RUN
-        result = ApplyResult(self._cache, callback)
+        result = ApplyResult(self._cache, callback, accept_callback)
         self._taskqueue.put(([(result._job, None, func, args, kwds)], None))
         return result
@@ -240,7 +312,6 @@
         else:
             debug('task handler got sentinel')
-
         try:
             # tell result handler to finish when cache is empty
             debug('task handler sending sentinel to result handler')
@@ -256,14 +327,138 @@
         debug('task handler exiting')
 
     @staticmethod
+    def _handle_timeouts(pool, sentinel_event, cache, t_soft, t_hard):
+        thread = threading.current_thread()
+        processes = pool._pool
+        dirty = set()
+
+        def _process_by_pid(pid):
+            for index, process in enumerate(processes):
+                if process.pid == pid:
+                    return process, index
+            return (None, None)
+
+        def _pop_by_pid(pid):
+            process, index = _process_by_pid(pid)
+            if not process:
+                return
+            p = processes.pop(index)
+            assert p is process
+            return process
+
+        def _timed_out(start, timeout):
+            if not start or not timeout:
+                return False
+            if time.time() >= start + timeout:
+                return True
+
+        def _on_soft_timeout(job, i):
+            debug('soft time limit exceeded for %i', i)
+            process, _index = _process_by_pid(job._accept_pid)
+            if not process:
+                return
+
+            try:
+                os.kill(job._accept_pid, SIGUSR1)
+            except OSError, exc:
+                if exc.errno == errno.ESRCH:
+                    pass
+                else:
+                    raise
+
+            dirty.add(i)
+
+        def _on_hard_timeout(job, i):
+            debug('hard time limit exceeded for %i', i)
+            # Remove from _pool.
+            process = _pop_by_pid(job._accept_pid)
+            # Remove from cache and set return value to an exception.
+            job._set(i, (False, TimeLimitExceeded()))
+            if not process:
+                return
+            # Terminate the process and create a new one.
+            process.terminate()
+            pool.grow(1)
+
+        # Inner loop.
+        while 1:
+            if sentinel_event.isSet():
+                debug('timeout handler received sentinel.')
+                break
+
+            # Remove dirty items not in the cache anymore.
+            if dirty:
+                dirty = set(k for k in dirty if k in cache)
+
+            for i, job in cache.items():
+                ack_time = job._time_accepted
+                if _timed_out(ack_time, t_hard):
+                    _on_hard_timeout(job, i)
+                elif i not in dirty and _timed_out(ack_time, t_soft):
+                    _on_soft_timeout(job, i)
+
+            time.sleep(1)  # Don't waste CPU cycles.
+
+        debug('timeout handler exiting')
+
+    @staticmethod
+    def _handle_ack(ackqueue, get, cache):
+        thread = threading.current_thread()
+
+        while 1:
+            try:
+                task = get()
+            except (IOError, EOFError), exc:
+                debug('ack handler got %s -- exiting',
+                      exc.__class__.__name__)
+                return
+
+            if thread._state:
+                assert thread._state == TERMINATE
+                debug('ack handler found thread._state=TERMINATE')
+                break
+
+            if task is None:
+                debug('ack handler got sentinel')
+                break
+
+            job, i, time_accepted, pid = task
+            try:
+                cache[job]._ack(time_accepted, pid)
+            except (KeyError, AttributeError):
+                # Object gone, or doesn't support _ack (e.g. IMapIterator).
+                pass
+
+        while cache and thread._state != TERMINATE:
+            try:
+                task = get()
+            except (IOError, EOFError), exc:
+                debug('ack handler got %s -- exiting',
+                      exc.__class__.__name__)
+                return
+
+            if task is None:
+                debug('ack handler ignoring extra sentinel')
+                continue
+
+            job, i, time_accepted, pid = task
+            try:
+                cache[job]._ack(time_accepted, pid)
+            except (KeyError, AttributeError):
+                pass
+
+        debug('ack handler exiting: len(cache)=%s, thread._state=%s',
+              len(cache), thread._state)
+
+    @staticmethod
     def _handle_results(outqueue, get, cache):
         thread = threading.current_thread()
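For context, here is a minimal sketch (not part of the patch) of a job function that cooperates with the soft limit. The name busy_task and its timings are made up, and it assumes the patch above has been applied so SoftTimeLimitExceeded is importable from multiprocessing.pool.

    import time
    from multiprocessing.pool import SoftTimeLimitExceeded

    def busy_task(n):
        """Hypothetical job that honours the soft time limit."""
        try:
            for _ in range(n):
                time.sleep(1)          # stand-in for real work
            return 'done'
        except SoftTimeLimitExceeded:
            # Raised by soft_timeout_sighandler when SIGUSR1 arrives:
            # clean up and return early, before the hard limit
            # terminates the worker process outright.
            return 'aborted after soft limit'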
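And a sketch of driving the patched Pool from the parent process, using the busy_task defined above. The 3/6 second limits are arbitrary, and the accept callback's exact call signature is not shown in the hunks above, so it is written to accept anything.

    from multiprocessing.pool import Pool

    def on_accept(*args):
        # Called from the ack handler thread once a worker picks the job up;
        # signature hedged, the docstring only shows accept_callback().
        print 'job accepted by a worker'

    if __name__ == '__main__':
        # soft_timeout delivers SIGUSR1 to the worker after ~3 seconds;
        # timeout terminates the worker after ~6 seconds, replaces it,
        # and sets the result to a TimeLimitExceeded instance.
        pool = Pool(processes=2, soft_timeout=3, timeout=6)
        result = pool.apply_async(busy_task, (10,), accept_callback=on_accept)
        print result.get()             # 'aborted after soft limit'
        pool.grow(1)                   # add one more worker on the fly
        pool.close()
        pool.join()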