Last active
December 28, 2022 22:49
-
-
Save Zomatree/5380dab7fd2bf9b5e7fab331cdd1f79e to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import annotations | |
from dataclasses import dataclass | |
from types import NoneType, UnionType | |
from typing import Annotated, Any, Callable, Generic, Iterable, TypeGuard, TypeVar, Union, cast, get_origin as _get_origin, get_args, Self | |
# Type variable for the value type carried by ``Type`` / ``InternalType``.
T = TypeVar("T")
# isinstance doesn't allow its second parameter to be a parameterized generic, so we must get the origin; however, typing.get_origin returns None (not the class) if the type isn't generic,
# so we fall back to the type itself
# isinstance(..., list[str]) -> Error (can't use isinstance with generics)
# isinstance(..., typing.get_origin(list[str])) -> Works (get_origin returns list)
# isinstance(..., typing.get_origin(list)) -> Error (get_origin returned None)
def get_origin(ty: Any) -> Any:
    """Return the unsubscripted origin of *ty*, or *ty* itself when it has none.

    ``typing.get_origin`` gives ``list`` for ``list[str]`` but ``None`` for a
    plain ``list``; in the latter case the argument is handed back unchanged.
    """
    origin = _get_origin(ty)
    if origin:
        return origin
    return ty
class Type(Generic[T]):
    """Fluent builder describing how a single model field is validated.

    Every configuration method stores its argument and returns ``self`` so
    calls can be chained, e.g. ``Type().values(range(3)).rename("n")``.
    ``_to_internal`` turns the collected settings into an ``InternalType``.
    """

    def __init__(self) -> None:
        # Not set by any builder method; callers of ``_to_internal`` assign
        # ``type`` and ``key`` on the resulting InternalType afterwards.
        self._type: type[T] | None = None
        self._key: str | None = None
        self._literals: Iterable[T] | None = None
        # By default every value passes validation.
        self._validator: Callable[[T], bool] = lambda v: True
        # By default a value matches when it is an instance of the type's origin.
        self._type_checker: Callable[[T, type[T]], bool] = lambda obj, ty: isinstance(obj, get_origin(ty))
        self._rename: str | None = None
        self._default: Callable[[], T] | None = None

    def values(self, literals: Iterable[T]) -> Self:
        """Restrict the field to the members of *literals*."""
        from collections.abc import Iterator

        # A one-shot iterator (generator, ``map`` object, ...) would be
        # exhausted by the first membership test, making every later
        # validation fail; snapshot it so it can be checked repeatedly.
        if isinstance(literals, Iterator):
            literals = tuple(literals)
        self._literals = literals
        return self

    def rename(self, name: str) -> Self:
        """Set the alternative name stored as ``rename`` on the InternalType."""
        self._rename = name
        return self

    def validator(self, f: Callable[[T], bool]) -> Self:
        """Register *f* as the field's validator predicate."""
        self._validator = f
        return self

    def default(self, f: Callable[[], T]) -> Self:
        """Register *f* as the zero-argument factory for the field's default."""
        self._default = f
        return self

    def type_checker(self, f: Callable[[T, type], bool]) -> Self:
        """Replace the default ``isinstance``-based check with ``f(value, type)``."""
        self._type_checker = f
        return self

    def _to_internal(self) -> InternalType[T]:
        """Build an ``InternalType`` from the collected settings.

        ``type`` and ``key`` may still be ``None`` here (hence the casts) and
        ``generics`` starts out empty.
        """
        return InternalType(
            cast(type[T], self._type),
            self._literals,
            self._validator,
            self._rename,
            cast(str, self._key),
            self._default,
            self._type_checker,
            (),
        )
def evaluate_annotation(cls: type, annot: Any) -> Any:
    """Resolve a possibly-stringified annotation of *cls* to a real object.

    With ``from __future__ import annotations`` class annotations are stored
    as strings. Non-string annotations are returned unchanged. Strings are
    evaluated with the globals of the module that defines *cls* (layered over
    this module's globals, so names that resolved before still resolve) and
    the class namespace as locals.

    Raises whatever evaluating the expression raises (e.g. ``NameError``).
    """
    if not isinstance(annot, str):
        return annot

    import sys

    namespace = dict(globals())
    owner_module = sys.modules.get(getattr(cls, "__module__", ""))
    if owner_module is not None:
        # Names from the class's own module win over this module's names.
        namespace.update(vars(owner_module))

    # NOTE(security): eval() executes arbitrary code. The strings evaluated
    # here are annotations written in class bodies; never feed it text that
    # originates from untrusted input.
    return eval(annot, namespace, dict(vars(cls)))
@dataclass
class InternalType(Generic[T]):
    """Resolved description of one model field, as consumed during validation."""

    # Expected type; passed to ``type_checker`` as its second argument.
    type: type[T]
    # Allowed values, or ``None`` when the field is unrestricted.
    literals: Iterable[T] | None
    # Predicate applied to the value; a falsy result fails validation.
    validator: Callable[[T], bool]
    # Name used in input/output dicts in place of ``key`` when set.
    rename: str | None
    # Attribute name of the field on the model.
    key: str
    # Zero-argument factory called when the input lacks the field.
    default: Callable[[], T] | None
    # Called as ``type_checker(value, type)``; a falsy result means a mismatch.
    type_checker: Callable[[T, type[T]], bool]
    # Descriptions of the type arguments of container annotations.
    generics: tuple[InternalType[Any], ...]
def rename(key: str) -> Type[Any]:
    """Start a field description whose ``rename`` is set to *key*."""
    builder = Type[Any]()
    return builder.rename(key)
def values(values: Iterable[T]) -> Type[T]:
    """Start a field description restricted to the members of *values*."""
    builder = Type[T]()
    return builder.values(values)
def default(f: Callable[[], T]) -> Type[T]:
    """Start a field description whose default is produced by calling *f*."""
    builder = Type[T]()
    return builder.default(f)
def type_checker(f: Callable[[T, type], bool]) -> Type[T]:
    """Start a field description that uses ``f(value, type)`` as its type check."""
    builder = Type[T]()
    return builder.type_checker(f)
class ModelError(Exception):
    """Base class of the errors raised while describing or validating a model."""


class MissingRequiredKey(ModelError):
    """The input lacks a key for a field that has no default."""


class InvalidType(ModelError):
    """A field's type checker rejected the supplied value."""


class ValidatorFailed(ModelError):
    """A value is outside the allowed literals or its validator returned false."""
def is_model_subclass(ty: Any) -> TypeGuard[type[Model]]:
    """Tell whether *ty* is a class derived from (or equal to) ``Model``."""
    # issubclass() raises for non-classes, so rule those out first.
    if not isinstance(ty, type):
        return False
    return issubclass(ty, Model)
def check_types(cls: type, key: str, ty: InternalType[Any], value: Any) -> Any:
    """Validate *value* against *ty* and return the (possibly converted) value.

    Args:
        cls: The model class being populated; only forwarded on recursion.
        key: Field name used in error messages.
        ty: Description of the expected type.
        value: The raw value to check.

    Returns:
        A model instance when ``ty.type`` is a ``Model`` subclass, a rebuilt
        container when ``ty.type`` is a list/set/tuple/dict with element
        descriptions in ``ty.generics``, otherwise *value* unchanged.

    Raises:
        InvalidType: If ``ty.type_checker`` rejects *value* (or an element).
    """
    if is_model_subclass(ty.type):
        # Nested models are built straight from the raw value; the model
        # constructor does its own validation.
        return ty.type(value)

    if not ty.type_checker(value, ty.type):
        raise InvalidType(f"Key {key} expected {ty.type} but found {type(value)}")

    # ``ty.type`` may be a parameterized alias such as ``list[int]``, which
    # never compares equal to ``list``; dispatch on its origin instead.
    origin = get_origin(ty.type)

    if origin in (list, set, tuple) and ty.generics:
        inner_ty = ty.generics[0]
        if is_model_subclass(inner_ty):
            checked = [inner_ty(inner_value) for inner_value in value]
        else:
            # Keep the result of the recursive check so converted elements
            # (e.g. nested models) end up in the returned container.
            checked = [check_types(cls, key, inner_ty, inner_value) for inner_value in value]
        # Elements were gathered in a list; cast back to the declared container.
        return origin(checked)

    if origin is dict and len(ty.generics) == 2:
        key_ty, value_ty = ty.generics
        new_dict: dict[Any, Any] = {}
        for inner_key, inner_value in value.items():
            inner_name = f"{key}.{inner_key}"
            new_key = check_types(cls, inner_name, key_ty, inner_key)
            new_dict[new_key] = check_types(cls, inner_name, value_ty, inner_value)
        return new_dict

    return value
def convert_to_type(cls: type, key: Any, value: Any) -> InternalType[Any]:
    """Turn the annotation *value* of field *key* into an ``InternalType``.

    Handled annotation shapes:

    * an existing ``InternalType`` is returned untouched;
    * ``Annotated[X, ...]`` uses the first ``Type`` builder found in the
      metadata (a default builder when there is none) with ``X`` as type;
    * ``X | None`` / ``Optional[X]`` (either order) accepts ``X`` or ``None``
      and defaults to ``None``; any other union is rejected;
    * ``list`` / ``set`` / ``tuple`` check every element against the FIRST
      type argument only; ``dict`` checks keys and values. Unparameterized
      containers accept any element;
    * anything else is used as a plain type (``Any`` is treated as ``object``).

    Args:
        cls: The class owning the annotation, used to evaluate string annotations.
        key: The field name, stored as ``key`` on the result.
        value: The annotation, possibly still a string.

    Raises:
        ModelError: For unions other than "one type or None".
    """
    if isinstance(value, InternalType):
        return value

    value = evaluate_annotation(cls, value)
    origin = get_origin(value)
    args = get_args(value)

    if origin is Annotated:
        # Metadata may hold unrelated objects; only a Type builder configures the field.
        builder = next((meta for meta in args[1:] if isinstance(meta, Type)), None)
        if builder is None:
            builder = Type[Any]()
        ty = builder._to_internal()
        ty.type = args[0]
    elif origin in (Union, UnionType):
        members = [arg for arg in args if arg is not NoneType]
        if len(args) == 2 and len(members) == 1:
            ty = Type[Any]()._to_internal()
            ty.default = lambda: None
            # A tuple of classes is a valid second argument to isinstance().
            ty.type = (members[0], NoneType)
        else:
            raise ModelError("Union is not supported")
    elif origin in (list, set, tuple):
        # A bare ``list`` has no type arguments; accept any element then.
        element_annotation = args[0] if args else object
        internal_ty = convert_to_type(cls, key, element_annotation)
        ty = Type[Any]()._to_internal()
        ty.type = value
        ty.type_checker = lambda obj, container: isinstance(obj, get_origin(container)) and all(
            internal_ty.type_checker(item, internal_ty.type) for item in obj
        )
        ty.generics = (internal_ty,)
    elif origin is dict:
        # A bare ``dict`` has no type arguments; accept any key and value then.
        key_annotation, value_annotation = args if args else (object, object)
        internal_key_ty = convert_to_type(cls, key, key_annotation)
        internal_value_ty = convert_to_type(cls, key, value_annotation)
        ty = Type[Any]()._to_internal()
        ty.type = value
        ty.type_checker = lambda obj, container: isinstance(obj, dict) and all(
            internal_key_ty.type_checker(k, internal_key_ty.type)
            and internal_value_ty.type_checker(v, internal_value_ty.type)
            for k, v in obj.items()
        )
        ty.generics = (internal_key_ty, internal_value_ty)
    else:
        ty = Type[Any]()._to_internal()
        # ``Any`` cannot be used with isinstance(); ``object`` matches everything.
        ty.type = object if value is Any else value

    ty.key = key
    return ty
class Model: | |
_items: dict[str, InternalType[Any]] | |
def __init_subclass__(cls) -> None: | |
items: dict[str, InternalType[Any]] = {} | |
for key, value in cls.__annotations__.items(): | |
ty = convert_to_type(cls, key, value) | |
items[ty.rename or ty.key] = ty | |
cls._items = items | |
def __init__(self, data: dict[str, Any]): | |
rebuilt: dict[str, Any] = {} | |
for key, ty in self._items.items(): | |
if key not in data: | |
if default := ty.default: | |
value = default() | |
else: | |
raise MissingRequiredKey(f"Missing required key `{ty.key}`") | |
else: | |
value = data[key] | |
value = check_types(self.__class__, ty.key, ty, value) | |
if (ty.literals and value not in ty.literals) or not ty.validator(value): | |
raise ValidatorFailed(f"Key {ty.key}'s validator failed") | |
rebuilt[ty.key] = value | |
for key, value in rebuilt.items(): | |
setattr(self, key, value) | |
def to_dict(self) -> dict[str, Any]: | |
output: dict[str, Any] = {} | |
for renamed, item in self._items.items(): | |
value = getattr(self, item.key) | |
if isinstance(value, Model): | |
value = value.to_dict() | |
output[renamed] = value | |
return output | |
def __repr__(self) -> str: | |
items: list[str] = [] | |
for item in self._items.values(): | |
value = getattr(self, item.key) | |
items.append(f"{item.key}={value!r}") | |
return f"<{self.__class__.__name__} {' '.join(items)}>" | |
class Foo(Model):
    """Example model: ``x`` is limited to 0..9 and ``y`` uses the name ``"z"``."""

    x: Annotated[int, values(range(10))]
    y: Annotated[str, rename("z")]
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment