Last active
November 12, 2021 03:14
-
-
Save altescy/ffafeff82c136927951ab1f92fadefe1 to your computer and use it in GitHub Desktop.
Construct a dataclass object from a dict recursively
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import copy
import dataclasses
import json
import types
import typing
from collections.abc import Sized
from typing import Any, Dict, Optional, Type, TypeVar, Union
T = TypeVar("T")


def dataclass_from_dict(dataclass: Type[T], obj: Dict[str, Any]) -> T:
    """Recursively construct an instance of ``dataclass`` from a plain dict.

    Annotations are followed recursively through nested dataclasses,
    ``List``/``Set``/``FrozenSet``/``Tuple``/``Dict`` containers,
    ``Union``/``Optional`` (including PEP 604 ``X | Y`` unions), ``Literal``
    and ``Any``.  A leaf value that is not already an instance of its
    annotated class is coerced by calling that class, e.g. ``int("-123")``.
    Union members are tried in declaration order; the first that succeeds wins.

    Fields that declare a ``default``/``default_factory``, or that are declared
    with ``init=False``, may be omitted from the dict.  String annotations
    (e.g. under ``from __future__ import annotations``) are resolved with
    ``typing.get_type_hints`` when possible.

    The input is deep-copied once, so ``obj`` is never modified and the
    result does not share mutable objects with it.

    Args:
        dataclass: The dataclass type to build.
        obj: Mapping from field names to (possibly nested) raw values.

    Returns:
        A new instance of ``dataclass``.

    Raises:
        ValueError: ``dataclass`` is not a dataclass, a required field is
            missing, a value cannot be converted to its annotated type, a
            ``Literal`` value is not allowed, or no union member accepts it.
        TypeError: Propagated from a dataclass constructor, e.g. when the
            dict contains keys that are not constructor parameters.
    """
    # ``X | Y`` unions (PEP 604, Python 3.10+) report ``types.UnionType`` as
    # their origin instead of ``typing.Union``; accept both spellings.
    union_origins = tuple(
        u for u in (Union, getattr(types, "UnionType", None)) if u is not None
    )

    def _cat(parent: str, child: Union[int, str]) -> str:
        """Join a dotted path used in error messages (e.g. ``items.0.x``)."""
        return f"{parent}.{child}" if parent else str(child)

    def _may_be_omitted(field: dataclasses.Field) -> bool:
        """Return True if ``field`` does not have to appear in the input dict."""
        return (
            not field.init
            or field.default is not dataclasses.MISSING
            or field.default_factory is not dataclasses.MISSING
        )

    def _field_annotations(cls: Any) -> Dict[str, Any]:
        """Map each field name of dataclass ``cls`` to its resolved annotation."""
        try:
            hints = typing.get_type_hints(cls)
        except Exception:
            # Deliberate fallback: unresolvable forward references (e.g. names
            # local to a function) keep the raw ``Field.type`` value.
            hints = {}
        return {f.name: hints.get(f.name, f.type) for f in dataclasses.fields(cls)}

    def _build(annotation: Any, obj: Any, name: str = "") -> Any:
        """Build a value matching ``annotation`` from ``obj``; ``name`` is its path."""
        origin = typing.get_origin(annotation)
        args = typing.get_args(annotation)

        if origin is list and args:
            return [_build(args[0], x, _cat(name, i)) for i, x in enumerate(obj)]
        if origin in (set, frozenset) and args:
            return origin(_build(args[0], x, _cat(name, i)) for i, x in enumerate(obj))
        if origin is tuple and args:
            if args[-1] is Ellipsis:
                # Tuple[X, ...]: homogeneous, any length.
                return tuple(_build(args[0], x, _cat(name, i)) for i, x in enumerate(obj))
            if isinstance(obj, Sized) and len(args) != len(obj):
                raise ValueError(
                    f"Sizes of tuple args and the given obj are mismatched: {name}"
                )
            return tuple(
                _build(arg, x, _cat(name, i))
                for i, (arg, x) in enumerate(zip(args, obj))
            )
        if origin is dict and args:
            return {
                _build(args[0], key, _cat(name, key)): _build(
                    args[1], value, _cat(name, key)
                )
                for key, value in obj.items()
            }
        if origin in union_origins and args:
            error_chain: Optional[Exception] = None
            for arg_annotation in args:
                try:
                    # Each candidate gets its own copy so a failed attempt
                    # cannot leave side effects behind for the next one.
                    return _build(arg_annotation, copy.deepcopy(obj), name)
                except (ValueError, TypeError, AttributeError) as e:
                    e.args = (
                        f"While constructing a field of type {arg_annotation}",
                    ) + e.args
                    e.__cause__ = error_chain
                    error_chain = e
            raise ValueError(
                f"Failed to construct field with type {annotation}: {name}"
            ) from error_chain
        if origin is typing.Literal:
            if obj in args:
                return obj
            raise ValueError(f"Value {obj!r} is not one of {args}: {name}")
        if dataclasses.is_dataclass(annotation):
            if isinstance(obj, annotation):
                # Already constructed (e.g. passed in by the caller).
                return obj
            field_types = _field_annotations(annotation)
            built: Dict[str, Any] = {}
            for field in dataclasses.fields(annotation):
                if field.name not in obj:
                    if _may_be_omitted(field):
                        continue
                    raise ValueError(
                        f"Field {field.name} of {annotation} is not in the given dict: {name}"
                    )
                built[field.name] = _build(
                    field_types[field.name], obj[field.name], _cat(name, field.name)
                )
            # Forward every given key, not only declared fields, so InitVar
            # arguments still reach __init__ and unknown keys still fail there.
            return annotation(**{**obj, **built})
        if annotation is Any:
            return obj
        if not isinstance(obj, annotation):
            if isinstance(annotation, type):
                try:
                    return annotation(obj)
                except Exception as e:
                    raise ValueError(
                        f"Failed to convert {obj} to {annotation}: {name}"
                    ) from e
            raise ValueError(
                f"Actual type {type(obj)} differs from annotation {annotation}: {name}"
            )
        return obj

    if not dataclasses.is_dataclass(dataclass):
        raise ValueError(f"Given type must be dataclass: {dataclass}")
    # Copy once up front: building never mutates, so this single copy keeps
    # the result from aliasing the caller's data.
    return _build(dataclass, copy.deepcopy(obj))  # type: ignore
def dataclass_from_json(dataclass: Type[T], s: str) -> T:
    """Decode the JSON text ``s`` and build a ``dataclass`` instance from it.

    A thin convenience wrapper: ``json.loads`` parses the text, then
    ``dataclass_from_dict`` performs the recursive construction. Errors from
    either step propagate unchanged.
    """
    decoded = json.loads(s)
    return dataclass_from_dict(dataclass, decoded)
if __name__ == "__main__":
    # Demo: a list whose items may be either of two dataclasses. "-123" is
    # coerced to int for Foo.x; the second dict lacks "x", so Foo fails and
    # Bar is chosen. Prints: Baz(items=[Foo(x=-123), Bar(y='abc')])
    from typing import List

    @dataclasses.dataclass
    class Foo:
        x: int

    @dataclasses.dataclass
    class Bar:
        y: str

    @dataclasses.dataclass
    class Baz:
        items: List[Union[Foo, Bar]]

    payload = {"items": [{"x": "-123"}, {"y": "abc"}]}
    result = dataclass_from_dict(Baz, payload)
    print(result)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment