Last active
July 28, 2023 06:22
-
-
Save rnag/db6bf83d9ca19dfe897d6ccabd4e2570 to your computer and use it in GitHub Desktop.
Dataclass Type Validation
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# dataclass type validation for fields on class instantiation -and- on assignment.
from dataclasses import dataclass

# NOTE: presumably the companion `validators.py` from this gist (the module
# defining `TypeValidator`), not the unrelated PyPI package of the same name.
from validators import TypeValidator


@dataclass
class FooDC:
    # alternatively, like:
    #   number: int = TypeValidator(default_factory=int)
    number: int = TypeValidator()
    word: str = TypeValidator()


foo = FooDC(number=3, word='1')
print(foo)

# Validation on instantiation: a `str` for the `int` field must be rejected.
try:
    _ = FooDC(number='test')
except TypeError as e:
    print(e)
else:
    raise ValueError('expected a TypeError to be raised!')

print()

# Validation on assignment: an `int` for the `str` field must be rejected,
# leaving the previous (default) value in place.
bar = FooDC()
try:
    bar.word = 2
except TypeError as e:
    print(e)
else:
    raise ValueError('expected a TypeError to be raised!')
print(bar)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# dataclass type validation for fields on class instantiation -and- on assignment.
#
# NOTE: this example should work in Python 3.7+
from dataclasses import dataclass
from typing import Dict, Optional, Union

# NOTE: presumably the companion `validators.py` from this gist (the module
# defining `TypeValidator`), not the unrelated PyPI package of the same name.
from validators import TypeValidator


@dataclass
class FooTest:
    map: Dict[str, float] = TypeValidator()
    num_or_str: Union[int, float, str] = TypeValidator()
    opt_word: Optional[str] = TypeValidator()


foo = FooTest()
print(foo)

# Validation on instantiation: `bytes` is none of int / float / str.
try:
    _ = FooTest(map={'key': 1.23}, num_or_str=b'byte string')
except TypeError as e:
    print(e)
else:
    raise ValueError('expected a TypeError to be raised!')

print()

# Validation on assignment: an `int` is not a `dict`.
bar = FooTest()
try:
    bar.map = 2
except TypeError as e:
    print(e)
else:
    raise ValueError('expected a TypeError to be raised!')
print(bar)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# dataclass type validation for fields on class instantiation -and- on assignment.
#
# NOTE: this example should work in Python 3.10+
from dataclasses import dataclass

# NOTE: presumably the companion `validators.py` from this gist (the module
# defining `TypeValidator`), not the unrelated PyPI package of the same name.
from validators import TypeValidator


@dataclass
class FooTest:
    map: dict[str, float] = TypeValidator()
    num_or_str: int | float | str = TypeValidator()
    opt_word: str | None = TypeValidator()


foo = FooTest()
print(foo)

# Validation on instantiation: `bytes` is none of int / float / str.
try:
    _ = FooTest(map={'key': 1.23}, num_or_str=b'byte string')
except TypeError as e:
    print(e)
else:
    raise ValueError('expected a TypeError to be raised!')

print()

# Validation on assignment: an `int` is not a `dict`.
bar = FooTest()
try:
    bar.map = 2
except TypeError as e:
    print(e)
else:
    raise ValueError('expected a TypeError to be raised!')
print(bar)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from dataclasses import MISSING | |
from typing import Callable, Generic, NewType, TypeVar, Union | |
_NoneType = type(None)
_T = TypeVar('_T')


class _UnsetType:
    """Type of the `_UNSET` sentinel; exists only to give it a readable repr."""

    __slots__ = ()

    def __repr__(self):
        return '<factory default>'


# Sentinel handed to `@dataclass` as a field's default when that field has a
# `default_factory`. The generated `__init__` passes it back to
# `TypeValidator.__set__`, which replaces it with a freshly built value.
_UNSET = _UnsetType()


def _unwrap_generic(tp):
    """Return the runtime class behind a subscripted generic, else *tp*.

    ``Dict[str, float]`` / ``dict[str, float]`` -> ``dict``; plain classes are
    returned unchanged. Needed because ``isinstance`` rejects subscripted
    generics.
    """
    return getattr(tp, '__origin__', tp)


def _resolve_type(tp):
    """Turn a field annotation into a valid second argument for ``isinstance``.

    Unions (``Union[...]``, ``Optional[...]`` and the 3.10+ ``X | Y`` form)
    become a tuple of runtime classes, with any subscripted generic members
    unwrapped to their origin. Subscripted generics become their origin
    class. Anything else is returned unchanged.
    """
    origin = getattr(tp, '__origin__', None)
    # `typing.Union[...]` reports `Union` as its origin (Python 3.7+), while
    # `X | Y` on 3.10+ has `__args__` but no `__origin__`.
    if origin is Union or (origin is None and getattr(tp, '__args__', None)):
        return tuple(_unwrap_generic(arg) for arg in tp.__args__)
    return tp if origin is None else origin


class TypeValidator(Generic[_T]):
    """Data descriptor that type-checks a dataclass field on every assignment.

    Use it as the field's default (``x: int = TypeValidator()``). The value
    is stored on the instance under ``'_' + <field name>``. Assigned values
    are checked with ``isinstance`` against the field's annotation, and a
    ``TypeError`` is raised on mismatch. Values produced by
    ``default_factory`` are not checked.

    Args:
        add_default: when neither ``default`` nor ``default_factory`` is
            given, infer one from the annotation. ``None`` is used for types
            that include ``None``. Otherwise the (first) type is used as a
            factory if it can be called with no arguments.
        default: static default value for the field.
        default_factory: zero-argument callable building the default value.

    Raises:
        ValueError: if both ``default`` and ``default_factory`` are given.
    """

    __slots__ = ('add_default',
                 'private_name',
                 'default',
                 'default_factory',
                 'type',
                 )

    def __init__(self, *,
                 add_default=True,
                 default: _T = MISSING,
                 default_factory: Callable[[], _T] = MISSING):
        # Mirror `dataclasses.field`: otherwise `default` would be silently
        # ignored in favour of the factory (see `__get__`).
        if default is not MISSING and default_factory is not MISSING:
            raise ValueError('cannot specify both default and default_factory')
        self.add_default = add_default
        self.default = default
        self.default_factory = default_factory

    def __set_name__(self, owner, name):
        """Record the storage name and resolve the field's annotated type."""
        self.private_name = '_' + name
        # NOTE(review): assumes the annotation is an evaluated type object;
        # string annotations (e.g. `from __future__ import annotations`)
        # are not resolved here.
        tps = self.type = _resolve_type(owner.__annotations__[name])

        if (not self.add_default
                or self.default is not MISSING
                or self.default_factory is not MISSING):
            return

        if isinstance(tps, tuple):  # a union of types
            if _NoneType in tps:
                # a `None` member makes `None` a reasonable default.
                self.default = None
                return
            # a tuple is not callable; try the first member as the factory.
            tps = tps[0]

        # Use the type as a default factory only if it builds with no args.
        try:
            tps()
        except TypeError:
            pass
        else:
            self.default_factory = tps

    def __get__(self, obj, objtype=None):
        """Return the stored value, or the field default on class access.

        Class-level access (``obj is None``) is how ``@dataclass`` collects
        the field default. ``_UNSET`` is returned when a factory should run,
        and ``MISSING`` (no default, i.e. a required field) is possible.
        """
        if obj is None:
            return self.default if self.default_factory is MISSING else _UNSET
        return getattr(obj, self.private_name)

    def __set__(self, obj, value):
        """Validate *value* (or build the factory default) and store it."""
        if value is _UNSET:
            value = self.default_factory()
        else:
            self.validate(value)
        setattr(obj, self.private_name, value)

    def validate(self, value):
        """Raise ``TypeError`` if *value* is not an instance of the field type."""
        if not isinstance(value, self.type):
            # Strip only the single '_' prefix added in `__set_name__`, so
            # field names that themselves start with '_' are reported intact.
            field_name = self.private_name[1:]
            msg = f'Expected {field_name} to be {self.type!r}, got {value!r}'
            raise TypeError(msg)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Nice, thank you!