Created
November 26, 2022 23:54
-
-
Save shoyer/5b0c485979cc9c36a9685d8cf8e94565 to your computer and use it in GitHub Desktop.
xarray zarr via tensor store
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Copyright 2022 Google LLC. | |
# SPDX-License-Identifier: Apache-2.0 | |
import tensorstore | |
import json | |
import os.path | |
import fsspec | |
import xarray | |
import xarray.backends | |
def zarr_spec_from_path(path):
    """Build a TensorStore spec dict for Zarr data at a filesystem path.

    Args:
      path: value placed under the 'path' key of the 'file' key-value store.

    Returns:
      A dict selecting the 'zarr' driver on top of a 'file' kvstore at `path`.
    """
    kvstore = dict(driver='file', path=path)
    return dict(driver='zarr', kvstore=kvstore)
def load_zarr_consolidated_metadata(path):
    """Read the consolidated metadata of a Zarr store on the local filesystem.

    Args:
      path: directory of the Zarr store; its '.zmetadata' file is read.

    Returns:
      The value stored under the 'metadata' key of the '.zmetadata' document.

    Raises:
      ValueError: if the document is not a JSON object or its
        'zarr_consolidated_format' entry is not 1.
      KeyError: if the document has no 'metadata' entry.
      OSError: if '.zmetadata' cannot be opened.
    """
    metadata_path = os.path.join(path, '.zmetadata')
    # JSON text is UTF-8; do not depend on the platform's locale encoding.
    with open(metadata_path, encoding='utf-8') as f:
        contents = json.load(f)
    # A non-object document (list, string, ...) is just as invalid as one with
    # the wrong format marker, so report both the same way.
    if not isinstance(contents, dict) or contents.get('zarr_consolidated_format') != 1:
        raise ValueError('invalid .zmetadata')
    return contents['metadata']
def load_zattrs(group_path, array_names):
    """Load the '.zattrs' JSON document of each named array under a group.

    Args:
      group_path: local directory that contains the arrays.
      array_names: names of the arrays, relative to `group_path`.

    Returns:
      A dict mapping each array name to its decoded '.zattrs' contents. Empty
      when `array_names` is empty.
    """
    names = list(array_names)
    if not names:
        # Nothing to read; avoids handing fsspec an empty path list.
        return {}
    fs = fsspec.filesystem('file')
    paths = {name: os.path.join(group_path, name, '.zattrs') for name in names}
    # Pass a real list rather than a dict view: cat() documents its argument as
    # "str or list", and list input yields a {path: bytes} dict.
    data = fs.cat(list(paths.values()))
    result = {}
    for name, attrs_path in paths.items():
        # The result of cat() is looked up by the filesystem's expanded form of
        # each path, not by the string we passed in.
        expanded = fs.expand_path(attrs_path)[0]
        result[name] = json.loads(data[expanded].decode('utf8'))
    return result
class ZarrTensorStoreDataStore(xarray.backends.AbstractDataStore):
    """Data store that hands back the variables and attributes it was given."""

    def __init__(self, variables, attrs):
        """Keep `variables` and `attrs` so that load() can return them."""
        self.variables, self.attrs = variables, attrs

    def load(self):
        """Return the stored (variables, attrs) pair unchanged."""
        loaded = (self.variables, self.attrs)
        return loaded
class TensorStoreWrapper(xarray.backends.BackendArray):
    """Expose a TensorStore array through xarray's BackendArray interface."""

    def __init__(self, ts_array):
        """Wrap `ts_array`, recording its shape and NumPy dtype."""
        self.ts_array = ts_array
        self.shape = ts_array.shape
        self.dtype = ts_array.dtype.numpy_dtype

    def __getitem__(self, key):
        """Read the selection described by an xarray explicit indexer.

        Outer and vectorized indexers go through the array's `oindex` and
        `vindex` accessors; a basic indexer indexes the array directly. The
        read is then issued and its result awaited before returning.
        """
        indexers = xarray.core.indexing
        source = self.ts_array
        if isinstance(key, indexers.OuterIndexer):
            return source.oindex[key.tuple].read().result()
        if isinstance(key, indexers.VectorizedIndexer):
            return source.vindex[key.tuple].read().result()
        # Anything left must be a basic indexer.
        assert isinstance(key, indexers.BasicIndexer)
        return source[key.tuple].read().result()

    def __repr__(self):
        """Return '<ClassName>(<repr of the wrapped array>)'."""
        return '{}({!r})'.format(type(self).__name__, self.ts_array)
def open_zarr_via_tensorstore(path):
    """Open a Zarr store via TensorStore.

    Current limitations:
    1. The Zarr store must be stored with consolidated metadata.
    2. Only supports the "file" TensorStore driver.

    Args:
      path: local directory of the Zarr store.

    Returns:
      The result of `xarray.open_dataset` on a store whose variables wrap the
      opened TensorStore arrays in `TensorStoreWrapper`.

    Raises:
      KeyError: if an array's attributes have no '_ARRAY_DIMENSIONS' entry.
    """
    metadata = load_zarr_consolidated_metadata(path)  # blocking

    # Every array contributes a '<name>/.zarray' key to the consolidated
    # metadata; strip the suffix to recover the array names.
    suffix = '/.zarray'
    array_names = [k[:-len(suffix)] for k in metadata if k.endswith(suffix)]

    # Kick off all opens before the blocking attribute read below; the futures
    # are only resolved once each variable is assembled.
    array_futures = {
        name: tensorstore.open(zarr_spec_from_path(os.path.join(path, name)))
        for name in array_names
    }
    array_zattrs = load_zattrs(path, array_names)  # blocking

    variables = {}
    for name in array_names:
        attrs = dict(array_zattrs[name])
        # The dimension names live alongside the user attributes; split them out.
        dims = attrs.pop('_ARRAY_DIMENSIONS')
        data = TensorStoreWrapper(array_futures[name].result())
        variables[name] = xarray.Variable(dims, data, attrs)

    # Fall back to empty attributes when the group has no '.zattrs' entry.
    store = ZarrTensorStoreDataStore(variables, metadata.get('.zattrs', {}))
    return xarray.open_dataset(store, engine='store')
def run_unit_test():
    """Round-trip a tutorial dataset through Zarr and open_zarr_via_tensorstore.

    Loads the 'eraint_uvz' tutorial dataset, writes it as a Zarr store with
    consolidated metadata, reopens it via TensorStore, and checks that the
    reopened dataset is identical and is backed by a TensorStore wrapper.

    Raises:
      AssertionError: if the reopened dataset is not TensorStore-backed or is
        not identical to the original.
    """
    import tempfile  # local import: only this self-test needs it

    ds = xarray.tutorial.load_dataset('eraint_uvz')
    # Write into a fresh temporary directory so a store left behind by an
    # earlier run cannot collide with this one, and nothing stays on disk.
    with tempfile.TemporaryDirectory() as tmp_dir:
        store_path = os.path.join(tmp_dir, 'eraint_uvz.zarr')
        # open_zarr_via_tensorstore requires consolidated metadata; ask for it
        # explicitly instead of relying on the default.
        ds.to_zarr(store_path, consolidated=True)
        roundtripped = open_zarr_via_tensorstore(store_path)
        assert 'TensorStore' in repr(roundtripped.variables['u']._data)
        # Data is read from the store during the comparison, so it must happen
        # while the temporary directory still exists.
        xarray.testing.assert_identical(roundtripped, ds)
if __name__ == '__main__':
    # Run the self-test only when executed as a script, so that importing this
    # module for open_zarr_via_tensorstore does not trigger it.
    run_unit_test()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Note: see https://github.com/google/xarray-tensorstore