Created
January 8, 2025 03:52
-
-
Save swarn/f7874cc9ddfaaadeaf99622954df5b38 to your computer and use it in GitHub Desktop.
Dataclass serialization
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import json | |
from dataclasses import MISSING, fields, is_dataclass | |
from inspect import isabstract, isclass | |
from types import GenericAlias, NoneType | |
from typing import Any, TypeAliasType, get_origin | |
import numpy as np | |
from numpy.typing import NDArray | |
# Recursive alias for everything the stdlib `json` module can represent:
# objects, arrays, strings, numbers, booleans, and null. Declared with the
# PEP 695 `type` statement (Python 3.12+); the quoted "JSON" forward
# references make the self-recursion legal.
type JSON = (
    dict[str, "JSON"]
    | list["JSON"]
    # NOTE(review): tuple["JSON"] is a one-element tuple type; if arbitrary
    # length was intended this should be tuple["JSON", ...] — confirm.
    | tuple["JSON"]
    | str
    | int
    | float
    | bool
    | NoneType
)
class ValidationError(ValueError):
    """Raised when a value cannot be serialized to, or decoded from, JSON."""

    pass
def serialize(x: JSON) -> str:
    """Serialize a GPSM value to a JSON string.

    Anything `json.dumps` cannot handle natively (dataclass instances,
    numpy arrays) is routed through `encode_value`.
    """
    return json.dumps(x, default=encode_value)
def deserialize[T](into: type[T], data: str) -> T: | |
"""Typed json deserialization for GPSM types.""" | |
return decode_value(into, json.loads(data)) | |
def encode_value(x: Any) -> JSON:
    """Encoder for GPSM types.

    Dataclass instances become dicts, numpy arrays become lists, and values
    already representable in JSON pass through unchanged; anything else is
    rejected with ValidationError.
    """
    if is_dataclass(x):
        return encode_dataclass(x)
    if isinstance(x, np.ndarray):
        return encode_array(x)
    # Exact type match (not isinstance) so subclasses of JSON types,
    # which may not round-trip, are rejected.
    if type(x) in (dict, list, tuple, str, int, float, bool, NoneType):
        return x
    raise ValidationError(f"Can't serialize type {type(x)}")
def encode_array(x: NDArray[np.generic]) -> list[JSON]:
    """Encode a numeric numpy array as (nested) lists of Python scalars.

    Raises:
        ValidationError: if the array's dtype is not numeric (object,
            string, ...), since only numbers round-trip through JSON.
    """
    if not np.issubdtype(x.dtype, np.number):
        # Bug fix: the bare `raise ValidationError` gave no diagnostic;
        # include the offending dtype like every other raise in this module.
        raise ValidationError(f"Can't serialize array with dtype {x.dtype}")
    return x.tolist()  # pyright: ignore[reportReturnType]
def encode_dataclass(x: Any) -> dict[str, JSON]:
    """Encode a dataclass *instance* as a dict of JSON-encoded fields.

    When the class defines a `tag` class variable it is emitted as well, so
    tagged unions can later be decoded through `decode_abc`.
    """
    if isclass(x):
        raise ValidationError("Can't serialize class definitions")
    # dataclasses.asdict would recurse into nested dataclasses on its own;
    # build the dict by hand so encode_value stays in charge of recursion.
    encoded: dict[str, JSON] = {}
    for field in fields(x):
        encoded[field.name] = encode_value(getattr(x, field.name))
    if hasattr(type(x), "tag"):
        encoded["tag"] = encode_value(getattr(type(x), "tag"))
    return encoded
def decode_value[T](into: type[T], value: Any) -> T: | |
"""A typed decoder for GPSM types. | |
Given a type and decoded json, convert to the specified type. | |
""" | |
concrete = get_actual_type(into) | |
if isabstract(concrete): | |
return decode_abc(concrete, value) | |
if is_dataclass(concrete) and isinstance(concrete, type): | |
return decode_dataclass(concrete, value) # pyright: ignore[reportReturnType] | |
if concrete is np.ndarray: | |
return decode_array(value) # pyright: ignore[reportReturnType] | |
if type(value) is not concrete: | |
raise ValidationError(f"expected {concrete}, got {value}") | |
return value | |
def decode_array(value: Any) -> NDArray[np.number]:
    """Convert decoded JSON (nested lists of numbers) into a numpy array.

    Raises:
        ValidationError: if the converted array is not numeric, e.g. the
            input held strings or was ragged (object dtype).
    """
    arr = np.array(value)
    if np.issubdtype(arr.dtype, np.number):
        return arr
    raise ValidationError(f"Expected array, got {value}")
def decode_dataclass[T](into: type[T], value: Any) -> T: | |
"""Do typed decoding of a datacass using its type annotations. | |
Attempts to convert nested json objects into the types described in the | |
dataclass. This includes nested dataclasses. | |
This treats class variables named "tag" as the discriminator in tagged | |
unions, see `decode_abc`. | |
""" | |
if not isinstance(value, dict): | |
raise ValidationError(f"Expected {into}, got {value}") | |
if not is_dataclass(into): | |
raise ValueError | |
field_info = {field.name: field for field in fields(into)} | |
# If present, the 'tag' field doesn't get deserialized or checked. | |
json_keys = set(value) - set(["tag"]) | |
required_fields = set( | |
field.name | |
for field in fields(into) | |
if field.default == MISSING and field.default_factory == MISSING | |
) | |
missing_fields = required_fields - json_keys | |
if missing_fields: | |
missing_names = ", ".join(f for f in missing_fields) | |
raise ValidationError(f"For {into}, missing the fields {missing_names}") | |
extra_fields = json_keys - set(field_info) | |
if extra_fields: | |
extra_names = ", ".join(f for f in extra_fields) | |
raise ValidationError(f"For {into}, extra fields {extra_names}") | |
args = {k: decode_value(field_info[k].type, value[k]) for k in json_keys} | |
return into(**args) | |
def decode_abc[T](into: type[T], value: Any) -> T: | |
"""Decode an abstract base class. | |
In GPSM, the convention is that concrete classes have a class variable | |
named "tag" with a string unique among the concrete subclasses of the | |
abastract base class. This tag is used when deserializing an object that | |
can be any of the subclasses of the ABC, like a tagged union. | |
""" | |
if not isinstance(value, dict): | |
raise ValidationError(f"Expected {into}, got {value}") | |
lookup = {cls.tag: cls for cls in subclasses_of(into) if hasattr(cls, "tag")} | |
if "tag" not in value: | |
msg = f"Expected {into.__name__}, which is abstract,\n" | |
msg += "But the given value does not have a 'tag' key\n" | |
msg += f"Possible tags: {', '.join(k for k in lookup)}\n" | |
msg += f"value: {value}" | |
raise ValidationError(msg) | |
return decode_dataclass(lookup[value["tag"]], value) | |
def subclasses_of(cls: type) -> set[type]:
    """Return every transitive subclass of `cls` (not including `cls` itself)."""
    found = set(cls.__subclasses__())
    for child in tuple(found):
        found |= subclasses_of(child)
    return found
def get_actual_type(obj: Any) -> Any: | |
# Strip off all type aliases. | |
while isinstance(obj, TypeAliasType): | |
obj = obj.__value__ | |
# Go from a generic type to the underlying type. | |
if isinstance(obj, GenericAlias): | |
return get_origin(obj) | |
return obj |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment