Skip to content

Instantly share code, notes, and snippets.

@conradry
Created April 24, 2026 17:06
Show Gist options
  • Select an option

  • Save conradry/c5a611cdada8caec3d50a295177f2a0c to your computer and use it in GitHub Desktop.

Select an option

Save conradry/c5a611cdada8caec3d50a295177f2a0c to your computer and use it in GitHub Desktop.
PyTorch DataLoader for zarr using async reads and kornia augmentation
"""Minimal ``AsyncZarrDataset`` pattern: async zarr crop reads + kornia collate.
Shows a :class:`torch.utils.data.Dataset` that fans out bounding-box crop reads
via ``zarr.AsyncArray.getitem`` + ``asyncio.gather``, decoded by the ``zarrs``
Rust codec pipeline. A kornia ``AugmentationSequential`` is wired as the
DataLoader ``collate_fn`` so stacking, device upload, and augmentation happen
in one batched call.
The zarr handle is opened lazily in ``_ensure_initialized`` rather than
``__init__`` so the dataset stays picklable across spawn-based workers (the
zarr array handle is not picklable).
Run:
python async_zarr_kornia_demo.py
"""
import asyncio
import os
import shutil
import tempfile
import numpy as np
import obstore
import torch
import zarr
import zarrs # noqa: F401 — registers ``zarrs.ZarrsCodecPipeline``
from torch.utils.data import DataLoader, Dataset
from zarr.core.sync import sync as zarr_sync
class AsyncZarrDataset(Dataset):
    """Dataset of bounding-box crops read via async zarr + zarrs codec.

    Implements ``__getitems__`` (plural) so the DataLoader hands over a whole
    batch of indices in one call; one ``getitem`` is issued per crop and all of
    them are awaited together with ``asyncio.gather``. The result is a
    pre-stacked ``(B, *box_shape, *trailing)`` ndarray, where ``trailing`` are
    the array axes left over after the box axes (almost always a channel
    dimension).

    The zarr handle is opened lazily by ``_ensure_initialized`` and dropped by
    ``__getstate__`` so the dataset can be pickled into worker processes.
    """

    def __init__(
        self,
        zarr_path: str,
        array_key: str,
        box_shape: tuple[int, ...],
        min_corners: np.ndarray,
        max_corners: np.ndarray,
    ) -> None:
        """Record the crop boxes; no zarr I/O happens here.

        Args:
            zarr_path: Path of the local zarr group.
            array_key: Key of the array inside that group.
            box_shape: Number of leading array axes each box indexes (only its
                length is checked against the corners).
            min_corners: ``(n_boxes, n_dims)`` inclusive start index per axis.
                Any array-like is accepted and stored as int64.
            max_corners: ``(n_boxes, n_dims)`` exclusive stop index per axis.

        Raises:
            ValueError: If the corner arrays differ in shape, are not 2-D, or
                their second dimension does not match ``len(box_shape)``.
        """
        # Coerce before validating so lists/tuples work and ``.shape`` exists.
        mins = np.asarray(min_corners)
        maxes = np.asarray(max_corners)
        if mins.shape != maxes.shape:
            raise ValueError(f"min/max corner shape mismatch: {mins.shape} vs {maxes.shape}")
        if mins.ndim != 2:
            # Without this check ``shape[1]`` below fails with a bare IndexError.
            raise ValueError(f"corners must be 2-D (n_boxes, n_dims); got shape {mins.shape}")
        if mins.shape[1] != len(box_shape):
            raise ValueError(f"corners have {mins.shape[1]} dims but box_shape has {len(box_shape)}")
        self._zarr_path = zarr_path
        self._array_key = array_key
        self._box_shape = tuple(int(b) for b in box_shape)
        self._min_corners = np.ascontiguousarray(mins, dtype=np.int64)
        self._max_corners = np.ascontiguousarray(maxes, dtype=np.int64)
        self._async_arr = None  # opened on first read, see _ensure_initialized
        self._trailing_ndim: int = 0

    def __len__(self) -> int:
        """Return the number of boxes."""
        return int(self._min_corners.shape[0])

    def _ensure_initialized(self) -> None:
        """Open the zarr array on first use and cache its async handle.

        Raises:
            ValueError: If the array has fewer dimensions than ``box_shape``.
        """
        if self._async_arr is not None:
            return
        zarr.config.set({"codec_pipeline.path": "zarrs.ZarrsCodecPipeline"})
        store = obstore.store.LocalStore(prefix=os.path.abspath(self._zarr_path))
        arr = zarr.open_group(zarr.storage.ObjectStore(store), mode="r")[self._array_key]
        trailing_ndim = len(arr.shape) - len(self._box_shape)
        if trailing_ndim < 0:
            # ``(slice(None),) * negative`` would silently become ``()`` and
            # defer the failure to an opaque indexing error inside the read.
            raise ValueError(
                f"array {self._array_key!r} has {len(arr.shape)} dims "
                f"but box_shape has {len(self._box_shape)}"
            )
        self._trailing_ndim = trailing_ndim
        # Assigned last so a failed validation leaves the dataset uninitialized.
        self._async_arr = arr._async_array

    def __getitem__(self, idx: int) -> np.ndarray:
        """Return a single crop by delegating to the batched reader."""
        return self.__getitems__([idx])[0]

    def __getitems__(self, indices: list[int]) -> np.ndarray:
        """Read the crops for ``indices`` concurrently and stack them on axis 0."""
        self._ensure_initialized()
        idx = np.asarray(indices, dtype=np.int64)
        mins = self._min_corners[idx]
        maxes = self._max_corners[idx]
        # Full slices for the axes that the boxes do not index.
        trailing = (slice(None),) * self._trailing_ndim

        async def _gather():
            coros = [
                self._async_arr.getitem(
                    tuple(slice(int(lo[j]), int(hi[j])) for j in range(mins.shape[1])) + trailing
                )
                for lo, hi in zip(mins, maxes, strict=True)
            ]
            return await asyncio.gather(*coros)

        return np.stack(zarr_sync(_gather()), axis=0)

    def __getstate__(self) -> dict:
        """Return picklable state with the zarr handle cleared."""
        state = self.__dict__.copy()
        state["_async_arr"] = None  # zarr handle isn't picklable; reopen in worker
        return state
class KorniaCollate:
    """Collate fn: NHWC→NCHW, device upload, kornia augmentation.

    Expects the batch as one pre-stacked ndarray of shape ``(B, H, W)`` or
    ``(B, H, W, C)``, so it skips the usual per-sample stacking and goes
    straight to the batched augmentation.
    """

    def __init__(self, crop_hw: tuple[int, int], device: str, dtype_max: float = 255.0) -> None:
        """Build the augmentation pipeline on ``device``.

        Args:
            crop_hw: Target size handed to ``RandomCrop``.
            device: Device the pipeline lives on and batches are moved to.
            dtype_max: Divisor applied to the float32 batch before augmenting
                (255.0 by default).
        """
        # Imported here rather than at module level, as in the original script.
        import kornia.augmentation as K

        self._device = device
        self._dtype_max = dtype_max
        self._aug = K.AugmentationSequential(
            K.RandomCrop(crop_hw, same_on_batch=False),
            K.RandomHorizontalFlip(p=0.5),
            K.RandomVerticalFlip(p=0.5),
            K.RandomRotation(degrees=15.0, p=0.5),
            K.ColorJitter(brightness=0.2, contrast=0.2, p=0.5),
            same_on_batch=False,
        ).to(device)
        self._aug.eval()

    def __call__(self, batch: np.ndarray) -> torch.Tensor:
        """Convert ``batch`` to a scaled float32 NCHW tensor and augment it.

        The input array is never modified.

        Raises:
            ValueError: If ``batch`` is neither 3-D nor 4-D.
        """
        if batch.ndim == 3:  # (B, H, W) — treat as single-channel
            t = torch.from_numpy(batch).unsqueeze(1)
        elif batch.ndim == 4:  # (B, H, W, C) → (B, C, H, W)
            t = torch.from_numpy(np.ascontiguousarray(batch)).permute(0, 3, 1, 2)
        else:
            raise ValueError(f"unexpected batch ndim {batch.ndim}; expected 3 or 4")
        # ``copy=True`` matters when the batch is already float32: a plain
        # ``.to(torch.float32)`` would return a view of the caller's ndarray and
        # the in-place ``div_`` would then rescale the caller's data.
        t = t.to(torch.float32, copy=True).div_(self._dtype_max)
        t = t.to(self._device, non_blocking=True)
        return self._aug(t)
def _write_demo_zarr(zarr_path: str, shape: tuple[int, int, int]) -> None:
    """Create a local zarr group holding one random uint8 ``"image"`` array.

    The array has the given ``shape`` and is chunked 128 x 128 over the first
    two axes with the whole last axis in each chunk. Pixel values come from a
    generator seeded with 0, so the contents are reproducible.
    """
    os.makedirs(zarr_path, exist_ok=True)
    local_store = obstore.store.LocalStore(prefix=os.path.abspath(zarr_path))
    group = zarr.open_group(zarr.storage.ObjectStore(local_store), mode="w")

    last_axis_len = shape[2]
    image = group.create_array(
        "image",
        shape=shape,
        chunks=(128, 128, last_axis_len),
        dtype=np.uint8,
    )

    generator = np.random.default_rng(0)
    pixels = generator.integers(0, 256, size=shape, dtype=np.uint8)
    image[:] = pixels
def _random_boxes(
    array_hw: tuple[int, int], box_hw: tuple[int, int], n: int, seed: int = 42
) -> tuple[np.ndarray, np.ndarray]:
    """Sample ``n`` fixed-size boxes that lie fully inside a 2-D array.

    Args:
        array_hw: Height and width of the array being cropped.
        box_hw: Height and width of every box.
        n: Number of boxes to draw.
        seed: Seed for the random generator; the default of 42 reproduces the
            boxes this function produced before the parameter existed.

    Returns:
        ``(mins, maxes)``, two int64 arrays of shape ``(n, 2)`` holding the
        ``(y, x)`` start corner and the exclusive stop corner of each box.

    Raises:
        ValueError: If the box is larger than the array along either axis.
    """
    if box_hw[0] > array_hw[0] or box_hw[1] > array_hw[1]:
        # The sampler would also raise ValueError here, but with a message
        # about its bounds instead of the actual cause.
        raise ValueError(f"box {box_hw} does not fit inside array {array_hw}")
    rng = np.random.default_rng(seed)
    # ``+ 1`` because the upper bound of ``integers`` is exclusive.
    ys = rng.integers(0, array_hw[0] - box_hw[0] + 1, size=n)
    xs = rng.integers(0, array_hw[1] - box_hw[1] + 1, size=n)
    mins = np.stack([ys, xs], axis=1).astype(np.int64)
    maxes = mins + np.array(box_hw, dtype=np.int64)
    return mins, maxes
if __name__ == "__main__":
    # Demo: write a scratch zarr image, sample boxes, and print each batch.
    image_shape = (512, 512, 3)
    box_hw = (128, 128)
    crop_hw = (112, 112)  # kornia RandomCrop target
    n_boxes = 32
    batch_size = 8

    # The context manager removes the scratch directory on exit, including
    # when anything inside the block raises.
    with tempfile.TemporaryDirectory() as tmp:
        zarr_path = os.path.join(tmp, "demo.zarr")
        _write_demo_zarr(zarr_path, image_shape)

        mins, maxes = _random_boxes(image_shape[:2], box_hw, n_boxes)
        dataset = AsyncZarrDataset(zarr_path, "image", box_hw, mins, maxes)
        collate = KorniaCollate(crop_hw, device="cpu")

        # ``spawn`` (not the default ``fork``) is required by zarr
        loader = DataLoader(
            dataset,
            batch_size=batch_size,
            collate_fn=collate,
            num_workers=4,
            multiprocessing_context="spawn",
            persistent_workers=True,
        )

        for i, batch in enumerate(loader):
            print(f"batch {i}: shape={tuple(batch.shape)} dtype={batch.dtype} device={batch.device}")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment