@clbarnes · Created February 3, 2023 10:44
Improve ergonomics for creating a dask array from a stack or chunking of files
#!/usr/bin/env python3
"""Load a list of images in dask."""
from pathlib import Path
import re
from typing import Callable, Iterable, Optional

import dask.array as da
from dask import delayed
import numpy as np

LoaderFn = Callable[[Path], np.ndarray]


def parse_ax(s: str) -> Optional[int]:
    """Get the integer value from a string like ``"ax10"``.

    Return None if the string does not have that form.

    Examples
    --------
    >>> parse_ax("ax10")
    10
    >>> parse_ax("potato") is None
    True
    """
    if not s.startswith("ax"):
        return None
    try:
        return int(s[2:])
    except ValueError:
        return None


def parse_axes(d: dict[str, str]) -> tuple[int, ...]:
    """Get the index tuple from a dict of strings with keys like ``"ax0"``.

    Keys which do not parse as axis names are ignored.

    Examples
    --------
    >>> parse_axes({"ax1": "1", "ax0": "0", "spade": "orange"})
    (0, 1)
    """
    idx_to_val = dict()
    for k, v in d.items():
        ax_idx = parse_ax(k)
        if ax_idx is None:
            continue
        idx_to_val[ax_idx] = int(v)
    out = []
    for idx in range(len(idx_to_val)):
        if idx not in idx_to_val:
            raise ValueError(
                f"Axis keys are not consecutive from 0: {sorted(idx_to_val)}"
            )
        out.append(idx_to_val[idx])
    return tuple(out)


def max_axes(maxes: tuple[int, ...], new: tuple[int, ...]) -> tuple[int, ...]:
    """Find the elementwise maximum of two tuples of the same length."""
    if len(maxes) != len(new):
        raise ValueError("Tuples have different lengths")
    return tuple(max(m, n) for m, n in zip(maxes, new))


def files_to_dask_block(
    pattern: str,
    paths: Iterable[Path],
    loader_fn: LoaderFn,
    fill_value=0,
    transpose=None,
    moveaxis=None,
) -> da.Array:
    """Create a dask array from files describing chunks.

    Files are assumed to contain arrays of identical shape and dtype.
    The first file matching the pattern will be read eagerly
    to determine this shape and dtype;
    the others will be wrapped in ``dask.delayed.delayed``.

    Parameters
    ----------
    pattern : str
        Regex pattern for parsing chunk indices from the file path.
        The index in dimension D must be in a named capture group
        with name ``f"ax{D}"``, which must describe an integer (e.g. ``\\d+``),
        e.g. ``r"path/to/section(?P<ax0>\\d+).tiff"``.
    paths : Iterable[Path]
        File paths to look through.
        Any which do not match the pattern above will be discarded.
    loader_fn : LoaderFn
        Function which takes a path and returns a numpy array.
    fill_value : int, optional
        If the chunk indices imply the existence of a chunk
        which does not have a file present for it,
        use an array filled with this value instead, by default 0.
    transpose : tuple[int, ...] or bool, optional
        Transpose the resulting array.
        If None (default), there will be no transpose.
        If True, behave like ``dask.array.transpose(..., axes=None)``.
        Otherwise, takes the place of the ``axes`` argument in ``dask.array.transpose``.
    moveaxis : tuple[list[int], list[int]], optional
        Move axes in the resulting array, by default None (no axis moving).
        Otherwise, the two elements of the tuple take the place
        of the ``source`` and ``destination`` arguments in ``dask.array.moveaxis``.

    Returns
    -------
    da.Array
        Dask array made up of chunks which are delayed arrays loaded from files.
    """
    regex = re.compile(pattern)
    empty = None
    idxs = dict()
    axmax = None
    if moveaxis is not None and transpose is not None:
        raise ValueError("Only one of moveaxis and transpose should be given")
    for p in paths:
        m = regex.match(str(p))
        if m is None:
            continue
        if empty is None:
            # eagerly read the first matching file
            # to learn the chunks' shape and dtype
            empty = np.full_like(loader_fn(p), fill_value)
        g = m.groupdict()
        idx = parse_axes(g)
        idxs[idx] = p
        if axmax is None:
            axmax = idx
        else:
            axmax = max_axes(axmax, idx)
    if empty is None or axmax is None:
        raise ValueError("No valid files found")
    blocks = np.empty(tuple(a + 1 for a in axmax), dtype=object)
    # can't use np.full because it tries to unpack the array's internals
    blocks.fill(empty)
    for idx, path in idxs.items():
        blocks[idx] = da.from_delayed(
            delayed(loader_fn, name=f"load {path}", pure=True)(path),
            empty.shape,
            empty.dtype,
            name=str(path),
        )
    darr = da.block(blocks.tolist())
    if transpose is not None:
        if transpose is True:
            # da.transpose spells "reverse the axes" as axes=None
            transpose = None
        darr = da.transpose(darr, transpose)
    elif moveaxis is not None:
        source, destination = moveaxis
        darr = da.moveaxis(darr, source, destination)
    return darr
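

# A minimal usage sketch for files_to_dask_block, assuming a 2D grid of TIFF
# tiles named like tile_0_0.tiff and a hypothetical `read_tiff` loader;
# transpose=True reverses the axes of the assembled array:
#
#     darr = files_to_dask_block(
#         r".*/tile_(?P<ax0>\d+)_(?P<ax1>\d+)\.tiff",
#         Path("tiles").glob("*.tiff"),
#         read_tiff,
#         transpose=True,
#     )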


def files_to_dask_stack(paths: list[Path], loader_fn: LoaderFn, axis=0) -> da.Array:
    """Create a dask array from a stack of delayed reads of files.

    Assumes all files contain arrays of the same shape and dtype.
    The first file will be read eagerly to determine that shape and dtype;
    the others will be wrapped in ``dask.delayed.delayed``.

    Parameters
    ----------
    paths : list[Path]
        Ordered list of files to read.
    loader_fn : LoaderFn
        Callable which takes a path and returns a numpy array.
    axis : int, optional
        Where to put the new axis created by stacking, by default 0.
        See ``dask.array.stack`` for more information.

    Returns
    -------
    da.Array
        Dask array which will lazily read files as necessary.
    """
    arr = loader_fn(paths[0])
    seq = [
        da.from_delayed(
            delayed(loader_fn, name=f"load {p}", pure=True)(p),
            arr.shape,
            arr.dtype,
            name=str(p),
        )
        for p in paths
    ]
    return da.stack(seq, axis)
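

# A minimal usage sketch for files_to_dask_stack, assuming 2D sections stored
# as section_000.npy, section_001.npy, ... (hypothetical names); the new axis
# is prepended, giving one plane per file:
#
#     paths = sorted(Path("sections").glob("section_*.npy"))
#     vol = files_to_dask_stack(paths, np.load, axis=0)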


if __name__ == "__main__":
    arr = np.ones((2, 2))
    outdir = Path("data")
    outdir.mkdir(exist_ok=True)
    # val=1..3 gives chunks (0, 1), (1, 0) and (1, 1);
    # chunk (0, 0) is left absent, so it is filled with fill_value
    for val in range(1, 4):
        x, y = divmod(val, 2)
        fpath = outdir / f"array_{x}_{y}.npy"
        np.save(fpath, arr * val)
    dask_arr = files_to_dask_block(
        r".*/array_(?P<ax0>\d+)_(?P<ax1>\d+).npy",
        outdir.glob("*.npy"),
        np.load,
        fill_value=0,
    )
    # until this point, only one array has been read
    # (to determine the chunks' shape and dtype);
    # the others have not
    print(np.array(dask_arr))
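    # The demo files can also be lazily stacked along a new leading axis
    # (a sketch reusing the files written above), giving a (3, 2, 2) array:
    stacked = files_to_dask_stack(sorted(outdir.glob("*.npy")), np.load)
    print(np.array(stacked))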