Created
February 3, 2023 10:44
-
-
Save clbarnes/d5d00dce176aa074032958b25f46f411 to your computer and use it in GitHub Desktop.
Improve ergonomics for creating a dask array from a stack or chunking of files
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
"""Load a list of images in dask.""" | |
from pathlib import Path | |
import re | |
from typing import Callable, Iterable, Optional | |
import dask.array as da | |
from dask import delayed | |
import numpy as np | |
LoaderFn = Callable[[Path], np.ndarray] | |
def parse_ax(s: str) -> Optional[int]:
    """Get the integer value from a string like ``"ax10"``.

    Returns None for any string not of the form ``"ax<int>"``.

    Examples
    --------
    >>> parse_ax("ax10")
    10
    >>> parse_ax("potato") is None
    True
    """
    # Fixed doctest: a bare ``>>> parse_ax("potato")`` followed by ``None``
    # fails under doctest, because an expression evaluating to None
    # prints nothing.
    if not s.startswith("ax"):
        return None
    try:
        # Everything after the "ax" prefix must parse as an int
        # (e.g. "ax" alone or "axolotl" fall through to None).
        return int(s[2:])
    except ValueError:
        return None
def parse_axes(d: dict[str, str]) -> tuple[int, ...]:
    """Extract a chunk-index tuple from a dict of strings.

    Keys of the form ``"ax<N>"`` contribute their (integer) value at
    position N; all other keys are ignored.

    Examples
    --------
    >>> parse_axes({"ax1": "1", "ax0": "0", "spade": "orange"})
    (0, 1)
    """
    positions: dict[int, int] = {}
    for key, value in d.items():
        pos = parse_ax(key)
        if pos is not None:
            positions[pos] = int(value)
    # Every position 0..len-1 must be present exactly once.
    if set(positions) != set(range(len(positions))):
        raise ValueError(
            f"Axis keys are not consecutive from 0: {sorted(positions)}"
        )
    return tuple(positions[i] for i in range(len(positions)))
def max_axes(maxes: tuple[int, ...], new: tuple[int, ...]) -> tuple[int, ...]:
    """Return the elementwise maximum of two equal-length tuples."""
    if len(maxes) != len(new):
        raise ValueError("Tuples have different lengths")
    # Lengths already match, so map() consuming both in lockstep is safe.
    return tuple(map(max, maxes, new))
def files_to_dask_block(
    pattern: str,
    paths: Iterable[Path],
    loader_fn: LoaderFn,
    fill_value=0,
    transpose=None,
    moveaxis=None,
) -> da.Array:
    """Create a dask array from files describing chunks.

    Files are assumed to contain arrays of identical shape and dtype.
    The first file matching the pattern will be read eagerly
    to determine this shape and dtype;
    the others will be wrapped in ``dask.delayed.delayed``.

    Parameters
    ----------
    pattern : str
        Regex pattern for parsing chunk indices from file path.
        The index in dimension D must be in a named capture group
        with name ``f"ax{D}"``, which must describe an integer (e.g. ``\\d+``).
        e.g. ``r"path/to/section(?P<ax0>\\d+).tiff"``
    paths : Iterable[Path]
        File paths to look through.
        Any which do not match the pattern above will be discarded.
    loader_fn : LoaderFn
        Function which takes a path and returns a numpy array.
    fill_value : int, optional
        If chunk indices imply the existence of a chunk
        which does not have a file present for it,
        instead use an array with the given value, by default 0
    transpose : tuple[int], optional
        Transpose the resulting array.
        If None (default) there will be no transpose.
        If True, behave like ``dask.array.transpose(..., axes=None)``.
        Otherwise, takes the place of the ``axes`` argument in ``dask.array.transpose``.
    moveaxis : tuple[list[int], list[int]], optional
        Move axes in resulting array, by default None (no axis moving).
        Otherwise, the two elements of the tuple take the place
        of the ``source`` and ``destination`` arguments in ``dask.array.moveaxis``.

    Returns
    -------
    da.Array
        Dask array made up of chunks which are delayed arrays loaded from files.

    Raises
    ------
    ValueError
        If both ``transpose`` and ``moveaxis`` are given,
        or if no path matches ``pattern``.
    """
    regex = re.compile(pattern)
    # Template chunk filled with fill_value; also carries the shared
    # shape/dtype used for every delayed chunk below.
    empty = None
    # Maps chunk index tuple (parsed from the filename) -> file path.
    idxs = dict()
    # Running elementwise maximum of all seen index tuples,
    # i.e. the highest chunk index per dimension.
    axmax = None
    if moveaxis is not None and transpose is not None:
        raise ValueError("Only one of moveaxis and transpose should be given")
    for p in paths:
        # NOTE: ``re.match`` anchors at the start of the string only;
        # trailing junk after the pattern would still match.
        m = regex.match(str(p))
        if m is None:
            continue
        if empty is None:
            # Eager read of the first matching file just for shape/dtype.
            # NOTE(review): this file is read a second time lazily below
            # when its chunk is computed — presumably an accepted cost.
            empty = np.full_like(loader_fn(p), fill_value)
        g = m.groupdict()
        idx = parse_axes(g)
        idxs[idx] = p
        if axmax is None:
            axmax = idx
        else:
            axmax = max_axes(axmax, idx)
    if empty is None or axmax is None:
        raise ValueError("No valid files found")
    # Object array holding one entry per chunk position (indices are
    # 0-based, so the grid extent is max index + 1 per dimension).
    blocks = np.empty(tuple(a+1 for a in axmax), dtype=object)
    # can't use np.full because it tries to unpack internals
    blocks.fill(empty)
    # Replace the fill template with a delayed read wherever a file exists;
    # positions with no file keep the shared fill chunk.
    for idx, path in idxs.items():
        blocks[idx] = da.from_delayed(
            delayed(loader_fn, name=f"load {path}", pure=True)(path),
            empty.shape,
            empty.dtype,
            name=str(path),
        )
    darr = da.block(blocks.tolist())
    if transpose is not None:
        if transpose is True:
            # True is shorthand for dask's axes=None (reverse all axes).
            transpose = None
        darr = da.transpose(darr, transpose)
    elif moveaxis is not None:
        source, destination = moveaxis
        darr = da.moveaxis(darr, source, destination)
    return darr
def files_to_dask_stack(paths: list[Path], loader_fn: LoaderFn, axis=0) -> da.Array:
    """Create a dask array by stacking delayed reads of the given files.

    All files are assumed to contain non-offset arrays of the same
    shape and dtype; the first file is read eagerly to determine them.

    Parameters
    ----------
    paths : list[Path]
        Ordered list of files to read.
    loader_fn : LoaderFn
        Callable which takes a path and returns a numpy array.
    axis : int, optional
        Where to put the new axis created by stacking, by default 0.
        See ``dask.array.stack`` for more information.

    Returns
    -------
    da.Array
        Dask array which will lazily read files as necessary.
    """
    # Eager read of the first file, used only for shape/dtype metadata.
    template = loader_fn(paths[0])
    lazy_chunks = []
    for p in paths:
        lazy = delayed(loader_fn, name=f"load {p}", pure=True)(p)
        lazy_chunks.append(
            da.from_delayed(
                lazy,
                template.shape,
                template.dtype,
                name=str(p),
            )
        )
    return da.stack(lazy_chunks, axis)
if __name__ == "__main__":
    # Demo: write 2x2 .npy chunks into a 2x2 grid (skipping chunk (0, 0))
    # and reassemble them lazily into one dask array.
    arr = np.ones((2, 2))
    outdir = Path("data")
    outdir.mkdir(exist_ok=True)
    for val in range(1, 4):
        # val 1..3 -> grid coordinates (0,1), (1,0), (1,1);
        # (0,0) is deliberately never written, so it gets fill_value.
        x, y = divmod(val, 2)
        fpath = outdir / f"array_{x}_{y}.npy"
        np.save(fpath, arr * val)
    dask_arr = files_to_dask_block(
        # Fixed: the '.' before "npy" was unescaped, so it matched any
        # character (e.g. "array_0_1Xnpy"); `\.` matches a literal dot.
        r".*/array_(?P<ax0>\d+)_(?P<ax1>\d+)\.npy",
        outdir.glob("*.npy"),
        np.load,
        fill_value=0,
    )
    # until this point, one array has been read
    # to determine the chunks' shape and dtype;
    # the others have not
    print(np.array(dask_arr))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment