Created
May 30, 2023 15:33
-
-
Save d-v-b/756887e3237f9cca5d4fa79922e94380 to your computer and use it in GitHub Desktop.
Integrating pydantic with zarr for typed zarr hierarchies.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import annotations

import os
import posixpath
from typing import (
    Any,
    Dict,
    Generic,
    Iterable,
    Literal,
    Protocol,
    TypedDict,
    TypeVar,
    Union,
    runtime_checkable,
)

import zarr
from pydantic import ValidationError
from pydantic.generics import GenericModel
from rich import print_json
from zarr.storage import BaseStore, init_group
AttrsType = TypeVar('AttrsType', bound=Dict[str, Any])


class ArraySpec(GenericModel, Generic[AttrsType]):
    """
    Declarative description of a Zarr v2 array.

    The field names match keyword arguments of ``zarr.create``: ``from_tree``
    forwards every field except ``attrs`` to it and writes ``attrs`` to the
    created array's attributes.
    """
    zarr_version: Literal[2] = 2
    shape: tuple[int, ...]
    chunks: tuple[int, ...]
    dtype: str
    fill_value: Union[None, int, float] = 0
    order: Union[Literal['C'], Literal['F']] = 'C'
    # NOTE(review): zarr describes filters as a list of codec configs; this
    # empty-dict default appears to rely on zarr treating a falsy value as
    # "no filters" -- TODO confirm before changing the declared type.
    filters: Dict[str, Any] = {}
    dimension_separator: Union[Literal['.'], Literal['/']] = '/'
    # Defaults to None, so the annotation admits None explicitly.
    compressor: Union[None, Dict[str, Any]] = None
    attrs: AttrsType = {}

    @classmethod
    def from_array(cls,
                   data: ArrayLike,
                   chunks: tuple[int, ...],
                   fill_value: Union[None, int, float] = 0,
                   order: Union[Literal['C'], Literal['F']] = 'C',
                   filters: Union[None, Dict[str, Any]] = None,
                   dimension_separator: Union[Literal['.'], Literal['/']] = '/',
                   compressor: Union[None, Dict[str, Any]] = None,
                   attrs: Union[None, Dict[str, Any]] = None):
        """
        Build a spec describing an existing array-like object.

        Only ``data.shape`` and ``data.dtype`` are read from ``data``; every
        other field comes from the keyword arguments.

        Parameters
        ----------
        data:
            Object exposing ``shape`` and ``dtype``.
        chunks:
            Chunk shape for the spec.
        filters, attrs:
            ``None`` (the default) means a fresh empty dict. This avoids
            sharing one mutable default object between calls.
        fill_value, order, dimension_separator, compressor:
            Copied unchanged into the corresponding spec fields.

        Returns
        -------
        A new instance of ``cls``.
        """
        if filters is None:
            filters = {}
        if attrs is None:
            attrs = {}
        return cls(
            shape=data.shape,
            # ``dtype`` is a ``str`` field; convert dtype objects (e.g. numpy's)
            # to their string form, as ``to_tree`` does.
            dtype=str(data.dtype),
            chunks=chunks,
            fill_value=fill_value,
            order=order,
            filters=filters,
            dimension_separator=dimension_separator,
            compressor=compressor,
            attrs=attrs,
        )
ValuesType = TypeVar("ValuesType", bound=Union["ArraySpec", "GroupSpec"])


class GroupSpec(GenericModel, Generic[AttrsType, ValuesType]):
    """
    Declarative description of a Zarr v2 group.

    ``attrs`` holds the group's attributes. ``children`` maps each member name
    to that member's spec, either an ArraySpec or a nested GroupSpec.
    """

    zarr_version: Literal[2] = 2
    attrs: AttrsType = {}
    children: Dict[str, ValuesType] = {}
@runtime_checkable
class NodeLike(Protocol):
    """
    Structural type shared by every node of a Zarr-like hierarchy.

    An object satisfies it when it exposes a ``basename`` and an ``attrs``
    mapping. A runtime ``isinstance`` check only tests that both attributes
    are present.
    """

    attrs: Dict[str, Any]
    basename: str
@runtime_checkable
class ArrayLike(NodeLike, Protocol):
    """
    Structural type for array nodes.

    It is a NodeLike that also exposes the array metadata read when building
    an ArraySpec: shape, chunks, dtype and fill value.
    """

    shape: tuple[int, ...]
    chunks: tuple[int, ...]
    dtype: str
    fill_value: Any
    attrs: Dict[str, Any]
@runtime_checkable
class GroupLike(NodeLike, Protocol):
    """Structural type for group nodes: a NodeLike whose children can be iterated."""

    attrs: Dict[str, Any]

    def values(self) -> Iterable[Union[GroupLike, ArrayLike]]:
        """Return an iterable over the children (arrays or groups) of this group."""
        ...
def to_tree(
    element: Union[GroupLike, ArrayLike],
) -> tuple[str, Union[ArraySpec, GroupSpec]]:
    """
    Recursively parse a Zarr group or Zarr array into an ArraySpec or GroupSpec.

    Parameters
    ----------
    element:
        A node matched structurally against ``ArrayLike`` first, then
        ``GroupLike``.

    Returns
    -------
    A ``(basename, spec)`` pair.

    - Arrays: only shape, dtype, attrs, chunks and fill_value are read from
      the node. The other ArraySpec fields keep their defaults.
    - Groups: ``children`` maps each child's basename to its parsed spec.

    Raises
    ------
    ValueError
        If ``element`` matches neither protocol.
    """
    result: tuple[str, Union[ArraySpec, GroupSpec]]
    if isinstance(element, ArrayLike):
        result = (element.basename, ArraySpec(
            shape=element.shape,
            # The spec stores dtype as a string.
            dtype=str(element.dtype),
            attrs=dict(element.attrs),
            chunks=element.chunks,
            fill_value=element.fill_value,
        ))
    elif isinstance(element, GroupLike):
        # Each recursive call yields (basename, spec), so dict() gives a
        # name -> spec mapping directly instead of relying on pydantic to
        # coerce a tuple of pairs.
        children = dict(map(to_tree, element.values()))
        # NodeLike guarantees ``basename`` (not ``name``). Use it for groups
        # as well as arrays so keys are consistent at every level.
        result = (element.basename, GroupSpec(attrs=dict(element.attrs),
                                              children=children))
    else:
        msg = f"Object of type {type(element)} cannot be processed."
        raise ValueError(msg)
    return result
def from_tree(store: BaseStore, path: str, tree: Union[ArraySpec, GroupSpec]) -> Union[zarr.Array, zarr.Group]:
    """
    Materialize a zarr hierarchy on a given storage backend from an ArraySpec or
    GroupSpec.

    Parameters
    ----------
    store:
        Storage backend to write into.
    path:
        Location of the node inside ``store``. Child paths are built by
        joining with ``/``.
    tree:
        Spec of the node to create. A GroupSpec's children are created
        recursively beneath ``path``.

    Returns
    -------
    The ``zarr.Array`` or ``zarr.Group`` created at ``path``.

    Raises
    ------
    ValueError
        If ``tree`` is neither an ArraySpec nor a GroupSpec.
    """
    if isinstance(tree, ArraySpec):
        tree_dict = tree.dict()
        attrs = tree_dict.pop('attrs')
        # Every remaining ArraySpec field is a zarr.create keyword argument.
        result: Union[zarr.Array, zarr.Group] = zarr.create(store=store, path=path, **tree_dict)
        result.attrs.put(attrs)
    elif isinstance(tree, GroupSpec):
        # Children are materialized one at a time below, so don't serialize
        # the whole subtree here just to throw it away.
        tree_dict = tree.dict(exclude={'children'})
        attrs = tree_dict.pop('attrs')
        # needing to call init_group, then zarr.group is not ergonomic
        init_group(store=store, path=path)
        result = zarr.group(store=store, path=path, **tree_dict)
        result.attrs.put(attrs)
        for name, child in tree.children.items():
            # Zarr keys always use '/', whatever the host OS. os.path.join
            # would insert '\\' on Windows.
            subpath = posixpath.join(path, name)
            from_tree(store, subpath, child)
    else:
        msg = f'''
Invalid argument for `tree`. Expected an instance of GroupSpec or ArraySpec, got
{type(tree)} instead.
'''
        raise ValueError(msg)
    return result
# ``zarr`` is already imported at the top of the file; no re-import needed.


# an example of structured group attributes
class GroupAttrs(TypedDict):
    foo: int
    bar: list[int]


# an example of structured array attributes
class ArrayAttrs(TypedDict):
    scale: list[float]


store = zarr.MemoryStore()
spec = GroupSpec(
    attrs={'foo': 100, 'bar': ['a', 'b', 'c']},
    children={
        's0': ArraySpec(
            shape=(1000,),
            chunks=(100,),
            dtype='uint8',
            attrs=ArrayAttrs(scale=[1.0])),
        's1': ArraySpec(
            shape=(500,),
            chunks=(100,),
            dtype='uint8',
            attrs=ArrayAttrs(scale=[2.0]))})

# materialize a zarr group, based on the spec
group_a = from_tree(store, '/group_a', spec)
# parse the spec from that group
name, tree = to_tree(group_a)
# check that the spec round-tripped
assert tree == spec
print_json(tree.json())

# now add types for schema validation
try:
    invalid_spec = tree.dict()
    # the type of the group attributes is invalid ('bar' holds strings, not
    # ints), so this will raise a validation error
    GroupSpec[GroupAttrs, ArraySpec[ArrayAttrs]](**invalid_spec)
except ValidationError as e:
    msg = f'There was a validation error: {e}.'
    print(msg)
else:
    # Fail loudly rather than letting the demo silently pass a bad spec.
    raise AssertionError('expected a ValidationError for the invalid group attributes')

tree_dict = tree.dict()
# this now complies with the declared type of GroupAttrs
tree_dict['attrs']['bar'] = [1, 2, 3]
parsed = GroupSpec[GroupAttrs, ArraySpec[ArrayAttrs]](**tree_dict)
# create a valid zarr group
group_b = from_tree(store, '/group_b', parsed)
print_json(parsed.json())
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment