Created
March 31, 2026 15:29
-
-
Save jayendra13/af793b3d377c4fa7d14147df7c2276df to your computer and use it in GitHub Desktop.
Generic Zarr → Apache Arrow Fixed/Variable Shape Tensor (domain-agnostic)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| """ | |
| Create a realistic weather Zarr store with: | |
| - Coordinates: time(4), lat(8), lon(16), pressure_level(5) | |
| - Variables: | |
| - temperature(time, lat, lon, pressure_level) -- 4D, has pressure levels | |
| - wind_u(time, lat, lon, pressure_level) -- 4D, has pressure levels | |
| - precipitation(time, lat, lon) -- 3D, NO pressure levels | |
| - surface_pressure(time, lat, lon) -- 3D, NO pressure levels | |
| This mirrors real NWP / reanalysis data where surface variables lack a | |
| vertical dimension while upper-air variables are defined on pressure levels. | |
| """ | |
| import numpy as np | |
| import xarray as xr | |
| from pathlib import Path | |
| import shutil | |
# Location of the generated store, relative to the current working directory.
STORE = Path("weather.zarr")


def main():
    """Write a small synthetic weather dataset to ``STORE`` as a Zarr store.

    Two 4-D upper-air variables (with ``pressure_level``) and two 3-D surface
    variables (without it) are filled with seeded uniform random float32
    values. Any existing store at ``STORE`` is deleted first.
    """
    # Start from a clean slate so repeated runs do not collide.
    if STORE.exists():
        shutil.rmtree(STORE)

    rng = np.random.default_rng(42)

    def uniform(shape, low, span):
        # float32 samples in [low, low + span)
        return low + span * rng.random(shape, dtype=np.float32)

    # -- coordinates ----------------------------------------------------------
    coords = {
        "time": np.arange("2024-01-01", "2024-01-05", dtype="datetime64[D]"),  # 4 steps
        "lat": np.linspace(-90, 90, 8),
        "lon": np.linspace(-180, 180, 16, endpoint=False),
        "pressure_level": np.array([1000, 850, 500, 300, 200], dtype=np.int32),  # hPa
    }

    surface_dims = ["time", "lat", "lon"]
    upper_air_dims = surface_dims + ["pressure_level"]
    shape_3d = tuple(len(coords[d]) for d in surface_dims)
    shape_4d = tuple(len(coords[d]) for d in upper_air_dims)

    # The draw order below (temperature, wind_u, precipitation,
    # surface_pressure) fixes which random numbers each variable receives.
    # -- 4D upper-air variables (time, lat, lon, pressure_level) --------------
    temperature = uniform(shape_4d, 250, 50)
    wind_u = uniform(shape_4d, -20, 40)
    # -- 3D surface variables (time, lat, lon) — no pressure_level ------------
    precipitation = uniform(shape_3d, 0, 50)
    surface_pressure = uniform(shape_3d, 950, 100)

    data_vars = {
        "temperature": (
            upper_air_dims,
            temperature,
            {"units": "K", "long_name": "Air Temperature"},
        ),
        "wind_u": (
            upper_air_dims,
            wind_u,
            {"units": "m/s", "long_name": "U-component of Wind"},
        ),
        "precipitation": (
            surface_dims,
            precipitation,
            {"units": "mm", "long_name": "Total Precipitation"},
        ),
        "surface_pressure": (
            surface_dims,
            surface_pressure,
            {"units": "hPa", "long_name": "Surface Pressure"},
        ),
    }

    ds = xr.Dataset(data_vars, coords=coords)
    ds.to_zarr(STORE)
    print(f"Created {STORE}")
    print(ds)


if __name__ == "__main__":
    main()
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| [project] | |
| name = "arrow-tensor-zarr" | |
| version = "0.1.0" | |
| description = "Convert Zarr stores to Apache Arrow fixed- and variable-shape tensor tables" | |
| readme = "README.md" | |
| requires-python = ">=3.12" | |
| dependencies = [ | |
| "numpy>=2.4.4", | |
| "pyarrow>=23.0.1", | |
| "xarray>=2026.2.0", | |
| "zarr>=3.1.6", | |
| ] |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| """ | |
| Generic Zarr → Apache Arrow tensor extension types. | |
| Reads any Zarr store, auto-discovers coordinate vs data arrays, and builds | |
| Arrow tables using the canonical tensor extension types: | |
| 1. **Fixed Shape Tensor** (arrow.fixed_shape_tensor) | |
| One column per data array. Each row = one slice along a user-chosen axis. | |
| 2. **Variable Shape Tensor** (arrow.variable_shape_tensor) | |
| Single tensor column holding all data arrays. Rows may have different | |
| shapes (some arrays have more dimensions than others). | |
| Both tables round-trip through Arrow IPC with full extension metadata. | |
| """ | |
| from __future__ import annotations | |
| import json | |
| from dataclasses import dataclass, field | |
| from pathlib import Path | |
| import numpy as np | |
| import pyarrow as pa | |
| import zarr | |
| # --------------------------------------------------------------------------- | |
| # Variable Shape Tensor — custom ExtensionType (not yet a pyarrow builtin) | |
| # --------------------------------------------------------------------------- | |
class VariableShapeTensorType(pa.ExtensionType):
    """Arrow canonical variable_shape_tensor implemented as a pyarrow ExtensionType.

    Storage is ``struct<data: list<value_type>, shape: fixed_size_list<int32>[ndim]>``.
    The optional ``dim_names``, ``permutation`` and ``uniform_shape`` settings
    are carried as a JSON object in the extension metadata.

    Parameters
    ----------
    value_type:
        Arrow type of the tensor elements.
    ndim:
        Number of dimensions of every tensor in the column.
    dim_names, permutation, uniform_shape:
        Optional metadata; each is written to the serialized form only when
        it is non-empty.
    """

    # Metadata keys understood by this implementation, in serialization order.
    _METADATA_KEYS = ("dim_names", "permutation", "uniform_shape")

    def __init__(
        self,
        value_type: pa.DataType,
        ndim: int,
        dim_names: list[str] | None = None,
        permutation: list[int] | None = None,
        uniform_shape: list[int | None] | None = None,
    ):
        # Attributes are assigned before super().__init__() so they already
        # exist if the base class serializes the type during construction.
        self._ndim = ndim
        self._dim_names = dim_names
        self._permutation = permutation
        self._uniform_shape = uniform_shape
        storage = pa.struct([
            pa.field("data", pa.list_(value_type)),
            pa.field("shape", pa.list_(pa.int32(), ndim)),
        ])
        super().__init__(storage, "arrow.variable_shape_tensor")

    @property
    def ndim(self) -> int:
        """Number of dimensions of each tensor."""
        return self._ndim

    @property
    def dim_names(self) -> list[str] | None:
        """Dimension names, or ``None`` when not set."""
        return self._dim_names

    @property
    def permutation(self) -> list[int] | None:
        """Dimension permutation, or ``None`` when not set."""
        return self._permutation

    @property
    def uniform_shape(self) -> list[int | None] | None:
        """Per-dimension uniform sizes (``None`` entries vary), or ``None``."""
        return self._uniform_shape

    def __arrow_ext_serialize__(self) -> bytes:
        """Return the JSON-encoded metadata, omitting empty/unset entries."""
        md: dict = {}
        for key in self._METADATA_KEYS:
            value = getattr(self, f"_{key}")
            if value:
                md[key] = value
        return json.dumps(md).encode()

    @classmethod
    def __arrow_ext_deserialize__(cls, storage_type, serialized):
        """Rebuild the type from its storage type and serialized metadata.

        Empty metadata is accepted (all metadata entries are optional), and
        keys this implementation does not know are ignored instead of being
        forwarded to ``__init__`` where they would raise ``TypeError``.
        """
        text = serialized.decode().strip() if serialized else ""
        md = (json.loads(text) if text else {}) or {}
        kwargs = {key: md[key] for key in cls._METADATA_KEYS if key in md}
        vtype = storage_type.field("data").type.value_type
        ndim = storage_type.field("shape").type.list_size
        return cls(vtype, ndim, **kwargs)
def _register_vst():
    """(Re-)register ``VariableShapeTensorType`` with pyarrow.

    An existing registration under the same extension name is dropped first,
    so calling this more than once is safe.
    """
    extension_name = "arrow.variable_shape_tensor"
    try:
        pa.unregister_extension_type(extension_name)
    except KeyError:
        # Nothing was registered under this name yet.
        pass
    # The value type and ndim passed here are arbitrary (float32, 1).
    prototype = VariableShapeTensorType(pa.float32(), 1)
    pa.register_extension_type(prototype)
| # --------------------------------------------------------------------------- | |
| # Zarr introspection | |
| # --------------------------------------------------------------------------- | |
@dataclass
class ZarrInfo:
    """Arrays found in a Zarr group, split into coordinates and data variables."""

    # name -> values, for arrays classified as coordinates
    coords: dict[str, np.ndarray] = field(default_factory=dict)
    # name -> (values, dim_names), for arrays classified as data variables
    data: dict[str, tuple[np.ndarray, list[str]]] = field(default_factory=dict)
def discover_arrays(grp: zarr.Group) -> ZarrInfo:
    """Classify every array in *grp* as a coordinate or a data array.

    Arrays with two or more dimensions become data arrays, stored together
    with their dimension names. Every other array (including 1-D arrays that
    are not named after their own dimension) is treated as a coordinate.
    Non-array members such as sub-groups are skipped.

    Each array is read fully into memory via ``np.asarray``.
    """
    info = ZarrInfo()
    for name, member in grp.members():
        if not isinstance(member, zarr.Array):
            continue
        values = np.asarray(member)
        if values.ndim >= 2:
            # Dimension names are only needed (and only looked up) for data arrays.
            info.data[name] = (values, _get_dim_names(grp, name))
        else:
            info.coords[name] = values
    return info
| def _get_dim_names(grp: zarr.Group, var_name: str) -> list[str]: | |
| """Read dimension_names from Zarr v3 consolidated metadata or variable metadata.""" | |
| store_path = grp.store.root | |
| if store_path is None: | |
| store_path = grp.store.path | |
| meta_path = Path(store_path) / "zarr.json" | |
| if meta_path.exists(): | |
| with open(meta_path) as f: | |
| root_meta = json.load(f) | |
| cm = root_meta.get("consolidated_metadata", {}).get("metadata", {}) | |
| if var_name in cm and "dimension_names" in cm[var_name]: | |
| return cm[var_name]["dimension_names"] | |
| # Fallback: per-array zarr.json | |
| arr_meta_path = Path(store_path) / var_name / "zarr.json" | |
| if arr_meta_path.exists(): | |
| with open(arr_meta_path) as f: | |
| arr_meta = json.load(f) | |
| if "dimension_names" in arr_meta: | |
| return arr_meta["dimension_names"] | |
| # Last resort: xarray v2 convention | |
| attrs = dict(grp[var_name].attrs) | |
| return attrs.get("_ARRAY_DIMENSIONS", []) | |
# NumPy dtype name -> Arrow data type. Every entry except "bool" uses the
# pyarrow factory of the same name; bool maps to ``pa.bool_``.
_NP_TO_ARROW: dict[str, pa.DataType] = {
    **{
        type_name: getattr(pa, type_name)()
        for type_name in (
            "float16",
            "float32",
            "float64",
            "int8",
            "int16",
            "int32",
            "int64",
            "uint8",
            "uint16",
            "uint32",
            "uint64",
        )
    },
    "bool": pa.bool_(),
}
def _numpy_to_arrow_type(dtype: np.dtype) -> pa.DataType:
    """Return the Arrow data type corresponding to NumPy *dtype*.

    Names listed in ``_NP_TO_ARROW`` are resolved from that table. Only
    dtypes missing from it fall back to ``pa.from_numpy_dtype``, so the
    conversion is no longer evaluated on every call as it was when used as
    the eager default argument of ``dict.get``.
    """
    known = _NP_TO_ARROW.get(dtype.name)
    if known is not None:
        return known
    return pa.from_numpy_dtype(dtype)
| def _resolve_slice_axis(info: ZarrInfo, slice_axis: int | str) -> tuple[str, int]: | |
| """Return (dim_name, axis_index) for the slicing dimension.""" | |
| first_data = next(iter(info.data.values())) | |
| dims = first_data[1] | |
| if isinstance(slice_axis, str): | |
| name = slice_axis | |
| idx = dims.index(name) | |
| else: | |
| idx = slice_axis | |
| name = dims[idx] if dims else f"dim_{idx}" | |
| return name, idx | |
| # --------------------------------------------------------------------------- | |
| # Approach 1 — Fixed Shape Tensor (one column per data array) | |
| # --------------------------------------------------------------------------- | |
def build_fixed_shape_table(
    store_path: Path,
    slice_axis: int | str = 0,
) -> pa.Table:
    """
    Each data array becomes a FixedShapeTensor column.
    Each row = one slice along *slice_axis*.

    If the slicing coordinate exists in the store it is emitted as a plain
    leading column. A data array that does not have the slicing dimension is
    repeated on every row, so all columns end up with the same length
    (previously such an array produced a single row and table construction
    failed on unequal column lengths).

    Raises
    ------
    ValueError
        If data arrays disagree on the length of the slicing dimension.
    """
    grp = zarr.open_group(store_path, mode="r")
    info = discover_arrays(grp)
    slice_dim, _ = _resolve_slice_axis(info, slice_axis)

    # Row count = length of the slicing dimension in the arrays that have it.
    lengths = set()
    for data, dims in info.data.values():
        dims = list(dims or [])
        if slice_dim in dims:
            lengths.add(data.shape[dims.index(slice_dim)])
    if len(lengths) > 1:
        raise ValueError(
            f"data arrays disagree on the length of {slice_dim!r}: {sorted(lengths)}"
        )
    n_rows = lengths.pop() if lengths else 1

    columns: dict[str, pa.Array] = {}
    fields: list[pa.Field] = []

    # Slicing coordinate as a plain column
    if slice_dim in info.coords:
        coord = info.coords[slice_dim]
        arrow_type = _numpy_to_arrow_type(coord.dtype)
        fields.append(pa.field(slice_dim, arrow_type))
        columns[slice_dim] = pa.array(coord, type=arrow_type)

    for var_name, (data, dims) in info.data.items():
        dims = list(dims or [])
        if slice_dim in dims:
            # Move the slicing axis of *this* array to the front.
            ax = dims.index(slice_dim)
            remaining_dims = dims[:ax] + dims[ax + 1:]
            data = np.moveaxis(data, ax, 0)
        else:
            # No slicing dimension: repeat the whole array on every row.
            remaining_dims = dims
            data = np.broadcast_to(data, (n_rows, *data.shape))

        per_slice_shape = list(data.shape[1:])
        arrow_value_type = _numpy_to_arrow_type(data.dtype)

        # Only attach dimension names when there is exactly one string per axis.
        names_usable = len(remaining_dims) == len(per_slice_shape) and all(
            isinstance(d, str) for d in remaining_dims
        )
        tensor_type = pa.fixed_shape_tensor(
            arrow_value_type,
            per_slice_shape,
            dim_names=remaining_dims if (remaining_dims and names_usable) else None,
        )

        # reshape(-1) yields the row-major flattening (copying when the
        # moved/broadcast view is not contiguous).
        storage = pa.FixedSizeListArray.from_arrays(
            pa.array(data.reshape(-1), type=arrow_value_type),
            int(np.prod(per_slice_shape)),
        )
        fields.append(pa.field(var_name, tensor_type))
        columns[var_name] = pa.ExtensionArray.from_storage(tensor_type, storage)

    return pa.table(columns, schema=pa.schema(fields))
| # --------------------------------------------------------------------------- | |
| # Approach 2 — Variable Shape Tensor (single mixed-shape column) | |
| # --------------------------------------------------------------------------- | |
def build_variable_shape_table(
    store_path: Path,
    slice_axis: int | str = 0,
) -> pa.Table:
    """
    One row per (slice_index, array_name). All data arrays share a single
    VariableShapeTensor column.

    The tensor dimensions are the union of the arrays' remaining dimension
    names (in first-seen order) once the slicing dimension is removed. Each
    array is aligned to that list by name: its axes are ordered to follow it
    and a size-1 axis is inserted for every dimension it lacks. For arrays
    that only miss trailing dimensions this is the same as padding with
    trailing size-1 dims. Arrays without usable dimension names get the
    positional names ``dim_0``, ``dim_1``, ...

    Aligning by name keeps ``len(dim_names) == ndim``, which the previous
    max-rank padding did not guarantee when arrays used different names.

    The returned table has the columns ``<slice_dim>``, ``array_name`` and
    ``tensor``; element values are promoted to a common dtype.
    """
    grp = zarr.open_group(store_path, mode="r")
    info = discover_arrays(grp)
    slice_dim, _ = _resolve_slice_axis(info, slice_axis)

    # 1. Move the slicing axis to the front and collect remaining dim names.
    sliced: list[tuple[str, np.ndarray, list[str]]] = []
    for var_name, (data, dims) in info.data.items():
        dims = list(dims or [])
        if slice_dim in dims:
            ax = dims.index(slice_dim)
            rdims = dims[:ax] + dims[ax + 1:]
            data = np.moveaxis(data, ax, 0)
        else:
            rdims = dims
            data = data[np.newaxis, ...]
        rank = data.ndim - 1
        named = (
            len(rdims) == rank
            and None not in rdims
            and len(set(rdims)) == rank
        )
        if not named:
            rdims = [f"dim_{i}" for i in range(rank)]
        sliced.append((var_name, data, rdims))

    # 2. Unified dimension names: union in first-seen order.
    all_dim_names: list[str] = []
    for _, _, rdims in sliced:
        for d in rdims:
            if d not in all_dim_names:
                all_dim_names.append(d)
    ndim = len(all_dim_names)

    # 3. Align every array to the unified names.
    aligned: list[tuple[str, np.ndarray]] = []
    for var_name, data, rdims in sliced:
        order = sorted(range(len(rdims)), key=lambda i: all_dim_names.index(rdims[i]))
        data = data.transpose([0, *(i + 1 for i in order)])
        present_sizes = iter(data.shape[1:])
        full_shape = [
            next(present_sizes) if name in rdims else 1 for name in all_dim_names
        ]
        aligned.append((var_name, data.reshape(data.shape[0], *full_shape)))

    # 4. uniform_shape: size where all arrays agree, None where they differ.
    uniform_shape: list[int | None] = []
    for axis in range(ndim):
        sizes = {block.shape[axis + 1] for _, block in aligned}
        uniform_shape.append(sizes.pop() if len(sizes) == 1 else None)

    # 5. Common Arrow value type (NumPy promotion across all arrays).
    common_dtype = np.result_type(*(block.dtype for _, block in aligned))
    arrow_value_type = _numpy_to_arrow_type(common_dtype)
    vst = VariableShapeTensorType(
        arrow_value_type,
        ndim,
        dim_names=all_dim_names,
        uniform_shape=uniform_shape,
    )

    # 6. Slice coordinate: stored values if available, else positions.
    if slice_dim in info.coords:
        coord = info.coords[slice_dim]
    else:
        coord = np.arange(aligned[0][1].shape[0])

    # 7. Rows, grouped by array and then by slice index. Values are kept as
    #    flat NumPy buffers and described by list offsets instead of being
    #    turned into nested Python lists.
    row_coords: list = []
    row_varnames: list[str] = []
    shape_vals: list[int] = []
    row_lengths: list[int] = []
    value_chunks: list[np.ndarray] = []
    for var_name, block in aligned:
        n_slices = block.shape[0]
        per_slice = block.shape[1:]
        row_coords.extend(
            coord[s_idx] if s_idx < len(coord) else s_idx for s_idx in range(n_slices)
        )
        row_varnames.extend([var_name] * n_slices)
        shape_vals.extend(list(per_slice) * n_slices)
        row_lengths.extend([int(np.prod(per_slice))] * n_slices)
        value_chunks.append(block.astype(common_dtype, copy=False).reshape(-1))

    values = np.concatenate(value_chunks)
    offsets = np.concatenate(([0], np.cumsum(row_lengths, dtype=np.int64)))

    data_arr = pa.ListArray.from_arrays(
        pa.array(offsets, type=pa.int32()),
        pa.array(values, type=arrow_value_type),
        type=pa.list_(arrow_value_type),
    )
    shape_arr = pa.FixedSizeListArray.from_arrays(
        pa.array(shape_vals, type=pa.int32()), ndim,
    )
    storage = pa.StructArray.from_arrays(
        [data_arr, shape_arr], names=["data", "shape"],
    )
    tensor_col = pa.ExtensionArray.from_storage(vst, storage)

    coord_arrow_type = _numpy_to_arrow_type(np.array(row_coords).dtype)
    return pa.table({
        slice_dim: pa.array(row_coords, type=coord_arrow_type),
        "array_name": pa.array(row_varnames, type=pa.utf8()),
        "tensor": tensor_col,
    })
| # --------------------------------------------------------------------------- | |
| # round-trip verification | |
| # --------------------------------------------------------------------------- | |
def write_ipc(table: pa.Table, path: Path):
    """Write *table* to *path* in the Arrow IPC file format and print its size."""
    schema = table.schema
    with pa.ipc.new_file(path, schema) as sink:
        sink.write_table(table)
    n_bytes = path.stat().st_size
    print(f" Wrote {path} ({n_bytes:,} bytes)")
def read_ipc(path: Path) -> pa.Table:
    """Read the Arrow IPC file at *path* and return its contents as one table."""
    source = pa.ipc.open_file(path)
    with source:
        table = source.read_all()
    return table
def verify_roundtrip(original: pa.Table, path: Path):
    """Write → read → compare.

    Writes *original* to *path*, reads it back and checks schema, row count
    and every column for equality.

    Raises
    ------
    AssertionError
        On any mismatch. The checks raise explicitly rather than using
        ``assert`` statements, so they still run under ``python -O``.
    """
    write_ipc(original, path)
    restored = read_ipc(path)

    if not original.schema.equals(restored.schema):
        raise AssertionError("Schema mismatch!")
    if original.num_rows != restored.num_rows:
        raise AssertionError("Row count mismatch!")
    # Compare by position so the check does not depend on name lookup.
    for col_name, orig_col, rest_col in zip(
        original.column_names, original.columns, restored.columns
    ):
        if not orig_col.equals(rest_col):
            raise AssertionError(f"Column {col_name} mismatch!")

    print(f" Round-trip OK ({original.num_rows} rows, {len(original.schema)} cols)")
| # --------------------------------------------------------------------------- | |
| # main | |
| # --------------------------------------------------------------------------- | |
def main():
    """Convert ``weather.zarr`` with both approaches and verify IPC round-trips.

    Prints the discovered arrays, builds the fixed-shape and variable-shape
    tables sliced along ``time``, writes each to an Arrow IPC file in the
    current directory and checks it reads back unchanged, then rebuilds a few
    tensors as NumPy arrays to show how values are extracted.
    """
    # The extension type must be registered for it to survive the IPC round trip.
    _register_vst()
    store = Path("weather.zarr")
    banner = "=" * 70

    grp = zarr.open_group(store, mode="r")
    info = discover_arrays(grp)
    print("Discovered arrays:")
    print(f" Coordinates: {list(info.coords)}")
    for name, (arr, dims) in info.data.items():
        print(f" Data: {name:20s} shape={arr.shape} dims={dims}")
    print()

    # -- Approach 1 ----------------------------------------------------------
    print(banner)
    print("APPROACH 1: Fixed Shape Tensor (one column per data array)")
    print(banner)
    fixed_table = build_fixed_shape_table(store, slice_axis="time")
    print(fixed_table.schema)
    print()
    print(fixed_table)
    print()
    verify_roundtrip(fixed_table, Path("fixed_shape.arrow"))

    # -- Approach 2 ----------------------------------------------------------
    print()
    print(banner)
    print("APPROACH 2: Variable Shape Tensor (mixed-shape single column)")
    print(banner)
    vst_table = build_variable_shape_table(store, slice_axis="time")
    print(vst_table.schema)
    print()
    print(vst_table)
    print()
    verify_roundtrip(vst_table, Path("variable_shape.arrow"))

    # -- Tensor extraction demo ----------------------------------------------
    print()
    print(banner)
    print("TENSOR EXTRACTION DEMO")
    print(banner)

    # Fixed shape: row 0 of every tensor column
    data_col_names = [
        f.name for f in fixed_table.schema
        if isinstance(f.type, pa.FixedShapeTensorType)
    ]
    if fixed_table.num_rows:
        for col_name in data_col_names:
            col = fixed_table.column(col_name)
            tt = col.type
            # Use the column's own element type instead of assuming float32.
            np_dtype = tt.value_type.to_pandas_dtype()
            row = np.array(col[0].as_py(), dtype=np_dtype).reshape(tt.shape)
            print(f"\n {col_name}[0] shape={row.shape} dims={tt.dim_names} "
                  f"min={row.min():.2f} max={row.max():.2f}")

    # Variable shape: first and last row (a single row is shown once,
    # an empty table shows nothing)
    tensor_col = vst_table.column("tensor")
    value_type = tensor_col.type.storage_type.field("data").type.value_type
    np_dtype = value_type.to_pandas_dtype()
    rows_to_show = sorted({0, len(tensor_col) - 1}) if len(tensor_col) else []
    for i in rows_to_show:
        row = tensor_col[i].as_py()
        shape = row["shape"]
        data = np.array(row["data"], dtype=np_dtype).reshape(shape)
        vname = vst_table.column("array_name")[i].as_py()
        print(f"\n row[{i}] ({vname}) shape={tuple(shape)} "
              f"min={data.min():.2f} max={data.max():.2f}")


if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment