Skip to content

Instantly share code, notes, and snippets.

@tlambert03
Last active February 26, 2025 19:19
Show Gist options
  • Save tlambert03/f8c1b069c2947b411ce24ea05aa370b1 to your computer and use it in GitHub Desktop.
Save tlambert03/f8c1b069c2947b411ce24ea05aa370b1 to your computer and use it in GitHub Desktop.
Benchmarking acquire-zarr and TensorStore
# /// script
# requires-python = ">=3.13"
# dependencies = [
# "acquire-zarr",
# "zarr",
# "rich",
# "tensorstore",
# ]
# ///
#!/usr/bin/env python3
"""Compare write performance of TensorStore vs. acquire-zarr for a Zarr v3 store."""
import sys
import time
import acquire_zarr as aqz
import numpy as np
import tensorstore
import zarr
from rich import print
def run_tensorstore_test(data: np.ndarray, path: str, metadata: dict) -> float:
    """Write data using TensorStore and print per-plane and total write times.

    Planes along axis 0 are buffered into a block whose length equals the
    store's write-chunk length along that axis. Each full block is submitted
    as one asynchronous write. Planes left over when ``data.shape[0]`` is not
    a multiple of that length are flushed as a final, shorter write, so every
    plane reaches the store.

    Args:
        data: Array whose planes ``data[i]`` are written one after another.
        path: Filesystem path of the Zarr v3 store. Any existing store at
            this path is deleted (``delete_existing``).
        metadata: Zarr v3 array metadata passed to the ``zarr3`` driver.

    Returns:
        Total elapsed wall time in milliseconds, including waiting for all
        pending write futures.
    """
    # Define a TensorStore spec for a Zarr v3 store.
    spec = {
        "driver": "zarr3",
        "kvstore": {"driver": "file", "path": path},
        "metadata": metadata,
        "delete_existing": True,
        "create": True,
    }
    # Open (or create) the store, blocking until it is ready.
    ts = tensorstore.open(spec).result()
    print(ts)

    total_start = time.perf_counter_ns()
    futures = []
    n_planes = data.shape[0]
    # Cache data until we've reached a write-chunk-aligned block, so each
    # write covers whole chunks along axis 0.
    chunk_length = ts.schema.chunk_layout.write_chunk.shape[0]
    write_chunk_shape = (chunk_length, *ts.domain.shape[1:])
    chunk = np.empty(write_chunk_shape, dtype=data.dtype)
    for i in range(n_planes):
        start_plane = time.perf_counter_ns()
        chunk_idx = i % chunk_length
        chunk[chunk_idx] = data[i]
        if chunk_idx == chunk_length - 1:
            slc = slice(i - chunk_length + 1, i + 1)
            futures.append(ts[slc].write(chunk))
            # Use a fresh buffer instead of refilling this one. The write is
            # asynchronous (a future is returned), so it may still reference
            # ``chunk``.
            chunk = np.empty(write_chunk_shape, dtype=data.dtype)
        elapsed = time.perf_counter_ns() - start_plane
        print(f"TensorStore: Plane {i} written in {elapsed / 1e6:.3f} ms")

    # Flush any trailing partial block. Without this, the last
    # ``n_planes % chunk_length`` planes would never be written.
    remainder = n_planes % chunk_length
    if remainder:
        slc = slice(n_planes - remainder, n_planes)
        futures.append(ts[slc].write(chunk[:remainder]))

    start_futures = time.perf_counter_ns()
    # Wait for all writes to finish.
    for future in futures:
        future.result()
    elapsed = time.perf_counter_ns() - start_futures
    print(f"TensorStore: Final futures took {elapsed / 1e6:.3f} ms")

    total_elapsed = time.perf_counter_ns() - total_start
    tot_ms = total_elapsed / 1e6
    print(f"TensorStore: Total write time: {tot_ms:.3f} ms")
    return tot_ms
def run_acquire_zarr_test(
    data: np.ndarray, path: str, tchunk_size: int = 1, xy_chunk_size: int = 2048
) -> float:
    """Write data using acquire-zarr and print per-plane and total write times.

    A Zarr v3 stream is configured with three dimensions, each with one chunk
    per shard:

    - ``t``, created with ``array_size_px=0`` and appended to frame by frame;
    - ``y`` and ``x``, sized from the last two axes of ``data``.

    Each plane ``data[i]`` is appended in turn. The stream is then finalized
    before the clock stops, so the total time includes that finalization.

    Args:
        data: Stack of frames indexed as ``data[t]``. The stream is
            configured as ``UINT16``, so uint16 data is expected.
        path: Store path handed to acquire-zarr.
        tchunk_size: Chunk size, in frames, along ``t``.
        xy_chunk_size: Chunk size, in pixels, along both ``y`` and ``x``.

    Returns:
        Total elapsed wall time in milliseconds.
    """
    # Size the spatial dimensions from the data rather than hard-coding 2048,
    # so frames of any height/width produce a correctly shaped array.
    height, width = data.shape[-2:]
    settings = aqz.StreamSettings(
        store_path=path,
        data_type=aqz.DataType.UINT16,
        version=aqz.ZarrVersion.V3,
    )
    settings.dimensions.extend(
        [
            aqz.Dimension(
                name="t",
                type=aqz.DimensionType.TIME,
                # presumably 0 = extent grows as frames are appended; confirm
                # against the acquire-zarr docs
                array_size_px=0,
                chunk_size_px=tchunk_size,
                shard_size_chunks=1,
            ),
            aqz.Dimension(
                name="y",
                type=aqz.DimensionType.SPACE,
                array_size_px=height,
                chunk_size_px=xy_chunk_size,
                shard_size_chunks=1,
            ),
            aqz.Dimension(
                name="x",
                type=aqz.DimensionType.SPACE,
                array_size_px=width,
                chunk_size_px=xy_chunk_size,
                shard_size_chunks=1,
            ),
        ],
    )
    # Create a ZarrStream for appending frames.
    stream = aqz.ZarrStream(settings)

    total_start = time.perf_counter_ns()
    for i in range(data.shape[0]):
        start_plane = time.perf_counter_ns()
        stream.append(data[i])
        elapsed = time.perf_counter_ns() - start_plane
        print(f"Acquire-zarr: Plane {i} written in {elapsed / 1e6:.3f} ms")
    # Close (or flush) the stream to finalize writes. Dropping the only
    # reference destroys the stream immediately under CPython refcounting,
    # so the finalization is included in the measured time.
    del stream

    total_elapsed = time.perf_counter_ns() - total_start
    tot_ms = total_elapsed / 1e6
    print(f"Acquire-zarr: Total write time: {tot_ms:.3f} ms")
    return tot_ms
def main() -> None:
    """Benchmark acquire-zarr against TensorStore and verify both outputs.

    Optional command-line arguments:

    - ``argv[1]``: chunk size along t (default 1);
    - ``argv[2]``: chunk size along y and x (default 2048).

    acquire-zarr runs first. TensorStore then reuses that array's metadata,
    so both libraries write the same chunk/codec layout. The data written by
    each library is read back with zarr and compared against the source array
    and against each other's metadata. Finally the timings are printed.
    """
    # Pre-generate the data (timing excluded). ``high`` is exclusive, so 2**16
    # covers the full uint16 range, including 65535.
    data: np.ndarray = np.random.randint(
        0,
        2**16,
        (64, 2048, 2048),
        dtype=np.uint16,
    )
    t_chunk_size = int(sys.argv[1]) if len(sys.argv) > 1 else 1
    xy_chunk_size = int(sys.argv[2]) if len(sys.argv) > 2 else 2048
    print("tchunk_size:", t_chunk_size)
    print("xy_chunk_size:", xy_chunk_size)

    print("\nRunning acquire-zarr test:")
    az_path = "acquire_zarr_test.zarr"
    time_az = run_acquire_zarr_test(data, az_path, t_chunk_size, xy_chunk_size)

    # Use the exact same metadata that was used for the acquire-zarr test,
    # so both stores use the same chunks, codecs, etc.
    az_array = zarr.open(az_path)["0"]
    print("\nRunning TensorStore test:")
    ts_path = "tensorstore_test.zarr"
    time_ts = run_tensorstore_test(
        data,
        ts_path,
        {**az_array.metadata.to_dict(), "data_type": str(data.dtype)},
    )

    # Ensure the data was written to disk and that both copies match the source.
    print("\nComparing the written data:", end=" ")
    ts_array = zarr.open(ts_path)
    np.testing.assert_array_equal(data, ts_array)  # TensorStore output
    np.testing.assert_array_equal(data, az_array)  # acquire-zarr output
    print("✅\n")

    # Explicit check (rather than ``assert``) so it still runs under ``python -O``.
    if ts_array.metadata != az_array.metadata:
        raise AssertionError("TensorStore and acquire-zarr metadata differ")
    print("Metadata matches:")
    print(ts_array.metadata)

    print("\nPerformance comparison:")
    print(f" acquire-zarr: {time_az:.3f} ms")
    print(f" TensorStore: {time_ts:.3f} ms")
    print(f" AZ/TS Ratio: {time_az / time_ts:.3f}")


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment