Last active
August 16, 2024 21:30
-
-
Save antoine-tran/d641e3590d3d9c6adc2a50525ab73ff7 to your computer and use it in GitHub Desktop.
Update a dataclass in place
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from dataclasses import Field, fields, is_dataclass
from typing import Any, ClassVar, Final, List, Mapping, Protocol, TypeAlias, TypeGuard, TypeVar

import typing_extensions
class DataClass(Protocol):
    """Structural type matching any data class instance.

    The :func:`dataclasses.dataclass` decorator attaches a
    ``__dataclass_fields__`` mapping (field name -> :class:`dataclasses.Field`)
    to every class it processes, so declaring that single attribute lets a
    static type checker accept any data class instance where this protocol
    is expected.
    """

    # Declared as ``ClassVar`` because the decorator stores the mapping on the
    # class itself rather than on each instance.
    __dataclass_fields__: ClassVar[dict[str, Field[Any]]]
def is_dataclass_instance(obj: Any) -> TypeGuard[DataClass]:
    """Tell whether ``obj`` is an *instance* of a data class.

    :func:`dataclasses.is_dataclass` answers ``True`` for both a data class
    and its instances; this helper rejects the class objects so that only
    instances pass.

    :param obj:
        The object to inspect.

    :returns:
        ``True`` if ``obj`` is a data class instance (narrowing it to
        :class:`DataClass` for type checkers), ``False`` otherwise.
    """
    if isinstance(obj, type):
        # A class object, even one decorated with @dataclass, is not an instance.
        return False
    return is_dataclass(obj)
def update_dataclass(
    obj: DataClass,
    overrides: Mapping[str, Any],
) -> list[str]:
    """Update ``obj`` in place with the data contained in ``overrides``.

    Keys of ``overrides`` are matched against the fields of ``obj``. When a
    field currently holds a data class instance and its override is a
    mapping, the update recurses into that child instead of replacing it.
    Keys that do not name a field are collected and returned instead of
    raising.

    :param obj:
        The data class instance to update.
    :param overrides:
        The dictionary containing the data to set in ``obj``. It is never
        mutated.

    :returns:
        The sorted, dot-separated paths (e.g. ``"child.name"``) of the keys in
        ``overrides`` that do not correspond to any field.

    :raises RuntimeError:
        If a field holding a data class instance is overridden with a value
        that is neither a mapping nor an instance of that field's current
        type. ``args[0]`` is the dot-separated field path and ``args[1]`` is
        a human-readable message. ``obj`` may already be partially updated
        when this is raised.
    """
    unknown_fields: list[str] = []

    # Path from ``obj`` down to the child currently being updated; used to
    # build fully-qualified names for unknown keys and error messages.
    field_path: list[str] = []

    def update(obj_: DataClass, overrides_: Mapping[str, Any]) -> None:
        # Work on a copy so the caller's mapping is untouched; whatever is
        # left after all fields are consumed is an unknown key.
        remaining = dict(overrides_)

        for f in fields(obj_):
            try:
                override = remaining.pop(f.name)
            except KeyError:
                continue

            # Read the current value only for fields actually being
            # overridden. A field declared with ``init=False`` and no default
            # may never have been set, so fall back to ``None`` rather than
            # raising ``AttributeError``.
            value = getattr(obj_, f.name, None)

            # Recursively traverse child dataclasses.
            if override is not None and is_dataclass_instance(value):
                if isinstance(override, Mapping):
                    field_path.append(f.name)
                    update(value, override)
                    field_path.pop()
                    continue

                # A replacement instance of the same type is assigned as-is
                # below; anything else is a type mismatch.
                if not isinstance(override, type(value)):
                    pathname = ".".join(field_path + [f.name])
                    raise RuntimeError(
                        pathname, f"The field '{pathname}' is expected to be of type `{type(value)}`, but is of type `{type(override)}` instead."  # fmt: skip
                    )

            setattr(obj_, f.name, override)

        unknown_fields.extend(".".join(field_path + [name]) for name in remaining)

    update(obj, overrides)

    unknown_fields.sort()

    return unknown_fields
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment