Last active
June 18, 2024 02:16
-
-
Save Sachaa-Thanasius/571cafc9087a4de1c5b865079753da29 to your computer and use it in GitHub Desktop.
Proto-draft implementation of a potential Refinement typeform in Python. Not official, just my first pass at it for fun. Reference: https://discuss.python.org/t/pep-746-typedmetadata-for-type-checking-of-pep-593-annotated/53834
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# pyright: enableExperimentalFeatures=true | |
# PEP 712 is only active for pyright with the "enableExperimentalFeatures" setting enabled. | |
import operator | |
import re | |
import sys | |
from typing import ( | |
TYPE_CHECKING, | |
Callable, | |
ClassVar, | |
Final, | |
Generic, | |
NoReturn, | |
Pattern, | |
Protocol, | |
Tuple, | |
TypeVar, | |
Union, | |
final, | |
) | |
import attrs | |
if sys.version_info >= (3, 11): | |
from typing import ParamSpec, Self, reveal_type | |
else: | |
from typing_extensions import ParamSpec, Self, reveal_type | |
if sys.version_info >= (3, 10): | |
from typing import ParamSpec | |
else: | |
from typing_extensions import ParamSpec | |
if sys.version_info >= (3, 9): | |
from typing import Annotated | |
else: | |
from typing_extensions import Annotated | |
# Sentinel used as a default to mean "argument not supplied", so that any real
# object (including None) can still be passed as a value.
_MISSING = object()

# Type variables shared by the generic parser machinery below.
T = TypeVar("T")
U = TypeVar("U")
P = ParamSpec("P")
P2 = ParamSpec("P2")

# Aliases for the attrs field factory and class decorator used by the example classes.
adapter = attrs.field
base_model = attrs.define
# ===================================================================================================================== | |
# ==== Implementation of Refined. | |
# ===================================================================================================================== | |
if TYPE_CHECKING: | |
Refined = Annotated | |
else: | |
import operator | |
from typing import _GenericAlias, _tp_cache, _type_check, _type_repr | |
if sys.version_info >= (3, 12): | |
from typing import Unpack | |
else: | |
from typing_extensions import Unpack | |
if sys.version_info >= (3, 10): | |
from typing import get_origin | |
else: | |
from typing_extensions import get_origin | |
# Almost an exact reimplementation of Annotated. | |
@final
class _RefinedGenericAlias(_GenericAlias, _root=True):
    """Runtime object produced by ``Refined[origin, *refinements]``.

    The alias stores the refined type in ``__origin__`` and the refinement
    objects, in order, in ``__refinements__``. Wrapping an existing alias
    flattens it: the inner alias's refinements come first, followed by the
    new ones, and the innermost origin is kept.
    """

    if TYPE_CHECKING:
        __origin__: type
        __refinements__: Tuple[object, ...]

    def __init__(self, origin: type, refinements: Tuple[object, ...]):
        """Create the alias, flattening ``origin`` if it is itself a refined alias."""
        if isinstance(origin, _RefinedGenericAlias):
            refinements = origin.__refinements__ + refinements
            origin = origin.__origin__
        super().__init__(origin, origin)
        self.__refinements__ = refinements

    def copy_with(self, params: Tuple[object, ...]):
        """Return a new alias over ``params[0]`` that keeps the same refinements.

        Raises:
            AssertionError: if ``params`` does not contain exactly one item.
        """
        if len(params) != 1:
            raise AssertionError
        new_type = params[0]
        return _RefinedGenericAlias(new_type, self.__refinements__)

    def __repr__(self):
        refinements = ", ".join(repr(r) for r in self.__refinements__)
        return f"Refined[{_type_repr(self.__origin__)}, {refinements}]"

    def __reduce__(self):
        # Rebuild through subscription: Refined[origin, *refinements].
        return operator.getitem, (Refined, (self.__origin__, *self.__refinements__))

    def __eq__(self, other: object, /):
        if isinstance(other, type(self)):
            if self.__origin__ != other.__origin__:
                return False
            return self.__refinements__ == other.__refinements__
        return NotImplemented

    def __hash__(self):
        # The attribute must be spelled with the trailing double underscore:
        # ``self.__refinements`` would be name-mangled inside the class body to
        # ``_RefinedGenericAlias__refinements``, which is never set.
        return hash((self.__origin__, self.__refinements__))
@final
class Refined:
    """Subscript-only special form: ``Refined[SomeType, refinement, ...]``.

    The class can be neither instantiated nor subclassed. Subscripting it
    checks the first argument as a type and returns a ``_RefinedGenericAlias``
    that carries the remaining arguments as refinements.
    """

    __slots__ = ()

    def __new__(cls, *args: object, **kwargs: object) -> NoReturn:
        raise TypeError("Type Refined cannot be instantiated.")

    def __init_subclass__(cls, *args: object, **kwargs: object) -> NoReturn:
        raise TypeError(f"Cannot subclass {cls.__module__}.Refined")

    def __class_getitem__(
        cls,
        params: Tuple[type, Unpack[Tuple[Union["TypeRefinement", "ValueRefinement"], ...]]],
    ) -> _RefinedGenericAlias:
        """Normalise the subscript to a tuple and delegate to the cached builder."""
        packed = params if isinstance(params, tuple) else (params,)
        return cls._class_getitem_inner(cls, *packed)

    @_tp_cache(typed=True)
    def _class_getitem_inner(
        cls,
        *params: Unpack[Tuple[type, Unpack[Tuple[Union["TypeRefinement", "ValueRefinement"], ...]]]],
    ) -> _RefinedGenericAlias:
        """Validate the subscript arguments and build the alias.

        Raises:
            TypeError: if fewer than two arguments are given, if the first is
                an unpacked TypeVarTuple, or if it fails typing's type check.
        """
        if len(params) < 2:
            raise TypeError("Refined[...] should be used with at least two arguments (a type and an annotation).")

        head = params[0]
        if not isinstance(head, type) and getattr(head, "__typing_is_unpacked_typevartuple__", False):
            raise TypeError("Refined[...] should not be used with an unpacked TypeVarTuple.")

        # ClassVar[...] and Final[...] are accepted as-is; anything else goes
        # through typing's own type check.
        if get_origin(head) in {ClassVar, Final}:
            origin = head
        else:
            origin = _type_check(head, "Refined[t, ...]: t must be a type.")

        return _RefinedGenericAlias(origin, tuple(params[1:]))
class TypeRefinement(Protocol):
    """Structural interface for refinement objects that expose ``__supports_type__``."""

    def __supports_type__(self, t: type) -> bool:
        """Return whether this refinement supports the type ``t``."""
class ValueRefinement(Protocol):
    """Structural interface for refinement objects that expose ``__supports_value__``."""

    def __supports_value__(self, o: object) -> bool:
        """Return whether this refinement supports the value ``o``."""
class NumCmp:
    """Value refinement built from rich-comparison bounds.

    Each keyword (``eq``, ``ne``, ``gt``, ``ge``, ``lt``, ``le``) names the
    ``operator`` function of the same name. A value is supported when every
    supplied bound compares truthy as ``op(value, bound)``. Bounds left at the
    ``_MISSING`` sentinel are ignored, so an instance with no bounds supports
    every value.
    """

    # Maps each constructor keyword to the comparison it triggers.
    _op_map: ClassVar = {
        "eq": operator.eq,
        "ne": operator.ne,
        "gt": operator.gt,
        "ge": operator.ge,
        "lt": operator.lt,
        "le": operator.le,
    }

    def __init__(
        self,
        eq: object = _MISSING,
        ne: object = _MISSING,
        gt: object = _MISSING,
        ge: object = _MISSING,
        lt: object = _MISSING,
        le: object = _MISSING,
    ):
        self.eq = eq
        self.ne = ne
        self.gt = gt
        self.ge = ge
        self.lt = lt
        self.le = le

    def __repr__(self) -> str:
        """Show only the bounds that were actually supplied."""
        supplied = ", ".join(
            f"{name}={getattr(self, name)!r}"
            for name in self._op_map
            if getattr(self, name) is not _MISSING
        )
        return f"{type(self).__name__}({supplied})"

    def __supports_value__(self, o: object) -> bool:
        """Return True when ``o`` satisfies every supplied bound.

        Results are combined by truthiness rather than bitwise ``&``, so a
        comparison that returns a truthy non-bool (for example ``2``) is not
        mistaken for a failure, and the return value is always a real ``bool``.
        Evaluation stops at the first bound that fails.
        """
        return all(
            cmp_op(o, bound)
            for cmp_name, cmp_op in self._op_map.items()
            if (bound := getattr(self, cmp_name)) is not _MISSING
        )
class RePtrn:
    """Value refinement that tests strings against a regular expression."""

    def __init__(self, pattern: Union[str, Pattern[str]]):
        """Store ``pattern``, compiling it first unless it is already compiled."""
        if isinstance(pattern, Pattern):
            compiled = pattern
        else:
            compiled = re.compile(pattern)
        self.pattern = compiled

    def __supports_value__(self, o: str) -> bool:
        """Return True when the pattern matches at the start of ``o``."""
        found = self.pattern.match(o)
        return found is not None
# ===================================================================================================================== | |
# ==== Implementation of parse to superficially match the semantics of Pydantic's thing/use case. | |
# ===================================================================================================================== | |
class ValidationError(Exception):
    """Error raised when parsing or validating a value fails."""
@final
class Parser(Generic[P, T]):
    """Composable wrapper around a conversion callable.

    ``typer`` is the wrapped callable. Calling the parser runs it and reports
    any failure as ``ValidationError``. ``transform``/``parse`` chain a further
    conversion onto the result, and ``a | b`` builds a parser that tries ``a``
    and falls back to ``b``.
    """

    __slots__ = ("typer",)

    def __init__(self, typer: Callable[P, T]):
        self.typer = typer

    def __init_subclass__(cls, *args: object, **kwargs: object) -> NoReturn:
        raise TypeError(f"Cannot subclass {cls.__module__}.Parser")

    def __or__(self, other: "Parser[P2, T]", /) -> "Parser[P2, T]":
        """Return a parser that tries this parser's callable, then ``other``'s.

        Any ``Exception`` raised by a candidate counts as a failed attempt,
        matching how ``__call__`` treats failures. The raw callables (such as
        ``int``) raise their own exception types rather than
        ``ValidationError``, so catching only ``ValidationError`` would never
        reach the second candidate.

        Raises (from the returned parser's callable):
            ValidationError: if both candidates fail; chained from the last
                candidate's exception.
        """
        if not isinstance(other, Parser):  # pyright: ignore [reportUnnecessaryIsInstance]
            return NotImplemented

        def temp(*args: P2.args, **kwargs: P2.kwargs) -> T:
            last_error = None
            for typer in (self.typer, other.typer):
                try:
                    return typer(*args, **kwargs)
                except Exception as exc:  # noqa: BLE001
                    last_error = exc
                    print(f"Failed to parse {(args, kwargs)} with {typer}. Attempting next.")  # noqa: T201
            raise ValidationError from last_error

        return Parser(temp)

    def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T:
        """Run the wrapped callable, converting any ``Exception`` into ``ValidationError``."""
        try:
            return self.typer(*args, **kwargs)
        except Exception as exc:  # noqa: BLE001
            raise ValidationError from exc

    def _convert(self, fn: Callable[[T], U]) -> "Parser[P, U]":
        """Return a new parser that feeds this parser's raw result through ``fn``."""

        def temp(*args: P.args, **kwargs: P.kwargs) -> U:
            return fn(self.typer(*args, **kwargs))

        return Parser(temp)

    def transform(self, fn: Callable[[T], U]) -> "Parser[P, U]":
        """Chain ``fn`` onto the result; currently identical to ``parse``."""
        return self._convert(fn)

    def parse(self, fn: Callable[[T], U]) -> "Parser[P, U]":
        """Chain ``fn`` onto the result; currently identical to ``transform``."""
        return self._convert(fn)

    def ge(self, floor: int) -> Self:
        """Placeholder for a lower-bound check; returns ``self`` unchanged."""
        # XXX: Nonfunctional placeholder.
        return self

    def lt(self, ceil: int) -> Self:
        """Placeholder for an upper-bound check; returns ``self`` unchanged."""
        # XXX: Nonfunctional placeholder.
        return self
def parse(tp: Callable[P, T]) -> Parser[P, T]:
    """Wrap the callable ``tp`` in a ``Parser``."""
    wrapped = Parser(tp)
    return wrapped
# ===================================================================================================================== | |
# ==== Attempt at an example using the above. | |
# ===================================================================================================================== | |
# Pretend these classes are subclasses of pydantic.BaseModel instead of fresh classes being wrapped by class decorators. | |
# This is what Pydantic wants their transformers and validators to look like. | |
@base_model
class Before:
    """Example model that attaches its parsers as ``Annotated`` metadata."""

    username: Annotated[str, parse(str).transform(str.lower)]
    birthday: Annotated[
        int,
        (parse(int) | parse(str).transform(str.strip).parse(int)).ge(0).lt(512),
    ]
    age: Annotated[int, parse(int)]
# This is an attrs class with PEP 712 active, and imo looks like a better alternative. | |
@base_model
class After:
    """Example model that passes its parsers as field converters and refines ``birthday``."""

    username: str = adapter(
        converter=parse(str).transform(str.lower),
    )
    birthday: Refined[int, NumCmp(ge=0, lt=512)] = adapter(
        converter=(parse(int) | parse(str).transform(str.strip).parse(int)),
    )
    age: int = adapter(
        converter=parse(int),
    )
def test() -> None:
    """Build an ``After`` instance and reveal/print the type and value of each field."""
    reveal_type(After.__init__)
    # Type of "After.__init__" is "(self: After, username: object, birthday: object, age: str | Buffer | SupportsInt | SupportsIndex | SupportsTrunc) -> None"

    sample = After(10, "1010", 1.0)

    reveal_type(sample.username)  # Type of "sample.username" is "str"
    print(sample.username)

    reveal_type(sample.birthday)  # Type of "sample.birthday" is "int"
    print(sample.birthday)

    reveal_type(sample.age)  # Type of "sample.age" is "int"
    print(sample.age)
# Run the demonstration only when this file is executed as a script.
if __name__ == "__main__":
    test()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment