Skip to content

Instantly share code, notes, and snippets.

@swarn
Created January 8, 2025 03:52
Show Gist options
  • Save swarn/f7874cc9ddfaaadeaf99622954df5b38 to your computer and use it in GitHub Desktop.
Save swarn/f7874cc9ddfaaadeaf99622954df5b38 to your computer and use it in GitHub Desktop.
Dataclass serialization
import json
from dataclasses import MISSING, fields, is_dataclass
from inspect import isabstract, isclass
from types import GenericAlias, NoneType
from typing import Any, TypeAliasType, get_origin
import numpy as np
from numpy.typing import NDArray
type JSON = (
dict[str, "JSON"]
| list["JSON"]
| tuple["JSON"]
| str
| int
| float
| bool
| NoneType
)
class ValidationError(ValueError):
    """Raised when a value cannot be encoded to or decoded from JSON."""
def serialize(x: JSON) -> str:
    """Serialization into json for GPSM types."""
    # encode_value is invoked by json.dumps for anything it can't natively
    # serialize (dataclasses, numpy arrays); it rejects everything else.
    return json.dumps(x, default=encode_value)
def deserialize[T](into: type[T], data: str) -> T:
"""Typed json deserialization for GPSM types."""
return decode_value(into, json.loads(data))
def encode_value(x: Any) -> JSON:
"""Encoder for GPSM types.
Converts dataclasses into types for JSON serialization.
"""
if is_dataclass(x):
return encode_dataclass(x)
if isinstance(x, np.ndarray):
return encode_array(x)
json_types = [dict, list, tuple, str, int, float, bool, NoneType]
if type(x) not in json_types:
raise ValidationError(f"Can't serialize type {type(x)}")
return x
def encode_array(x: NDArray[np.generic]) -> list[JSON]:
    """Convert a numeric ndarray into (possibly nested) Python lists."""
    # Only numeric dtypes round-trip through JSON; reject object/str arrays.
    if not np.issubdtype(x.dtype, np.number):
        raise ValidationError
    nested = x.tolist()
    return nested  # pyright: ignore[reportReturnType]
def encode_dataclass(x: Any) -> dict[str, JSON]:
    """Encode a dataclass instance as a dict of JSON-compatible values.

    If the class defines a ``tag`` class variable, it is included under the
    "tag" key (the discriminator used by `decode_abc`).
    """
    if isclass(x):
        raise ValidationError("Can't serialize class definitions")
    # Don't use dataclasses.asdict: it recurses into nested dataclasses
    # itself, and we want encode_value to control that recursion.
    out: dict[str, JSON] = {}
    for field in fields(x):
        out[field.name] = encode_value(getattr(x, field.name))
    cls = type(x)
    if hasattr(cls, "tag"):
        out["tag"] = encode_value(cls.tag)
    return out
def decode_value[T](into: type[T], value: Any) -> T:
"""A typed decoder for GPSM types.
Given a type and decoded json, convert to the specified type.
"""
concrete = get_actual_type(into)
if isabstract(concrete):
return decode_abc(concrete, value)
if is_dataclass(concrete) and isinstance(concrete, type):
return decode_dataclass(concrete, value) # pyright: ignore[reportReturnType]
if concrete is np.ndarray:
return decode_array(value) # pyright: ignore[reportReturnType]
if type(value) is not concrete:
raise ValidationError(f"expected {concrete}, got {value}")
return value
def decode_array(value: Any) -> NDArray[np.number]:
    """Convert decoded json (nested lists of numbers) into an ndarray.

    Raises:
        ValidationError: if `value` is not rectangular numeric data.
    """
    try:
        decoded = np.array(value)
    except ValueError as err:
        # Ragged nested lists make np.array raise a bare ValueError
        # (numpy >= 1.24); surface that as a validation failure instead.
        raise ValidationError(f"Expected array, got {value}") from err
    if not np.issubdtype(decoded.dtype, np.number):
        raise ValidationError(f"Expected array, got {value}")
    return decoded
def decode_dataclass[T](into: type[T], value: Any) -> T:
"""Do typed decoding of a datacass using its type annotations.
Attempts to convert nested json objects into the types described in the
dataclass. This includes nested dataclasses.
This treats class variables named "tag" as the discriminator in tagged
unions, see `decode_abc`.
"""
if not isinstance(value, dict):
raise ValidationError(f"Expected {into}, got {value}")
if not is_dataclass(into):
raise ValueError
field_info = {field.name: field for field in fields(into)}
# If present, the 'tag' field doesn't get deserialized or checked.
json_keys = set(value) - set(["tag"])
required_fields = set(
field.name
for field in fields(into)
if field.default == MISSING and field.default_factory == MISSING
)
missing_fields = required_fields - json_keys
if missing_fields:
missing_names = ", ".join(f for f in missing_fields)
raise ValidationError(f"For {into}, missing the fields {missing_names}")
extra_fields = json_keys - set(field_info)
if extra_fields:
extra_names = ", ".join(f for f in extra_fields)
raise ValidationError(f"For {into}, extra fields {extra_names}")
args = {k: decode_value(field_info[k].type, value[k]) for k in json_keys}
return into(**args)
def decode_abc[T](into: type[T], value: Any) -> T:
"""Decode an abstract base class.
In GPSM, the convention is that concrete classes have a class variable
named "tag" with a string unique among the concrete subclasses of the
abastract base class. This tag is used when deserializing an object that
can be any of the subclasses of the ABC, like a tagged union.
"""
if not isinstance(value, dict):
raise ValidationError(f"Expected {into}, got {value}")
lookup = {cls.tag: cls for cls in subclasses_of(into) if hasattr(cls, "tag")}
if "tag" not in value:
msg = f"Expected {into.__name__}, which is abstract,\n"
msg += "But the given value does not have a 'tag' key\n"
msg += f"Possible tags: {', '.join(k for k in lookup)}\n"
msg += f"value: {value}"
raise ValidationError(msg)
return decode_dataclass(lookup[value["tag"]], value)
def subclasses_of(cls: type) -> set[type]:
    """Return every (transitive) subclass of `cls`."""
    result: set[type] = set()
    for child in cls.__subclasses__():
        result.add(child)
        result |= subclasses_of(child)
    return result
def get_actual_type(obj: Any) -> Any:
# Strip off all type aliases.
while isinstance(obj, TypeAliasType):
obj = obj.__value__
# Go from a generic type to the underlying type.
if isinstance(obj, GenericAlias):
return get_origin(obj)
return obj
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment