Skip to content

Instantly share code, notes, and snippets.

@lgarrison
Created July 8, 2025 21:28
Show Gist options
  • Save lgarrison/ef2922fdb8ca78e8753c39758afedb56 to your computer and use it in GitHub Desktop.
Save lgarrison/ef2922fdb8ca78e8753c39758afedb56 to your computer and use it in GitHub Desktop.
zarr write benchmark script
# /// script
# requires-python = ">=3.13"
# dependencies = [
# "click",
# "tqdm",
# "zarr",
# "zarrs",
# ]
# ///
"""
Benchmark writing a zarr array.
The basic setup is that we're filling a zarr array with N "images" of size CxC.
We are calling each image a chunk. The chunks are individually small, so to
avoid writing too many files, we group the chunks into shards. Writes in zarr
are parallelized over shards, so we need to make sure we have enough shards to
keep all the workers busy. Since we're using the zarrs (rust) backend, each
Python worker may have a number of Rust workers. In general, it's better to
use fewer Python workers and more Rust workers, but real applications may want
multiple Python workers to do pre-processing, so benchmarking that case is
interesting, too.
This can max out the SSD on my workstation (6+ GB/s), and do pretty well on
the ceph network filesystem, too (3+ GB/s).
"""
import concurrent.futures
import functools
import multiprocessing
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
from pathlib import Path
from timeit import default_timer as timer
import click
import numpy as np
import zarr
import zarr.codecs
import zarr.storage
import zarrs # noqa: F401
from tqdm import tqdm
def fill(z, sl) -> None:
    """Write the constant value 1 into the region of ``z`` selected by ``sl``.

    This is the unit of work handed to each Python worker. It is kept at
    module level so it can be submitted to an executor by reference.

    Args:
        z: Any object supporting item assignment (here, the zarr array).
        sl: The index or slice selecting the region to overwrite.
    """
    fill_value = 1
    z[sl] = fill_value
def _ceil_div(numerator: int, denominator: int) -> int:
    """Return ``numerator / denominator`` rounded up, using integer math only."""
    return (numerator + denominator - 1) // denominator


def _worker_slices(n: int, shard_size: int, python_workers: int) -> list[slice]:
    """Split ``range(n)`` into contiguous, shard-aligned slices, one per worker.

    Every slice starts on a shard boundary so that no two workers ever write
    to the same shard. Workers that would receive no shards (which happens
    when ``python_workers`` exceeds the number of shards) get no slice at all,
    instead of an out-of-range slice such as ``slice(12000, 10000)``.

    Args:
        n: Length of the array along the first axis (must be >= 1).
        shard_size: Number of entries per shard along the first axis (>= 1).
        python_workers: Maximum number of slices to produce (>= 1).

    Returns:
        Non-empty slices that together cover exactly ``0..n``, in order.
    """
    n_shards = _ceil_div(n, shard_size)
    shards_per_worker = _ceil_div(n_shards, python_workers)

    slices = []
    for worker_id in range(python_workers):
        first_shard = worker_id * shards_per_worker
        if first_shard >= n_shards:
            # All shards are already assigned; remaining workers stay idle.
            break
        last_shard = min(first_shard + shards_per_worker, n_shards)
        # The final shard may be partial, so clamp to the array length.
        slices.append(slice(first_shard * shard_size, min(last_shard * shard_size, n)))
    return slices


@click.command()
@click.argument(
    "data_dir",
    type=click.Path(exists=True, file_okay=False, path_type=Path),
)
@click.option(
    "-pw",
    "--python-workers",
    default=1,
    type=click.IntRange(min=1),
    help="Number of Python threads or processes",
)
@click.option(
    "-rw",
    "--rust-workers",
    default=None,
    type=click.IntRange(min=1),
    help="Number of Rust (zarrs) threads",
)
@click.option(
    "-c",
    "--chunk-size",
    default=256,
    type=click.IntRange(min=1),
    help="Chunk size C for CxC chunks",
)
@click.option(
    "-n",
    "--N",
    default=10000,
    type=click.IntRange(min=1),
    help="Total number of chunks to create",
)
@click.option(
    "-p",
    "--use-processes",
    default=False,
    is_flag=True,
    help="Use processes instead of threads for parallel execution",
)
@click.option(
    "-s",
    "--shard-size",
    default=1000,
    type=click.IntRange(min=1),
    help="Size of each shard",
)
@click.option(
    "-C",
    "--compression",
    default=False,
    is_flag=True,
    help="Compress with Blosc",
)
def main(
    data_dir: Path,
    python_workers: int,
    rust_workers: int | None,
    chunk_size: int,
    n: int,
    use_processes: bool,
    compression: bool,
    shard_size: int,
):
    """Benchmark writing a sharded zarr array into DATA_DIR.

    The array is created at DATA_DIR/large_array.zarr (overwriting any
    existing array there), filled in parallel, and the write rate is printed.
    """
    # Array dimensions and chunking
    shape = (n, chunk_size, chunk_size)
    chunk_shape = (1, chunk_size, chunk_size)
    shard_shape = (shard_size, chunk_size, chunk_size)

    print("Creating zarr array:")
    print(f" Shape : {shape}")
    print(f" Chunk shape: {chunk_shape}")
    print(f" Shard shape: {shard_shape}")

    if compression:
        # This is probably not interesting unless the `fill()` function writes
        # non-constant data.
        compressors = [
            zarr.codecs.BloscCodec(
                cname="zstd",
                clevel=1,
                shuffle=zarr.codecs.BloscShuffle.shuffle,
                blocksize=1 << 20,
            )
        ]
    else:
        compressors = None

    # thread or process pool?
    if use_processes:

        # functools.wraps copies ProcessPoolExecutor's __name__ onto the
        # wrapper so the status line below reports the real executor type.
        @functools.wraps(ProcessPoolExecutor)
        def PoolExecutor(*args, **kwargs):
            return ProcessPoolExecutor(
                *args, **kwargs, mp_context=multiprocessing.get_context("forkserver")
            )

    else:
        PoolExecutor = ThreadPoolExecutor

    # figure out how many shards, and how many shards per worker
    n_shards = _ceil_div(n, shard_size)
    # Send as many shards as possible to each Python worker, so that each Rust worker
    # can write one shard.
    # This is imagining that in real applications, there is probably some processing
    # that's going to be parallelized in Python. Otherwise, we would shove it all to Rust.
    shards_per_worker = _ceil_div(n_shards, python_workers)
    rust_workers_label = "default" if rust_workers is None else rust_workers
    print(
        f"Using {PoolExecutor.__name__} with {python_workers} Python workers, "
        f"{rust_workers_label} Rust workers, "
        f"{shards_per_worker} shards per Python worker"
    )

    zarr.config.set(
        {
            # TODO: when using ThreadPoolExecutor, is this still per-Python-worker?
            "threading.max_workers": rust_workers,
            "array.write_empty_chunks": False,
            "codec_pipeline.path": "zarrs.ZarrsCodecPipeline",
            "codec_pipeline.validate_checksums": False,
            "codec_pipeline.store_empty_chunks": True,
            "codec_pipeline.chunk_concurrent_maximum": rust_workers,
            "codec_pipeline.chunk_concurrent_minimum": rust_workers,
        }
    )

    # Create the empty zarr array
    store = zarr.storage.LocalStore(data_dir / "large_array.zarr")
    z = zarr.create_array(
        store=store,
        shape=shape,
        chunks=chunk_shape,
        shards=shard_shape,
        dtype=np.float32,
        overwrite=True,
        compressors=compressors,
    )

    # TODO: does zarr have any native way to find shard boundaries?
    # The helper only yields non-empty, shard-aligned slices that cover the
    # whole array, so no runtime coverage assertion is needed here.
    worker_slices = _worker_slices(n, shard_size, python_workers)

    # Fill the array in parallel. The timed region covers task submission and
    # completion, but not executor start-up or shutdown.
    n_failed = 0
    with PoolExecutor(max_workers=python_workers) as executor:
        start = timer()
        futures = [executor.submit(fill, z, sl) for sl in worker_slices]

        for future in tqdm(
            concurrent.futures.as_completed(futures),
            desc="Python Workers done",
            total=len(futures),
        ):
            error = future.exception()
            if error is not None:
                n_failed += 1
                print(f"Error filling shard: {error}")
        t = timer() - start

    print("\nZarr array creation and filling complete")
    if n_failed:
        # Keep the best-effort behavior (report and continue), but make it
        # obvious that the throughput figure below cannot be trusted.
        print(
            f"WARNING: {n_failed} of {len(futures)} Python workers failed; "
            "the reported rate is not meaningful"
        )
    print(
        f"Total size: {z.nbytes / (1024**3):.2f} GB (file size: {z.nbytes // n_shards / 1024**2:.2f} MB)"
    )
    print(f"Rate: {z.nbytes / (1024**3 * t):.4g} GB/s ({t:.4g} seconds)\n")

    # sometimes expensive, but can see compression ratios, etc.
    # print("Zarr info:")
    # print(z.info_complete())
# Run the benchmark CLI only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment