zarr write benchmark script
# /// script
# requires-python = ">=3.13"
# dependencies = [
#     "click",
#     "tqdm",
#     "zarr",
#     "zarrs",
# ]
# ///

"""
Benchmark writing a zarr array.

The basic setup is that we're filling a zarr array with N "images" of size CxC.
We are calling each image a chunk. The chunks are individually small, so to
avoid writing too many files, we group the chunks into shards. Writes in zarr
are parallelized over shards, so we need to make sure we have enough shards to
keep all the workers busy. Since we're using the zarrs (Rust) backend, each
Python worker may have a number of Rust workers. In general, it's better to
use fewer Python workers and more Rust workers, but real applications may want
multiple Python workers to do pre-processing, so benchmarking that case is
interesting, too.

This can max out the SSD on my workstation (6+ GB/s), and do pretty well on
the Ceph network filesystem, too (3+ GB/s).
"""
import concurrent.futures
import functools
import multiprocessing
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
from pathlib import Path
from timeit import default_timer as timer

import click
import numpy as np
import zarr
import zarr.codecs
import zarr.storage
import zarrs  # noqa: F401
from tqdm import tqdm


def fill(z, sl):
    z[sl] = 1
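    # Note: writing a constant is the cheapest possible fill. If the --compression flag
    # is used, the compression ratio will be unrealistically good; a hypothetical variant
    # (untested sketch) would write random data instead, e.g.:
    #   rng = np.random.default_rng()
    #   z[sl] = rng.random((sl.stop - sl.start, *z.shape[1:]), dtype=np.float32)

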
@click.command()
@click.argument(
    "data_dir",
    type=click.Path(exists=True, file_okay=False, path_type=Path),
)
@click.option(
    "-pw",
    "--python-workers",
    default=1,
    type=int,
    help="Number of Python threads or processes",
)
@click.option(
    "-rw",
    "--rust-workers",
    default=None,
    type=int,
    help="Number of Rust (zarrs) threads",
)
@click.option(
    "-c",
    "--chunk-size",
    default=256,
    type=int,
    help="Chunk size C for CxC chunks",
)
@click.option(
    "-n",
    "--N",
    default=10000,
    type=int,
    help="Total number of chunks to create",
)
@click.option(
    "-p",
    "--use-processes",
    default=False,
    is_flag=True,
    help="Use processes instead of threads for parallel execution",
)
@click.option(
    "-s",
    "--shard-size",
    default=1000,
    type=int,
    help="Number of chunks per shard",
)
@click.option(
    "-C",
    "--compression",
    default=False,
    is_flag=True,
    help="Compress with Blosc",
)
def main(
    data_dir: Path,
    python_workers: int,
    rust_workers: int,
    chunk_size: int,
    n: int,
    use_processes: bool,
    compression: bool,
    shard_size: int,
):
    # Array dimensions and chunking
    shape = (n, chunk_size, chunk_size)
    chunk_shape = (1, chunk_size, chunk_size)
    shard_shape = (shard_size, chunk_size, chunk_size)

    print("Creating zarr array:")
    print(f"  Shape      : {shape}")
    print(f"  Chunk shape: {chunk_shape}")
    print(f"  Shard shape: {shard_shape}")

    if compression:
        # This is probably not interesting unless the `fill()` function writes
        # non-constant data.
        compressors = [
            zarr.codecs.BloscCodec(
                cname="zstd",
                clevel=1,
                shuffle=zarr.codecs.BloscShuffle.shuffle,
                blocksize=1 << 20,
            )
        ]
    else:
        compressors = None

    # thread or process pool?
    if use_processes:

        @functools.wraps(ProcessPoolExecutor)
        def PoolExecutor(*args, **kwargs):
            return ProcessPoolExecutor(
                *args, **kwargs, mp_context=multiprocessing.get_context("forkserver")
            )

    else:
        PoolExecutor = ThreadPoolExecutor

    # figure out how many shards, and how many shards per worker
    n_shards = (n + shard_shape[0] - 1) // shard_shape[0]

    # Send as many shards as possible to each Python worker, so that each Rust worker
    # can write one shard.
    # This is imagining that in real applications, there is probably some processing
    # that's going to be parallelized in Python. Otherwise, we would shove it all to Rust.
    shards_per_worker = (n_shards + python_workers - 1) // python_workers
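    # Worked example (illustrative, not from the gist): with the defaults N=10000,
    # C=256, shard_size=1000, the array is 10000 * 256 * 256 * 4 B ~= 2.44 GiB, split
    # into 10 shards of ~250 MiB each; running with --python-workers 2 would give
    # each Python worker ceil(10 / 2) = 5 shards.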
    print(
        f"Using {PoolExecutor.__name__} with {python_workers} Python workers, "
        f"{'default' if rust_workers is None else rust_workers} Rust workers, "
        f"{shards_per_worker} shards per Python worker"
    )

    zarr.config.set(
        {
            # TODO: when using ThreadPoolExecutor, is this still per-Python-worker?
            "threading.max_workers": rust_workers,
            "array.write_empty_chunks": False,
            "codec_pipeline.path": "zarrs.ZarrsCodecPipeline",
            "codec_pipeline.validate_checksums": False,
            "codec_pipeline.store_empty_chunks": True,
            "codec_pipeline.chunk_concurrent_maximum": rust_workers,
            "codec_pipeline.chunk_concurrent_minimum": rust_workers,
        }
    )
    # Create the empty zarr array
    store = zarr.storage.LocalStore(data_dir / "large_array.zarr")
    z = zarr.create_array(
        store=store,
        shape=shape,
        chunks=chunk_shape,
        shards=shard_shape,
        dtype=np.float32,
        overwrite=True,
        compressors=compressors,
    )
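    # With the sharding codec, each shard is stored as a single object, so a run
    # should produce roughly n_shards data files under large_array.zarr (plus the
    # zarr.json metadata), rather than one file per chunk.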
    # Fill the array in parallel
    with PoolExecutor(max_workers=python_workers) as executor:
        t = -timer()
        futures = []
        # TODO: does zarr have any native way to find shard boundaries?
        for worker_id in range(python_workers):
            worker_start_shard = worker_id * shards_per_worker
            worker_end_shard = min(worker_start_shard + shards_per_worker, n_shards)
            chunk_start = worker_start_shard * shard_shape[0]
            chunk_end = min(worker_end_shard * shard_shape[0], n)
            futures.append(executor.submit(fill, z, slice(chunk_start, chunk_end)))
        assert worker_end_shard == n_shards

        for future in tqdm(
            concurrent.futures.as_completed(futures),
            desc="Python Workers done",
            total=len(futures),
        ):
            if future.exception() is not None:
                print(f"Error filling shard: {future.exception()}")
        t += timer()

    print("\nZarr array creation and filling complete")
    print(
        f"Total size: {z.nbytes / (1024**3):.2f} GB "
        f"(file size: {z.nbytes // n_shards / 1024**2:.2f} MB)"
    )
    print(f"Rate: {z.nbytes / (1024**3 * t):.4g} GB/s ({t:.4g} seconds)\n")

    # sometimes expensive, but can see compression ratios, etc.
    # print("Zarr info:")
    # print(z.info_complete())


if __name__ == "__main__":
    main()
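
After a run, a quick sanity check is to open the store read-only and look at the array metadata and a sample chunk. This is a minimal sketch (not part of the gist), assuming zarr-python v3 and the default `large_array.zarr` store name; `/path/to/data_dir` is a placeholder for the directory passed to the script:

# check_output.py -- hypothetical read-back sketch, not part of the benchmark itself
from pathlib import Path

import zarr

data_dir = Path("/path/to/data_dir")  # placeholder: the DATA_DIR given to the script
z = zarr.open_array(data_dir / "large_array.zarr", mode="r")

print(z.shape, z.chunks, z.dtype)   # expect (N, C, C), (1, C, C), float32
print(f"{z.nbytes / 1024**3:.2f} GiB logical size")
print(z[0].mean())                  # expect 1.0, since fill() writes ones

Since the benchmark script carries inline (PEP 723) dependency metadata, it can presumably be run directly with a tool that understands it, e.g. `uv run <script>.py DATA_DIR -pw 2 -rw 8`; otherwise, any environment with click, tqdm, zarr, and zarrs installed should work.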