Skip to content

Instantly share code, notes, and snippets.

@jayendra13
Created March 31, 2026 15:29
Show Gist options
  • Select an option

  • Save jayendra13/af793b3d377c4fa7d14147df7c2276df to your computer and use it in GitHub Desktop.

Select an option

Save jayendra13/af793b3d377c4fa7d14147df7c2276df to your computer and use it in GitHub Desktop.
Generic Zarr → Apache Arrow Fixed/Variable Shape Tensor (domain-agnostic)
"""
Create a small, realistically structured weather Zarr store (values are random):

- Coordinates: time(4), lat(8), lon(16), pressure_level(5)
- Variables:
  - temperature(time, lat, lon, pressure_level)  -- 4D, has pressure levels
  - wind_u(time, lat, lon, pressure_level)       -- 4D, has pressure levels
  - precipitation(time, lat, lon)                -- 3D, NO pressure levels
  - surface_pressure(time, lat, lon)             -- 3D, NO pressure levels

This mirrors real NWP / reanalysis data, where surface variables lack a
vertical dimension while upper-air variables are defined on pressure levels.
"""
import numpy as np
import xarray as xr
from pathlib import Path
import shutil
STORE = Path("weather.zarr")
def main():
    """Write a fresh synthetic weather dataset to ``STORE`` and print it.

    Any existing store at that path is deleted first. Values are drawn from a
    generator seeded with 42, so repeated runs produce identical data.
    """
    if STORE.exists():
        shutil.rmtree(STORE)

    rng = np.random.default_rng(42)

    # -- coordinates ----------------------------------------------------------
    coords = {
        "time": np.arange("2024-01-01", "2024-01-05", dtype="datetime64[D]"),  # 4 steps
        "lat": np.linspace(-90, 90, 8),
        "lon": np.linspace(-180, 180, 16, endpoint=False),
        "pressure_level": np.array([1000, 850, 500, 300, 200], dtype=np.int32),  # hPa
    }

    surface_dims = ["time", "lat", "lon"]
    upper_air_dims = ["time", "lat", "lon", "pressure_level"]
    surface_shape = tuple(len(coords[d]) for d in surface_dims)
    upper_air_shape = tuple(len(coords[d]) for d in upper_air_dims)

    # NOTE: the draw order below (temperature, wind_u, precipitation,
    # surface_pressure) determines the generated values; keep it stable.

    # -- 4D upper-air variables (time, lat, lon, pressure_level) --------------
    temperature = 250 + 50 * rng.random(upper_air_shape, dtype=np.float32)
    wind_u = -20 + 40 * rng.random(upper_air_shape, dtype=np.float32)

    # -- 3D surface variables (time, lat, lon) — no pressure_level ------------
    precipitation = 50 * rng.random(surface_shape, dtype=np.float32)
    surface_pressure = 950 + 100 * rng.random(surface_shape, dtype=np.float32)

    data_vars = {
        "temperature": (
            upper_air_dims,
            temperature,
            {"units": "K", "long_name": "Air Temperature"},
        ),
        "wind_u": (
            upper_air_dims,
            wind_u,
            {"units": "m/s", "long_name": "U-component of Wind"},
        ),
        "precipitation": (
            surface_dims,
            precipitation,
            {"units": "mm", "long_name": "Total Precipitation"},
        ),
        "surface_pressure": (
            surface_dims,
            surface_pressure,
            {"units": "hPa", "long_name": "Surface Pressure"},
        ),
    }

    ds = xr.Dataset(data_vars, coords=coords)
    ds.to_zarr(STORE)

    print(f"Created {STORE}")
    print(ds)


if __name__ == "__main__":
    main()
[project]
name = "arrow-tensor-zarr"
version = "0.1.0"
description = "Convert Zarr stores to Apache Arrow fixed- and variable-shape tensor tables"
readme = "README.md"
requires-python = ">=3.12"
dependencies = [
"numpy>=2.4.4",
"pyarrow>=23.0.1",
"xarray>=2026.2.0",
"zarr>=3.1.6",
]
"""
Generic Zarr → Apache Arrow tensor extension types.

Reads any Zarr store, auto-discovers coordinate vs. data arrays, and builds
Arrow tables using the canonical tensor extension types:

1. **Fixed Shape Tensor** (``arrow.fixed_shape_tensor``)
   One column per data array. Each row is one slice along a user-chosen axis.

2. **Variable Shape Tensor** (``arrow.variable_shape_tensor``)
   A single tensor column holding all data arrays. Rows may have different
   shapes (some arrays have more dimensions than others).

Both tables round-trip through Arrow IPC with full extension metadata.
"""
from __future__ import annotations
import json
from dataclasses import dataclass, field
from pathlib import Path
import numpy as np
import pyarrow as pa
import zarr
# ---------------------------------------------------------------------------
# Variable Shape Tensor — custom ExtensionType (not yet a pyarrow builtin)
# ---------------------------------------------------------------------------
class VariableShapeTensorType(pa.ExtensionType):
    """Arrow canonical ``arrow.variable_shape_tensor`` as a pyarrow ExtensionType.

    Storage layout is ``struct<data: list<value_type>, shape: fixed_size_list<int32>[ndim]>``:
    ``data`` holds the flattened elements of one tensor and ``shape`` holds
    its extent along each of the ``ndim`` dimensions.

    Args:
        value_type: Arrow type of the tensor elements.
        ndim: Number of dimensions of every tensor in the column.
        dim_names: Optional name for each dimension.
        permutation: Optional ordering of the dimensions.
        uniform_shape: Optional per-dimension size, ``None`` where the size
            differs between rows.

    The three optional parameters are carried in the serialized extension
    metadata (as JSON) and are omitted from it when empty or ``None``.
    """

    # Metadata keys this class understands, in serialization order.
    _METADATA_KEYS = ("dim_names", "permutation", "uniform_shape")

    def __init__(
        self,
        value_type: pa.DataType,
        ndim: int,
        dim_names: list[str] | None = None,
        permutation: list[int] | None = None,
        uniform_shape: list[int | None] | None = None,
    ):
        # Instance attributes are assigned before the base initializer runs,
        # so they are available as soon as the type object exists.
        self._ndim = ndim
        # Copy into plain lists so tuples etc. serialize to JSON the same way.
        self._dim_names = None if dim_names is None else list(dim_names)
        self._permutation = None if permutation is None else list(permutation)
        self._uniform_shape = None if uniform_shape is None else list(uniform_shape)
        storage = pa.struct([
            pa.field("data", pa.list_(value_type)),
            pa.field("shape", pa.list_(pa.int32(), ndim)),
        ])
        super().__init__(storage, "arrow.variable_shape_tensor")

    def __arrow_ext_serialize__(self) -> bytes:
        """Encode the optional parameters as a JSON object (``{}`` if none)."""
        candidates = {
            "dim_names": self._dim_names,
            "permutation": self._permutation,
            "uniform_shape": self._uniform_shape,
        }
        md = {key: value for key, value in candidates.items() if value}
        return json.dumps(md).encode()

    @classmethod
    def __arrow_ext_deserialize__(cls, storage_type, serialized):
        """Rebuild the type from its storage type and serialized metadata.

        Empty or ``null`` metadata is treated as "no optional parameters",
        and keys this class does not know are ignored instead of being
        forwarded to ``__init__`` (where they would raise ``TypeError``).
        """
        md = (json.loads(serialized.decode("utf-8")) if serialized else None) or {}
        kwargs = {key: md[key] for key in cls._METADATA_KEYS if key in md}
        vtype = storage_type.field("data").type.value_type
        ndim = storage_type.field("shape").type.list_size
        return cls(vtype, ndim, **kwargs)
def _register_vst():
    """(Re-)register ``VariableShapeTensorType`` with pyarrow.

    An existing registration under the same extension name is dropped first,
    so calling this more than once is harmless.
    """
    extension_name = "arrow.variable_shape_tensor"
    try:
        pa.unregister_extension_type(extension_name)
    except KeyError:
        # Nothing was registered under that name yet.
        pass

    # Registration takes an instance; these parameters are only placeholders.
    placeholder = VariableShapeTensorType(pa.float32(), 1)
    pa.register_extension_type(placeholder)
# ---------------------------------------------------------------------------
# Zarr introspection
# ---------------------------------------------------------------------------
@dataclass
class ZarrInfo:
    """Arrays found in one Zarr group, split into coordinates and data.

    Both mappings are keyed by array name and start out empty.
    """

    # name -> coordinate values
    coords: dict[str, np.ndarray] = field(default_factory=dict)
    # name -> (array values, dimension names of that array)
    data: dict[str, tuple[np.ndarray, list[str]]] = field(default_factory=dict)
def discover_arrays(grp: zarr.Group) -> ZarrInfo:
    """Classify every array returned by ``grp.members()`` as coordinate or data.

    * Arrays of rank 2 or higher are data arrays; they are stored in
      ``info.data`` as ``(values, dim_names)``.
    * Every other array (rank 0 or 1) is stored in ``info.coords``, whether
      or not it is named after its own dimension.

    Non-array members (e.g. sub-groups) are skipped. Each array is read into
    memory in full.
    """
    info = ZarrInfo()
    for name, member in grp.members():
        if not isinstance(member, zarr.Array):
            continue
        values = np.asarray(member)
        if values.ndim >= 2:
            # Dimension names are only needed for data arrays, so the
            # metadata lookup is skipped for coordinates.
            info.data[name] = (values, _get_dim_names(grp, name))
        else:
            info.coords[name] = values
    return info
def _get_dim_names(grp: zarr.Group, var_name: str) -> list[str]:
    """Return the dimension names of array *var_name* in *grp*.

    Sources are tried in order and the first one that yields names wins:

    1. ``dimension_names`` on the array's already-parsed metadata. This needs
       no filesystem access, so it works for any store backend.
    2. For stores that expose a filesystem root: the consolidated metadata in
       the group's ``zarr.json``, then the array's own ``zarr.json``.
    3. The ``_ARRAY_DIMENSIONS`` attribute (xarray's Zarr v2 convention).

    Returns an empty list when no source provides names. Entries are returned
    as stored, so an unnamed dimension may appear as ``None``.
    """
    arr = grp[var_name]

    names = getattr(getattr(arr, "metadata", None), "dimension_names", None)
    if names is not None:
        return list(names)

    # Not every store has a filesystem location; look it up defensively.
    store = grp.store
    root = getattr(store, "root", None)
    if root is None:
        root = getattr(store, "path", None)
    if root is not None:
        # Honour the group's own location inside the store. Slashes are
        # stripped so an absolute-looking path cannot replace ``root``.
        group_dir = Path(root) / str(getattr(grp, "path", "") or "").strip("/")

        group_meta = _load_json_if_present(group_dir / "zarr.json")
        # ``or {}`` also covers keys that are present but JSON null.
        consolidated = (group_meta.get("consolidated_metadata") or {}).get("metadata") or {}
        names = (consolidated.get(var_name) or {}).get("dimension_names")
        if names is None:
            # Fallback: per-array zarr.json
            array_meta = _load_json_if_present(group_dir / var_name / "zarr.json")
            names = array_meta.get("dimension_names")
        if names is not None:
            return list(names)

    # Last resort: xarray v2 convention
    return list(dict(arr.attrs).get("_ARRAY_DIMENSIONS", []))


def _load_json_if_present(path: Path) -> dict:
    """Return the JSON document stored at *path*, or ``{}`` if no such file."""
    if not path.is_file():
        return {}
    with open(path, encoding="utf-8") as f:
        return json.load(f)
# NumPy dtype name -> Arrow data type, for the common numeric and boolean dtypes.
_NP_TO_ARROW: dict[str, pa.DataType] = {
    np_name: make_type()
    for np_name, make_type in (
        ("float16", pa.float16),
        ("float32", pa.float32),
        ("float64", pa.float64),
        ("int8", pa.int8),
        ("int16", pa.int16),
        ("int32", pa.int32),
        ("int64", pa.int64),
        ("uint8", pa.uint8),
        ("uint16", pa.uint16),
        ("uint32", pa.uint32),
        ("uint64", pa.uint64),
        ("bool", pa.bool_),
    )
}
def _numpy_to_arrow_type(dtype: np.dtype) -> pa.DataType:
    """Return the Arrow data type corresponding to NumPy *dtype*.

    Common dtypes come from the ``_NP_TO_ARROW`` table; anything else is
    delegated to ``pa.from_numpy_dtype``, which is only called on a table
    miss. Anything ``np.dtype()`` accepts may be passed.
    """
    dtype = np.dtype(dtype)
    try:
        return _NP_TO_ARROW[dtype.name]
    except KeyError:
        return pa.from_numpy_dtype(dtype)
def _resolve_slice_axis(info: ZarrInfo, slice_axis: int | str) -> tuple[str, int]:
    """Return ``(dim_name, axis_index)`` for the slicing dimension.

    * A string is looked up in the dimension names of each data array in
      turn; the index within the first array that has it is returned.
    * An integer is taken as an axis of the first data array. Its name is
      that array's dimension name at that position, or ``dim_<n>`` when the
      array has no dimension names.

    Raises:
        ValueError: if there are no data arrays, or a named dimension does
            not occur in any data array.
        IndexError: if an integer axis is out of range for the first data
            array's dimension names.
    """
    if not info.data:
        raise ValueError("cannot resolve slice axis: no data arrays were discovered")

    if isinstance(slice_axis, str):
        for _, dims in info.data.values():
            if slice_axis in dims:
                return slice_axis, dims.index(slice_axis)
        known = sorted(
            {d for _, dims in info.data.values() for d in dims if isinstance(d, str)}
        )
        raise ValueError(
            f"dimension {slice_axis!r} not found in any data array; "
            f"known dimensions: {known}"
        )

    _, dims = next(iter(info.data.values()))
    name = dims[slice_axis] if dims else f"dim_{slice_axis}"
    return name, slice_axis
# ---------------------------------------------------------------------------
# Approach 1 — Fixed Shape Tensor (one column per data array)
# ---------------------------------------------------------------------------
def build_fixed_shape_table(
    store_path: Path,
    slice_axis: int | str = 0,
) -> pa.Table:
    """Build a table with one ``FixedShapeTensor`` column per data array.

    Each row holds one slice along *slice_axis*. If the slicing dimension has
    a coordinate array, it is emitted as a plain leading column.

    Per data array, the slicing axis is found by dimension name. An array
    that has no dimension names at all is sliced positionally, using the
    resolved axis index. An array that has names but not the slicing
    dimension is repeated unchanged on every row, so that all columns end up
    with the same length.

    Raises:
        ValueError: if two data arrays disagree on the length of the slicing
            dimension.
    """
    grp = zarr.open_group(store_path, mode="r")
    info = discover_arrays(grp)
    slice_dim, slice_idx = _resolve_slice_axis(info, slice_axis)

    # Pass 1: bring the slice axis to the front where an array has one, and
    # establish the common row count.
    prepared: list[tuple[str, np.ndarray, list[str], bool]] = []
    n_rows: int | None = None
    for var_name, (data, dims) in info.data.items():
        if slice_dim in dims:
            ax = dims.index(slice_dim)
        elif not dims and -data.ndim <= slice_idx < data.ndim:
            # No dimension names to match against: fall back to position.
            ax = slice_idx % data.ndim
        else:
            ax = None

        if ax is None:
            prepared.append((var_name, data, list(dims), False))
            continue

        moved = np.moveaxis(data, ax, 0)
        if n_rows is None:
            n_rows = moved.shape[0]
        elif moved.shape[0] != n_rows:
            raise ValueError(
                f"array {var_name!r} has {moved.shape[0]} slices along "
                f"{slice_dim!r}, expected {n_rows}"
            )
        remaining_dims = [d for i, d in enumerate(dims) if i != ax]
        prepared.append((var_name, moved, remaining_dims, True))

    coord = info.coords.get(slice_dim)
    if n_rows is None:
        # No array carries the slice dimension; size rows from the coordinate
        # if there is one, otherwise emit a single row.
        n_rows = int(coord.size) if coord is not None else 1

    columns: dict[str, pa.Array] = {}
    fields: list[pa.Field] = []

    # Slicing coordinate as a plain column
    if coord is not None:
        coord_type = _numpy_to_arrow_type(coord.dtype)
        fields.append(pa.field(slice_dim, coord_type))
        columns[slice_dim] = pa.array(coord, type=coord_type)

    # Pass 2: one tensor column per data array.
    for var_name, data, remaining_dims, has_slice_axis in prepared:
        if not has_slice_axis:
            # Repeat the whole array on every row.
            data = np.broadcast_to(data, (n_rows, *data.shape))

        per_slice_shape = list(data.shape[1:])
        arrow_value_type = _numpy_to_arrow_type(data.dtype)
        tensor_type = pa.fixed_shape_tensor(
            arrow_value_type,
            per_slice_shape,
            dim_names=remaining_dims or None,
        )
        # Row-major flattening keeps each slice contiguous in the storage.
        storage = pa.FixedSizeListArray.from_arrays(
            pa.array(data.reshape(-1), type=arrow_value_type),
            int(np.prod(per_slice_shape)),
        )
        fields.append(pa.field(var_name, tensor_type))
        columns[var_name] = pa.ExtensionArray.from_storage(tensor_type, storage)

    return pa.table(columns, schema=pa.schema(fields))
# ---------------------------------------------------------------------------
# Approach 2 — Variable Shape Tensor (single mixed-shape column)
# ---------------------------------------------------------------------------
def build_variable_shape_table(
    store_path: Path,
    slice_axis: int | str = 0,
) -> pa.Table:
    """Build a table with a single ``VariableShapeTensor`` column.

    One row per (slice_index, array_name). All data arrays share the tensor
    column; values are promoted to a common dtype and arrays with fewer
    dimensions are padded with trailing size-1 dimensions up to the maximum
    rank.

    The result has three columns: the slice coordinate (named after the
    slicing dimension), ``array_name`` and ``tensor``.

    Per data array, the slicing axis is found by dimension name; an array
    with no dimension names at all is sliced positionally.

    NOTE: an array that does not have the slicing dimension contributes a
    single row, which is labelled with the first slice coordinate value.
    """
    grp = zarr.open_group(store_path, mode="r")
    info = discover_arrays(grp)
    slice_dim, slice_idx = _resolve_slice_axis(info, slice_axis)

    # Bring the slice axis to the front of every array (adding a length-1
    # axis where an array lacks it) and remember the remaining dim names.
    per_array_sliced: list[tuple[str, np.ndarray, list[str]]] = []
    for var_name, (data, dims) in info.data.items():
        if slice_dim in dims:
            ax = dims.index(slice_dim)
        elif not dims and -data.ndim <= slice_idx < data.ndim:
            # No dimension names to match against: fall back to position.
            ax = slice_idx % data.ndim
        else:
            ax = None

        if ax is None:
            remaining_dims = list(dims)
            data = data[np.newaxis, ...]
        else:
            remaining_dims = [d for i, d in enumerate(dims) if i != ax]
            data = np.moveaxis(data, ax, 0)
        per_array_sliced.append((var_name, data, remaining_dims))

    max_ndim = max((data.ndim - 1 for _, data, _ in per_array_sliced), default=0)

    # Exactly one name per tensor dimension (never more than max_ndim).
    all_dim_names = _positional_dim_names(
        [rdims for _, _, rdims in per_array_sliced], max_ndim
    )

    # Compute uniform_shape: size where all arrays agree, None where they differ
    padded_shapes = [
        data.shape[1:] + (1,) * (max_ndim - (data.ndim - 1))
        for _, data, _ in per_array_sliced
    ]
    uniform_shape: list[int | None] = []
    for dim_i in range(max_ndim):
        sizes = {shape[dim_i] for shape in padded_shapes}
        uniform_shape.append(sizes.pop() if len(sizes) == 1 else None)

    # Resolve the common arrow value type (promote to widest)
    common_dtype = np.result_type(*(data.dtype for _, data, _ in per_array_sliced))
    arrow_value_type = _numpy_to_arrow_type(common_dtype)
    vst = VariableShapeTensorType(
        arrow_value_type,
        max_ndim,
        dim_names=all_dim_names,
        uniform_shape=uniform_shape,
    )

    # Slice coordinate
    if slice_dim in info.coords:
        coord = info.coords[slice_dim]
    else:
        coord = np.arange(per_array_sliced[0][1].shape[0])

    row_coords: list = []
    row_varnames: list[str] = []
    row_lengths: list[int] = []
    shape_vals: list[int] = []
    value_chunks: list[np.ndarray] = []
    for (var_name, data, _), slab_shape in zip(per_array_sliced, padded_shapes):
        n_slices = data.shape[0]
        slab_size = int(np.prod(slab_shape, dtype=np.int64))
        row_coords.extend(
            coord[s_idx] if s_idx < len(coord) else s_idx for s_idx in range(n_slices)
        )
        row_varnames.extend([var_name] * n_slices)
        row_lengths.extend([slab_size] * n_slices)
        shape_vals.extend(slab_shape * n_slices)
        # Row-major flattening with the slice axis first keeps every slice
        # contiguous, so rows are described by offsets alone.
        value_chunks.append(data.astype(common_dtype).reshape(-1))

    # Build the list storage from offsets + one flat buffer instead of
    # converting every element to a Python object.
    offsets = np.concatenate(([0], np.cumsum(row_lengths, dtype=np.int64)))
    data_arr = pa.ListArray.from_arrays(
        pa.array(offsets, type=pa.int32()),
        pa.array(np.concatenate(value_chunks), type=arrow_value_type),
    )
    shape_arr = pa.FixedSizeListArray.from_arrays(
        pa.array(shape_vals, type=pa.int32()), max_ndim,
    )
    storage = pa.StructArray.from_arrays(
        [data_arr, shape_arr], names=["data", "shape"],
    )
    tensor_col = pa.ExtensionArray.from_storage(vst, storage)

    coord_values = np.array(row_coords)
    coord_arrow_type = _numpy_to_arrow_type(coord_values.dtype)
    return pa.table({
        slice_dim: pa.array(coord_values, type=coord_arrow_type),
        "array_name": pa.array(row_varnames, type=pa.utf8()),
        "tensor": tensor_col,
    })


def _positional_dim_names(dims_per_array: list[list[str]], ndim: int) -> list[str]:
    """Name each of *ndim* positions, using ``dim_<i>`` where arrays disagree.

    A position gets a real name only when every array that has a named
    dimension at that position uses the same string there.
    """
    names: list[str] = []
    for pos in range(ndim):
        candidates = {dims[pos] for dims in dims_per_array if pos < len(dims)}
        only = candidates.pop() if len(candidates) == 1 else None
        names.append(only if isinstance(only, str) else f"dim_{pos}")
    return names
# ---------------------------------------------------------------------------
# round-trip verification
# ---------------------------------------------------------------------------
def write_ipc(table: pa.Table, path: Path):
    """Write *table* to an Arrow IPC file at *path* and print the file size."""
    writer = pa.ipc.new_file(path, table.schema)
    with writer:
        writer.write_table(table)
    size_bytes = path.stat().st_size
    print(f" Wrote {path} ({size_bytes:,} bytes)")
def read_ipc(path: Path) -> pa.Table:
    """Load the whole Arrow IPC file at *path* into a table."""
    with pa.ipc.open_file(path) as reader:
        table = reader.read_all()
    return table
def verify_roundtrip(original: pa.Table, path: Path):
    """Write *original* to *path*, read it back, and check nothing changed.

    Compares the schema, the row count and every column, then prints a
    one-line summary.

    Raises:
        AssertionError: on the first mismatch found.
    """
    write_ipc(original, path)
    restored = read_ipc(path)

    # Raised explicitly rather than via ``assert`` so the checks still run
    # under ``python -O``.
    if not original.schema.equals(restored.schema):
        raise AssertionError("Schema mismatch!")
    if original.num_rows != restored.num_rows:
        raise AssertionError("Row count mismatch!")
    for col_name in original.column_names:
        if not original.column(col_name).equals(restored.column(col_name)):
            raise AssertionError(f"Column {col_name} mismatch!")

    print(f" Round-trip OK ({original.num_rows} rows, {len(original.schema)} cols)")
# ---------------------------------------------------------------------------
# main
# ---------------------------------------------------------------------------
def main():
    """Convert ``weather.zarr`` with both approaches and print a short demo.

    Registers the variable-shape extension type, lists the discovered arrays,
    builds and round-trips both tables through Arrow IPC files, then pulls a
    few tensors back out as NumPy arrays.
    """
    rule = "=" * 70

    def banner(title):
        print(rule)
        print(title)
        print(rule)

    def show(table):
        print(table.schema)
        print()
        print(table)
        print()

    _register_vst()
    store = Path("weather.zarr")
    info = discover_arrays(zarr.open_group(store, mode="r"))

    print("Discovered arrays:")
    print(f" Coordinates: {list(info.coords)}")
    for name, (arr, dims) in info.data.items():
        print(f" Data: {name:20s} shape={arr.shape} dims={dims}")
    print()

    # -- Approach 1 ----------------------------------------------------------
    banner("APPROACH 1: Fixed Shape Tensor (one column per data array)")
    fixed_table = build_fixed_shape_table(store, slice_axis="time")
    show(fixed_table)
    verify_roundtrip(fixed_table, Path("fixed_shape.arrow"))

    # -- Approach 2 ----------------------------------------------------------
    print()
    banner("APPROACH 2: Variable Shape Tensor (mixed-shape single column)")
    vst_table = build_variable_shape_table(store, slice_axis="time")
    show(vst_table)
    verify_roundtrip(vst_table, Path("variable_shape.arrow"))

    # -- Tensor extraction demo ----------------------------------------------
    print()
    banner("TENSOR EXTRACTION DEMO")

    # Fixed shape: row 0 of every tensor column
    for schema_field in fixed_table.schema:
        if not isinstance(schema_field.type, pa.FixedShapeTensorType):
            continue
        col_name = schema_field.name
        col = fixed_table.column(col_name)
        tt = col.type
        row = np.array(col[0].as_py(), dtype=np.float32).reshape(tt.shape)
        print(f"\n {col_name}[0] shape={row.shape} dims={tt.dim_names} "
              f"min={row.min():.2f} max={row.max():.2f}")

    # Variable shape: first and last row
    tensor_col = vst_table.column("tensor")
    name_col = vst_table.column("array_name")
    for i in (0, len(tensor_col) - 1):
        row = tensor_col[i].as_py()
        shape = row["shape"]
        data = np.array(row["data"], dtype=np.float32).reshape(shape)
        vname = name_col[i].as_py()
        print(f"\n row[{i}] ({vname}) shape={tuple(shape)} "
              f"min={data.min():.2f} max={data.max():.2f}")


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment