Last active
April 29, 2026 18:53
-
-
Save kurtbrose/aa21bbee470c4c1255346874b5b8f557 to your computer and use it in GitHub Desktop.
Alternative to raw __dict__ manipulation for translating objects between types.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import inspect | |
| from dataclasses import dataclass, field | |
| from types import SimpleNamespace | |
| from typing import Any | |
| _MISSING = object() | |
| def init_from[T](cls: type[T], /, *src_objs: object, **overrides: Any) -> T: | |
| """ | |
| Construct `cls` by filling in __init__ parameters from the given source objects and kwargs. | |
| Each src_obj logically overrides the previous ones, and **overides has highest precedence. | |
| getattr() is used on the minimum number of source objects, to avoid expensive properties. | |
| """ | |
| # All keyword-passable parameters | |
| sig = inspect.signature(cls) | |
| kw_param_kinds = (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY) | |
| kw_names = [p.name for p in sig.parameters.values() if p.kind in kw_param_kinds] | |
| init_kwargs = overrides | |
| for name in kw_names: | |
| if name in init_kwargs: | |
| continue | |
| for src in reversed(src_objs): | |
| val = getattr(src, name, _MISSING) | |
| if val is not _MISSING: | |
| init_kwargs[name] = val | |
| break | |
| return cls(**init_kwargs) | |
def test_init_from_resolution_order() -> None:
    """Overrides beat later sources, later sources beat earlier ones, defaults fill the rest."""

    @dataclass
    class A:
        a: int
        b: int
        c: int
        d: int = 0

    older = SimpleNamespace(a=1, b=1, c=1)
    newer = SimpleNamespace(b=2, c=2)
    built = init_from(A, older, newer, c=3)
    # a comes from `older`, b from `newer`, c from the kwarg, d from its default
    expected = A(1, 2, 3, 0)
    assert built == expected
def test_init_from_none_okay() -> None:
    """An attribute whose value is None is passed through, not treated as missing."""

    @dataclass
    class A:
        a: int | None

    source = SimpleNamespace(a=None)
    assert init_from(A, source) == A(None)
def test_init_from_missing_required() -> None:
    """A required parameter that no source or override supplies makes construction fail.

    Uses a stdlib try/except/else check instead of `pytest.raises`: the file never
    imports pytest, so the original raised NameError, and the `__main__` runner
    should work without pytest installed.
    """

    @dataclass
    class A:
        a: int

    # A source exists, but it has no attribute `a`.
    try:
        init_from(A, SimpleNamespace(x=1))
    except TypeError:
        pass
    else:
        raise AssertionError("expected TypeError when no source provides `a`")

    # No sources and no overrides at all.
    try:
        init_from(A)
    except TypeError:
        pass
    else:
        raise AssertionError("expected TypeError with no sources and no overrides")
def test_init_from_extra() -> None:
    """Unknown attributes on sources are ignored; unknown overrides are rejected.

    Uses a stdlib try/except/else check instead of `pytest.raises`, because the
    file never imports pytest (the original raised NameError).
    """

    @dataclass
    class A:
        a: int

    # extra fields on a source object are ignored
    assert init_from(A, SimpleNamespace(a=1, b=2)) == A(1)

    # extra kwargs are forwarded to A(...) unchanged, and A rejects them
    try:
        init_from(A, a=1, b=2)
    except TypeError:
        pass
    else:
        raise AssertionError("expected TypeError for unknown override `b`")
def test_init_from_no_posargs() -> None:
    """A positional-only __init__ parameter cannot be supplied by keyword.

    init_from only fills keyword-passable parameters, so passing `a=1` reaches
    A(a=1), which raises TypeError. Uses a stdlib try/except/else check because
    the file never imports pytest (the original raised NameError).
    """

    class A:
        def __init__(self, a, /): ...

    try:
        init_from(A, a=1)
    except TypeError:
        pass
    else:
        raise AssertionError("expected TypeError for positional-only `a` given by keyword")
def test_init_from_non_init_field() -> None:
    """Fields declared with init=False are not passed to __init__, even if a source has them."""

    @dataclass
    class A:
        a: int = field(init=False, default=0)

    result = init_from(A, SimpleNamespace(a=1))
    # the source's a=1 is never forwarded, so the field keeps its default
    assert result.a == 0
if __name__ == "__main__":
    # Script mode: run each test in definition order; the first failure propagates.
    for _test_fn in (
        test_init_from_resolution_order,
        test_init_from_none_okay,
        test_init_from_missing_required,
        test_init_from_extra,
        test_init_from_no_posargs,
        test_init_from_non_init_field,
    ):
        _test_fn()
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment