Skip to content

Instantly share code, notes, and snippets.

@loganlinn
Created November 27, 2019 20:55
Show Gist options
  • Save loganlinn/fb562868cfa0629874586c951f414891 to your computer and use it in GitHub Desktop.
Save loganlinn/fb562868cfa0629874586c951f414891 to your computer and use it in GitHub Desktop.
dask graph wrapper a la https://github.com/plumatic/plumbing - experimental
import itertools
from typing import Iterable
import dask
import dask.core
import dask.multiprocessing
import dask.optimization
class LazyGraph:
    """Expose the entries of a dask graph as attributes.

    Every string key of the wrapped graph dict can be read as an attribute
    of the instance:

    * if the entry is not a dask task (per ``dask.istask``), the stored
      value is returned unchanged;
    * if the entry is a task, a callable is returned.  Calling it with
      keyword arguments substitutes those names inside that task's own
      definition, culls (and optionally fuses) the graph, and executes it
      with the configured scheduler ``get`` function.
    """

    def __init__(self, dsk: dict, optimize_graph=True):
        """Wrap the graph ``dsk``.

        Args:
            dsk: the dask graph (a dict of key -> value/task) to wrap.  It is
                never mutated; each run works on a shallow copy.
            optimize_graph: when true, the culled graph is additionally
                passed through ``dask.optimization.fuse`` before execution.
        """
        self._dsk = dsk
        self._optimize_graph = optimize_graph
        # Scheduler entry point used to execute the (sub)graph.
        self._get = dask.multiprocessing.get

    def __getattr__(self, key):
        """Resolve ``key`` against the wrapped graph.

        Only invoked when normal attribute lookup fails.

        Returns:
            The stored value for non-task entries, or a ``_runtask(**kwargs)``
            callable for task entries.

        Raises:
            AttributeError: if ``key`` is not a key of the graph, or the
                graph has not been attached to this instance yet.
        """
        # Read the graph straight from the instance dict.  Going through
        # ``self._dsk`` would re-enter __getattr__ whenever ``_dsk`` is not
        # set yet (e.g. copy/pickle probing a freshly created instance with
        # hasattr) and blow up with RecursionError instead of AttributeError.
        try:
            graph = self.__dict__["_dsk"]
        except KeyError:
            raise AttributeError(key) from None

        if key not in graph:
            raise AttributeError(key)

        entry = graph[key]
        if not dask.istask(entry):
            return entry

        def _runtask(**kwargs):
            """Compute task ``key``, substituting ``kwargs`` into its definition."""
            # Work on a shallow copy so the wrapped graph is left untouched.
            dsk = dict(self._dsk)

            # Parameter substitution: each keyword name occurring in the
            # target task is replaced by its (quoted) value.  Quoting keeps
            # task-like or list/dict values from being interpreted by dask.
            task = dsk[key]
            for name, value in kwargs.items():
                task = dask.core.subs(task, name, dask.core.quote(value))
            dsk[key] = task

            # Optimize graph: drop everything the target does not depend on,
            # then optionally fuse linear chains of tasks.
            keys = [key]
            culled, dependencies = dask.optimization.cull(dsk, keys)
            if self._optimize_graph:
                runnable, dependencies = dask.optimization.fuse(
                    culled, keys, dependencies
                )
            else:
                runnable = culled
            return self._get(runnable, keys)

        return _runtask

    def __dir__(self) -> Iterable[str]:
        """List the regular attributes plus the graph's string keys.

        Non-string graph keys (dask commonly uses tuples) are left out: they
        cannot be used as attribute names, and ``dir()`` sorts the result,
        which fails on a mix of ``str`` and non-``str`` entries.
        """
        graph = self.__dict__.get("_dsk", {})
        names = list(super().__dir__())
        names.extend(k for k in graph if isinstance(k, str))
        return names
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment