Skip to content

Instantly share code, notes, and snippets.

@linar-jether
Created March 18, 2018 09:20
Show Gist options
  • Save linar-jether/0cc77e386c4a1b591bf3963062f0eaef to your computer and use it in GitHub Desktop.
Save linar-jether/0cc77e386c4a1b591bf3963062f0eaef to your computer and use it in GitHub Desktop.
A distributed dask scheduler built on Celery tasks - allows reusing an existing Celery cluster for ad-hoc computation
from __future__ import absolute_import, division, print_function
import multiprocessing
import pickle
from multiprocessing.pool import ThreadPool
from celery import shared_task
from dask.local import get_async # TODO: get better get
from dask.context import _globals
from dask.optimize import fuse, cull
import cloudpickle
def _dumps(x):
    """Serialize *x* with cloudpickle using the highest pickle protocol available."""
    highest_protocol = pickle.HIGHEST_PROTOCOL
    return cloudpickle.dumps(x, protocol=highest_protocol)


# The reverse direction uses the standard-library unpickler directly.
_loads = pickle.loads
from dask.multiprocessing import pack_exception, reraise
def _process_get_id():
import platform
return '%s:%s' % (platform.node(), multiprocessing.current_process().ident)
class SerializedCallable(object):
    """Carry a callable as cloudpickle bytes plus a human-readable description.

    The callable is serialized eagerly in ``__init__`` with ``cloudpickle.dumps``.
    It is only rebuilt when :meth:`func` is called. The instance itself holds
    just a bytes payload and a string, so it can be sent as a Celery task
    argument.
    """

    def __init__(self, func):
        self._func_data = cloudpickle.dumps(func)
        try:
            import inspect
            self._func_repr = inspect.getsource(func)
        except Exception:
            # Best effort only: getsource can raise (e.g. OSError / TypeError)
            # when no source is retrievable for ``func``. Catch Exception rather
            # than using a bare ``except:`` so KeyboardInterrupt and SystemExit
            # are not swallowed.
            self._func_repr = repr(func)

    def __repr__(self):
        # Shows the callable's source when it was available, otherwise its repr.
        return '<%s>' % self._func_repr

    def func(self):
        """Deserialize and return the wrapped callable."""
        return cloudpickle.loads(self._func_data)
# Polling interval (seconds) handed to ``AsyncResult.get(interval=...)`` while
# waiting for a Celery task result.
CELERY_PULL_INTERVAL = 0.05
@shared_task(name='dask_custom_task')
def _dask_exec_async(*args, **kwargs):
    """Worker-side Celery task: rebuild the shipped callable and invoke it.

    The callable arrives as a ``SerializedCallable`` in the private
    ``_serialized_callable`` keyword argument. All remaining positional and
    keyword arguments are forwarded to it, and its return value becomes the
    task result.
    """
    wrapped = kwargs.pop('_serialized_callable', None)
    if wrapped is not None:
        target = wrapped.func()
        return target(*args, **kwargs)
    raise Exception('Missing serialized_callable kwarg, must contain the cloudpickle-serialized callable.')
def exec_async(func, queue=None, *args, **kwargs):
    """Run ``func(*args, **kwargs)`` as a ``dask_custom_task`` on the Celery queue ``queue``.

    ``func`` is wrapped in a :class:`SerializedCallable` and passed to the task
    as the ``_serialized_callable`` keyword argument. The message is sent with
    zlib compression.

    Parameters
    ----------
    func : callable
        Callable to execute on the worker.
    queue : str
        Celery queue name. Required, despite the ``None`` default.
    *args, **kwargs
        Forwarded to ``func`` on the worker. The private keyword ``_as_sig`` is
        consumed here: when truthy, a Celery signature is built with ``.s()``
        and returned instead of calling ``apply_async``.

    Returns
    -------
    The result of ``_dask_exec_async.apply_async(...)``, or the signature when
    ``_as_sig`` is set.

    Raises
    ------
    ValueError
        If ``queue`` is None.
    """
    if queue is None:
        raise ValueError("Missing queue name")
    as_signature = kwargs.pop('_as_sig', False)
    kwargs['_serialized_callable'] = SerializedCallable(func)
    task_args = list(args)
    if as_signature:
        return _dask_exec_async.s(*task_args, **kwargs).set(queue=queue, compression='zlib')
    return _dask_exec_async.apply_async(task_args, kwargs, queue=queue, compression='zlib')
def async_(func, queue=None, *args, **kwargs):
    """Build a Celery signature that will run ``func(*args, **kwargs)`` on ``queue``.

    Takes the same arguments as :func:`exec_async`; this wrapper only forces its
    ``_as_sig`` mode, so the signature is returned rather than submitted.
    """
    kwargs['_as_sig'] = True
    return exec_async(func, queue, *args, **kwargs)


# ``async`` is a reserved keyword from Python 3.7 on. There, ``def async(...)``
# is a SyntaxError that stops this whole module from importing. Keep the
# original public name available as a module attribute for existing callers:
# ``getattr(module, 'async')`` works everywhere, and ``module.async`` still
# works on Python < 3.7.
globals()['async'] = async_
def get(dsk, keys, num_workers=None, queue=None, func_loads=None, func_dumps=None,
        optimize_graph=True, **kwargs):
    """Compute ``keys`` from the dask graph ``dsk``, running every task on Celery workers.

    Each task is submitted through :func:`exec_async` to the Celery queue
    ``queue``. A local thread pool polls the resulting Celery results and feeds
    them back into ``dask.local.get_async``.

    Parameters
    ----------
    dsk : dict
        dask graph
    keys : object or list
        Desired results from graph
    num_workers : int
        Maximum number of tasks in flight at once
        (defaults to the number of local CPU cores)
    queue : str
        Queue name to which tasks are sent. Required.
    func_dumps : function
        Function to use for function serialization
        (defaults to ``_globals['func_dumps']``, then cloudpickle.dumps with
        the highest pickle protocol)
    func_loads : function
        Function to use for function deserialization
        (defaults to ``_globals['func_loads']``, then pickle.loads)
    optimize_graph : bool
        If True [default], `fuse` is applied to the graph before computation.
    **kwargs
        Forwarded to ``dask.local.get_async``.

    Raises
    ------
    ValueError
        If ``queue`` is not given, or ``num_workers`` is less than 1.
    """
    # Validate arguments before doing any graph work so bad calls fail fast.
    if queue is None:
        raise ValueError('Must specify a queue name to submit celery tasks.')
    if num_workers is None:
        # get_async uses num_workers as the cap on concurrently running tasks,
        # so passing None through would schedule nothing (or raise). Apply the
        # documented "number of cores" default instead.
        num_workers = multiprocessing.cpu_count()
    if num_workers < 1:
        raise ValueError('num_workers must be at least 1, got %r' % (num_workers,))

    # Optimize Dask
    dsk2, dependencies = cull(dsk, keys)
    if optimize_graph:
        dsk3, dependencies = fuse(dsk2, keys, dependencies)
    else:
        dsk3 = dsk2

    # We specify marshalling functions in order to catch serialization
    # errors and report them to the user.
    loads = func_loads or _globals.get('func_loads') or _loads
    dumps = func_dumps or _globals.get('func_dumps') or _dumps

    # One polling thread per task that may be in flight. Results can then be
    # collected as soon as they are ready, instead of queueing behind a
    # fixed-size pool.
    pool = ThreadPool(num_workers)

    def apply_async(func, args=None, kwargs=None, callback=None):
        # Submit the task to Celery right away, then wait for its result on a
        # pool thread and hand it to get_async's callback.
        # NOTE(review): if future.get itself raises (e.g. a result-backend
        # error), the callback never fires and get_async may wait forever.
        future = exec_async(lambda: func(*(args or []), **(kwargs or {})), queue=queue)
        pool.apply_async(lambda: future.get(interval=CELERY_PULL_INTERVAL, propagate=False),
                         callback=callback)

    try:
        # Run
        result = get_async(apply_async, num_workers, dsk3, keys,
                           get_id=_process_get_id, dumps=dumps, loads=loads,
                           pack_exception=pack_exception,
                           raise_exception=reraise, **kwargs)
    finally:
        pool.close()
    return result
if __name__ == '__main__':
    # Start a sample Celery worker that consumes the `dask_test` queue.
    # Alternatively, add `--include=dask_celery_scheduler` to an existing
    # `celery worker` command line so the worker imports this module's task.
    # `__insert_app_module__` is a placeholder: replace it with the module that
    # creates your Celery app instance.
    from __insert_app_module__ import app
    app.set_current()
    worker_queue = 'dask_test'
    command_line = 'celery worker -l info -Q %s --events -P solo --include=dask_celery_scheduler' % worker_queue
    worker_argv = command_line.split()
    app.start(worker_argv)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment