Created: April 22, 2024 08:01
Save mementum/707c570d1eccf7377d6289caa1e2d202 to your computer and use it in GitHub Desktop.
Dataclass with Field Annotations using @
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
# -*- coding: utf-8; py-indent-offset:4 -*- | |
############################################################################### | |
from __future__ import annotations | |
from collections.abc import Callable | |
import re | |
import inspect | |
from typing import Annotated, overload | |
# dataclasses imports | |
import dataclasses | |
# Imports meant for re-export - ignore non-used and values cannot be determined | |
from dataclasses import * # noqa: F403 F401 | |
# Specific imports for development error-checking | |
from dataclasses import dataclass as _dataclass, field, KW_ONLY, MISSING | |
# Public API: this module's own additions first, followed by every name
# re-exported from the standard ``dataclasses`` module.
__all__ = [
    "at_dataclass",
    "NO_INIT",
    "NO_INIT_FACTORY",
    "NO_FIELD",
    *dataclasses.__all__,
]
class _Sentinel:
    """Base class for this module's marker singletons.

    Gives every marker a readable repr (markers end up inside
    ``Annotated[...]`` metadata) and makes copy/deepcopy/pickle hand back the
    very same singleton. This matters because ``at_dataclass`` detects
    markers with ``in`` checks, and default equality is identity.
    """

    _name = "_Sentinel"  # overridden by each subclass with its global name

    def __repr__(self) -> str:
        return self._name

    def __reduce__(self) -> str:
        # Returning a string tells copy/pickle that this object *is* the
        # module global of that name, so no new instance is ever created.
        return self._name


class _NO_FIELD_TYPE(_Sentinel):
    """Type of NO_FIELD: keep the annotated name out of the dataclass fields."""

    _name = "NO_FIELD"


NO_FIELD = _NO_FIELD_TYPE()


class _NO_INIT_TYPE(_Sentinel):
    """Type of NO_INIT: field(init=False, default=<class value>)."""

    _name = "NO_INIT"


NO_INIT = _NO_INIT_TYPE()


class _NO_INIT_FACTORY_TYPE(_Sentinel):
    """Type of NO_INIT_FACTORY: field(init=False, default_factory=<class value>)."""

    _name = "NO_INIT_FACTORY"


NO_INIT_FACTORY = _NO_INIT_FACTORY_TYPE()
# Default token separating a field's type from its markers: ``x: int @ NO_INIT``.
ANNOTATOR = "@"
# Marker text starting with this is evaluated as-is, as a list display:
# ``x: int @ [NO_INIT, Dummy()]``.
ANN_LBRACKET = "["
# Splits marker text on "|" characters that are not inside a '...' or "..."
# string literal: the lookahead requires the remaining text to consist only
# of non-quote characters and complete quoted strings.
# Raw string: "\|" in a plain literal is an invalid escape sequence
# (DeprecationWarning, SyntaxWarning on Python 3.12+); the value is unchanged.
ANN_RE = re.compile(r'''\|(?=(?:[^'"]|'[^']*'|"[^"]*")*$)''')
# Typing-only signatures for at_dataclass (the implementation follows).
# Called directly on a class (``@at_dataclass``) it returns the class.
@overload
def at_dataclass(
    cls: type,
    reannotate: bool = True,
    keep_dc_ann: bool = True,
    annotator: str = ANNOTATOR,
    **kwargs,
) -> type:
    ...


# Called with options only (``@at_dataclass(frozen=True)``) it returns the
# real decorator; extra keyword arguments go to dataclasses.dataclass.
@overload
def at_dataclass(
    cls: None = None,
    reannotate: bool = True,
    keep_dc_ann: bool = True,
    annotator: str = ANNOTATOR,
    **kwargs,
) -> Callable[[type], type]:
    ...
def at_dataclass( | |
cls: type | None = None, | |
reannotate: bool = True, | |
keep_dc_ann: bool = True, | |
annotator: str = ANNOTATOR, | |
**kwargs, | |
) -> type | Callable[[type], type]: | |
# actual decorator for when cls is not None | |
def _annotifier(cls: type) -> type: | |
no_fields = {} # keep track of no_fields to remove and readd their annotations | |
# go over all annotations | |
for name, annotation in inspect.get_annotations(cls).items(): | |
if not (type(annotation) is str): | |
continue # only parsing annotations which are in str format | |
try: # try a split [type, rest] from "type @ | |
_type, f_ann = annotation.split(annotator, maxsplit=1) | |
except ValueError: | |
continue # splitting was not possible, nothing after the type | |
else: | |
_type = _type.rstrip() # remove trailing whitespace | |
f_ann = f_ann.lstrip() # remove leading whitespace | |
if f_ann.startswith(ANN_LBRACKET): | |
subannotations = eval(f_ann) | |
else: | |
f_tokens = ANN_RE.split(f_ann) | |
subannotations = eval(f"[{','.join(f_tokens)}]") | |
if NO_FIELD in subannotations: # remove from annotations | |
cls.__annotations__.pop(name) | |
if not keep_dc_ann: | |
subannotations.remove(NO_FIELD) | |
no_fields[name] = _type, subannotations # store for later re-adding | |
else: | |
f_kwargs = {} | |
defval = getattr(cls, name, MISSING) | |
if NO_INIT in subannotations: | |
f_kwargs["init"] = False | |
f_kwargs["default"] = defval | |
if not keep_dc_ann: | |
subannotations.remove(NO_INIT) | |
elif NO_INIT_FACTORY in subannotations: | |
f_kwargs["init"] = False | |
f_kwargs["default_factory"] = defval | |
if not keep_dc_ann: | |
subannotations.remove(NO_INIT_FACTORY) | |
elif KW_ONLY in subannotations: | |
f_kwargs["kw_only"] = True | |
f_kwargs["default"] = defval | |
if not keep_dc_ann: | |
subannotations.remove(KW_ONLY) | |
if f_kwargs: | |
setattr(cls, name, field(**f_kwargs)) | |
if not reannotate or not subannotations: | |
cls.__annotations__[name] = _type | |
else: | |
cls.__annotations__[name] = Annotated[_type, *subannotations] | |
dataclassed = _dataclass(cls, **kwargs) # apply std dataclass processing | |
# restore no_field attributes to the annotations | |
for name, (_type, subannotations) in no_fields.items(): | |
if not reannotate or not subannotations: | |
cls.__annotations__[name] = _type | |
else: | |
cls.__annotations__[name] = Annotated[_type, *subannotations] | |
return dataclassed | |
# decorator functionality when kwargs are used, return real deco (with closure) | |
if cls is None: | |
return _annotifier # -> Callable[[type], type] | |
# A cls is there, process it | |
return _annotifier(cls) # -> type | |
# Also export at_dataclass under the standard name, so that ``dataclass``
# imported from this module is the "@"-aware decorator. This rebinding
# replaces the ``dataclasses.dataclass`` brought in by the star import.
dataclass = at_dataclass
# Small test
if __name__ == "__main__":
    # Demo: decorate a class that uses every marker, then show the resulting
    # annotations and fields, and which constructor calls are accepted.
    from dataclasses import fields
    from typing import ClassVar

    class Dummy:
        """Arbitrary user object, used only as extra Annotated metadata."""

    @at_dataclass
    class A:
        cv: ClassVar[str] = "classvar"
        a: int
        b: int @ KW_ONLY = 25
        c: int @ NO_INIT = 5
        d: list[str] @ NO_INIT_FACTORY = list
        e: int @ NO_INIT | Dummy() | Dummy() = 0
        f: int @ [NO_INIT, Dummy(), Dummy()] = 1
        g: int @ NO_FIELD = 7
        h: int @ NO_FIELD

    # ############
    a = A(3)
    heavy_rule = "=" * 80
    print(heavy_rule)
    print(f"{a.__annotations__ = }")
    print(heavy_rule)
    print(f"{a.a = }")
    print(f"{a.b = }")
    for f in fields(A):
        print("-- " + "-" * 70)
        print(f"{f = }")
    print("-" * 70)

    # (message printed on success or None, constructor call to attempt)
    attempts = (
        ("b is a keyword argument. Ok", lambda: A(1, b=2)),
        (None, lambda: A(1, 2)),  # b was marked KW_ONLY
        (None, lambda: A(1, c=2)),  # c was marked NO_INIT
        (None, lambda: A(1, d=2)),  # d was marked NO_INIT_FACTORY
    )
    for ok_message, build in attempts:
        try:
            b = build()
        except Exception as e:
            print(f"Exception: {e = }")
        else:
            if ok_message is not None:
                print(ok_message)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.