Skip to content

Instantly share code, notes, and snippets.

@d-v-b
Created May 30, 2023 15:33
Show Gist options
  • Save d-v-b/756887e3237f9cca5d4fa79922e94380 to your computer and use it in GitHub Desktop.
Save d-v-b/756887e3237f9cca5d4fa79922e94380 to your computer and use it in GitHub Desktop.
Integrating pydantic with zarr for typed zarr hierarchies.
from __future__ import annotations
from typing import (
Any,
Dict, Generic, Iterable, Literal, TypeVar,
TypedDict, Union, Protocol, runtime_checkable)
from pydantic import ValidationError
from pydantic.generics import GenericModel
from zarr.storage import init_group, BaseStore
import zarr
import os
from rich import print_json
# Type of the user attributes of an array/group spec; parametrize a spec with a
# concrete mapping type (e.g. a TypedDict) to validate attributes against it.
AttrsType = TypeVar('AttrsType', bound=Dict[str, Any])


class ArraySpec(GenericModel, Generic[AttrsType]):
    """
    Declarative description of a Zarr (version 2) array: its metadata plus its
    user attributes.

    The fields other than ``attrs`` are passed as keyword arguments to
    ``zarr.create`` by ``from_tree``; ``attrs`` is written to the array's
    attributes afterwards.
    """

    zarr_version: Literal[2] = 2
    shape: tuple[int, ...]
    chunks: tuple[int, ...]
    # stored as a string (e.g. 'uint8'); dtype objects must be converted with str()
    dtype: str
    fill_value: Union[None, int, float] = 0
    order: Union[Literal['C'], Literal['F']] = 'C'
    filters: Dict[str, Any] = {}
    dimension_separator: Union[Literal['.'], Literal['/']] = '/'
    # the default is None, so the annotation must admit None
    compressor: Union[None, Dict[str, Any]] = None
    attrs: AttrsType = {}

    @classmethod
    def from_array(
        cls,
        data: ArrayLike,
        chunks: tuple[int, ...],
        fill_value=0,
        order='C',
        filters=None,
        dimension_separator='/',
        compressor=None,
        attrs=None,
    ) -> ArraySpec:
        """
        Build a spec whose shape and dtype are taken from an existing array.

        Parameters
        ----------
        data : ArrayLike
            Array supplying ``shape`` and ``dtype``. ``dtype`` is converted
            with ``str()``, so dtype objects are accepted as well as strings.
        chunks : tuple[int, ...]
            Chunk shape of the described array.
        fill_value, order, dimension_separator, compressor
            Copied verbatim into the corresponding fields.
        filters, attrs : dict, optional
            Copied into the corresponding fields; ``None`` (the default)
            means an empty dict.

        Returns
        -------
        ArraySpec
            A new spec (validated by pydantic on construction).
        """
        return cls(
            shape=data.shape,
            # match to_tree: the field is a str, dtype objects are not
            dtype=str(data.dtype),
            chunks=chunks,
            fill_value=fill_value,
            order=order,
            # fresh dicts per call instead of shared mutable defaults
            filters={} if filters is None else filters,
            dimension_separator=dimension_separator,
            compressor=compressor,
            attrs={} if attrs is None else attrs,
        )
# Type of the values of ``GroupSpec.children``: array specs and/or group specs.
ValuesType = TypeVar('ValuesType', bound=Union['ArraySpec', 'GroupSpec'])


class GroupSpec(GenericModel, Generic[AttrsType, ValuesType]):
    """
    Declarative description of a Zarr (version 2) group: its user attributes
    and its children, keyed by child name.

    Parametrizing the model, e.g. ``GroupSpec[SomeAttrs, ArraySpec[OtherAttrs]]``,
    makes validation check the group attributes and the children against
    those types.

    NOTE(review): ``ValuesType`` refers to ``GroupSpec`` via a forward
    reference; nested group children may need pydantic's forward-reference
    resolution (``update_forward_refs``) - confirm before relying on it.
    """

    zarr_version: Literal[2] = 2
    attrs: AttrsType = {}
    # child name -> child spec; from_tree creates each child at the group's
    # path joined with this name
    children: dict[str, ValuesType] = {}
@runtime_checkable
class NodeLike(Protocol):
    """Structural type shared by every node (array or group) of a hierarchy."""

    basename: str  # the node's own name
    attrs: Dict[str, Any]  # user attributes of the node


@runtime_checkable
class ArrayLike(NodeLike, Protocol):
    """
    Structural type for an array node, intended to match e.g. ``zarr.Array``.

    Any object exposing the attributes below (plus ``basename``) passes an
    ``isinstance`` check against this protocol.
    """

    shape: tuple[int, ...]
    chunks: tuple[int, ...]
    dtype: str
    fill_value: Any
    attrs: Dict[str, Any]


@runtime_checkable
class GroupLike(NodeLike, Protocol):
    """Structural type for a group node whose children come from ``values()``."""

    attrs: Dict[str, Any]

    def values(self) -> Iterable[Union[GroupLike, ArrayLike]]:
        """Return an iterable over the children (arrays and groups) of this group."""
def to_tree(element: Union[GroupLike, ArrayLike]) -> tuple[str, Union[ArraySpec, GroupSpec]]:
    """
    Recursively parse a Zarr group or Zarr array into an ArraySpec or GroupSpec.

    Parameters
    ----------
    element : GroupLike | ArrayLike
        The array or group to describe.

    Returns
    -------
    tuple[str, ArraySpec | GroupSpec]
        ``(element.basename, spec)``. For a group, ``spec.children`` is keyed
        by each child's ``basename``, so the keys are names relative to the
        group (suitable for ``from_tree`` to join onto a parent path).

    Raises
    ------
    ValueError
        If ``element`` satisfies neither ``ArrayLike`` nor ``GroupLike``.
    """
    # Arrays are checked first, so an object satisfying both protocols is
    # treated as an array.
    if isinstance(element, ArrayLike):
        spec = ArraySpec(
            shape=element.shape,
            dtype=str(element.dtype),
            attrs=dict(element.attrs),
            chunks=element.chunks,
            fill_value=element.fill_value,
        )
        return element.basename, spec

    if isinstance(element, GroupLike):
        # Use basename (declared by NodeLike) for groups as well as arrays:
        # a full path here would make from_tree's path join discard the
        # parent path for nested groups.
        children = dict(to_tree(child) for child in element.values())
        spec = GroupSpec(attrs=dict(element.attrs), children=children)
        return element.basename, spec

    msg = f"Object of type {type(element)} cannot be processed."
    raise ValueError(msg)
def from_tree(store: BaseStore, path: str, tree: Union[ArraySpec, GroupSpec]) -> Union[zarr.Array, zarr.Group]:
    """
    Materialize a zarr hierarchy on a given storage backend from an ArraySpec or
    GroupSpec.

    Parameters
    ----------
    store : BaseStore
        Storage backend to write to.
    path : str
        Path within ``store`` at which to create the node described by
        ``tree``. Children of a group are created at ``path`` joined with
        their names using '/'.
    tree : ArraySpec | GroupSpec
        Spec of the node (and, for groups, its descendants) to create.

    Returns
    -------
    zarr.Array | zarr.Group
        The node created at ``path``. Descendants are created in ``store``
        but not returned.

    Raises
    ------
    ValueError
        If ``tree`` is neither an ArraySpec nor a GroupSpec.
    """
    # zarr store keys are always '/'-separated; os.path.join would use '\\'
    # on Windows.
    import posixpath

    result: Union[zarr.Array, zarr.Group]
    if isinstance(tree, ArraySpec):
        tree_dict = tree.dict()
        attrs = tree_dict.pop('attrs')
        result = zarr.create(store=store, path=path, **tree_dict)
        result.attrs.put(attrs)
    elif isinstance(tree, GroupSpec):
        # children are materialized recursively below, so don't serialize them
        tree_dict = tree.dict(exclude={'children'})
        attrs = tree_dict.pop('attrs')
        # needing to call init_group, then zarr.group is not ergonomic
        init_group(store=store, path=path)
        result = zarr.group(store=store, path=path, **tree_dict)
        result.attrs.put(attrs)
        for name, child in tree.children.items():
            from_tree(store, posixpath.join(path, name), child)
    else:
        msg = (
            '\nInvalid argument for `tree`. Expected an instance of GroupSpec or ArraySpec, got\n'
            f'{type(tree)} instead.\n'
        )
        raise ValueError(msg)
    return result
# an example of structured group attributes
class GroupAttrs(TypedDict):
    foo: int
    bar: list[int]


# an example of structured array attributes
class ArrayAttrs(TypedDict):
    scale: list[float]


# (zarr is already imported at the top of the file)
store = zarr.MemoryStore()
spec = GroupSpec(
    # 'bar' deliberately violates GroupAttrs (strings, not ints); the typed
    # parse below relies on this to demonstrate a validation error
    attrs={'foo': 100, 'bar': ['a', 'b', 'c']},
    children={
        's0': ArraySpec(
            shape=(1000,),
            chunks=(100,),
            dtype='uint8',
            attrs=ArrayAttrs(scale=[1.0])),
        's1': ArraySpec(
            shape=(500,),
            chunks=(100,),
            dtype='uint8',
            attrs=ArrayAttrs(scale=[2.0]))})

# materialize a zarr group, based on the spec
group_a = from_tree(store, '/group_a', spec)
# parse the spec from that group
name, tree = to_tree(group_a)
# check that the spec round-tripped
assert tree == spec
print_json(tree.json())

# now add types for schema validation
invalid_spec = tree.dict()
try:
    # the type of the group attributes is invalid, so this will raise a validation error
    parsed = GroupSpec[GroupAttrs, ArraySpec[ArrayAttrs]](**invalid_spec)
except ValidationError as e:
    msg = f'There was a validation error: {e}.'
    print(msg)

tree_dict = tree.dict()
# this now complies with the declared type of GroupAttrs
tree_dict['attrs']['bar'] = [1, 2, 3]
parsed = GroupSpec[GroupAttrs, ArraySpec[ArrayAttrs]](**tree_dict)
# create a valid zarr group
group_b = from_tree(store, '/group_b', parsed)
print_json(parsed.json())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment