Created November 27, 2019 20:55
-
-
Save loganlinn/fb562868cfa0629874586c951f414891 to your computer and use it in GitHub Desktop.
Dask graph wrapper à la https://github.com/plumatic/plumbing (experimental).
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import itertools | |
from typing import Iterable | |
import dask | |
import dask.core | |
import dask.multiprocessing | |
import dask.optimization | |
class LazyGraph:
    """Attribute-style access to the nodes of a dask graph (plumbing-style).

    Every key of the wrapped graph is exposed as an attribute:

    * a key whose value is not a dask task is returned as-is;
    * a key whose value is a task is returned as a callable that, when
      called, computes that key with the configured scheduler. Keyword
      arguments passed to the callable replace references to the graph key
      of the same name inside that task.

    Example (hypothetical graph; ``inc`` is any function)::

        g = LazyGraph({"x": 1, "y": (inc, "x")})
        g.x          # -> 1 (not a task, returned directly)
        g.y()        # computes "y" from "x"
        g.y(x=10)    # computes "y" with its reference to "x" replaced by 10
    """

    def __init__(self, dsk: dict, optimize_graph=True, *, get=None):
        """Wrap the dask graph ``dsk``.

        Args:
            dsk: Dask graph mapping keys to values or tasks. Stored by
                reference; each computation works on a shallow copy of it.
            optimize_graph: If true, run ``dask.optimization.fuse`` on the
                culled graph before handing it to the scheduler.
            get: Scheduler function, called as ``get(dsk, keys)``. Defaults
                to ``dask.multiprocessing.get`` (the previously hard-coded
                scheduler); pass e.g. ``dask.get`` for synchronous execution.
        """
        self._dsk = dsk
        self._optimize_graph = optimize_graph
        self._get = dask.multiprocessing.get if get is None else get

    def __getattr__(self, key):
        """Resolve ``key`` against the graph when normal attribute lookup fails.

        Returns:
            The stored value for a non-task entry, or a zero-positional-arg
            function computing ``key`` for a task entry.

        Raises:
            AttributeError: if ``key`` is not in the graph, or the instance
                has no graph yet (e.g. mid-copy or mid-unpickle).
        """
        # Go through __dict__ instead of ``self._dsk``: on an instance whose
        # __init__ never ran (copy.copy, pickle, __new__), ``self._dsk`` would
        # re-enter __getattr__ and recurse until RecursionError.
        dsk = self.__dict__.get("_dsk")
        if dsk is None or key not in dsk:
            raise AttributeError(key)
        if not dask.istask(dsk[key]):
            return dsk[key]

        def _runtask(**kwargs):
            """Compute ``key``; each kwarg replaces that graph key inside the task."""
            # Shallow copy so the substitutions below never mutate the wrapped graph.
            graph = dict(self._dsk)
            # Parameter substitution: only the requested task is rewritten.
            for arg_key, arg_val in kwargs.items():
                # quote() so the override is kept as a literal value in the graph.
                graph[key] = dask.core.subs(
                    graph[key], arg_key, dask.core.quote(arg_val)
                )
            # Optimize graph: cull to what ``key`` needs (which can drop
            # dependencies whose only references were overridden above),
            # then optionally fuse.
            keys = [key]
            culled, dependencies = dask.optimization.cull(graph, keys)
            if self._optimize_graph:
                culled, dependencies = dask.optimization.fuse(
                    culled, keys, dependencies
                )
            # NOTE(review): ``keys`` is a list, so dask schedulers presumably
            # return a one-element sequence here rather than the bare value —
            # confirm what callers expect before changing it.
            return self._get(culled, keys)

        return _runtask

    def __dir__(self) -> Iterable[str]:
        """List regular attributes plus string graph keys (for tab completion)."""
        # dir() sorts this result: non-string dask keys such as ("x", 0) would
        # make that sort raise TypeError, and can never be attribute names.
        graph_keys = (
            k for k in self.__dict__.get("_dsk", {}) if isinstance(k, str)
        )
        return itertools.chain(super().__dir__(), graph_keys)
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment