A dask distributed scheduler based on Celery tasks - allows reusing an existing Celery cluster for ad-hoc computation.
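A minimal usage sketch (assumptions: the file below is saved as dask_celery_scheduler.py on the Python path, a Celery worker is consuming a dask_test queue as shown at the bottom of the file, and a dask version contemporary with this gist, where a custom scheduler is selected via the get= keyword):

    import dask.bag as db
    from dask_celery_scheduler import get

    b = db.from_sequence(range(100), npartitions=4)
    # Extra keyword arguments to compute() are forwarded to the scheduler's get function
    total = b.map(lambda x: x * x).sum().compute(get=get, queue='dask_test')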
from __future__ import absolute_import, division, print_function

import multiprocessing
import pickle
from multiprocessing.pool import ThreadPool

from celery import shared_task

from dask.local import get_async  # TODO: get better get
from dask.context import _globals
from dask.optimize import fuse, cull

import cloudpickle


def _dumps(x):
    return cloudpickle.dumps(x, protocol=pickle.HIGHEST_PROTOCOL)

_loads = pickle.loads

from dask.multiprocessing import pack_exception, reraise


def _process_get_id():
    import platform
    return '%s:%s' % (platform.node(), multiprocessing.current_process().ident)

class SerializedCallable(object):
    """Wraps a callable so it can be shipped to Celery workers via cloudpickle."""

    def __init__(self, func):
        self._func_data = cloudpickle.dumps(func)
        try:
            import inspect
            self._func_repr = inspect.getsource(func)
        except Exception:
            self._func_repr = repr(func)

    def __repr__(self):
        return '<%s>' % self._func_repr

    def func(self):
        return cloudpickle.loads(self._func_data)

CELERY_PULL_INTERVAL = 0.05


@shared_task(name='dask_custom_task')
def _dask_exec_async(*args, **kwargs):
    serialized_callable = kwargs.pop('_serialized_callable', None)
    if serialized_callable is None:
        raise Exception('Missing serialized_callable kwarg, must contain the cloudpickle-serialized callable.')
    return serialized_callable.func()(*args, **kwargs)

def exec_async(func, queue=None, *args, **kwargs):
    if queue is None:
        raise ValueError("Missing queue name")
    serialized_callable = SerializedCallable(func)
    kwargs['_serialized_callable'] = serialized_callable
    args_ = list(args)
    if kwargs.pop('_as_sig', False):
        return _dask_exec_async.s(*args_, **kwargs).set(queue=queue, compression='zlib')
    return _dask_exec_async.apply_async(args_, kwargs, queue=queue, compression='zlib')


# NOTE: 'async' became a reserved keyword in Python 3.7; rename this helper when targeting newer interpreters.
def async(func, queue=None, *args, **kwargs):
    kwargs['_as_sig'] = True
    return exec_async(func, queue, *args, **kwargs)

def get(dsk, keys, num_workers=None, queue=None, func_loads=None, func_dumps=None,
        optimize_graph=True, **kwargs):
    """ Celery-backed get function, appropriate for Bags

    Parameters
    ----------
    dsk : dict
        dask graph
    keys : object or list
        Desired results from graph
    num_workers : int
        Number of tasks to run concurrently (defaults to number of cores)
    queue : str
        Celery queue name to which tasks are sent.
    func_dumps : function
        Function to use for function serialization
        (defaults to cloudpickle.dumps)
    func_loads : function
        Function to use for function deserialization
        (defaults to pickle.loads)
    optimize_graph : bool
        If True [default], `fuse` is applied to the graph before computation.
    """
    # Optimize Dask
    dsk2, dependencies = cull(dsk, keys)
    if optimize_graph:
        dsk3, dependencies = fuse(dsk2, keys, dependencies)
    else:
        dsk3 = dsk2

    # We specify marshalling functions in order to catch serialization
    # errors and report them to the user.
    loads = func_loads or _globals.get('func_loads') or _loads
    dumps = func_dumps or _globals.get('func_dumps') or _dumps

    # Queue to which tasks are sent
    if queue is None:
        raise ValueError('Must specify a queue name to submit celery tasks.')

    # Create thread pool to handle task callbacks
    pool = ThreadPool(10)

    def apply_async(func, args=None, kwargs=None, callback=None):
        # Submit the task to Celery, then poll for its result from a callback thread
        future = exec_async(lambda: func(*(args or []), **(kwargs or {})), queue=queue)
        pool.apply_async(lambda: future.get(interval=CELERY_PULL_INTERVAL, propagate=False), callback=callback)

    try:
        # Run
        result = get_async(apply_async, num_workers, dsk3, keys,
                           get_id=_process_get_id, dumps=dumps, loads=loads,
                           pack_exception=pack_exception,
                           raise_exception=reraise, **kwargs)
    finally:
        pool.close()

    return result
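
# Illustrative example (assumes a Celery worker is consuming the 'dask_test' queue):
# get() can also be called directly on a raw dask graph:
#
#     dsk = {'x': 1, 'y': (lambda v: v + 1, 'x')}
#     get(dsk, 'y', queue='dask_test')  # -> 2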

if __name__ == '__main__':
    # To launch a worker, either add `--include=dask_celery_scheduler` to your celery worker command line,
    # or run this script after creating a Celery app instance.
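    # A standalone command-line equivalent might look like this (illustrative; substitute
    # your own Celery app module for 'your_app'):
    #   celery -A your_app worker -l info -Q dask_test -P solo --events --include=dask_celery_scheduler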
    from __insert_app_module__ import app
    app.set_current()

    queue_name = 'dask_test'

    # Sample worker listening to the `dask_test` queue
    app.start(('celery worker -l info -Q %s --events -P solo --include=dask_celery_scheduler' % queue_name).split())