Last active
August 21, 2018 15:12
-
-
Save mivade/db0f280ff845d51979a86f87d8ec6d03 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import codecs
import json
from typing import Union

import h5py
import numpy as np
import pandas as pd
# Element-wise helpers used when converting text columns to and from bytes.
# ``otypes`` is declared explicitly because ``np.vectorize`` otherwise has to
# probe the first element to discover the output type, which raises on
# zero-length input.
vlen = np.vectorize(len, otypes=[int])  # length of every element
vencode = np.vectorize(codecs.encode, otypes=[bytes])  # str -> UTF-8 bytes
vdecode = np.vectorize(codecs.decode, otypes=[str])  # UTF-8 bytes -> str
def save(hfile: h5py.File, where: str, data: Union[pd.DataFrame, np.ndarray]):
    """Save record array-like data to HDF5.

    Text fields (NumPy unicode or object dtype) are stored as fixed-width
    UTF-8 byte strings. The names of those fields are written to the
    dataset attribute ``utf8_encoded_fields`` as a JSON list, and the
    ``str(type(data))`` of the input is written to ``original_type``, so
    that :func:`load` can reverse the conversion.

    Parameters
    ----------
    hfile
        Opened HDF5 file object.
    where
        Dataset name.
    data
        The data to write. A data frame is converted with
        ``DataFrame.to_records()``, so its index is stored as a field too.

    """
    original_type = str(type(data))

    if isinstance(data, pd.DataFrame):
        data = data.to_records()
    if not isinstance(data, np.recarray):
        data = np.rec.array(data)

    dtype = []
    encoded = {}  # field name -> encoded bytes array, in field order
    for name in data.dtype.names:
        column = data[name]
        if column.dtype.kind not in "OU":
            dtype.append((name, column.dtype))
            continue

        # Encode before measuring: the stored width must be counted in
        # UTF-8 bytes, not characters, or non-ASCII text gets truncated.
        if column.size:
            as_bytes = vencode(column)
        else:
            as_bytes = np.empty(column.shape, dtype="S1")

        # Never request a zero-width string field (all-empty column).
        width = max(1, as_bytes.dtype.itemsize)
        dtype.append((name, f"|S{width}"))
        encoded[name] = as_bytes

    sanitized = np.recarray(data.shape, dtype=dtype)
    for name, _ in dtype:
        sanitized[name] = encoded[name] if name in encoded else data[name]

    hfile[where] = sanitized
    attrs = hfile[where].attrs
    attrs["utf8_encoded_fields"] = json.dumps(list(encoded))
    attrs["original_type"] = original_type
def load(hfile: h5py.File, where: str) -> Union[pd.DataFrame, np.recarray]:
    """Load data stored with :func:`save`.

    Parameters
    ----------
    hfile
        Open HDF5 file object.
    where
        Key to load data from.

    Returns
    -------
    pandas.DataFrame or numpy.recarray
        A data frame when the dataset's ``original_type`` attribute
        mentions ``DataFrame``; otherwise a record array holding exactly
        the stored fields. Every stored field comes back as a column, so
        an index that was written as a field is returned as an ordinary
        column.

    """
    dataset = hfile[where]
    frame = pd.DataFrame(dataset[:])

    # Fields named in this attribute were written as UTF-8 bytes; turn them
    # back into text. ``Series.str.decode`` also copes with empty columns.
    for name in json.loads(dataset.attrs["utf8_encoded_fields"]):
        frame[name] = frame[name].str.decode("utf-8")

    if "DataFrame" not in dataset.attrs["original_type"]:
        # The frame's index only exists because of the detour through
        # pandas; it was never part of the stored records, so drop it.
        return frame.to_records(index=False)
    return frame
if __name__ == "__main__":
    # Round-trip demo: write one data frame and one record array to
    # ``test.h5`` in the current directory, then read both back and print.
    df = pd.DataFrame({
        "string": ["a", "string"],
        "integer": [1, 2],
        "float": [1., 2.],
    })

    ra = np.rec.array(
        [("a", 1), ("longer string", 2)],
        dtype=[("description", "<U32"), ("number", int)],
    )

    samples = (("dataframe", df), ("recarray", ra))

    with h5py.File("test.h5", "w") as hfile:
        for key, value in samples:
            save(hfile, key, value)

    with h5py.File("test.h5", "r") as hfile:
        for key, _ in samples:
            print(load(hfile, key))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment