Created
April 24, 2026 17:06
-
-
Save conradry/c5a611cdada8caec3d50a295177f2a0c to your computer and use it in GitHub Desktop.
Pytorch dataloader for zarr using async and kornia augmentation
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| """Minimal ``AsyncZarrDataset`` pattern: async zarr crop reads + kornia collate. | |
| Shows a :class:`torch.utils.data.Dataset` that fans out bounding-box crop reads | |
| via ``zarr.AsyncArray.getitem`` + ``asyncio.gather``, decoded by the ``zarrs`` | |
| Rust codec pipeline. A kornia ``AugmentationSequential`` is wired as the | |
| DataLoader ``collate_fn`` so stacking, device upload, and augmentation happen | |
| in one batched call. | |
| The zarr handle is opened lazily in ``_ensure_initialized`` rather than | |
| ``__init__`` so the dataset stays picklable across spawn-based workers (the | |
| zarr array handle is not picklable). | |
| Run: | |
| python async_zarr_kornia_demo.py | |
| """ | |
| import asyncio | |
| import os | |
| import shutil | |
| import tempfile | |
| import numpy as np | |
| import obstore | |
| import torch | |
| import zarr | |
| import zarrs # noqa: F401 — registers ``zarrs.ZarrsCodecPipeline`` | |
| from torch.utils.data import DataLoader, Dataset | |
| from zarr.core.sync import sync as zarr_sync | |
class AsyncZarrDataset(Dataset):
    """Dataset of bounding-box crops read via async zarr + zarrs codec.

    Implements ``__getitems__`` (plural) so the DataLoader hands us a whole
    batch of indices in one call; we issue one ``getitem`` per crop and await
    them together with ``asyncio.gather``. Returns a pre-stacked
    ``(B, *box_shape, *trailing)`` ndarray. Trailing is almost always a
    channel dimension.

    The crops of one batch are combined with ``np.stack``, so every box must
    span the same extent along each axis.
    """

    def __init__(
        self,
        zarr_path: str,
        array_key: str,
        box_shape: tuple[int, ...],
        min_corners: np.ndarray,
        max_corners: np.ndarray,
    ) -> None:
        """Record the crop boxes; the zarr array itself is opened lazily.

        Args:
            zarr_path: Local path of the zarr group.
            array_key: Key of the array inside that group.
            box_shape: Extent of one crop along each leading (cropped) axis.
            min_corners: ``(n_boxes, len(box_shape))`` inclusive start indices.
            max_corners: ``(n_boxes, len(box_shape))`` exclusive stop indices.

        Raises:
            ValueError: If the corner arrays disagree in shape, are not 2-D,
                or have a different number of columns than ``box_shape``.
        """
        min_corners = np.asarray(min_corners)
        max_corners = np.asarray(max_corners)
        if min_corners.shape != max_corners.shape:
            raise ValueError(f"min/max corner shape mismatch: {min_corners.shape} vs {max_corners.shape}")
        # Guard the ``shape[1]`` access below: a 1-D array would otherwise
        # surface as an unrelated IndexError.
        if min_corners.ndim != 2:
            raise ValueError(f"corners must be 2-D (n_boxes, n_dims); got shape {min_corners.shape}")
        if min_corners.shape[1] != len(box_shape):
            raise ValueError(f"corners have {min_corners.shape[1]} dims but box_shape has {len(box_shape)}")
        self._zarr_path = zarr_path
        self._array_key = array_key
        self._box_shape = tuple(int(b) for b in box_shape)
        self._min_corners = np.ascontiguousarray(min_corners, dtype=np.int64)
        self._max_corners = np.ascontiguousarray(max_corners, dtype=np.int64)
        # Filled in by ``_ensure_initialized`` on first read.
        self._async_arr = None
        self._trailing_ndim: int = 0
        self._trailing_shape: tuple[int, ...] = ()
        self._dtype = None

    def __len__(self) -> int:
        """Return the number of crop boxes."""
        return int(self._min_corners.shape[0])

    def _ensure_initialized(self) -> None:
        """Open the zarr array on first use and cache its async handle.

        Raises:
            ValueError: If the stored array has fewer dimensions than
                ``box_shape``.
        """
        if self._async_arr is not None:
            return
        zarr.config.set({"codec_pipeline.path": "zarrs.ZarrsCodecPipeline"})
        store = obstore.store.LocalStore(prefix=os.path.abspath(self._zarr_path))
        arr = zarr.open_group(zarr.storage.ObjectStore(store), mode="r")[self._array_key]
        n_box_dims = len(self._box_shape)
        trailing_ndim = len(arr.shape) - n_box_dims
        if trailing_ndim < 0:
            raise ValueError(
                f"array {self._array_key!r} has {len(arr.shape)} dims but box_shape has {n_box_dims}"
            )
        self._trailing_ndim = trailing_ndim
        self._trailing_shape = tuple(int(s) for s in arr.shape[n_box_dims:])
        self._dtype = arr.dtype
        # Assigned last so a failure above never leaves a half-initialised state
        # that the early return would treat as ready.
        self._async_arr = arr._async_array

    def __getitem__(self, idx: int) -> np.ndarray:
        """Return the single crop at ``idx`` (a batch of one, unwrapped)."""
        return self.__getitems__([idx])[0]

    def __getitems__(self, indices: list[int]) -> np.ndarray:
        """Read the crops for ``indices`` concurrently and stack them.

        Args:
            indices: Box indices making up one batch.

        Returns:
            Array whose first axis indexes the requested boxes. For an empty
            ``indices`` an empty ``(0, *box_shape, *trailing)`` array of the
            stored dtype is returned.
        """
        self._ensure_initialized()
        idx = np.asarray(indices, dtype=np.int64)
        if idx.size == 0:
            # ``np.stack`` rejects an empty sequence, so build the result directly.
            return np.empty((0, *self._box_shape, *self._trailing_shape), dtype=self._dtype)
        # ``tolist`` yields plain Python ints, which is what the slices need.
        mins = self._min_corners[idx].tolist()
        maxes = self._max_corners[idx].tolist()
        trailing = (slice(None),) * self._trailing_ndim
        async_arr = self._async_arr

        async def _gather() -> list:
            coros = [
                async_arr.getitem(
                    tuple(slice(lo, hi) for lo, hi in zip(box_lo, box_hi, strict=True)) + trailing
                )
                for box_lo, box_hi in zip(mins, maxes, strict=True)
            ]
            return await asyncio.gather(*coros)

        return np.stack(zarr_sync(_gather()), axis=0)

    def __getstate__(self) -> dict:
        """Return picklable state with the zarr handle removed."""
        state = self.__dict__.copy()
        state["_async_arr"] = None  # zarr handle isn't picklable; reopen in worker
        return state
class KorniaCollate:
    """Collate fn: NHWC→NCHW, device upload, kornia augmentation.

    ``AsyncZarrDataset.__getitems__`` returns a pre-stacked ndarray of shape
    ``(B, H, W)`` or ``(B, H, W, C)``, so this collate skips the usual
    per-sample stacking and goes straight to the batched augmentation.
    """

    def __init__(self, crop_hw: tuple[int, int], device: str, dtype_max: float = 255.0) -> None:
        """Build the augmentation pipeline.

        Args:
            crop_hw: Output ``(height, width)`` passed to ``RandomCrop``.
            device: Device the batch and the augmentation module are moved to.
            dtype_max: Divisor applied to the batch after conversion to
                float32.

        Raises:
            ValueError: If ``dtype_max`` is not positive.
        """
        if dtype_max <= 0:
            raise ValueError(f"dtype_max must be positive; got {dtype_max}")
        # Imported lazily so the module can be imported without kornia installed.
        import kornia.augmentation as K

        self._device = device
        self._dtype_max = dtype_max
        self._aug = K.AugmentationSequential(
            K.RandomCrop(crop_hw, same_on_batch=False),
            K.RandomHorizontalFlip(p=0.5),
            K.RandomVerticalFlip(p=0.5),
            K.RandomRotation(degrees=15.0, p=0.5),
            K.ColorJitter(brightness=0.2, contrast=0.2, p=0.5),
            same_on_batch=False,
        ).to(device)
        # NOTE(review): confirm that eval mode leaves the random augmentations
        # active in the kornia version in use.
        self._aug.eval()

    def __call__(self, batch: np.ndarray) -> torch.Tensor:
        """Convert a pre-stacked batch to a scaled NCHW tensor and augment it.

        Args:
            batch: ``(B, H, W)`` or ``(B, H, W, C)`` array.

        Returns:
            The output of the augmentation pipeline for the float32 batch
            divided by ``dtype_max``. ``batch`` itself is left unmodified.

        Raises:
            ValueError: If ``batch`` is neither 3-D nor 4-D.
        """
        if batch.ndim == 3:  # (B, H, W) — treat as single-channel
            t = torch.from_numpy(batch).unsqueeze(1)
        elif batch.ndim == 4:  # (B, H, W, C) → (B, C, H, W)
            t = torch.from_numpy(np.ascontiguousarray(batch)).permute(0, 3, 1, 2)
        else:
            raise ValueError(f"unexpected batch ndim {batch.ndim}; expected 3 or 4")
        # ``from_numpy`` shares memory with ``batch`` and ``.to(float32)`` is a
        # no-op for float32 input, so force a copy before the in-place divide;
        # otherwise the caller's array would be rescaled as a side effect.
        t = t.to(torch.float32, copy=True).div_(self._dtype_max)
        t = t.to(self._device, non_blocking=True)
        return self._aug(t)
def _write_demo_zarr(zarr_path: str, shape: tuple[int, int, int]) -> None:
    """Create a local zarr group at ``zarr_path`` holding one random uint8 image.

    The group gets a single array named ``"image"`` of the given ``shape``,
    chunked 128 x 128 over the first two axes with the whole last axis per
    chunk, and filled from a generator seeded with 0.
    """
    os.makedirs(zarr_path, exist_ok=True)
    absolute_path = os.path.abspath(zarr_path)
    local_store = obstore.store.LocalStore(prefix=absolute_path)
    group = zarr.open_group(zarr.storage.ObjectStore(local_store), mode="w")
    chunk_shape = (128, 128, shape[2])
    image = group.create_array("image", shape=shape, chunks=chunk_shape, dtype=np.uint8)
    pixels = np.random.default_rng(0).integers(0, 256, size=shape, dtype=np.uint8)
    image[:] = pixels
def _random_boxes(
    array_hw: tuple[int, int], box_hw: tuple[int, int], n: int, seed: int = 42
) -> tuple[np.ndarray, np.ndarray]:
    """Sample ``n`` random boxes of size ``box_hw`` lying inside ``array_hw``.

    Args:
        array_hw: ``(height, width)`` of the array the boxes must fit in.
        box_hw: ``(height, width)`` of every box.
        n: Number of boxes to draw.
        seed: Seed for the random generator; the default reproduces the
            boxes of earlier versions that hard-coded it.

    Returns:
        ``(mins, maxes)``, two ``(n, 2)`` int64 arrays of ``(y, x)`` corners
        with ``maxes == mins + box_hw``.

    Raises:
        ValueError: If the box is larger than the array along either axis.
    """
    max_y = array_hw[0] - box_hw[0]
    max_x = array_hw[1] - box_hw[1]
    if max_y < 0 or max_x < 0:
        raise ValueError(f"box {tuple(box_hw)} does not fit inside array {tuple(array_hw)}")
    rng = np.random.default_rng(seed)
    # Upper bounds are exclusive, hence the +1 so a box may touch the far edge.
    ys = rng.integers(0, max_y + 1, size=n)
    xs = rng.integers(0, max_x + 1, size=n)
    mins = np.stack([ys, xs], axis=1).astype(np.int64)
    maxes = mins + np.array(box_hw, dtype=np.int64)
    return mins, maxes
def main() -> None:
    """Write a demo zarr, then print the batches a DataLoader produces from it."""
    image_shape = (512, 512, 3)
    box_hw = (128, 128)
    crop_hw = (112, 112)  # kornia RandomCrop target
    n_boxes, batch_size = 32, 8

    # The context manager removes the temporary store on exit, including when
    # the demo fails part-way through.
    with tempfile.TemporaryDirectory() as tmp:
        zarr_path = os.path.join(tmp, "demo.zarr")
        _write_demo_zarr(zarr_path, image_shape)
        mins, maxes = _random_boxes(image_shape[:2], box_hw, n_boxes)
        dataset = AsyncZarrDataset(zarr_path, "image", box_hw, mins, maxes)
        # ``spawn`` (not the default ``fork``) is required by zarr
        loader = DataLoader(
            dataset,
            batch_size=batch_size,
            collate_fn=KorniaCollate(crop_hw, device="cpu"),
            num_workers=4,
            multiprocessing_context="spawn",
            persistent_workers=True,
        )
        for i, batch in enumerate(loader):
            print(f"batch {i}: shape={tuple(batch.shape)} dtype={batch.dtype} device={batch.device}")


if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment