Last active
February 26, 2025 19:19
-
-
Save tlambert03/f8c1b069c2947b411ce24ea05aa370b1 to your computer and use it in GitHub Desktop.
benchmarking acquire-zarr and tensorstore
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# /// script
# requires-python = ">=3.13"
# dependencies = [
#   "acquire-zarr",
#   "zarr",
#   "rich",
#   "tensorstore",
# ]
# ///
#!/usr/bin/env python3
"""Compare write performance of TensorStore vs. acquire-zarr for a Zarr v3 store."""
import sys | |
import time | |
import acquire_zarr as aqz | |
import numpy as np | |
import tensorstore | |
import zarr | |
from rich import print | |
def run_tensorstore_test(data: np.ndarray, path: str, metadata: dict) -> float:
    """Write data using TensorStore and print per-plane and total write times.

    Planes are staged in an in-memory buffer until a full write chunk along
    the first axis has been assembled; the buffer is then handed to
    TensorStore as a single asynchronous write. A trailing, partially filled
    buffer is flushed too, so every plane of ``data`` reaches the store even
    when ``data.shape[0]`` is not a multiple of the write-chunk length.

    Args:
        data: Stack of planes to write, indexed along axis 0.
        path: Filesystem path of the Zarr v3 store. Any existing store at
            this path is deleted (``delete_existing=True``).
        metadata: Zarr v3 array metadata passed verbatim to the TensorStore
            spec (shape, chunking, codecs, ...).

    Returns:
        Total elapsed write time in milliseconds, including the time spent
        waiting for all outstanding write futures.
    """
    # Define a TensorStore spec for a Zarr v3 store.
    spec = {
        "driver": "zarr3",
        "kvstore": {"driver": "file", "path": path},
        "metadata": metadata,
        "delete_existing": True,
        "create": True,
    }
    # Open (or create) the store.
    ts = tensorstore.open(spec).result()
    print(ts)

    total_start = time.perf_counter_ns()
    futures = []
    n_planes = data.shape[0]

    # Cache data until we've reached a write-chunk-aligned block.
    chunk_length = ts.schema.chunk_layout.write_chunk.shape[0]
    write_chunk_shape = (chunk_length, *ts.domain.shape[1:])
    # Stage in the store's own dtype rather than a hard-coded one.
    chunk_dtype = ts.dtype.numpy_dtype
    chunk = np.empty(write_chunk_shape, dtype=chunk_dtype)

    for i in range(n_planes):
        start_plane = time.perf_counter_ns()
        chunk_idx = i % chunk_length
        chunk[chunk_idx] = data[i]
        if chunk_idx == chunk_length - 1:
            slc = slice(i - chunk_length + 1, i + 1)
            futures.append(ts[slc].write(chunk))
            # Allocate a fresh buffer: the pending write may still be
            # reading from the one we just handed over.
            chunk = np.empty(write_chunk_shape, dtype=chunk_dtype)
        elif i == n_planes - 1:
            # Last plane landed mid-chunk: flush only the filled prefix so
            # the trailing planes are not silently dropped.
            slc = slice(i - chunk_idx, i + 1)
            futures.append(ts[slc].write(chunk[: chunk_idx + 1]))
        elapsed = time.perf_counter_ns() - start_plane
        print(f"TensorStore: Plane {i} written in {elapsed / 1e6:.3f} ms")

    start_futures = time.perf_counter_ns()
    # Wait for all writes to finish.
    for future in futures:
        future.result()
    elapsed = time.perf_counter_ns() - start_futures
    print(f"TensorStore: Final futures took {elapsed / 1e6:.3f} ms")

    total_elapsed = time.perf_counter_ns() - total_start
    tot_ms = total_elapsed / 1e6
    print(f"TensorStore: Total write time: {tot_ms:.3f} ms")
    return tot_ms
def run_acquire_zarr_test(
    data: np.ndarray, path: str, tchunk_size: int = 1, xy_chunk_size: int = 2048
) -> float:
    """Write data using acquire-zarr and print per-plane and total write times.

    The stream is configured with three dimensions (t, y, x). The t dimension
    is declared with ``array_size_px=0`` and frames are appended one at a
    time; the y and x extents are taken from ``data.shape`` instead of being
    hard-coded.

    Args:
        data: 3-D stack of planes ordered (t, y, x). The stream is configured
            as UINT16, so ``data`` is expected to be ``uint16``.
        path: Filesystem path of the Zarr v3 store to create.
        tchunk_size: Chunk size, in planes, along the t dimension.
        xy_chunk_size: Chunk size, in pixels, along both y and x.

    Returns:
        Total elapsed write time in milliseconds, including finalizing the
        stream.

    Raises:
        ValueError: If ``data`` is not 3-dimensional.
    """
    # Unpacking doubles as a rank check: anything but 3-D raises ValueError.
    _, height, width = data.shape

    settings = aqz.StreamSettings(
        store_path=path,
        data_type=aqz.DataType.UINT16,
        version=aqz.ZarrVersion.V3,
    )
    settings.dimensions.extend(
        [
            aqz.Dimension(
                name="t",
                type=aqz.DimensionType.TIME,
                array_size_px=0,
                chunk_size_px=tchunk_size,
                shard_size_chunks=1,
            ),
            aqz.Dimension(
                name="y",
                type=aqz.DimensionType.SPACE,
                array_size_px=int(height),
                chunk_size_px=xy_chunk_size,
                shard_size_chunks=1,
            ),
            aqz.Dimension(
                name="x",
                type=aqz.DimensionType.SPACE,
                array_size_px=int(width),
                chunk_size_px=xy_chunk_size,
                shard_size_chunks=1,
            ),
        ],
    )

    # Create a ZarrStream for appending frames.
    stream = aqz.ZarrStream(settings)

    total_start = time.perf_counter_ns()
    for i in range(data.shape[0]):
        start_plane = time.perf_counter_ns()
        stream.append(data[i])
        elapsed = time.perf_counter_ns() - start_plane
        print(f"Acquire-zarr: Plane {i} written in {elapsed / 1e6:.3f} ms")

    # Close (or flush) the stream to finalize writes. This script relies on
    # dropping the last reference to do so -- NOTE(review): confirm against
    # the acquire-zarr docs that destruction flushes pending data.
    del stream
    total_elapsed = time.perf_counter_ns() - total_start
    tot_ms = total_elapsed / 1e6
    print(f"Acquire-zarr: Total write time: {tot_ms:.3f} ms")
    return tot_ms
def main() -> None:
    """Run both write benchmarks, verify the stored data, and print a summary.

    Command-line arguments (both optional, positional):
        1. chunk size along t (default 1)
        2. chunk size along y and x (default 2048)

    The acquire-zarr test runs first; the metadata of the array it produced
    is then reused for the TensorStore test so both writers use the same
    chunking and codecs. Afterwards both stores are read back with zarr and
    compared against the source data.
    """
    # Pre-generate the data (timing excluded).
    # `high` is exclusive in np.random.randint, so 2**16 is required to cover
    # the full uint16 range including 65535.
    data: np.ndarray = np.random.randint(
        0,
        2**16,
        (64, 2048, 2048),
        dtype=np.uint16,
    )

    t_chunk_size = int(sys.argv[1]) if len(sys.argv) > 1 else 1
    xy_chunk_size = int(sys.argv[2]) if len(sys.argv) > 2 else 2048
    print("tchunk_size:", t_chunk_size)
    print("xy_chunk_size:", xy_chunk_size)

    print("\nRunning acquire-zarr test:")
    az_path = "acquire_zarr_test.zarr"
    time_az = run_acquire_zarr_test(data, az_path, t_chunk_size, xy_chunk_size)

    # use the exact same metadata that was used for the acquire-zarr test
    # to ensure we're using the same chunks and codecs, etc...
    az = zarr.open(az_path)["0"]

    print("\nRunning TensorStore test:")
    ts_path = "tensorstore_test.zarr"
    time_ts = run_tensorstore_test(
        data,
        ts_path,
        {**az.metadata.to_dict(), "data_type": "uint16"},
    )

    # ensure that the data is written to disk and that they are the same
    print("\nComparing the written data:", end=" ")
    ts = zarr.open(ts_path)
    # ensure tensorstore wrote the correct data
    np.testing.assert_array_equal(data, ts)
    # ensure acquire-zarr wrote the correct data
    np.testing.assert_array_equal(data, az)
    print("✅\n")

    assert ts.metadata == az.metadata
    print("Metadata matches:")
    print(ts.metadata)

    print("\nPerformance comparison:")
    print(f" acquire-zarr: {time_az:.3f} ms")
    print(f" TensorStore: {time_ts:.3f} ms")
    print(f" AZ/TS Ratio: {time_az / time_ts:.3f}")
# Run the benchmark only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment