multiprocessing Pool example
#!/usr/bin/env python2
"Demonstrate using multiprocessing.Pool()"

import multiprocessing, time, logging, os, random, signal, pprint, traceback

logging.basicConfig(level=logging.DEBUG)
_L = logging.getLogger()


class JobTimeoutException(Exception):
    def __init__(self, jobstack=[]):
        super(JobTimeoutException, self).__init__()
        self.jobstack = jobstack


# http://stackoverflow.com/questions/8616630/time-out-decorator-on-a-multprocessing-function
def timeout(timeout):
    """
    Return a decorator that raises a JobTimeoutException exception
    after timeout seconds, if the decorated function did not return.
    """
    def decorate(f):
        def timeout_handler(signum, frame):
            raise JobTimeoutException(traceback.format_stack())

        def new_f(*args, **kwargs):
            old_handler = signal.signal(signal.SIGALRM, timeout_handler)
            signal.alarm(timeout)
            result = f(*args, **kwargs)  # f() always returns, in this scheme
            signal.signal(signal.SIGALRM, old_handler)  # Old signal handler is restored
            signal.alarm(0)  # Alarm removed
            return result
        new_f.func_name = f.func_name
        return new_f
    return decorate


@timeout(3)
def task(args):
    "Sleep t seconds, logging messages at random intervals, then return n*n"
    # Python 2 Pool.map() can only pass one argument to a task, so we unpack args here
    n, t = args
    _L.info("PID %s starting task(%d, %.2f)", os.getpid(), n, t)

    # Sleep t seconds, printing log messages at random intervals
    start_time = time.time()
    end_time = time.time() + t
    avg_sleep = t / 5.0
    while time.time() < end_time:
        nap_time = min(random.uniform(avg_sleep*0.5, avg_sleep*1.5), end_time - time.time())
        time.sleep(nap_time)
        _L.info(u"PID %s \u2603 running for %.2fs", os.getpid(), (time.time() - start_time))

    _L.info("PID %s ending task(%d, %.2f)", os.getpid(), n, t)
    return n*n


# The work we want to do
tasks = ((1, 2.0), (2, 1.0), (3, 3.2), (4, 1.0), (5, 1.0), (6, 1.0))

pool = multiprocessing.Pool(processes=2, maxtasksperchild=1)

# Run the tasks unordered through the pool and give us an iterator
result_iter = pool.imap_unordered(task, tasks, chunksize=1)

# Result collection object
result_collection = []

# Iterate through all the results
try:
    while True:
        try:
            # If no timeout is set, Ctrl-C does weird things.
            result = result_iter.next(timeout=99999999999)
            _L.info("Result received %s", result)
            result_collection.append(result)
        except JobTimeoutException as timeout_ex:
            _L.warning("Job timed out %s", timeout_ex)
            _L.warning("Stack trace:\n%s", ''.join(timeout_ex.jobstack))
            result_collection.append(None)
except StopIteration:
    _L.info("All jobs complete!")

print result_collection
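
The script above targets Python 2 (the print statement, f.func_name, and result_iter.next()). For anyone adapting the same idea to Python 3, a minimal sketch of the SIGALRM timeout decorator might look like the following. It is not part of the original gist: functools.wraps and the try/finally cleanup are my substitutions, and the SIGALRM approach still only works on Unix, in the main thread of each worker process.

#!/usr/bin/env python3
"""Hypothetical Python 3 sketch of the gist's SIGALRM-based timeout decorator."""
import functools
import signal
import traceback


class JobTimeoutException(Exception):
    def __init__(self, jobstack=()):
        super().__init__()
        self.jobstack = list(jobstack)


def timeout(seconds):
    """Return a decorator that raises JobTimeoutException if the
    decorated function runs longer than `seconds` (Unix only)."""
    def decorate(f):
        def timeout_handler(signum, frame):
            raise JobTimeoutException(traceback.format_stack())

        @functools.wraps(f)  # replaces the Python 2 func_name assignment
        def new_f(*args, **kwargs):
            old_handler = signal.signal(signal.SIGALRM, timeout_handler)
            signal.alarm(seconds)
            try:
                return f(*args, **kwargs)
            finally:
                # Restore the previous handler and cancel the alarm,
                # even if f() raised (a small change from the original).
                signal.signal(signal.SIGALRM, old_handler)
                signal.alarm(0)
        return new_f
    return decorate

The rest of the script ports mechanically: print(result_collection) instead of the print statement, and next(result_iter) in place of the Python 2 .next() call (CPython's pool iterator also still accepts result_iter.next(timeout=...) if you want to keep the Ctrl-C workaround).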